From 368e8c61105f7beafcdedd3d6c3376293db345ca Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Fri, 27 Aug 2021 17:34:06 -0400 Subject: [PATCH] simplify contract for state polling --- cmd/ghcs/create.go | 83 ++++++++++++++++------------------- internal/codespaces/states.go | 51 +++++++++------------ 2 files changed, 58 insertions(+), 76 deletions(-) diff --git a/cmd/ghcs/create.go b/cmd/ghcs/create.go index 42e8f11be..6b6d10511 100644 --- a/cmd/ghcs/create.go +++ b/cmd/ghcs/create.go @@ -106,64 +106,55 @@ func create(opts *createOptions) error { } func showStatus(ctx context.Context, log *output.Logger, apiClient *api.API, user *api.User, codespace *api.Codespace) error { - states, err := codespaces.PollPostCreateStates(ctx, log, apiClient, user, codespace) - if err != nil { - return fmt.Errorf("failed to subscribe to state changes from codespace: %v", err) - } - var lastState codespaces.PostCreateState - finishedStates := make(map[string]bool) var breakNextState bool -PollStates: - for { - select { - case <-ctx.Done(): - return nil + finishedStates := make(map[string]bool) + ctx, stopPolling := context.WithCancel(ctx) - case stateUpdate := <-states: - if stateUpdate.Err != nil { - return fmt.Errorf("receive state update: %v", err) + poller := func(states []codespaces.PostCreateState) { + var inProgress bool + for _, state := range states { + if _, found := finishedStates[state.Name]; found { + continue // skip this state as we've processed it already } - var inProgress bool - for _, state := range stateUpdate.PostCreateStates { - if _, found := finishedStates[state.Name]; found { - continue // skip this state as we've processed it already + if state.Name != lastState.Name { + log.Print(state.Name) + + if state.Status == codespaces.PostCreateStateRunning { + inProgress = true + lastState = state + log.Print("...") + break } - if state.Name != lastState.Name { - log.Print(state.Name) - - if state.Status == codespaces.PostCreateStateRunning { - inProgress = true - lastState = state - log.Print("...") - break - } - - finishedStates[state.Name] = true - log.Println("..." + state.Status) - } else { - if state.Status == codespaces.PostCreateStateRunning { - inProgress = true - log.Print(".") - break - } - - finishedStates[state.Name] = true - log.Println(state.Status) - lastState = codespaces.PostCreateState{} // reset the value + finishedStates[state.Name] = true + log.Println("..." + state.Status) + } else { + if state.Status == codespaces.PostCreateStateRunning { + inProgress = true + log.Print(".") + break } - } - if !inProgress { - if breakNextState { - break PollStates - } - breakNextState = true + finishedStates[state.Name] = true + log.Println(state.Status) + lastState = codespaces.PostCreateState{} // reset the value } } + + if !inProgress { + if breakNextState { + stopPolling() + return + } + breakNextState = true + } + } + + if err := codespaces.PollPostCreateStates(ctx, log, apiClient, user, codespace, poller); err != nil { + return fmt.Errorf("failed to poll state changes from codespace: %v", err) } return nil diff --git a/internal/codespaces/states.go b/internal/codespaces/states.go index 9ace150d6..427726a46 100644 --- a/internal/codespaces/states.go +++ b/internal/codespaces/states.go @@ -33,50 +33,40 @@ type PostCreateState struct { Status PostCreateStateStatus `json:"status"` } -func PollPostCreateStates(ctx context.Context, log logger, apiClient *api.API, user *api.User, codespace *api.Codespace) (<-chan PostCreateStatesResult, error) { - pollch := make(chan PostCreateStatesResult) - +func PollPostCreateStates(ctx context.Context, log logger, apiClient *api.API, user *api.User, codespace *api.Codespace, poller func([]PostCreateState)) error { token, err := apiClient.GetCodespaceToken(ctx, user.Login, codespace.Name) if err != nil { - return nil, fmt.Errorf("getting codespace token: %v", err) + return fmt.Errorf("getting codespace token: %v", err) } lsclient, err := ConnectToLiveshare(ctx, log, apiClient, user.Login, token, codespace) if err != nil { - return nil, fmt.Errorf("connect to liveshare: %v", err) + return fmt.Errorf("connect to liveshare: %v", err) } tunnelPort, connClosed, err := MakeSSHTunnel(ctx, lsclient, 0) if err != nil { - return nil, fmt.Errorf("make ssh tunnel: %v", err) + return fmt.Errorf("make ssh tunnel: %v", err) } - go func() { - t := time.NewTicker(1 * time.Second) - for { - select { - case <-ctx.Done(): - return - case err := <-connClosed: - if err != nil { - pollch <- PostCreateStatesResult{Err: fmt.Errorf("connection closed: %v", err)} - return - } - case <-t.C: - states, err := getPostCreateOutput(ctx, tunnelPort, codespace) - if err != nil { - pollch <- PostCreateStatesResult{Err: fmt.Errorf("get post create output: %v", err)} - return - } - - pollch <- PostCreateStatesResult{ - PostCreateStates: states, - } + t := time.NewTicker(1 * time.Second) + for { + select { + case <-ctx.Done(): + return nil + case err := <-connClosed: + return fmt.Errorf("connection closed: %v", err) + case <-t.C: + states, err := getPostCreateOutput(ctx, tunnelPort, codespace) + if err != nil { + return fmt.Errorf("get post create output: %v", err) } - } - }() - return pollch, nil + poller(states) + } + } + + return nil } func getPostCreateOutput(ctx context.Context, tunnelPort int, codespace *api.Codespace) ([]PostCreateState, error) { @@ -87,6 +77,7 @@ func getPostCreateOutput(ctx context.Context, tunnelPort int, codespace *api.Cod if err != nil { return nil, fmt.Errorf("run command: %v", err) } + defer stdout.Close() b, err := ioutil.ReadAll(stdout) if err != nil {