Forward codespace ports over Dev Tunnels

This commit is contained in:
David Gardiner 2023-09-19 13:36:06 -07:00
parent 48b0d53d0e
commit e059f32aa5
13 changed files with 1271 additions and 303 deletions

2
go.mod
View file

@ -27,6 +27,7 @@ require (
github.com/mattn/go-colorable v0.1.13
github.com/mattn/go-isatty v0.0.19
github.com/mgutz/ansi v0.0.0-20200706080929-d51e80ef957d
github.com/microsoft/dev-tunnels v0.0.21
github.com/muhammadmuzzammil1998/jsonc v0.0.0-20201229145248-615b0916ca38
github.com/opentracing/opentracing-go v1.1.0
github.com/rivo/tview v0.0.0-20221029100920-c4a7e501810d
@ -75,6 +76,7 @@ require (
github.com/olekukonko/tablewriter v0.0.5 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/rivo/uniseg v0.4.4 // indirect
github.com/rodaine/table v1.0.1 // indirect
github.com/russross/blackfriday/v2 v2.1.0 // indirect
github.com/shurcooL/graphql v0.0.0-20230722043721-ed46e5a46466 // indirect
github.com/stretchr/objx v0.5.0 // indirect

4
go.sum
View file

@ -117,6 +117,8 @@ github.com/mgutz/ansi v0.0.0-20200706080929-d51e80ef957d h1:5PJl274Y63IEHC+7izoQ
github.com/mgutz/ansi v0.0.0-20200706080929-d51e80ef957d/go.mod h1:01TrycV0kFyexm33Z7vhZRXopbI8J3TDReVlkTgMUxE=
github.com/microcosm-cc/bluemonday v1.0.21 h1:dNH3e4PSyE4vNX+KlRGHT5KrSvjeUkoNPwEORjffHJg=
github.com/microcosm-cc/bluemonday v1.0.21/go.mod h1:ytNkv4RrDrLJ2pqlsSI46O6IVXmZOBBD4SaJyDwwTkM=
github.com/microsoft/dev-tunnels v0.0.21 h1:p4QP7C5ZOyP9bGbmanRjPxUMckfi9Z41Gl+KY4C11w0=
github.com/microsoft/dev-tunnels v0.0.21/go.mod h1:frU++12T/oqxckXkDpTuYa427ncguEOodSPZcGCCrzQ=
github.com/muesli/reflow v0.2.1-0.20210115123740-9e1d0d53df68/go.mod h1:Xk+z4oIWdQqJzsxyjgl3P22oYZnHdZ8FFTHAQQt5BMQ=
github.com/muesli/reflow v0.3.0 h1:IFsN6K9NfGtjeggFP+68I4chLZV2yIKsXJFNZ+eWh6s=
github.com/muesli/reflow v0.3.0/go.mod h1:pbwTDkVPibjO2kyvBQRBxTWEEGDGq0FlB1BIKtnHY/8=
@ -139,6 +141,8 @@ github.com/rivo/uniseg v0.1.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJ
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
github.com/rivo/uniseg v0.4.4 h1:8TfxU8dW6PdqD27gjM8MVNuicgxIjxpm4K7x4jp8sis=
github.com/rivo/uniseg v0.4.4/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
github.com/rodaine/table v1.0.1 h1:U/VwCnUxlVYxw8+NJiLIuCxA/xa6jL38MY3FYysVWWQ=
github.com/rodaine/table v1.0.1/go.mod h1:UVEtfBsflpeEcD56nF4F5AocNFta0ZuolpSVdPtlmP4=
github.com/russross/blackfriday/v2 v2.1.0 h1:JIOH55/0cWyOuilr9/qlrm0BSXldqnqwMsf35Ld67mk=
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
github.com/shurcooL/githubv4 v0.0.0-20230704064427-599ae7bbf278 h1:kdEGVAV4sO46DPtb8k793jiecUEhaX9ixoIBt41HEGU=

View file

@ -201,6 +201,7 @@ type Codespace struct {
GitStatus CodespaceGitStatus `json:"git_status"`
Connection CodespaceConnection `json:"connection"`
Machine CodespaceMachine `json:"machine"`
RuntimeConstraints RuntimeConstraints `json:"runtime_constraints"`
VSCSTarget string `json:"vscs_target"`
PendingOperation bool `json:"pending_operation"`
PendingOperationDisabledReason string `json:"pending_operation_disabled_reason"`
@ -246,11 +247,25 @@ const (
)
type CodespaceConnection struct {
SessionID string `json:"sessionId"`
SessionToken string `json:"sessionToken"`
RelayEndpoint string `json:"relayEndpoint"`
RelaySAS string `json:"relaySas"`
HostPublicKeys []string `json:"hostPublicKeys"`
SessionID string `json:"sessionId"`
SessionToken string `json:"sessionToken"`
RelayEndpoint string `json:"relayEndpoint"`
RelaySAS string `json:"relaySas"`
HostPublicKeys []string `json:"hostPublicKeys"`
TunnelProperties TunnelProperties `json:"tunnelProperties"`
}
type TunnelProperties struct {
ConnectAccessToken string `json:"connectAccessToken"`
ManagePortsAccessToken string `json:"managePortsAccessToken"`
ServiceUri string `json:"serviceUri"`
TunnelId string `json:"tunnelId"`
ClusterId string `json:"clusterId"`
Domain string `json:"domain"`
}
type RuntimeConstraints struct {
AllowedPortPrivacySettings []string `json:"allowed_port_privacy_settings"`
}
// ListCodespaceFields is the list of exportable fields for a codespace when using the `gh cs list` command.
@ -1162,3 +1177,13 @@ func (a *API) withRetry(f func() (*http.Response, error)) (*http.Response, error
return nil, fmt.Errorf("received response with status code %d", resp.StatusCode)
}, backoff.WithMaxRetries(bo, 3))
}
// HTTPClient returns the HTTP client used to make requests to the API.
func (a *API) HTTPClient() (*http.Client, error) {
httpClient, err := a.client()
if err != nil {
return nil, err
}
return httpClient, nil
}

View file

@ -5,24 +5,42 @@ import (
"errors"
"fmt"
"net"
"net/http"
"time"
"github.com/cenkalti/backoff/v4"
"github.com/cli/cli/v2/internal/codespaces/api"
"github.com/cli/cli/v2/internal/codespaces/connection"
"github.com/cli/cli/v2/pkg/liveshare"
)
func connectionReady(codespace *api.Codespace) bool {
func connectionReady(codespace *api.Codespace, usingDevTunnels bool) bool {
// If the codespace is not available, it is not ready
if codespace.State != api.CodespaceStateAvailable {
return false
}
// If using Dev Tunnels, we need to check that we have all of the required tunnel properties
if usingDevTunnels {
return codespace.Connection.TunnelProperties.ConnectAccessToken != "" &&
codespace.Connection.TunnelProperties.ManagePortsAccessToken != "" &&
codespace.Connection.TunnelProperties.ServiceUri != "" &&
codespace.Connection.TunnelProperties.TunnelId != "" &&
codespace.Connection.TunnelProperties.ClusterId != "" &&
codespace.Connection.TunnelProperties.Domain != ""
}
// If not using Dev Tunnels, we need to check that we have all of the required Live Share properties
return codespace.Connection.SessionID != "" &&
codespace.Connection.SessionToken != "" &&
codespace.Connection.RelayEndpoint != "" &&
codespace.Connection.RelaySAS != "" &&
codespace.State == api.CodespaceStateAvailable
codespace.Connection.RelaySAS != ""
}
type apiClient interface {
GetCodespace(ctx context.Context, name string, includeConnection bool) (*api.Codespace, error)
StartCodespace(ctx context.Context, name string) error
HTTPClient() (*http.Client, error)
}
type progressIndicator interface {
@ -43,9 +61,48 @@ func (e *TimeoutError) Error() string {
return e.message
}
// ConnectToLiveshare waits for a Codespace to become running,
// and connects to it using a Live Share session.
// GetCodespaceConnection waits until a codespace is able
// to be connected to and initializes a connection to it.
func GetCodespaceConnection(ctx context.Context, progress progressIndicator, apiClient apiClient, codespace *api.Codespace) (*connection.CodespaceConnection, error) {
codespace, err := waitUntilCodespaceConnectionReady(ctx, progress, apiClient, codespace, true)
if err != nil {
return nil, err
}
progress.StartProgressIndicatorWithLabel("Connecting to codespace")
defer progress.StopProgressIndicator()
httpClient, err := apiClient.HTTPClient()
if err != nil {
return nil, fmt.Errorf("error getting http client: %w", err)
}
return connection.NewCodespaceConnection(ctx, codespace, httpClient)
}
// ConnectToLiveshare waits until a codespace is able to be
// connected to and connects to it using a Live Share session.
func ConnectToLiveshare(ctx context.Context, progress progressIndicator, sessionLogger logger, apiClient apiClient, codespace *api.Codespace) (*liveshare.Session, error) {
codespace, err := waitUntilCodespaceConnectionReady(ctx, progress, apiClient, codespace, false)
if err != nil {
return nil, err
}
progress.StartProgressIndicatorWithLabel("Connecting to codespace")
defer progress.StopProgressIndicator()
return liveshare.Connect(ctx, liveshare.Options{
SessionID: codespace.Connection.SessionID,
SessionToken: codespace.Connection.SessionToken,
RelaySAS: codespace.Connection.RelaySAS,
RelayEndpoint: codespace.Connection.RelayEndpoint,
HostPublicKeys: codespace.Connection.HostPublicKeys,
Logger: sessionLogger,
})
}
// waitUntilCodespaceConnectionReady waits for a Codespace to be running and is able to be connected to.
func waitUntilCodespaceConnectionReady(ctx context.Context, progress progressIndicator, apiClient apiClient, codespace *api.Codespace, usingDevTunnels bool) (*api.Codespace, error) {
if codespace.State != api.CodespaceStateAvailable {
progress.StartProgressIndicatorWithLabel("Starting codespace")
defer progress.StopProgressIndicator()
@ -54,7 +111,7 @@ func ConnectToLiveshare(ctx context.Context, progress progressIndicator, session
}
}
if !connectionReady(codespace) {
if !connectionReady(codespace, usingDevTunnels) {
expBackoff := backoff.NewExponentialBackOff()
expBackoff.Multiplier = 1.1
expBackoff.MaxInterval = 10 * time.Second
@ -67,7 +124,7 @@ func ConnectToLiveshare(ctx context.Context, progress progressIndicator, session
return backoff.Permanent(fmt.Errorf("error getting codespace: %w", err))
}
if connectionReady(codespace) {
if connectionReady(codespace, usingDevTunnels) {
return nil
}
@ -83,17 +140,7 @@ func ConnectToLiveshare(ctx context.Context, progress progressIndicator, session
}
}
progress.StartProgressIndicatorWithLabel("Connecting to codespace")
defer progress.StopProgressIndicator()
return liveshare.Connect(ctx, liveshare.Options{
SessionID: codespace.Connection.SessionID,
SessionToken: codespace.Connection.SessionToken,
RelaySAS: codespace.Connection.RelaySAS,
RelayEndpoint: codespace.Connection.RelayEndpoint,
HostPublicKeys: codespace.Connection.HostPublicKeys,
Logger: sessionLogger,
})
return codespace, nil
}
// ListenTCP starts a localhost tcp listener on 127.0.0.1 (unless allInterfaces is true) and returns the listener and bound port

View file

@ -0,0 +1,116 @@
package connection
import (
"context"
"fmt"
"io"
"log"
"net/http"
"net/url"
"github.com/cli/cli/v2/internal/codespaces/api"
"github.com/microsoft/dev-tunnels/go/tunnels"
)
const (
clientName = "gh"
)
type CodespaceConnection struct {
tunnelProperties api.TunnelProperties
TunnelManager *tunnels.Manager
TunnelClient *tunnels.Client
Options *tunnels.TunnelRequestOptions
Tunnel *tunnels.Tunnel
AllowedPortPrivacySettings []string
}
// NewCodespaceConnection initializes a connection to a codespace.
// This connections allows for port forwarding which enables the
// use of most features of the codespace command.
func NewCodespaceConnection(ctx context.Context, codespace *api.Codespace, httpClient *http.Client) (connection *CodespaceConnection, err error) {
// Get the tunnel properties
tunnelProperties := codespace.Connection.TunnelProperties
// Create the tunnel manager
tunnelManager, err := getTunnelManager(tunnelProperties, httpClient)
if err != nil {
return nil, fmt.Errorf("error getting tunnel management client: %w", err)
}
// Calculate allowed port privacy settings
allowedPortPrivacySettings := codespace.RuntimeConstraints.AllowedPortPrivacySettings
// Get the access tokens
connectToken := tunnelProperties.ConnectAccessToken
managementToken := tunnelProperties.ManagePortsAccessToken
// Create the tunnel definition
tunnel := &tunnels.Tunnel{
AccessTokens: map[tunnels.TunnelAccessScope]string{tunnels.TunnelAccessScopeConnect: connectToken, tunnels.TunnelAccessScopeManagePorts: managementToken},
TunnelID: tunnelProperties.TunnelId,
ClusterID: tunnelProperties.ClusterId,
Domain: tunnelProperties.Domain,
}
// Create options
options := &tunnels.TunnelRequestOptions{
IncludePorts: true,
}
// Create the tunnel client (not connected yet)
tunnelClient, err := getTunnelClient(ctx, tunnelManager, tunnel, options)
if err != nil {
return nil, fmt.Errorf("error getting tunnel client: %w", err)
}
return &CodespaceConnection{
tunnelProperties: tunnelProperties,
TunnelManager: tunnelManager,
TunnelClient: tunnelClient,
Options: options,
Tunnel: tunnel,
AllowedPortPrivacySettings: allowedPortPrivacySettings,
}, nil
}
// getTunnelManager creates a tunnel manager for the given codespace.
// The tunnel manager is used to get the tunnel hosted in the codespace that we
// want to connect to and perform operations on ports (add, remove, list, etc.).
func getTunnelManager(tunnelProperties api.TunnelProperties, httpClient *http.Client) (tunnelManager *tunnels.Manager, err error) {
userAgent := []tunnels.UserAgent{{Name: clientName}}
url, err := url.Parse(tunnelProperties.ServiceUri)
if err != nil {
return nil, fmt.Errorf("error parsing tunnel service uri: %w", err)
}
// Create the tunnel manager
tunnelManager, err = tunnels.NewManager(userAgent, nil, url, httpClient)
if err != nil {
return nil, fmt.Errorf("error creating tunnel manager: %w", err)
}
return tunnelManager, nil
}
// getTunnelClient creates a tunnel client for the given tunnel.
// The tunnel client is used to connect to the the tunnel and allows
// for ports to be forwarded locally.
func getTunnelClient(ctx context.Context, tunnelManager *tunnels.Manager, tunnel *tunnels.Tunnel, options *tunnels.TunnelRequestOptions) (tunnelClient *tunnels.Client, err error) {
// Get the tunnel that we want to connect to
codespaceTunnel, err := tunnelManager.GetTunnel(ctx, tunnel, options)
if err != nil {
return nil, fmt.Errorf("error getting tunnel: %w", err)
}
// Copy the access tokens from the tunnel definition
codespaceTunnel.AccessTokens = tunnel.AccessTokens
// We need to pass false for accept local connections because we don't want to automatically connect to all forwarded ports
tunnelClient, err = tunnels.NewClient(log.New(io.Discard, "", log.LstdFlags), codespaceTunnel, false)
if err != nil {
return nil, fmt.Errorf("error creating tunnel client: %w", err)
}
return tunnelClient, nil
}

View file

@ -0,0 +1,75 @@
package connection
import (
"context"
"reflect"
"testing"
"github.com/cli/cli/v2/internal/codespaces/api"
"github.com/microsoft/dev-tunnels/go/tunnels"
)
func TestNewCodespaceConnection(t *testing.T) {
ctx := context.Background()
// Create a mock codespace
connection := api.CodespaceConnection{
TunnelProperties: api.TunnelProperties{
ConnectAccessToken: "connect-token",
ManagePortsAccessToken: "manage-ports-token",
ServiceUri: "http://global.rel.tunnels.api.visualstudio.com/",
TunnelId: "tunnel-id",
ClusterId: "usw2",
Domain: "domain.com",
},
}
allowedPortPrivacySettings := []string{"public", "private"}
codespace := &api.Codespace{
Connection: connection,
RuntimeConstraints: api.RuntimeConstraints{AllowedPortPrivacySettings: allowedPortPrivacySettings},
}
// Create the mock HTTP client
httpClient, err := NewMockHttpClient()
if err != nil {
t.Fatalf("NewHttpClient returned an error: %v", err)
}
// Create the connection
conn, err := NewCodespaceConnection(ctx, codespace, httpClient)
if err != nil {
t.Fatalf("NewCodespaceConnection returned an error: %v", err)
}
// Check that the connection was created successfully
if conn == nil {
t.Fatal("NewCodespaceConnection returned nil")
}
// Verify that the connection contains the expected tunnel properties
if conn.tunnelProperties != connection.TunnelProperties {
t.Fatalf("NewCodespaceConnection returned a connection with unexpected tunnel properties: %+v", conn.tunnelProperties)
}
// Verify that the connection contains the expected tunnel
expectedTunnel := &tunnels.Tunnel{
AccessTokens: map[tunnels.TunnelAccessScope]string{tunnels.TunnelAccessScopeConnect: connection.TunnelProperties.ConnectAccessToken, tunnels.TunnelAccessScopeManagePorts: connection.TunnelProperties.ManagePortsAccessToken},
TunnelID: connection.TunnelProperties.TunnelId,
ClusterID: connection.TunnelProperties.ClusterId,
Domain: connection.TunnelProperties.Domain,
}
if !reflect.DeepEqual(conn.Tunnel, expectedTunnel) {
t.Fatalf("NewCodespaceConnection returned a connection with unexpected tunnel: %+v", conn.Tunnel)
}
// Verify that the connection contains the expected tunnel options
expectedOptions := &tunnels.TunnelRequestOptions{IncludePorts: true}
if !reflect.DeepEqual(conn.Options, expectedOptions) {
t.Fatalf("NewCodespaceConnection returned a connection with unexpected options: %+v", conn.Options)
}
// Verify that the connection contains the expected allowed port privacy settings
if !reflect.DeepEqual(conn.AllowedPortPrivacySettings, allowedPortPrivacySettings) {
t.Fatalf("NewCodespaceConnection returned a connection with unexpected allowed port privacy settings: %+v", conn.AllowedPortPrivacySettings)
}
}

View file

@ -0,0 +1,396 @@
package connection
import (
"context"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"encoding/json"
"encoding/pem"
"fmt"
"io"
"log"
"net/http"
"net/http/httptest"
"net/url"
"regexp"
"strings"
"sync"
"time"
"github.com/gorilla/websocket"
"github.com/microsoft/dev-tunnels/go/tunnels"
tunnelssh "github.com/microsoft/dev-tunnels/go/tunnels/ssh"
"github.com/microsoft/dev-tunnels/go/tunnels/ssh/messages"
"golang.org/x/crypto/ssh"
)
func NewMockHttpClient() (*http.Client, error) {
accessToken := "tunnel access-token"
relayServer, err := newMockrelayServer(withAccessToken(accessToken))
if err != nil {
return nil, fmt.Errorf("NewrelayServer returned an error: %w", err)
}
hostURL := strings.Replace(relayServer.URL(), "http://", "ws://", 1)
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var response []byte
if r.URL.Path == "/api/v1/tunnels/tunnel-id" {
tunnel := &tunnels.Tunnel{
AccessTokens: map[tunnels.TunnelAccessScope]string{
tunnels.TunnelAccessScopeConnect: accessToken,
},
Endpoints: []tunnels.TunnelEndpoint{
{
HostID: "host1",
TunnelRelayTunnelEndpoint: tunnels.TunnelRelayTunnelEndpoint{
ClientRelayURI: hostURL,
},
},
},
}
response, err = json.Marshal(*tunnel)
if err != nil {
log.Fatalf("json.Marshal returned an error: %v", err)
}
} else if strings.HasPrefix(r.URL.Path, "/api/v1/tunnels/tunnel-id/ports") {
// Use regex to check if the path ends with a number
match, err := regexp.MatchString(`\/\d+$`, r.URL.Path)
if err != nil {
log.Fatalf("regexp.MatchString returned an error: %v", err)
}
// If the path ends with a number, it's a request for a specific port
if match || r.Method == http.MethodPost {
if r.Method == http.MethodDelete {
w.WriteHeader(http.StatusOK)
return
}
tunnelPort := &tunnels.TunnelPort{
AccessControl: &tunnels.TunnelAccessControl{
Entries: []tunnels.TunnelAccessControlEntry{},
},
}
// Convert the tunnel to JSON and write it to the response
response, err = json.Marshal(*tunnelPort)
if err != nil {
log.Fatalf("json.Marshal returned an error: %v", err)
}
} else {
// If the path doesn't end with a number and we aren't making a POST request, return an array of ports
tunnelPorts := []tunnels.TunnelPort{
{
AccessControl: &tunnels.TunnelAccessControl{
Entries: []tunnels.TunnelAccessControlEntry{},
},
},
}
response, err = json.Marshal(tunnelPorts)
if err != nil {
log.Fatalf("json.Marshal returned an error: %v", err)
}
}
} else {
w.WriteHeader(http.StatusNotFound)
return
}
// Write the response
_, _ = w.Write(response)
}))
url, err := url.Parse(mockServer.URL)
if err != nil {
return nil, fmt.Errorf("url.Parse returned an error: %w", err)
}
return &http.Client{
Transport: &http.Transport{
Proxy: http.ProxyURL(url),
},
}, nil
}
type relayServer struct {
httpServer *httptest.Server
errc chan error
sshConfig *ssh.ServerConfig
channels map[string]channelHandler
accessToken string
serverConn *ssh.ServerConn
}
type relayServerOption func(*relayServer)
type channelHandler func(context.Context, ssh.NewChannel) error
func newMockrelayServer(opts ...relayServerOption) (*relayServer, error) {
server := &relayServer{
errc: make(chan error),
sshConfig: &ssh.ServerConfig{
NoClientAuth: true,
},
}
// Create a private key with the crypto package
key, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
return nil, fmt.Errorf("failed to generate key: %w", err)
}
privateKeyPEM := pem.EncodeToMemory(
&pem.Block{
Type: "RSA PRIVATE KEY",
Bytes: x509.MarshalPKCS1PrivateKey(key),
},
)
// Parse the private key
sshPrivateKey, err := ssh.ParsePrivateKey(privateKeyPEM)
if err != nil {
return nil, fmt.Errorf("failed to parse private key: %w", err)
}
server.sshConfig.AddHostKey(ssh.Signer(sshPrivateKey))
server.httpServer = httptest.NewServer(http.HandlerFunc(makeConnection(server)))
for _, opt := range opts {
opt(server)
}
return server, nil
}
func withAccessToken(accessToken string) func(*relayServer) {
return func(server *relayServer) {
server.accessToken = accessToken
}
}
func (rs *relayServer) URL() string {
return rs.httpServer.URL
}
func (rs *relayServer) Err() <-chan error {
return rs.errc
}
func (rs *relayServer) sendError(err error) {
select {
case rs.errc <- err:
default:
// channel is blocked with a previous error, so we ignore this one
}
}
func (rs *relayServer) ForwardPort(ctx context.Context, port uint16) error {
pfr := messages.NewPortForwardRequest("127.0.0.1", uint32(port))
b, err := pfr.Marshal()
if err != nil {
return fmt.Errorf("error marshaling port forward request: %w", err)
}
replied, data, err := rs.serverConn.SendRequest(messages.PortForwardRequestType, true, b)
if err != nil {
return fmt.Errorf("error sending port forward request: %w", err)
}
if !replied {
return fmt.Errorf("port forward request not replied")
}
if data == nil {
return fmt.Errorf("no data returned")
}
return nil
}
func makeConnection(server *relayServer) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
if server.accessToken != "" {
if r.Header.Get("Authorization") != server.accessToken {
server.sendError(fmt.Errorf("invalid access token"))
return
}
}
upgrader := websocket.Upgrader{}
c, err := upgrader.Upgrade(w, r, nil)
if err != nil {
server.sendError(fmt.Errorf("error upgrading to websocket: %w", err))
return
}
defer func() {
if err := c.Close(); err != nil {
server.sendError(fmt.Errorf("error closing websocket: %w", err))
}
}()
socketConn := newSocketConn(c)
serverConn, chans, reqs, err := ssh.NewServerConn(socketConn, server.sshConfig)
if err != nil {
server.sendError(fmt.Errorf("error creating ssh server conn: %w", err))
return
}
go handleRequests(ctx, convertRequests(reqs))
server.serverConn = serverConn
if err := handleChannels(ctx, server, chans); err != nil {
server.sendError(fmt.Errorf("error handling channels: %w", err))
return
}
}
}
func (sr *sshRequest) Type() string {
return sr.request.Type
}
type sshRequest struct {
request *ssh.Request
}
// Reply method for sshRequest to satisfy the tunnelssh.SSHRequest interface
func (sr *sshRequest) Reply(success bool, message []byte) error {
return sr.request.Reply(success, message)
}
// convertRequests function
func convertRequests(reqs <-chan *ssh.Request) <-chan tunnelssh.SSHRequest {
out := make(chan tunnelssh.SSHRequest)
go func() {
for req := range reqs {
out <- &sshRequest{req}
}
close(out)
}()
return out
}
func handleChannels(ctx context.Context, server *relayServer, chans <-chan ssh.NewChannel) error {
errc := make(chan error, 1)
go func() {
for ch := range chans {
if handler, ok := server.channels[ch.ChannelType()]; ok {
if err := handler(ctx, ch); err != nil {
errc <- err
return
}
} else {
// generic accept of the channel to not block
_, _, err := ch.Accept()
if err != nil {
errc <- fmt.Errorf("error accepting channel: %w", err)
return
}
}
}
}()
return awaitError(ctx, errc)
}
func handleRequests(ctx context.Context, reqs <-chan tunnelssh.SSHRequest) {
for {
select {
case <-ctx.Done():
return
case req, ok := <-reqs:
if !ok {
return
}
if req.Type() == "RefreshPorts" {
_ = req.Reply(true, nil)
continue
} else {
_ = req.Reply(false, nil)
}
}
}
}
func awaitError(ctx context.Context, errc <-chan error) error {
select {
case <-ctx.Done():
return ctx.Err()
case err := <-errc:
return err
}
}
type socketConn struct {
*websocket.Conn
reader io.Reader
writeMutex sync.Mutex
readMutex sync.Mutex
}
func newSocketConn(conn *websocket.Conn) *socketConn {
return &socketConn{Conn: conn}
}
func (s *socketConn) Read(b []byte) (int, error) {
s.readMutex.Lock()
defer s.readMutex.Unlock()
if s.reader == nil {
msgType, r, err := s.Conn.NextReader()
if err != nil {
return 0, fmt.Errorf("error getting next reader: %w", err)
}
if msgType != websocket.BinaryMessage {
return 0, fmt.Errorf("invalid message type")
}
s.reader = r
}
bytesRead, err := s.reader.Read(b)
if err != nil {
s.reader = nil
if err == io.EOF {
err = nil
}
}
return bytesRead, err
}
func (s *socketConn) Write(b []byte) (int, error) {
s.writeMutex.Lock()
defer s.writeMutex.Unlock()
w, err := s.Conn.NextWriter(websocket.BinaryMessage)
if err != nil {
return 0, fmt.Errorf("error getting next writer: %w", err)
}
n, err := w.Write(b)
if err != nil {
return 0, fmt.Errorf("error writing: %w", err)
}
if err := w.Close(); err != nil {
return 0, fmt.Errorf("error closing writer: %w", err)
}
return n, nil
}
func (s *socketConn) SetDeadline(deadline time.Time) error {
if err := s.Conn.SetReadDeadline(deadline); err != nil {
return err
}
return s.Conn.SetWriteDeadline(deadline)
}

