diff --git a/cmd/ghcs/ports.go b/cmd/ghcs/ports.go index 77d1b00f7..b51490ddc 100644 --- a/cmd/ghcs/ports.go +++ b/cmd/ghcs/ports.go @@ -15,6 +15,7 @@ import ( "github.com/muhammadmuzzammil1998/jsonc" "github.com/olekukonko/tablewriter" "github.com/spf13/cobra" + "golang.org/x/sync/errgroup" ) func NewPortsCmd() *cobra.Command { @@ -249,18 +250,23 @@ func NewPortsForwardCmd() *cobra.Command { Short: "forward", Long: "forward", RunE: func(cmd *cobra.Command, args []string) error { - if len(args) < 3 { - return errors.New("[codespace_name] [source] [dst] port number are required.") + if len(args) < 2 { + return errors.New("[codespace_name] [source]:[dst] port number are required.") } - return forwardPort(args[0], args[1], args[2]) + return forwardPort(args[0], args[1:]) }, } } -func forwardPort(codespaceName, sourcePort, destPort string) error { +func forwardPort(codespaceName string, ports []string) error { ctx := context.Background() apiClient := api.New(os.Getenv("GITHUB_TOKEN")) + portPairs, err := getPortPairs(ports) + if err != nil { + return fmt.Errorf("get port pairs: %v", err) + } + user, err := apiClient.GetUser(ctx) if err != nil { return fmt.Errorf("error getting user: %v", err) @@ -286,25 +292,58 @@ func forwardPort(codespaceName, sourcePort, destPort string) error { return fmt.Errorf("error creating server: %v", err) } - sourcePortInt, err := strconv.Atoi(sourcePort) - if err != nil { - return fmt.Errorf("error reading source port: %v", err) + g, gctx := errgroup.WithContext(ctx) + for _, portPair := range portPairs { + portPair := portPair + + srcstr := strconv.Itoa(portPair.Src) + if err := server.StartSharing(gctx, "share-"+srcstr, portPair.Src); err != nil { + return fmt.Errorf("start sharing port: %v", err) + } + + g.Go(func() error { + fmt.Println("Forwarding port: " + srcstr + " ==> " + strconv.Itoa(portPair.Dst)) + portForwarder := liveshare.NewPortForwarder(lsclient, server, portPair.Dst) + if err := portForwarder.Start(gctx); err != nil { + return fmt.Errorf("error forwarding port: %v", err) + } + + return nil + }) } - dstPortInt, err := strconv.Atoi(destPort) - if err != nil { - return fmt.Errorf("error reading destination port: %v", err) - } - - if err := server.StartSharing(ctx, "share-"+sourcePort, sourcePortInt); err != nil { - return fmt.Errorf("error sharing source port: %v", err) - } - - fmt.Println("Forwarding port: " + sourcePort + " -> " + destPort) - portForwarder := liveshare.NewPortForwarder(lsclient, server, dstPortInt) - if err := portForwarder.Start(ctx); err != nil { - return fmt.Errorf("error forwarding port: %v", err) + if err := g.Wait(); err != nil { + return err } return nil } + +type portPair struct { + Src, Dst int +} + +func getPortPairs(ports []string) ([]portPair, error) { + pp := make([]portPair, 0, len(ports)) + + for _, portString := range ports { + parts := strings.Split(portString, ":") + if len(parts) < 2 { + return pp, fmt.Errorf("port pair: '%v' is not valid", portString) + } + + srcp, err := strconv.Atoi(parts[0]) + if err != nil { + return pp, fmt.Errorf("convert source port to int: %v", err) + } + + dstp, err := strconv.Atoi(parts[1]) + if err != nil { + return pp, fmt.Errorf("convert dest port to int: %v", err) + } + + pp = append(pp, portPair{srcp, dstp}) + } + + return pp, nil +}