From a8d30dc1337fcfd463fe3c63cad69d8b71a99198 Mon Sep 17 00:00:00 2001 From: Bo-Yi Wu Date: Tue, 4 Apr 2023 16:56:50 +0800 Subject: [PATCH] feat: refactor code for parallel execution on multiple hosts (#249) - Add `trimValues` function for cleaning up slice values - Remove unused `wg.Done()` call - Modify `Exec` function to launch goroutines for each host in `Config.Host` - Add test for `ScriptStop` with multiple hosts and sync mode refer to: https://github.com/appleboy/ssh-action/issues/233 --- plugin.go | 84 +++++++++++++++++++++++++++++++------------------- plugin_test.go | 35 +++++++++++++++++++++ 2 files changed, 87 insertions(+), 32 deletions(-) diff --git a/plugin.go b/plugin.go index 67049b3..b2cd752 100644 --- a/plugin.go +++ b/plugin.go @@ -65,6 +65,7 @@ func (p Plugin) hostPort(host string) (string, string) { } func (p Plugin) exec(host string, wg *sync.WaitGroup, errChannel chan error) { + defer wg.Done() host, port := p.hostPort(host) // Create MakeConfig instance with remote username, server address and path to private key. ssh := &easyssh.MakeConfig{ @@ -117,38 +118,36 @@ func (p Plugin) exec(host string, wg *sync.WaitGroup, errChannel chan error) { stdoutChan, stderrChan, doneChan, errChan, err := ssh.Stream(strings.Join(p.Config.Script, "\n"), p.Config.CommandTimeout) if err != nil { errChannel <- err - } else { - // read from the output channel until the done signal is passed - isTimeout := true - loop: - for { - select { - case isTimeout = <-doneChan: - break loop - case outline := <-stdoutChan: - if outline != "" { - p.log(host, "out:", outline) - } - case errline := <-stderrChan: - if errline != "" { - p.log(host, "err:", errline) - } - case err = <-errChan: + return + } + // read from the output channel until the done signal is passed + isTimeout := true +loop: + for { + select { + case isTimeout = <-doneChan: + break loop + case outline := <-stdoutChan: + if outline != "" { + p.log(host, "out:", outline) } - } - - // get exit code or command error. - if err != nil { - errChannel <- err - } - - // command time out - if !isTimeout { - errChannel <- errCommandTimeOut + case errline := <-stderrChan: + if errline != "" { + p.log(host, "err:", errline) + } + case err = <-errChan: } } - wg.Done() + // get exit code or command error. + if err != nil { + errChannel <- err + } + + // command time out + if !isTimeout { + errChannel <- errCommandTimeOut + } } func (p Plugin) log(host string, message ...interface{}) { @@ -164,6 +163,8 @@ func (p Plugin) log(host string, message ...interface{}) { // Exec executes the plugin. func (p Plugin) Exec() error { + p.Config.Host = trimValues(p.Config.Host) + if len(p.Config.Host) == 0 { return errMissingHost } @@ -176,10 +177,14 @@ func (p Plugin) Exec() error { wg.Add(len(p.Config.Host)) errChannel := make(chan error) finished := make(chan struct{}) - for _, host := range p.Config.Host { - if p.Config.Sync { - p.exec(host, &wg, errChannel) - } else { + if p.Config.Sync { + go func() { + for _, host := range p.Config.Host { + p.exec(host, &wg, errChannel) + } + }() + } else { + for _, host := range p.Config.Host { go p.exec(host, &wg, errChannel) } } @@ -230,3 +235,18 @@ func (p Plugin) scriptCommands() []string { return commands } + +func trimValues(keys []string) []string { + var newKeys []string + + for _, value := range keys { + value = strings.TrimSpace(value) + if len(value) == 0 { + continue + } + + newKeys = append(newKeys, value) + } + + return newKeys +} diff --git a/plugin_test.go b/plugin_test.go index dc33248..91e252e 100644 --- a/plugin_test.go +++ b/plugin_test.go @@ -440,6 +440,41 @@ func TestFingerprint(t *testing.T) { assert.Equal(t, unindent(expected), unindent(buffer.String())) } +func TestScriptStopWithMultipleHostAndSyncMode(t *testing.T) { + var ( + buffer bytes.Buffer + expected = ` + ======CMD====== + mkdir a/b/c + mkdir d/e/f + ======END====== + err: mkdir: can't create directory 'a/b/c': No such file or directory + ` + ) + + plugin := Plugin{ + Config: Config{ + Host: []string{"", "localhost"}, + Username: "drone-scp", + Port: 22, + KeyPath: "./tests/.ssh/id_rsa", + Script: []string{ + "mkdir a/b/c", + "mkdir d/e/f", + }, + CommandTimeout: 10 * time.Second, + ScriptStop: true, + Sync: true, + }, + Writer: &buffer, + } + + err := plugin.Exec() + assert.NotNil(t, err) + + assert.Equal(t, unindent(expected), unindent(buffer.String())) +} + func TestScriptStop(t *testing.T) { var ( buffer bytes.Buffer