diff --git a/cmd/ghcs/ports.go b/cmd/ghcs/ports.go index 6bbf05091..0f2460e0a 100644 --- a/cmd/ghcs/ports.go +++ b/cmd/ghcs/ports.go @@ -16,6 +16,7 @@ import ( "github.com/github/go-liveshare" "github.com/muhammadmuzzammil1998/jsonc" "github.com/spf13/cobra" + "golang.org/x/sync/errgroup" ) type PortsOptions struct { @@ -241,19 +242,24 @@ func updatePortVisibility(log *output.Logger, codespaceName, sourcePort string, func NewPortsForwardCmd() *cobra.Command { return &cobra.Command{ Use: "forward ", - Short: "Forward port", - Args: cobra.ExactArgs(3), + Short: "Forward ports", + Args: cobra.MinimumNArgs(2), RunE: func(cmd *cobra.Command, args []string) error { log := output.NewLogger(os.Stdout, os.Stderr, false) - return forwardPort(log, args[0], args[1], args[2]) + return forwardPorts(log, args[0], args[1:]) }, } } -func forwardPort(log *output.Logger, codespaceName, sourcePort, destPort string) error { +func forwardPorts(log *output.Logger, codespaceName string, ports []string) error { ctx := context.Background() apiClient := api.New(os.Getenv("GITHUB_TOKEN")) + portPairs, err := getPortPairs(ports) + if err != nil { + return fmt.Errorf("get port pairs: %v", err) + } + user, err := apiClient.GetUser(ctx) if err != nil { return fmt.Errorf("error getting user: %v", err) @@ -279,29 +285,62 @@ func forwardPort(log *output.Logger, codespaceName, sourcePort, destPort string) return fmt.Errorf("error creating server: %v", err) } - sourcePortInt, err := strconv.Atoi(sourcePort) - if err != nil { - return fmt.Errorf("error reading source port: %v", err) + g, gctx := errgroup.WithContext(ctx) + for _, portPair := range portPairs { + pp := portPair + + srcstr := strconv.Itoa(portPair.Src) + if err := server.StartSharing(gctx, "share-"+srcstr, pp.Src); err != nil { + return fmt.Errorf("start sharing port: %v", err) + } + + g.Go(func() error { + log.Println("Forwarding port: " + srcstr + " ==> " + strconv.Itoa(pp.Dst)) + portForwarder := liveshare.NewPortForwarder(lsclient, server, pp.Dst) + if err := portForwarder.Start(gctx); err != nil { + return fmt.Errorf("error forwarding port: %v", err) + } + + return nil + }) } - dstPortInt, err := strconv.Atoi(destPort) - if err != nil { - return fmt.Errorf("error reading destination port: %v", err) - } - - if err := server.StartSharing(ctx, "share-"+sourcePort, sourcePortInt); err != nil { - return fmt.Errorf("error sharing source port: %v", err) - } - - log.Println("Forwarding port: " + sourcePort + " -> " + destPort) - portForwarder := liveshare.NewPortForwarder(lsclient, server, dstPortInt) - if err := portForwarder.Start(ctx); err != nil { - return fmt.Errorf("error forwarding port: %v", err) + if err := g.Wait(); err != nil { + return err } return nil } +type portPair struct { + Src, Dst int +} + +func getPortPairs(ports []string) ([]portPair, error) { + pp := make([]portPair, 0, len(ports)) + + for _, portString := range ports { + parts := strings.Split(portString, ":") + if len(parts) < 2 { + return pp, fmt.Errorf("port pair: '%v' is not valid", portString) + } + + srcp, err := strconv.Atoi(parts[0]) + if err != nil { + return pp, fmt.Errorf("convert source port to int: %v", err) + } + + dstp, err := strconv.Atoi(parts[1]) + if err != nil { + return pp, fmt.Errorf("convert dest port to int: %v", err) + } + + pp = append(pp, portPair{srcp, dstp}) + } + + return pp, nil +} + func normalizeJSON(j []byte) []byte { // remove trailing commas return bytes.ReplaceAll(j, []byte("},}"), []byte("}}"))