View file

@ -0,0 +1,253 @@
package portforwarder
import (
"context"
"fmt"
"net"
"strings"
"github.com/cli/cli/v2/internal/codespaces/connection"
"github.com/microsoft/dev-tunnels/go/tunnels"
)
const (
githubSubjectId = "1"
InternalPortTag = "InternalPort"
UserForwardedPortTag = "UserForwardedPort"
)
const (
PrivatePortVisibility = "private"
OrgPortVisibility = "org"
PublicPortVisibility = "public"
)
type PortForwarder struct {
connection connection.CodespaceConnection
}
// NewPortForwarder returns a new PortForwarder for the specified codespace.
func NewPortForwarder(ctx context.Context, codespaceConnection *connection.CodespaceConnection) (fwd *PortForwarder, err error) {
return &PortForwarder{
connection: *codespaceConnection,
}, nil
}
// ForwardAndConnectToPort forwards a port and connects to it via a local TCP port.
func (fwd *PortForwarder) ForwardAndConnectToPort(ctx context.Context, remotePort uint16, listen *net.TCPListener, keepAlive bool, internal bool) error {
return fwd.ForwardPort(ctx, remotePort, listen, keepAlive, true, internal, "")
}
// ForwardPort forwards a port and optionally connects to it via a local TCP port.
func (fwd *PortForwarder) ForwardPort(ctx context.Context, remotePort uint16, listen *net.TCPListener, keepAlive bool, connect bool, internal bool, visibility string) error {
tunnelPort := tunnels.NewTunnelPort(remotePort, "", "", tunnels.TunnelProtocolHttp)
// If no visibility is provided, Dev Tunnels will use the default (private)
if visibility != "" {
// Check if the requested visibility is allowed
allowed := false
for _, allowedVisibility := range fwd.connection.AllowedPortPrivacySettings {
if allowedVisibility == visibility {
allowed = true
break
}
}
// If the requested visibility is not allowed, return an error
if !allowed {
return fmt.Errorf("visibility %s is not allowed", visibility)
}
accessControlEntries := visibilityToAccessControlEntries(visibility)
if len(accessControlEntries) > 0 {
tunnelPort.AccessControl = &tunnels.TunnelAccessControl{
Entries: accessControlEntries,
}
}
}
// Tag the port as internal or user forwarded so we know if it needs to be shown in the UI
if internal {
tunnelPort.Tags = []string{InternalPortTag}
} else {
tunnelPort.Tags = []string{UserForwardedPortTag}
}
// Create the tunnel port
_, err := fwd.connection.TunnelManager.CreateTunnelPort(ctx, fwd.connection.Tunnel, tunnelPort, fwd.connection.Options)
if err != nil && !strings.Contains(err.Error(), "409") {
return fmt.Errorf("create tunnel port failed: %v", err)
}
done := make(chan error)
go func() {
// Connect to the tunnel
err = fwd.connection.TunnelClient.Connect(ctx, "")
if err != nil {
done <- fmt.Errorf("connect failed: %v", err)
return
}
// Inform the host that we've forwarded the port locally
err = fwd.connection.TunnelClient.RefreshPorts(ctx)
if err != nil {
done <- fmt.Errorf("refresh ports failed: %v", err)
return
}
// If we don't want to connect to the port, exit early
if !connect {
done <- nil
return
}
// Ensure the port is forwarded before connecting
err = fwd.connection.TunnelClient.WaitForForwardedPort(ctx, remotePort)
if err != nil {
done <- fmt.Errorf("wait for forwarded port failed: %v", err)
return
}
// Connect to the forwarded port via a local TCP port
err = fwd.connection.TunnelClient.ConnectToForwardedPort(ctx, listen, remotePort)
if err != nil {
done <- fmt.Errorf("connect to forwarded port failed: %v", err)
return
}
done <- nil
}()
select {
case err := <-done:
if err != nil {
return fmt.Errorf("error connecting to tunnel: %w", err)
}
return nil
case <-ctx.Done():
return nil
}
}
// ListPorts fetches the list of ports that are currently forwarded.
func (fwd *PortForwarder) ListPorts(ctx context.Context) (ports []*tunnels.TunnelPort, err error) {
ports, err = fwd.connection.TunnelManager.ListTunnelPorts(ctx, fwd.connection.Tunnel, fwd.connection.Options)
if err != nil {
return nil, fmt.Errorf("error listing ports: %w", err)
}
return ports, nil
}
// UpdatePortVisibility changes the visibility (private, org, public) of the specified port.
func (fwd *PortForwarder) UpdatePortVisibility(ctx context.Context, remotePort int, visibility string) error {
tunnelPort, err := fwd.connection.TunnelManager.GetTunnelPort(ctx, fwd.connection.Tunnel, remotePort, fwd.connection.Options)
if err != nil {
return fmt.Errorf("error getting tunnel port: %w", err)
}
// If the port visibility isn't changing, don't do anything
if AccessControlEntriesToVisibility(tunnelPort.AccessControl.Entries) == visibility {
return nil
}
// Delete the existing tunnel port to update
err = fwd.connection.TunnelManager.DeleteTunnelPort(ctx, fwd.connection.Tunnel, uint16(remotePort), fwd.connection.Options)
if err != nil {
return fmt.Errorf("error deleting tunnel port: %w", err)
}
done := make(chan error)
go func() {
// Connect to the tunnel
err = fwd.connection.TunnelClient.Connect(ctx, "")
if err != nil {
done <- fmt.Errorf("connect failed: %v", err)
return
}
// Inform the host that we've deleted the port
err = fwd.connection.TunnelClient.RefreshPorts(ctx)
if err != nil {
done <- fmt.Errorf("refresh ports failed: %v", err)
return
}
done <- nil
}()
// Wait for the done channel to be closed
select {
case err := <-done:
if err != nil {
return fmt.Errorf("error connecting to tunnel: %w", err)
}
// Re-forward the port with the updated visibility
err = fwd.ForwardPort(ctx, uint16(remotePort), nil, false, false, false, visibility)
if err != nil {
return fmt.Errorf("error forwarding port: %w", err)
}
return nil
case <-ctx.Done():
return nil
}
}
// AccessControlEntriesToVisibility converts the access control entries used by Dev Tunnels to a friendly visibility value.
func AccessControlEntriesToVisibility(accessControlEntries []tunnels.TunnelAccessControlEntry) string {
for _, entry := range accessControlEntries {
// If we have the anonymous type (and we're not denying it), it's public
if (entry.Type == tunnels.TunnelAccessControlEntryTypeAnonymous) && (!entry.IsDeny) {
return PublicPortVisibility
}
// If we have the organizations type (and we're not denying it), it's org
if (entry.Provider == string(tunnels.TunnelAuthenticationSchemeGitHub)) && (!entry.IsDeny) {
return OrgPortVisibility
}
}
// Else, it's private
return PrivatePortVisibility
}
// visibilityToAccessControlEntries converts the given visibility to access control entries that can be used by Dev Tunnels.
func visibilityToAccessControlEntries(visibility string) []tunnels.TunnelAccessControlEntry {
switch visibility {
case PublicPortVisibility:
return []tunnels.TunnelAccessControlEntry{{
Type: tunnels.TunnelAccessControlEntryTypeAnonymous,
Subjects: []string{},
Scopes: []string{string(tunnels.TunnelAccessScopeConnect)},
}}
case OrgPortVisibility:
return []tunnels.TunnelAccessControlEntry{{
Type: tunnels.TunnelAccessControlEntryTypeOrganizations,
Subjects: []string{githubSubjectId},
Scopes: []string{
string(tunnels.TunnelAccessScopeConnect),
},
Provider: string(tunnels.TunnelAuthenticationSchemeGitHub),
}}
default:
// The tunnel manager doesn't accept empty access control entries, so we need to return a deny entry
return []tunnels.TunnelAccessControlEntry{{
Type: tunnels.TunnelAccessControlEntryTypeOrganizations,
Subjects: []string{githubSubjectId},
Scopes: []string{},
IsDeny: true,
}}
}
}
// IsInternalPort returns true if the port is internal.
func IsInternalPort(port *tunnels.TunnelPort) bool {
for _, tag := range port.Tags {
if strings.EqualFold(tag, InternalPortTag) {
return true
}
}
return false
}

