diff --git a/main.go b/main.go index 26cf8c1..f78943c 100644 --- a/main.go +++ b/main.go @@ -71,6 +71,11 @@ func main() { EnvVar: "PLUGIN_PORT,SSH_PORT", Value: 22, }, + cli.DurationFlag{ + Name: "retry-timeout", + Usage: "connection retry timeout", + EnvVar: "PLUGIN_RETRY_TIMEOUT", + }, cli.BoolFlag{ Name: "sync", Usage: "sync mode", @@ -194,6 +199,7 @@ func run(c *cli.Context) error { Password: c.String("password"), Host: c.StringSlice("host"), Port: c.Int("port"), + RetryTimeout: c.Duration("retry-timeout"), Timeout: c.Duration("timeout"), CommandTimeout: c.Int("command.timeout"), Script: c.StringSlice("script"), diff --git a/plugin.go b/plugin.go index d47fcab..7793c13 100644 --- a/plugin.go +++ b/plugin.go @@ -10,6 +10,7 @@ import ( "github.com/appleboy/easyssh-proxy" "io" + "net" ) const ( @@ -29,6 +30,7 @@ type ( Host []string Port int Timeout time.Duration + RetryTimeout time.Duration CommandTimeout int Script []string Secrets []string @@ -90,7 +92,7 @@ func (p Plugin) exec(host string, wg *sync.WaitGroup, errChannel chan error) { p.log(host, "======END======") } - stdoutChan, stderrChan, doneChan, errChan, err := ssh.Stream(strings.Join(p.Config.Script, "\n"), p.Config.CommandTimeout) + stdoutChan, stderrChan, doneChan, errChan, err := retryStream(ssh, p) if err != nil { errChannel <- err } else { @@ -179,3 +181,34 @@ func (p Plugin) Exec() error { return nil } + +func retryStream(ssh *easyssh.MakeConfig, p Plugin) (<-chan string, <-chan string, <-chan bool, <-chan error, error) { + var ( + timeout = time.After(p.Config.RetryTimeout) + wait = time.Second + ) + + for { + stdoutChan, stderrChan, doneChan, errChan, err := ssh.Stream(strings.Join(p.Config.Script, "\n"), p.Config.CommandTimeout) + + // If there was no error, return all channels + if err == nil { + return stdoutChan, stderrChan, doneChan, errChan, nil + } + + // If the error was not a net.OpError, return that error + if _, ok := err.(*net.OpError); !ok { + return nil, nil, nil, nil, err + } + + select { + case <-timeout: + return nil, nil, nil, nil, err + case <-time.After(wait): + break + } + + // Double our back-off time + wait *= 2 + } +} diff --git a/plugin_test.go b/plugin_test.go index 563c06b..26160bb 100644 --- a/plugin_test.go +++ b/plugin_test.go @@ -8,6 +8,8 @@ import ( "github.com/appleboy/easyssh-proxy" "github.com/stretchr/testify/assert" "strings" + + "time" ) func TestMissingHostOrUser(t *testing.T) { @@ -457,6 +459,56 @@ func TestEnvOutput(t *testing.T) { assert.Equal(t, unindent(expected), unindent(buffer.String())) } +// TestConnectTimeoutRetry tests that when a network error occurs, the connect +// is retried until it either succeeds or the configured timeout is hit. +func TestConnectTimeoutRetry(t *testing.T) { + start := time.Now() + + plugin := Plugin{ + Config: Config{ + Host: []string{"localhost"}, + UserName: "drone-scp", + Port: 2200, + KeyPath: "./tests/.ssh/id_rsa", + Script: []string{"exit"}, + RetryTimeout: 15 * time.Second, + Sync: true, + }, + } + + err := plugin.Exec() + assert.NotNil(t, err) + + end := time.Now() + assert.WithinDuration(t, start.Local().Add(15*time.Second), end, 2*time.Second) + +} + +// TestConnectTimeoutRetry tests that when a non-network error occurs, the connect +// is not retried and instead returns the error immediately. +func TestConnectTimeoutImmediate(t *testing.T) { + start := time.Now() + + plugin := Plugin{ + Config: Config{ + Host: []string{"localhost"}, + UserName: "drone-scp", + Port: 22, + Key: "", + Script: []string{"exit"}, + RetryTimeout: 60 * time.Second, + Sync: true, + }, + } + + err := plugin.Exec() + assert.NotNil(t, err) + + end := time.Now() + assert.WithinDuration(t, start.Local().Add(time.Second), end, 2*time.Second) + +} + func unindent(text string) string { return strings.TrimSpace(strings.Replace(text, "\t", "", -1)) }