diff --git a/cmd/ghcs/ports.go b/cmd/ghcs/ports.go index 2fab4254d..2bb3e6917 100644 --- a/cmd/ghcs/ports.go +++ b/cmd/ghcs/ports.go @@ -37,7 +37,7 @@ func newPortsCmd() *cobra.Command { }, } - portsCmd.Flags().StringVarP(&codespace, "codespace", "c", "", "Name of the codespace") + portsCmd.PersistentFlags().StringVarP(&codespace, "codespace", "c", "", "Name of the codespace") portsCmd.Flags().BoolVar(&asJSON, "json", false, "Output as JSON") portsCmd.AddCommand(newPortsPublicCmd()) @@ -157,18 +157,24 @@ func getDevContainer(ctx context.Context, apiClient *api.API, codespace *api.Cod // newPortsPublicCmd returns a Cobra "ports public" subcommand, which makes a given port public. func newPortsPublicCmd() *cobra.Command { - var codespace string - newPortsPublicCmd := &cobra.Command{ Use: "public ", Short: "Mark port as public", Args: cobra.MinimumNArgs(1), RunE: func(cmd *cobra.Command, args []string) error { + codespace, err := cmd.Flags().GetString("codespace") + if err != nil { + // should only happen if flag is not defined + // or if the flag is not of string type + // since it's a persistent flag that we control it should never happen + return fmt.Errorf("get codespace flag: %v", err) + } + log := output.NewLogger(os.Stdout, os.Stderr, false) port := args[0] if len(args) > 1 { - log.Println(" argument is deprecated. Use --codespace instead.") + log.Errorln(" argument is deprecated. Use --codespace instead.") codespace, port = args[0], args[1] } @@ -176,25 +182,29 @@ func newPortsPublicCmd() *cobra.Command { }, } - newPortsPublicCmd.Flags().StringVarP(&codespace, "codespace", "c", "", "Name of the codespace") - return newPortsPublicCmd } // newPortsPrivateCmd returns a Cobra "ports private" subcommand, which makes a given port private. func newPortsPrivateCmd() *cobra.Command { - var codespace string - newPortsPrivateCmd := &cobra.Command{ Use: "private ", Short: "Mark port as private", Args: cobra.MinimumNArgs(1), RunE: func(cmd *cobra.Command, args []string) error { + codespace, err := cmd.Flags().GetString("codespace") + if err != nil { + // should only happen if flag is not defined + // or if the flag is not of string type + // since it's a persistent flag that we control it should never happen + return fmt.Errorf("get codespace flag: %v", err) + } + log := output.NewLogger(os.Stdout, os.Stderr, false) port := args[0] if len(args) > 1 { - log.Println(" argument is deprecated. Use --codespace instead.") + log.Errorln(" argument is deprecated. Use --codespace instead.") codespace, port = args[0], args[1] } @@ -202,8 +212,6 @@ func newPortsPrivateCmd() *cobra.Command { }, } - newPortsPrivateCmd.Flags().StringVarP(&codespace, "codespace", "c", "", "Name of the codespace") - return newPortsPrivateCmd } @@ -216,13 +224,11 @@ func updatePortVisibility(log *output.Logger, codespaceName, sourcePort string, return fmt.Errorf("error getting user: %v", err) } - token, err := apiClient.GetCodespaceToken(ctx, user.Login, codespaceName) - if err != nil { - return fmt.Errorf("error getting codespace token: %v", err) - } - - codespace, err := apiClient.GetCodespace(ctx, token, user.Login, codespaceName) + codespace, token, err := codespaces.GetOrChooseCodespace(ctx, apiClient, user, codespaceName) if err != nil { + if err == codespaces.ErrNoCodespaces { + return err + } return fmt.Errorf("error getting codespace: %v", err) } @@ -252,19 +258,25 @@ func updatePortVisibility(log *output.Logger, codespaceName, sourcePort string, // NewPortsForwardCmd returns a Cobra "ports forward" subcommand, which forwards a set of // port pairs from the codespace to localhost. func newPortsForwardCmd() *cobra.Command { - var codespace string - newPortsForwardCmd := &cobra.Command{ Use: "forward :...", Short: "Forward ports", Args: cobra.MinimumNArgs(1), RunE: func(cmd *cobra.Command, args []string) error { + codespace, err := cmd.Flags().GetString("codespace") + if err != nil { + // should only happen if flag is not defined + // or if the flag is not of string type + // since it's a persistent flag that we control it should never happen + return fmt.Errorf("get codespace flag: %v", err) + } + log := output.NewLogger(os.Stdout, os.Stderr, false) ports := args[0:] if len(args) > 1 && !strings.Contains(args[0], ":") { // assume this is a codespace name - log.Println(" argument is deprecated. Use --codespace instead.") + log.Errorln(" argument is deprecated. Use --codespace instead.") codespace = args[0] ports = args[1:] } @@ -273,8 +285,6 @@ func newPortsForwardCmd() *cobra.Command { }, } - newPortsForwardCmd.Flags().StringVarP(&codespace, "codespace", "c", "", "Name of the codespace") - return newPortsForwardCmd } @@ -292,13 +302,11 @@ func forwardPorts(log *output.Logger, codespaceName string, ports []string) erro return fmt.Errorf("error getting user: %v", err) } - token, err := apiClient.GetCodespaceToken(ctx, user.Login, codespaceName) - if err != nil { - return fmt.Errorf("error getting codespace token: %v", err) - } - - codespace, err := apiClient.GetCodespace(ctx, token, user.Login, codespaceName) + codespace, token, err := codespaces.GetOrChooseCodespace(ctx, apiClient, user, codespaceName) if err != nil { + if err == codespaces.ErrNoCodespaces { + return err + } return fmt.Errorf("error getting codespace: %v", err) } diff --git a/internal/codespaces/codespaces.go b/internal/codespaces/codespaces.go index b5fa4a583..fd04b303e 100644 --- a/internal/codespaces/codespaces.go +++ b/internal/codespaces/codespaces.go @@ -124,6 +124,8 @@ func ConnectToLiveshare(ctx context.Context, log logger, apiClient *api.API, use return lsclient.JoinWorkspace(ctx) } +// GetOrChooseCodespace prompts the user to choose a codespace if the codespaceName is empty. +// It then fetches the codespace token and the codespace record. func GetOrChooseCodespace(ctx context.Context, apiClient *api.API, user *api.User, codespaceName string) (codespace *api.Codespace, token string, err error) { if codespaceName == "" { codespace, err = ChooseCodespace(ctx, apiClient, user)