View file

@ -0,0 +1,139 @@
package portforwarder
import (
"context"
"testing"
"github.com/cli/cli/v2/internal/codespaces/api"
"github.com/cli/cli/v2/internal/codespaces/connection"
"github.com/microsoft/dev-tunnels/go/tunnels"
)
func TestNewPortForwarder(t *testing.T) {
ctx := context.Background()
// Create a mock codespace
codespace := &api.Codespace{
Connection: api.CodespaceConnection{
TunnelProperties: api.TunnelProperties{
ConnectAccessToken: "connect-token",
ManagePortsAccessToken: "manage-ports-token",
ServiceUri: "http://global.rel.tunnels.api.visualstudio.com/",
TunnelId: "tunnel-id",
ClusterId: "usw2",
Domain: "domain.com",
},
},
RuntimeConstraints: api.RuntimeConstraints{
AllowedPortPrivacySettings: []string{"public", "private"},
},
}
// Create the mock HTTP client
httpClient, err := connection.NewMockHttpClient()
if err != nil {
t.Fatalf("NewHttpClient returned an error: %v", err)
}
// Call the function being tested
conn, err := connection.NewCodespaceConnection(ctx, codespace, httpClient)
if err != nil {
t.Fatalf("NewCodespaceConnection returned an error: %v", err)
}
// Create the new port forwarder
portForwarder, err := NewPortForwarder(ctx, conn)
if err != nil {
t.Fatalf("NewPortForwarder returned an error: %v", err)
}
// Check that the port forwarder was created successfully
if portForwarder == nil {
t.Fatal("NewPortForwarder returned nil")
}
}
func TestAccessControlEntriesToVisibility(t *testing.T) {
publicAccessControlEntry := []tunnels.TunnelAccessControlEntry{{
Type: tunnels.TunnelAccessControlEntryTypeAnonymous,
}}
orgAccessControlEntry := []tunnels.TunnelAccessControlEntry{{
Provider: string(tunnels.TunnelAuthenticationSchemeGitHub),
}}
privateAccessControlEntry := []tunnels.TunnelAccessControlEntry{}
orgIsDenyAccessControlEntry := []tunnels.TunnelAccessControlEntry{{
Provider: string(tunnels.TunnelAuthenticationSchemeGitHub),
IsDeny: true,
}}
tests := []struct {
name string
accessControlEntries []tunnels.TunnelAccessControlEntry
expected string
}{
{
name: "public",
accessControlEntries: publicAccessControlEntry,
expected: PublicPortVisibility,
},
{
name: "org",
accessControlEntries: orgAccessControlEntry,
expected: OrgPortVisibility,
},
{
name: "private",
accessControlEntries: privateAccessControlEntry,
expected: PrivatePortVisibility,
},
{
name: "orgIsDeny",
accessControlEntries: orgIsDenyAccessControlEntry,
expected: PrivatePortVisibility,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
visibility := AccessControlEntriesToVisibility(test.accessControlEntries)
if visibility != test.expected {
t.Errorf("expected %q, got %q", test.expected, visibility)
}
})
}
}
func TestIsInternalPort(t *testing.T) {
internalPort := &tunnels.TunnelPort{
Tags: []string{"InternalPort"},
}
userForwardedPort := &tunnels.TunnelPort{
Tags: []string{"UserForwardedPort"},
}
tests := []struct {
name string
port *tunnels.TunnelPort
expected bool
}{
{
name: "internal",
port: internalPort,
expected: true,
},
{
name: "user-forwarded",
port: userForwardedPort,
expected: false,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
isInternal := IsInternalPort(test.port)
if isInternal != test.expected {
t.Errorf("expected %v, got %v", test.expected, isInternal)
}
})
}
}

