use WaitGroup for a more idiomatic concurrency pattern

This commit is contained in:
Jason Lunz 2021-12-20 12:12:22 -07:00
parent 81b34d272c
commit a864985f0a
No known key found for this signature in database
GPG key ID: C3EB59E26C4EE4F3

View file

@ -13,6 +13,7 @@ import (
"os"
"path/filepath"
"strings"
"sync"
"text/template"
"github.com/MakeNowJust/heredoc"
@ -199,7 +200,7 @@ func (a *App) printOpenSSHConfig(ctx context.Context, opts configOptions, execut
}
sshUsers := make(chan sshResult)
fetches := 0
var wg sync.WaitGroup
var status error
for _, cs := range csList {
if cs.State != "Available" && opts.codespace == "" {
@ -209,35 +210,34 @@ func (a *App) printOpenSSHConfig(ctx context.Context, opts configOptions, execut
}
cs := cs
fetches += 1
wg.Add(1)
go func() {
result := sshResult{}
defer func() {
select {
case sshUsers <- result:
case <-ctx.Done():
}
}()
defer wg.Done()
session, err := codespaces.ConnectToLiveshare(ctx, a, noopLogger(), a.apiClient, cs)
if err != nil {
result.err = fmt.Errorf("error connecting to codespace: %w", err)
return
}
defer session.Close()
} else {
defer session.Close()
//a.StartProgressIndicatorWithLabel(fmt.Sprintf("Fetching SSH Details for %s", cs.Name))
_, result.user, err = session.StartSSHServer(ctx)
//a.StopProgressIndicator()
if err != nil {
result.err = fmt.Errorf("error getting ssh server details: %w", err)
return
_, result.user, err = session.StartSSHServer(ctx)
if err != nil {
result.err = fmt.Errorf("error getting ssh server details: %w", err)
} else {
result.codespace = cs
}
}
result.codespace = cs
sshUsers <- result
}()
}
go func() {
wg.Wait()
close(sshUsers)
}()
// While the above fetches are running, ensure that the user has keys installed.
// That lets us report a more useful error message if they don't.
if err = checkAuthorizedKeys(ctx, a.apiClient); err != nil {
@ -258,8 +258,7 @@ func (a *App) printOpenSSHConfig(ctx context.Context, opts configOptions, execut
return fmt.Errorf("error formatting template: %w", err)
}
for i := 0; i < fetches; i++ {
result := <-sshUsers
for result := range sshUsers {
if result.err != nil {
fmt.Fprintf(os.Stderr, "%v\n", result.err)
status = cmdutil.SilentError