diff --git a/main.go b/main.go index 26cf8c1..f78943c 100644 --- a/main.go +++ b/main.go @@ -71,6 +71,11 @@ func main() { EnvVar: "PLUGIN_PORT,SSH_PORT", Value: 22, }, + cli.DurationFlag{ + Name: "retry-timeout", + Usage: "connection retry timeout", + EnvVar: "PLUGIN_RETRY_TIMEOUT", + }, cli.BoolFlag{ Name: "sync", Usage: "sync mode", @@ -194,6 +199,7 @@ func run(c *cli.Context) error { Password: c.String("password"), Host: c.StringSlice("host"), Port: c.Int("port"), + RetryTimeout: c.Duration("retry-timeout"), Timeout: c.Duration("timeout"), CommandTimeout: c.Int("command.timeout"), Script: c.StringSlice("script"), diff --git a/plugin.go b/plugin.go index d47fcab..7793c13 100644 --- a/plugin.go +++ b/plugin.go @@ -10,6 +10,7 @@ import ( "github.com/appleboy/easyssh-proxy" "io" + "net" ) const ( @@ -29,6 +30,7 @@ type ( Host []string Port int Timeout time.Duration + RetryTimeout time.Duration CommandTimeout int Script []string Secrets []string @@ -90,7 +92,7 @@ func (p Plugin) exec(host string, wg *sync.WaitGroup, errChannel chan error) { p.log(host, "======END======") } - stdoutChan, stderrChan, doneChan, errChan, err := ssh.Stream(strings.Join(p.Config.Script, "\n"), p.Config.CommandTimeout) + stdoutChan, stderrChan, doneChan, errChan, err := retryStream(ssh, p) if err != nil { errChannel <- err } else { @@ -179,3 +181,34 @@ func (p Plugin) Exec() error { return nil } + +func retryStream(ssh *easyssh.MakeConfig, p Plugin) (<-chan string, <-chan string, <-chan bool, <-chan error, error) { + var ( + timeout = time.After(p.Config.RetryTimeout) + wait = time.Second + ) + + for { + stdoutChan, stderrChan, doneChan, errChan, err := ssh.Stream(strings.Join(p.Config.Script, "\n"), p.Config.CommandTimeout) + + // If there was no error, return all channels + if err == nil { + return stdoutChan, stderrChan, doneChan, errChan, nil + } + + // If the error was not a net.OpError, return that error + if _, ok := err.(*net.OpError); !ok { + return nil, nil, nil, nil, err + } + + select { + case <-timeout: + return nil, nil, nil, nil, err + case <-time.After(wait): + break + } + + // Double our back-off time + wait *= 2 + } +}