View file

@ -8,6 +8,7 @@ import (
"fmt"
"io"
"log"
"net/http"
"os"
"sort"
"strings"
@ -104,6 +105,7 @@ type apiClient interface {
ListDevContainers(ctx context.Context, repoID int, branch string, limit int) (devcontainers []api.DevContainerEntry, err error)
GetCodespaceRepoSuggestions(ctx context.Context, partialSearch string, params api.RepoSearchParameters) ([]string, error)
GetCodespaceBillableOwner(ctx context.Context, nwo string) (*api.User, error)
HTTPClient() (*http.Client, error)
}
var errNoCodespaces = errors.New("you have no codespaces")

View file

@ -5,6 +5,7 @@ package codespace
import (
"context"
"net/http"
"sync"
codespacesAPI "github.com/cli/cli/v2/internal/codespaces/api"
@ -40,15 +41,15 @@ import (
// GetCodespacesMachinesFunc: func(ctx context.Context, repoID int, branch string, location string, devcontainerPath string) ([]*codespacesAPI.Machine, error) {
// panic("mock out the GetCodespacesMachines method")
// },
// HTTPClientFunc: func() (*http.Client, error) {
// panic("mock out the HTTPClient method")
// },
// GetOrgMemberCodespaceFunc: func(ctx context.Context, orgName string, userName string, codespaceName string) (*codespacesAPI.Codespace, error) {
// panic("mock out the GetOrgMemberCodespace method")
// },
// GetRepositoryFunc: func(ctx context.Context, nwo string) (*codespacesAPI.Repository, error) {
// panic("mock out the GetRepository method")
// },
// ServerURLFunc: func() string {
// panic("mock out the ServerURL method")
// },
// GetUserFunc: func(ctx context.Context) (*codespacesAPI.User, error) {
// panic("mock out the GetUser method")
// },
@ -58,6 +59,9 @@ import (
// ListDevContainersFunc: func(ctx context.Context, repoID int, branch string, limit int) ([]codespacesAPI.DevContainerEntry, error) {
// panic("mock out the ListDevContainers method")
// },
// ServerURLFunc: func() string {
// panic("mock out the ServerURL method")
// },
// StartCodespaceFunc: func(ctx context.Context, name string) error {
// panic("mock out the StartCodespace method")
// },
@ -95,15 +99,15 @@ type apiClientMock struct {
// GetCodespacesMachinesFunc mocks the GetCodespacesMachines method.
GetCodespacesMachinesFunc func(ctx context.Context, repoID int, branch string, location string, devcontainerPath string) ([]*codespacesAPI.Machine, error)
// HTTPClientFunc mocks the HTTPClient method.
HTTPClientFunc func() (*http.Client, error)
// GetOrgMemberCodespaceFunc mocks the GetOrgMemberCodespace method.
GetOrgMemberCodespaceFunc func(ctx context.Context, orgName string, userName string, codespaceName string) (*codespacesAPI.Codespace, error)
// GetRepositoryFunc mocks the GetRepository method.
GetRepositoryFunc func(ctx context.Context, nwo string) (*codespacesAPI.Repository, error)
// ServerURLFunc mocks the ServerURL method.
ServerURLFunc func() string
// GetUserFunc mocks the GetUser method.
GetUserFunc func(ctx context.Context) (*codespacesAPI.User, error)
@ -113,6 +117,9 @@ type apiClientMock struct {
// ListDevContainersFunc mocks the ListDevContainers method.
ListDevContainersFunc func(ctx context.Context, repoID int, branch string, limit int) ([]codespacesAPI.DevContainerEntry, error)
// ServerURLFunc mocks the ServerURL method.
ServerURLFunc func() string
// StartCodespaceFunc mocks the StartCodespace method.
StartCodespaceFunc func(ctx context.Context, name string) error
@ -195,6 +202,9 @@ type apiClientMock struct {
// DevcontainerPath is the devcontainerPath argument value.
DevcontainerPath string
}
// HTTPClient holds details about calls to the HTTPClient method.
HTTPClient []struct {
}
// GetOrgMemberCodespace holds details about calls to the GetOrgMemberCodespace method.
GetOrgMemberCodespace []struct {
// Ctx is the ctx argument value.
@ -213,9 +223,6 @@ type apiClientMock struct {
// Nwo is the nwo argument value.
Nwo string
}
// ServerURL holds details about calls to the ServerURL method.
ServerURL []struct {
}
// GetUser holds details about calls to the GetUser method.
GetUser []struct {
// Ctx is the ctx argument value.
@ -239,6 +246,9 @@ type apiClientMock struct {
// Limit is the limit argument value.
Limit int
}
// ServerURL holds details about calls to the ServerURL method.
ServerURL []struct {
}
// StartCodespace holds details about calls to the StartCodespace method.
StartCodespace []struct {
// Ctx is the ctx argument value.
@ -266,12 +276,13 @@ type apiClientMock struct {
lockGetCodespaceRepoSuggestions sync.RWMutex
lockGetCodespaceRepositoryContents sync.RWMutex
lockGetCodespacesMachines sync.RWMutex
lockHTTPClient sync.RWMutex
lockGetOrgMemberCodespace sync.RWMutex
lockGetRepository sync.RWMutex
lockServerURL sync.RWMutex
lockGetUser sync.RWMutex
lockListCodespaces sync.RWMutex
lockListDevContainers sync.RWMutex
lockServerURL sync.RWMutex
lockStartCodespace sync.RWMutex
lockStopCodespace sync.RWMutex
}
@ -600,6 +611,33 @@ func (mock *apiClientMock) GetCodespacesMachinesCalls() []struct {
return calls
}
// HTTPClient calls HTTPClientFunc.
func (mock *apiClientMock) HTTPClient() (*http.Client, error) {
if mock.HTTPClientFunc == nil {
panic("apiClientMock.HTTPClientFunc: method is nil but apiClient.HTTPClient was just called")
}
callInfo := struct {
}{}
mock.lockHTTPClient.Lock()
mock.calls.HTTPClient = append(mock.calls.HTTPClient, callInfo)
mock.lockHTTPClient.Unlock()
return mock.HTTPClientFunc()
}
// HTTPClientCalls gets all the calls that were made to HTTPClient.
// Check the length with:
//
// len(mockedapiClient.HTTPClientCalls())
func (mock *apiClientMock) HTTPClientCalls() []struct {
} {
var calls []struct {
}
mock.lockHTTPClient.RLock()
calls = mock.calls.HTTPClient
mock.lockHTTPClient.RUnlock()
return calls
}
// GetOrgMemberCodespace calls GetOrgMemberCodespaceFunc.
func (mock *apiClientMock) GetOrgMemberCodespace(ctx context.Context, orgName string, userName string, codespaceName string) (*codespacesAPI.Codespace, error) {
if mock.GetOrgMemberCodespaceFunc == nil {
@ -680,33 +718,6 @@ func (mock *apiClientMock) GetRepositoryCalls() []struct {
return calls
}
// ServerURL calls ServerURLFunc.
func (mock *apiClientMock) ServerURL() string {
if mock.ServerURLFunc == nil {
panic("apiClientMock.ServerURLFunc: method is nil but apiClient.ServerURL was just called")
}
callInfo := struct {
}{}
mock.lockServerURL.Lock()
mock.calls.ServerURL = append(mock.calls.ServerURL, callInfo)
mock.lockServerURL.Unlock()
return mock.ServerURLFunc()
}
// ServerURLCalls gets all the calls that were made to ServerURL.
// Check the length with:
//
// len(mockedapiClient.ServerURLCalls())
func (mock *apiClientMock) ServerURLCalls() []struct {
} {
var calls []struct {
}
mock.lockServerURL.RLock()
calls = mock.calls.ServerURL
mock.lockServerURL.RUnlock()
return calls
}
// GetUser calls GetUserFunc.
func (mock *apiClientMock) GetUser(ctx context.Context) (*codespacesAPI.User, error) {
if mock.GetUserFunc == nil {
@ -819,6 +830,33 @@ func (mock *apiClientMock) ListDevContainersCalls() []struct {
return calls
}
// ServerURL calls ServerURLFunc.
func (mock *apiClientMock) ServerURL() string {
if mock.ServerURLFunc == nil {
panic("apiClientMock.ServerURLFunc: method is nil but apiClient.ServerURL was just called")
}
callInfo := struct {
}{}
mock.lockServerURL.Lock()
mock.calls.ServerURL = append(mock.calls.ServerURL, callInfo)
mock.lockServerURL.Unlock()
return mock.ServerURLFunc()
}
// ServerURLCalls gets all the calls that were made to ServerURL.
// Check the length with:
//
// len(mockedapiClient.ServerURLCalls())
func (mock *apiClientMock) ServerURLCalls() []struct {
} {
var calls []struct {
}
mock.lockServerURL.RLock()
calls = mock.calls.ServerURL
mock.lockServerURL.RUnlock()
return calls
}
// StartCodespace calls StartCodespaceFunc.
func (mock *apiClientMock) StartCodespace(ctx context.Context, name string) error {
if mock.StartCodespaceFunc == nil {

View file

@ -6,26 +6,21 @@ import (
"encoding/json"
"errors"
"fmt"
"net/http"
"strconv"
"strings"
"time"
"github.com/cli/cli/v2/internal/codespaces"
"github.com/cli/cli/v2/internal/codespaces/api"
"github.com/cli/cli/v2/internal/codespaces/portforwarder"
"github.com/cli/cli/v2/internal/tableprinter"
"github.com/cli/cli/v2/pkg/cmdutil"
"github.com/cli/cli/v2/pkg/liveshare"
"github.com/cli/cli/v2/utils"
"github.com/microsoft/dev-tunnels/go/tunnels"
"github.com/muhammadmuzzammil1998/jsonc"
"github.com/spf13/cobra"
"golang.org/x/sync/errgroup"
)
const (
vscodeServerPortName = "VSCodeServerInternal"
codespacesInternalPortName = "CodespacesInternal"
)
// newPortsCmd returns a Cobra "ports" command that displays a table of available ports,
// according to the specified flags.
func newPortsCmd(app *App) *cobra.Command {
@ -62,15 +57,19 @@ func (a *App) ListPorts(ctx context.Context, selector *CodespaceSelector, export
devContainerCh := getDevContainer(ctx, a.apiClient, codespace)
session, err := startLiveShareSession(ctx, codespace, a, false, "")
codespaceConnection, err := codespaces.GetCodespaceConnection(ctx, a, a.apiClient, codespace)
if err != nil {
return err
return fmt.Errorf("error connecting to codespace: %w", err)
}
defer safeClose(session, &err)
var ports []*liveshare.Port
fwd, err := portforwarder.NewPortForwarder(ctx, codespaceConnection)
if err != nil {
return fmt.Errorf("failed to create port forwarder: %w", err)
}
var ports []*tunnels.TunnelPort
err = a.RunWithProgress("Fetching ports", func() (err error) {
ports, err = session.GetSharedServers(ctx)
ports, err = fwd.ListPorts(ctx)
return
})
if err != nil {
@ -87,9 +86,10 @@ func (a *App) ListPorts(ctx context.Context, selector *CodespaceSelector, export
for _, p := range ports {
// filter out internal ports from list
if strings.HasPrefix(p.SessionName, vscodeServerPortName) || strings.HasPrefix(p.SessionName, codespacesInternalPortName) {
if portforwarder.IsInternalPort(p) {
continue
}
portInfos = append(portInfos, &portInfo{
Port: p,
codespace: codespace,
@ -107,40 +107,42 @@ func (a *App) ListPorts(ctx context.Context, selector *CodespaceSelector, export
}
cs := a.io.ColorScheme()
//nolint:staticcheck // SA1019: utils.NewTablePrinter is deprecated: use internal/tableprinter
tp := utils.NewTablePrinter(a.io)
tp := tableprinter.New(a.io)
if tp.IsTTY() {
tp.AddField("LABEL", nil, nil)
tp.AddField("PORT", nil, nil)
tp.AddField("VISIBILITY", nil, nil)
tp.AddField("BROWSE URL", nil, nil)
if a.io.IsStdoutTTY() {
tp.AddField("LABEL")
tp.AddField("PORT")
tp.AddField("VISIBILITY")
tp.AddField("BROWSE URL")
tp.EndRow()
}
for _, port := range portInfos {
tp.AddField(port.Label(), nil, nil)
tp.AddField(strconv.Itoa(port.SourcePort), nil, cs.Yellow)
tp.AddField(port.Privacy, nil, nil)
tp.AddField(port.BrowseURL(), nil, nil)
// Convert the ACE to a friendly visibility string (private, org, public)
visibility := portforwarder.AccessControlEntriesToVisibility(port.Port.AccessControl.Entries)
tp.AddField(port.Label())
tp.AddField(cs.Yellow(fmt.Sprintf("%d", port.Port.PortNumber)))
tp.AddField(visibility)
tp.AddField(port.BrowseURL())
tp.EndRow()
}
return tp.Render()
}
type portInfo struct {
*liveshare.Port
Port *tunnels.TunnelPort
codespace *api.Codespace
devContainer *devContainer
}
func (pi *portInfo) BrowseURL() string {
return fmt.Sprintf("https://%s-%d.preview.app.github.dev", pi.codespace.Name, pi.Port.SourcePort)
return fmt.Sprintf("https://%s-%d.app.github.dev", pi.codespace.Name, pi.Port.PortNumber)
}
func (pi *portInfo) Label() string {
if pi.devContainer != nil {
portStr := strconv.Itoa(pi.Port.SourcePort)
portStr := strconv.Itoa(int(pi.Port.PortNumber))
if attributes, ok := pi.devContainer.PortAttributes[portStr]; ok {
return attributes.Label
}
@ -150,7 +152,6 @@ func (pi *portInfo) Label() string {
var portFields = []string{
"sourcePort",
// "destinationPort", // TODO(mislav): this appears to always be blank?
"visibility",
"label",
"browseUrl",
@ -162,11 +163,9 @@ func (pi *portInfo) ExportData(fields []string) map[string]interface{} {
for _, f := range fields {
switch f {
case "sourcePort":
data[f] = pi.Port.SourcePort
case "destinationPort":
data[f] = pi.Port.DestinationPort
data[f] = pi.Port.PortNumber
case "visibility":
data[f] = pi.Port.Privacy
data[f] = portforwarder.AccessControlEntriesToVisibility(pi.Port.AccessControl.Entries)
case "label":
data[f] = pi.Label()
case "browseUrl":
@ -235,30 +234,6 @@ func newPortsVisibilityCmd(app *App, selector *CodespaceSelector) *cobra.Command
}
}
type ErrUpdatingPortVisibility struct {
port int
visibility string
err error
}
func newErrUpdatingPortVisibility(port int, visibility string, err error) *ErrUpdatingPortVisibility {
return &ErrUpdatingPortVisibility{
port: port,
visibility: visibility,
err: err,
}
}
func (e *ErrUpdatingPortVisibility) Error() string {
return fmt.Sprintf("error waiting for port %d to update to %s: %s", e.port, e.visibility, e.err)
}
func (e *ErrUpdatingPortVisibility) Unwrap() error {
return e.err
}
var errUpdatePortVisibilityForbidden = errors.New("organization admin has forbidden this privacy setting")
func (a *App) UpdatePortVisibility(ctx context.Context, selector *CodespaceSelector, args []string) (err error) {
ports, err := a.parsePortVisibilities(args)
if err != nil {
@ -270,47 +245,28 @@ func (a *App) UpdatePortVisibility(ctx context.Context, selector *CodespaceSelec
return err
}
session, err := codespaces.ConnectToLiveshare(ctx, a, noopLogger(), a.apiClient, codespace)
codespaceConnection, err := codespaces.GetCodespaceConnection(ctx, a, a.apiClient, codespace)
if err != nil {
return fmt.Errorf("error connecting to codespace: %w", err)
}
defer safeClose(session, &err)
fwd, err := portforwarder.NewPortForwarder(ctx, codespaceConnection)
if err != nil {
return fmt.Errorf("failed to create port forwarder: %w", err)
}
// TODO: check if port visibility can be updated in parallel instead of sequentially
for _, port := range ports {
err := a.RunWithProgress(fmt.Sprintf("Updating port %d visibility to: %s", port.number, port.visibility), func() (err error) {
// wait for success or failure
g, ctx := errgroup.WithContext(ctx)
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel()
g.Go(func() error {
updateNotif, err := session.WaitForPortNotification(ctx, port.number, liveshare.PortChangeKindUpdate)
if err != nil {
return fmt.Errorf("error waiting for port %d to update: %w", port.number, err)
}
if !updateNotif.Success {
if updateNotif.StatusCode == http.StatusForbidden {
return newErrUpdatingPortVisibility(port.number, port.visibility, errUpdatePortVisibilityForbidden)
}
return newErrUpdatingPortVisibility(port.number, port.visibility, errors.New(updateNotif.ErrorDetail))
}
return nil // success
})
g.Go(func() error {
err := session.UpdateSharedServerPrivacy(ctx, port.number, port.visibility)
if err != nil {
return fmt.Errorf("error updating port %d to %s: %w", port.number, port.visibility, err)
}
return nil
})
// wait for success or failure
err = g.Wait()
return
err = fwd.UpdatePortVisibility(ctx, port.number, port.visibility)
if err != nil {
return fmt.Errorf("error updating port %d to %s: %w", port.number, port.visibility, err)
}
return nil
})
if err != nil {
return err
@ -367,11 +323,10 @@ func (a *App) ForwardPorts(ctx context.Context, selector *CodespaceSelector, por
return err
}
session, err := codespaces.ConnectToLiveshare(ctx, a, noopLogger(), a.apiClient, codespace)
codespaceConnection, err := codespaces.GetCodespaceConnection(ctx, a, a.apiClient, codespace)
if err != nil {
return fmt.Errorf("error connecting to codespace: %w", err)
}
defer safeClose(session, &err)
// Run forwarding of all ports concurrently, aborting all of
// them at the first failure, including cancellation of the context.
@ -386,9 +341,11 @@ func (a *App) ForwardPorts(ctx context.Context, selector *CodespaceSelector, por
defer listen.Close()
a.errLogger.Printf("Forwarding ports: remote %d <=> local %d", pair.remote, pair.local)
name := fmt.Sprintf("share-%d", pair.remote)
fwd := liveshare.NewPortForwarder(session, name, pair.remote, false)
return fwd.ForwardToListener(ctx, listen) // error always non-nil
fwd, err := portforwarder.NewPortForwarder(ctx, codespaceConnection)
if err != nil {
return fmt.Errorf("failed to create port forwarder: %w", err)
}
return fwd.ForwardAndConnectToPort(ctx, uint16(pair.remote), listen, false, false)
})
}
return group.Wait() // first error

View file

@ -2,18 +2,34 @@ package codespace
import (
"context"
"errors"
"fmt"
"os"
"net/http"
"testing"
"github.com/cli/cli/v2/internal/codespaces/api"
"github.com/cli/cli/v2/internal/codespaces/connection"
"github.com/cli/cli/v2/pkg/iostreams"
"github.com/cli/cli/v2/pkg/liveshare"
livesharetest "github.com/cli/cli/v2/pkg/liveshare/test"
"github.com/sourcegraph/jsonrpc2"
)
func TestListPorts(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
mockApi := GetMockApi(false)
ios, _, _, _ := iostreams.Test()
a := &App{
io: ios,
apiClient: mockApi,
}
selector := &CodespaceSelector{api: a.apiClient, codespaceName: "codespace-name"}
err := a.ListPorts(ctx, selector, nil)
if err != nil {
t.Errorf("unexpected error: %v", err)
}
}
func TestPortsUpdateVisibilitySuccess(t *testing.T) {
portVisibilities := []portVisibility{
{
@ -26,175 +42,35 @@ func TestPortsUpdateVisibilitySuccess(t *testing.T) {
},
}
eventResponses := []string{
"serverSharing.sharingSucceeded",
"serverSharing.sharingSucceeded",
}
portsData := []liveshare.PortNotification{
{
Success: true,
Port: 80,
ChangeKind: liveshare.PortChangeKindUpdate,
},
{
Success: true,
Port: 9999,
ChangeKind: liveshare.PortChangeKindUpdate,
},
}
err := runUpdateVisibilityTest(t, portVisibilities, eventResponses, portsData)
err := runUpdateVisibilityTest(t, portVisibilities, true)
if err != nil {
t.Errorf("unexpected error: %v", err)
}
}
func TestPortsUpdateVisibilityFailure403(t *testing.T) {
portVisibilities := []portVisibility{
{
number: 80,
visibility: "org",
},
{
number: 9999,
visibility: "public",
},
}
eventResponses := []string{
"serverSharing.sharingSucceeded",
"serverSharing.sharingFailed",
}
portsData := []liveshare.PortNotification{
{
Success: true,
Port: 80,
ChangeKind: liveshare.PortChangeKindUpdate,
},
{
Success: false,
Port: 9999,
ChangeKind: liveshare.PortChangeKindUpdate,
ErrorDetail: "test error",
StatusCode: 403,
},
}
err := runUpdateVisibilityTest(t, portVisibilities, eventResponses, portsData)
if err == nil {
t.Fatalf("runUpdateVisibilityTest succeeded unexpectedly")
}
if errors.Unwrap(err) != errUpdatePortVisibilityForbidden {
t.Errorf("expected: %v, got: %v", errUpdatePortVisibilityForbidden, errors.Unwrap(err))
}
}
func TestPortsUpdateVisibilityFailure(t *testing.T) {
portVisibilities := []portVisibility{
{
number: 80,
visibility: "org",
},
{
number: 9999,
visibility: "public",
},
}
eventResponses := []string{
"serverSharing.sharingSucceeded",
"serverSharing.sharingFailed",
}
portsData := []liveshare.PortNotification{
{
Success: true,
Port: 80,
ChangeKind: liveshare.PortChangeKindUpdate,
},
{
Success: false,
Port: 9999,
ChangeKind: liveshare.PortChangeKindUpdate,
ErrorDetail: "test error",
number: 80,
visibility: "org",
},
}
err := runUpdateVisibilityTest(t, portVisibilities, eventResponses, portsData)
err := runUpdateVisibilityTest(t, portVisibilities, false)
if err == nil {
t.Fatalf("runUpdateVisibilityTest succeeded unexpectedly")
}
var expectedErr *ErrUpdatingPortVisibility
if !errors.As(err, &expectedErr) {
t.Errorf("expected: %v, got: %v", expectedErr, err)
}
}
type joinWorkspaceResult struct {
SessionNumber int `json:"sessionNumber"`
}
func runUpdateVisibilityTest(t *testing.T, portVisibilities []portVisibility, eventResponses []string, portsData []liveshare.PortNotification) error {
t.Helper()
if os.Getenv("GITHUB_ACTIONS") == "true" {
t.Skip("fails intermittently in CI: https://github.com/cli/cli/issues/5663")
}
joinWorkspace := func(conn *jsonrpc2.Conn, req *jsonrpc2.Request) (interface{}, error) {
return joinWorkspaceResult{1}, nil
}
const sessionToken = "session-token"
ch := make(chan *jsonrpc2.Conn, 1)
updateSharedVisibility := func(conn *jsonrpc2.Conn, rpcReq *jsonrpc2.Request) (interface{}, error) {
ch <- conn
return nil, nil
}
testServer, err := livesharetest.NewServer(
livesharetest.WithNonSecure(),
livesharetest.WithPassword(sessionToken),
livesharetest.WithService("workspace.joinWorkspace", joinWorkspace),
livesharetest.WithService("serverSharing.updateSharedServerPrivacy", updateSharedVisibility),
)
if err != nil {
return fmt.Errorf("unable to create test server: %w", err)
}
func runUpdateVisibilityTest(t *testing.T, portVisibilities []portVisibility, allowOrgPorts bool) error {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go func() {
for i, pd := range portsData {
select {
case <-ctx.Done():
return
case conn := <-ch:
_, _ = conn.DispatchCall(ctx, eventResponses[i], pd, nil)
}
}
}()
mockApi := &apiClientMock{
GetCodespaceFunc: func(ctx context.Context, codespaceName string, includeConnection bool) (*api.Codespace, error) {
return &api.Codespace{
Name: "codespace-name",
State: api.CodespaceStateAvailable,
Connection: api.CodespaceConnection{
SessionID: "session-id",
SessionToken: sessionToken,
RelayEndpoint: testServer.URL(),
RelaySAS: "relay-sas",
HostPublicKeys: []string{livesharetest.SSHPublicKey},
},
}, nil
},
}
mockApi := GetMockApi(allowOrgPorts)
ios, _, _, _ := iostreams.Test()
a := &App{
@ -251,6 +127,44 @@ func TestPendingOperationDisallowsForwardPorts(t *testing.T) {
}
}
func GetMockApi(allowOrgPorts bool) *apiClientMock {
return &apiClientMock{
GetCodespaceFunc: func(ctx context.Context, codespaceName string, includeConnection bool) (*api.Codespace, error) {
allowedPortPrivacySettings := []string{"public", "private"}
if allowOrgPorts {
allowedPortPrivacySettings = append(allowedPortPrivacySettings, "org")
}
return &api.Codespace{
Name: "codespace-name",
State: api.CodespaceStateAvailable,
Connection: api.CodespaceConnection{
TunnelProperties: api.TunnelProperties{
ConnectAccessToken: "tunnel access-token",
ManagePortsAccessToken: "manage-ports-token",
ServiceUri: "http://global.rel.tunnels.api.visualstudio.com/",
TunnelId: "tunnel-id",
ClusterId: "usw2",
Domain: "domain.com",
},
},
RuntimeConstraints: api.RuntimeConstraints{
AllowedPortPrivacySettings: allowedPortPrivacySettings,
},
}, nil
},
StartCodespaceFunc: func(ctx context.Context, codespaceName string) error {
return nil
},
GetCodespaceRepositoryContentsFunc: func(ctx context.Context, codespace *api.Codespace, path string) ([]byte, error) {
return nil, nil
},
HTTPClientFunc: func() (*http.Client, error) {
return connection.NewMockHttpClient()
},
}
}
func testingPortsApp() *App {
disabledCodespace := &api.Codespace{
Name: "disabledCodespace",