Forward codespace ports over Dev Tunnels
This commit is contained in:
parent
48b0d53d0e
commit
e059f32aa5
13 changed files with 1271 additions and 303 deletions
2
go.mod
2
go.mod
|
|
@ -27,6 +27,7 @@ require (
|
|||
github.com/mattn/go-colorable v0.1.13
|
||||
github.com/mattn/go-isatty v0.0.19
|
||||
github.com/mgutz/ansi v0.0.0-20200706080929-d51e80ef957d
|
||||
github.com/microsoft/dev-tunnels v0.0.21
|
||||
github.com/muhammadmuzzammil1998/jsonc v0.0.0-20201229145248-615b0916ca38
|
||||
github.com/opentracing/opentracing-go v1.1.0
|
||||
github.com/rivo/tview v0.0.0-20221029100920-c4a7e501810d
|
||||
|
|
@ -75,6 +76,7 @@ require (
|
|||
github.com/olekukonko/tablewriter v0.0.5 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/rivo/uniseg v0.4.4 // indirect
|
||||
github.com/rodaine/table v1.0.1 // indirect
|
||||
github.com/russross/blackfriday/v2 v2.1.0 // indirect
|
||||
github.com/shurcooL/graphql v0.0.0-20230722043721-ed46e5a46466 // indirect
|
||||
github.com/stretchr/objx v0.5.0 // indirect
|
||||
|
|
|
|||
4
go.sum
4
go.sum
|
|
@ -117,6 +117,8 @@ github.com/mgutz/ansi v0.0.0-20200706080929-d51e80ef957d h1:5PJl274Y63IEHC+7izoQ
|
|||
github.com/mgutz/ansi v0.0.0-20200706080929-d51e80ef957d/go.mod h1:01TrycV0kFyexm33Z7vhZRXopbI8J3TDReVlkTgMUxE=
|
||||
github.com/microcosm-cc/bluemonday v1.0.21 h1:dNH3e4PSyE4vNX+KlRGHT5KrSvjeUkoNPwEORjffHJg=
|
||||
github.com/microcosm-cc/bluemonday v1.0.21/go.mod h1:ytNkv4RrDrLJ2pqlsSI46O6IVXmZOBBD4SaJyDwwTkM=
|
||||
github.com/microsoft/dev-tunnels v0.0.21 h1:p4QP7C5ZOyP9bGbmanRjPxUMckfi9Z41Gl+KY4C11w0=
|
||||
github.com/microsoft/dev-tunnels v0.0.21/go.mod h1:frU++12T/oqxckXkDpTuYa427ncguEOodSPZcGCCrzQ=
|
||||
github.com/muesli/reflow v0.2.1-0.20210115123740-9e1d0d53df68/go.mod h1:Xk+z4oIWdQqJzsxyjgl3P22oYZnHdZ8FFTHAQQt5BMQ=
|
||||
github.com/muesli/reflow v0.3.0 h1:IFsN6K9NfGtjeggFP+68I4chLZV2yIKsXJFNZ+eWh6s=
|
||||
github.com/muesli/reflow v0.3.0/go.mod h1:pbwTDkVPibjO2kyvBQRBxTWEEGDGq0FlB1BIKtnHY/8=
|
||||
|
|
@ -139,6 +141,8 @@ github.com/rivo/uniseg v0.1.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJ
|
|||
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
|
||||
github.com/rivo/uniseg v0.4.4 h1:8TfxU8dW6PdqD27gjM8MVNuicgxIjxpm4K7x4jp8sis=
|
||||
github.com/rivo/uniseg v0.4.4/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
|
||||
github.com/rodaine/table v1.0.1 h1:U/VwCnUxlVYxw8+NJiLIuCxA/xa6jL38MY3FYysVWWQ=
|
||||
github.com/rodaine/table v1.0.1/go.mod h1:UVEtfBsflpeEcD56nF4F5AocNFta0ZuolpSVdPtlmP4=
|
||||
github.com/russross/blackfriday/v2 v2.1.0 h1:JIOH55/0cWyOuilr9/qlrm0BSXldqnqwMsf35Ld67mk=
|
||||
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
|
||||
github.com/shurcooL/githubv4 v0.0.0-20230704064427-599ae7bbf278 h1:kdEGVAV4sO46DPtb8k793jiecUEhaX9ixoIBt41HEGU=
|
||||
|
|
|
|||
|
|
@ -201,6 +201,7 @@ type Codespace struct {
|
|||
GitStatus CodespaceGitStatus `json:"git_status"`
|
||||
Connection CodespaceConnection `json:"connection"`
|
||||
Machine CodespaceMachine `json:"machine"`
|
||||
RuntimeConstraints RuntimeConstraints `json:"runtime_constraints"`
|
||||
VSCSTarget string `json:"vscs_target"`
|
||||
PendingOperation bool `json:"pending_operation"`
|
||||
PendingOperationDisabledReason string `json:"pending_operation_disabled_reason"`
|
||||
|
|
@ -246,11 +247,25 @@ const (
|
|||
)
|
||||
|
||||
type CodespaceConnection struct {
|
||||
SessionID string `json:"sessionId"`
|
||||
SessionToken string `json:"sessionToken"`
|
||||
RelayEndpoint string `json:"relayEndpoint"`
|
||||
RelaySAS string `json:"relaySas"`
|
||||
HostPublicKeys []string `json:"hostPublicKeys"`
|
||||
SessionID string `json:"sessionId"`
|
||||
SessionToken string `json:"sessionToken"`
|
||||
RelayEndpoint string `json:"relayEndpoint"`
|
||||
RelaySAS string `json:"relaySas"`
|
||||
HostPublicKeys []string `json:"hostPublicKeys"`
|
||||
TunnelProperties TunnelProperties `json:"tunnelProperties"`
|
||||
}
|
||||
|
||||
type TunnelProperties struct {
|
||||
ConnectAccessToken string `json:"connectAccessToken"`
|
||||
ManagePortsAccessToken string `json:"managePortsAccessToken"`
|
||||
ServiceUri string `json:"serviceUri"`
|
||||
TunnelId string `json:"tunnelId"`
|
||||
ClusterId string `json:"clusterId"`
|
||||
Domain string `json:"domain"`
|
||||
}
|
||||
|
||||
type RuntimeConstraints struct {
|
||||
AllowedPortPrivacySettings []string `json:"allowed_port_privacy_settings"`
|
||||
}
|
||||
|
||||
// ListCodespaceFields is the list of exportable fields for a codespace when using the `gh cs list` command.
|
||||
|
|
@ -1162,3 +1177,13 @@ func (a *API) withRetry(f func() (*http.Response, error)) (*http.Response, error
|
|||
return nil, fmt.Errorf("received response with status code %d", resp.StatusCode)
|
||||
}, backoff.WithMaxRetries(bo, 3))
|
||||
}
|
||||
|
||||
// HTTPClient returns the HTTP client used to make requests to the API.
|
||||
func (a *API) HTTPClient() (*http.Client, error) {
|
||||
httpClient, err := a.client()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return httpClient, nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -5,24 +5,42 @@ import (
|
|||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
"github.com/cli/cli/v2/internal/codespaces/api"
|
||||
"github.com/cli/cli/v2/internal/codespaces/connection"
|
||||
"github.com/cli/cli/v2/pkg/liveshare"
|
||||
)
|
||||
|
||||
func connectionReady(codespace *api.Codespace) bool {
|
||||
func connectionReady(codespace *api.Codespace, usingDevTunnels bool) bool {
|
||||
// If the codespace is not available, it is not ready
|
||||
if codespace.State != api.CodespaceStateAvailable {
|
||||
return false
|
||||
}
|
||||
|
||||
// If using Dev Tunnels, we need to check that we have all of the required tunnel properties
|
||||
if usingDevTunnels {
|
||||
return codespace.Connection.TunnelProperties.ConnectAccessToken != "" &&
|
||||
codespace.Connection.TunnelProperties.ManagePortsAccessToken != "" &&
|
||||
codespace.Connection.TunnelProperties.ServiceUri != "" &&
|
||||
codespace.Connection.TunnelProperties.TunnelId != "" &&
|
||||
codespace.Connection.TunnelProperties.ClusterId != "" &&
|
||||
codespace.Connection.TunnelProperties.Domain != ""
|
||||
}
|
||||
|
||||
// If not using Dev Tunnels, we need to check that we have all of the required Live Share properties
|
||||
return codespace.Connection.SessionID != "" &&
|
||||
codespace.Connection.SessionToken != "" &&
|
||||
codespace.Connection.RelayEndpoint != "" &&
|
||||
codespace.Connection.RelaySAS != "" &&
|
||||
codespace.State == api.CodespaceStateAvailable
|
||||
codespace.Connection.RelaySAS != ""
|
||||
}
|
||||
|
||||
type apiClient interface {
|
||||
GetCodespace(ctx context.Context, name string, includeConnection bool) (*api.Codespace, error)
|
||||
StartCodespace(ctx context.Context, name string) error
|
||||
HTTPClient() (*http.Client, error)
|
||||
}
|
||||
|
||||
type progressIndicator interface {
|
||||
|
|
@ -43,9 +61,48 @@ func (e *TimeoutError) Error() string {
|
|||
return e.message
|
||||
}
|
||||
|
||||
// ConnectToLiveshare waits for a Codespace to become running,
|
||||
// and connects to it using a Live Share session.
|
||||
// GetCodespaceConnection waits until a codespace is able
|
||||
// to be connected to and initializes a connection to it.
|
||||
func GetCodespaceConnection(ctx context.Context, progress progressIndicator, apiClient apiClient, codespace *api.Codespace) (*connection.CodespaceConnection, error) {
|
||||
codespace, err := waitUntilCodespaceConnectionReady(ctx, progress, apiClient, codespace, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
progress.StartProgressIndicatorWithLabel("Connecting to codespace")
|
||||
defer progress.StopProgressIndicator()
|
||||
|
||||
httpClient, err := apiClient.HTTPClient()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting http client: %w", err)
|
||||
}
|
||||
|
||||
return connection.NewCodespaceConnection(ctx, codespace, httpClient)
|
||||
}
|
||||
|
||||
// ConnectToLiveshare waits until a codespace is able to be
|
||||
// connected to and connects to it using a Live Share session.
|
||||
func ConnectToLiveshare(ctx context.Context, progress progressIndicator, sessionLogger logger, apiClient apiClient, codespace *api.Codespace) (*liveshare.Session, error) {
|
||||
codespace, err := waitUntilCodespaceConnectionReady(ctx, progress, apiClient, codespace, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
progress.StartProgressIndicatorWithLabel("Connecting to codespace")
|
||||
defer progress.StopProgressIndicator()
|
||||
|
||||
return liveshare.Connect(ctx, liveshare.Options{
|
||||
SessionID: codespace.Connection.SessionID,
|
||||
SessionToken: codespace.Connection.SessionToken,
|
||||
RelaySAS: codespace.Connection.RelaySAS,
|
||||
RelayEndpoint: codespace.Connection.RelayEndpoint,
|
||||
HostPublicKeys: codespace.Connection.HostPublicKeys,
|
||||
Logger: sessionLogger,
|
||||
})
|
||||
}
|
||||
|
||||
// waitUntilCodespaceConnectionReady waits for a Codespace to be running and is able to be connected to.
|
||||
func waitUntilCodespaceConnectionReady(ctx context.Context, progress progressIndicator, apiClient apiClient, codespace *api.Codespace, usingDevTunnels bool) (*api.Codespace, error) {
|
||||
if codespace.State != api.CodespaceStateAvailable {
|
||||
progress.StartProgressIndicatorWithLabel("Starting codespace")
|
||||
defer progress.StopProgressIndicator()
|
||||
|
|
@ -54,7 +111,7 @@ func ConnectToLiveshare(ctx context.Context, progress progressIndicator, session
|
|||
}
|
||||
}
|
||||
|
||||
if !connectionReady(codespace) {
|
||||
if !connectionReady(codespace, usingDevTunnels) {
|
||||
expBackoff := backoff.NewExponentialBackOff()
|
||||
expBackoff.Multiplier = 1.1
|
||||
expBackoff.MaxInterval = 10 * time.Second
|
||||
|
|
@ -67,7 +124,7 @@ func ConnectToLiveshare(ctx context.Context, progress progressIndicator, session
|
|||
return backoff.Permanent(fmt.Errorf("error getting codespace: %w", err))
|
||||
}
|
||||
|
||||
if connectionReady(codespace) {
|
||||
if connectionReady(codespace, usingDevTunnels) {
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
@ -83,17 +140,7 @@ func ConnectToLiveshare(ctx context.Context, progress progressIndicator, session
|
|||
}
|
||||
}
|
||||
|
||||
progress.StartProgressIndicatorWithLabel("Connecting to codespace")
|
||||
defer progress.StopProgressIndicator()
|
||||
|
||||
return liveshare.Connect(ctx, liveshare.Options{
|
||||
SessionID: codespace.Connection.SessionID,
|
||||
SessionToken: codespace.Connection.SessionToken,
|
||||
RelaySAS: codespace.Connection.RelaySAS,
|
||||
RelayEndpoint: codespace.Connection.RelayEndpoint,
|
||||
HostPublicKeys: codespace.Connection.HostPublicKeys,
|
||||
Logger: sessionLogger,
|
||||
})
|
||||
return codespace, nil
|
||||
}
|
||||
|
||||
// ListenTCP starts a localhost tcp listener on 127.0.0.1 (unless allInterfaces is true) and returns the listener and bound port
|
||||
|
|
|
|||
116
internal/codespaces/connection/connection.go
Normal file
116
internal/codespaces/connection/connection.go
Normal file
|
|
@ -0,0 +1,116 @@
|
|||
package connection
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
"github.com/cli/cli/v2/internal/codespaces/api"
|
||||
"github.com/microsoft/dev-tunnels/go/tunnels"
|
||||
)
|
||||
|
||||
const (
|
||||
clientName = "gh"
|
||||
)
|
||||
|
||||
type CodespaceConnection struct {
|
||||
tunnelProperties api.TunnelProperties
|
||||
TunnelManager *tunnels.Manager
|
||||
TunnelClient *tunnels.Client
|
||||
Options *tunnels.TunnelRequestOptions
|
||||
Tunnel *tunnels.Tunnel
|
||||
AllowedPortPrivacySettings []string
|
||||
}
|
||||
|
||||
// NewCodespaceConnection initializes a connection to a codespace.
|
||||
// This connections allows for port forwarding which enables the
|
||||
// use of most features of the codespace command.
|
||||
func NewCodespaceConnection(ctx context.Context, codespace *api.Codespace, httpClient *http.Client) (connection *CodespaceConnection, err error) {
|
||||
// Get the tunnel properties
|
||||
tunnelProperties := codespace.Connection.TunnelProperties
|
||||
|
||||
// Create the tunnel manager
|
||||
tunnelManager, err := getTunnelManager(tunnelProperties, httpClient)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting tunnel management client: %w", err)
|
||||
}
|
||||
|
||||
// Calculate allowed port privacy settings
|
||||
allowedPortPrivacySettings := codespace.RuntimeConstraints.AllowedPortPrivacySettings
|
||||
|
||||
// Get the access tokens
|
||||
connectToken := tunnelProperties.ConnectAccessToken
|
||||
managementToken := tunnelProperties.ManagePortsAccessToken
|
||||
|
||||
// Create the tunnel definition
|
||||
tunnel := &tunnels.Tunnel{
|
||||
AccessTokens: map[tunnels.TunnelAccessScope]string{tunnels.TunnelAccessScopeConnect: connectToken, tunnels.TunnelAccessScopeManagePorts: managementToken},
|
||||
TunnelID: tunnelProperties.TunnelId,
|
||||
ClusterID: tunnelProperties.ClusterId,
|
||||
Domain: tunnelProperties.Domain,
|
||||
}
|
||||
|
||||
// Create options
|
||||
options := &tunnels.TunnelRequestOptions{
|
||||
IncludePorts: true,
|
||||
}
|
||||
|
||||
// Create the tunnel client (not connected yet)
|
||||
tunnelClient, err := getTunnelClient(ctx, tunnelManager, tunnel, options)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting tunnel client: %w", err)
|
||||
}
|
||||
|
||||
return &CodespaceConnection{
|
||||
tunnelProperties: tunnelProperties,
|
||||
TunnelManager: tunnelManager,
|
||||
TunnelClient: tunnelClient,
|
||||
Options: options,
|
||||
Tunnel: tunnel,
|
||||
AllowedPortPrivacySettings: allowedPortPrivacySettings,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// getTunnelManager creates a tunnel manager for the given codespace.
|
||||
// The tunnel manager is used to get the tunnel hosted in the codespace that we
|
||||
// want to connect to and perform operations on ports (add, remove, list, etc.).
|
||||
func getTunnelManager(tunnelProperties api.TunnelProperties, httpClient *http.Client) (tunnelManager *tunnels.Manager, err error) {
|
||||
userAgent := []tunnels.UserAgent{{Name: clientName}}
|
||||
url, err := url.Parse(tunnelProperties.ServiceUri)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing tunnel service uri: %w", err)
|
||||
}
|
||||
|
||||
// Create the tunnel manager
|
||||
tunnelManager, err = tunnels.NewManager(userAgent, nil, url, httpClient)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating tunnel manager: %w", err)
|
||||
}
|
||||
|
||||
return tunnelManager, nil
|
||||
}
|
||||
|
||||
// getTunnelClient creates a tunnel client for the given tunnel.
|
||||
// The tunnel client is used to connect to the the tunnel and allows
|
||||
// for ports to be forwarded locally.
|
||||
func getTunnelClient(ctx context.Context, tunnelManager *tunnels.Manager, tunnel *tunnels.Tunnel, options *tunnels.TunnelRequestOptions) (tunnelClient *tunnels.Client, err error) {
|
||||
// Get the tunnel that we want to connect to
|
||||
codespaceTunnel, err := tunnelManager.GetTunnel(ctx, tunnel, options)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting tunnel: %w", err)
|
||||
}
|
||||
|
||||
// Copy the access tokens from the tunnel definition
|
||||
codespaceTunnel.AccessTokens = tunnel.AccessTokens
|
||||
|
||||
// We need to pass false for accept local connections because we don't want to automatically connect to all forwarded ports
|
||||
tunnelClient, err = tunnels.NewClient(log.New(io.Discard, "", log.LstdFlags), codespaceTunnel, false)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating tunnel client: %w", err)
|
||||
}
|
||||
|
||||
return tunnelClient, nil
|
||||
}
|
||||
75
internal/codespaces/connection/connection_test.go
Normal file
75
internal/codespaces/connection/connection_test.go
Normal file
|
|
@ -0,0 +1,75 @@
|
|||
package connection
|
||||
|
||||
import (
|
||||
"context"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/cli/cli/v2/internal/codespaces/api"
|
||||
"github.com/microsoft/dev-tunnels/go/tunnels"
|
||||
)
|
||||
|
||||
func TestNewCodespaceConnection(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
// Create a mock codespace
|
||||
connection := api.CodespaceConnection{
|
||||
TunnelProperties: api.TunnelProperties{
|
||||
ConnectAccessToken: "connect-token",
|
||||
ManagePortsAccessToken: "manage-ports-token",
|
||||
ServiceUri: "http://global.rel.tunnels.api.visualstudio.com/",
|
||||
TunnelId: "tunnel-id",
|
||||
ClusterId: "usw2",
|
||||
Domain: "domain.com",
|
||||
},
|
||||
}
|
||||
allowedPortPrivacySettings := []string{"public", "private"}
|
||||
codespace := &api.Codespace{
|
||||
Connection: connection,
|
||||
RuntimeConstraints: api.RuntimeConstraints{AllowedPortPrivacySettings: allowedPortPrivacySettings},
|
||||
}
|
||||
|
||||
// Create the mock HTTP client
|
||||
httpClient, err := NewMockHttpClient()
|
||||
if err != nil {
|
||||
t.Fatalf("NewHttpClient returned an error: %v", err)
|
||||
}
|
||||
|
||||
// Create the connection
|
||||
conn, err := NewCodespaceConnection(ctx, codespace, httpClient)
|
||||
if err != nil {
|
||||
t.Fatalf("NewCodespaceConnection returned an error: %v", err)
|
||||
}
|
||||
|
||||
// Check that the connection was created successfully
|
||||
if conn == nil {
|
||||
t.Fatal("NewCodespaceConnection returned nil")
|
||||
}
|
||||
|
||||
// Verify that the connection contains the expected tunnel properties
|
||||
if conn.tunnelProperties != connection.TunnelProperties {
|
||||
t.Fatalf("NewCodespaceConnection returned a connection with unexpected tunnel properties: %+v", conn.tunnelProperties)
|
||||
}
|
||||
|
||||
// Verify that the connection contains the expected tunnel
|
||||
expectedTunnel := &tunnels.Tunnel{
|
||||
AccessTokens: map[tunnels.TunnelAccessScope]string{tunnels.TunnelAccessScopeConnect: connection.TunnelProperties.ConnectAccessToken, tunnels.TunnelAccessScopeManagePorts: connection.TunnelProperties.ManagePortsAccessToken},
|
||||
TunnelID: connection.TunnelProperties.TunnelId,
|
||||
ClusterID: connection.TunnelProperties.ClusterId,
|
||||
Domain: connection.TunnelProperties.Domain,
|
||||
}
|
||||
if !reflect.DeepEqual(conn.Tunnel, expectedTunnel) {
|
||||
t.Fatalf("NewCodespaceConnection returned a connection with unexpected tunnel: %+v", conn.Tunnel)
|
||||
}
|
||||
|
||||
// Verify that the connection contains the expected tunnel options
|
||||
expectedOptions := &tunnels.TunnelRequestOptions{IncludePorts: true}
|
||||
if !reflect.DeepEqual(conn.Options, expectedOptions) {
|
||||
t.Fatalf("NewCodespaceConnection returned a connection with unexpected options: %+v", conn.Options)
|
||||
}
|
||||
|
||||
// Verify that the connection contains the expected allowed port privacy settings
|
||||
if !reflect.DeepEqual(conn.AllowedPortPrivacySettings, allowedPortPrivacySettings) {
|
||||
t.Fatalf("NewCodespaceConnection returned a connection with unexpected allowed port privacy settings: %+v", conn.AllowedPortPrivacySettings)
|
||||
}
|
||||
}
|
||||
396
internal/codespaces/connection/tunnels_api_server_mock.go
Normal file
396
internal/codespaces/connection/tunnels_api_server_mock.go
Normal file
|
|
@ -0,0 +1,396 @@
|
|||
package connection
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/microsoft/dev-tunnels/go/tunnels"
|
||||
tunnelssh "github.com/microsoft/dev-tunnels/go/tunnels/ssh"
|
||||
"github.com/microsoft/dev-tunnels/go/tunnels/ssh/messages"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
func NewMockHttpClient() (*http.Client, error) {
|
||||
accessToken := "tunnel access-token"
|
||||
relayServer, err := newMockrelayServer(withAccessToken(accessToken))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("NewrelayServer returned an error: %w", err)
|
||||
}
|
||||
|
||||
hostURL := strings.Replace(relayServer.URL(), "http://", "ws://", 1)
|
||||
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
var response []byte
|
||||
if r.URL.Path == "/api/v1/tunnels/tunnel-id" {
|
||||
tunnel := &tunnels.Tunnel{
|
||||
AccessTokens: map[tunnels.TunnelAccessScope]string{
|
||||
tunnels.TunnelAccessScopeConnect: accessToken,
|
||||
},
|
||||
Endpoints: []tunnels.TunnelEndpoint{
|
||||
{
|
||||
HostID: "host1",
|
||||
TunnelRelayTunnelEndpoint: tunnels.TunnelRelayTunnelEndpoint{
|
||||
ClientRelayURI: hostURL,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
response, err = json.Marshal(*tunnel)
|
||||
if err != nil {
|
||||
log.Fatalf("json.Marshal returned an error: %v", err)
|
||||
}
|
||||
} else if strings.HasPrefix(r.URL.Path, "/api/v1/tunnels/tunnel-id/ports") {
|
||||
// Use regex to check if the path ends with a number
|
||||
match, err := regexp.MatchString(`\/\d+$`, r.URL.Path)
|
||||
if err != nil {
|
||||
log.Fatalf("regexp.MatchString returned an error: %v", err)
|
||||
}
|
||||
|
||||
// If the path ends with a number, it's a request for a specific port
|
||||
if match || r.Method == http.MethodPost {
|
||||
if r.Method == http.MethodDelete {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
return
|
||||
}
|
||||
|
||||
tunnelPort := &tunnels.TunnelPort{
|
||||
AccessControl: &tunnels.TunnelAccessControl{
|
||||
Entries: []tunnels.TunnelAccessControlEntry{},
|
||||
},
|
||||
}
|
||||
|
||||
// Convert the tunnel to JSON and write it to the response
|
||||
response, err = json.Marshal(*tunnelPort)
|
||||
if err != nil {
|
||||
log.Fatalf("json.Marshal returned an error: %v", err)
|
||||
}
|
||||
} else {
|
||||
// If the path doesn't end with a number and we aren't making a POST request, return an array of ports
|
||||
tunnelPorts := []tunnels.TunnelPort{
|
||||
{
|
||||
AccessControl: &tunnels.TunnelAccessControl{
|
||||
Entries: []tunnels.TunnelAccessControlEntry{},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
response, err = json.Marshal(tunnelPorts)
|
||||
if err != nil {
|
||||
log.Fatalf("json.Marshal returned an error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
} else {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
// Write the response
|
||||
_, _ = w.Write(response)
|
||||
}))
|
||||
|
||||
url, err := url.Parse(mockServer.URL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("url.Parse returned an error: %w", err)
|
||||
}
|
||||
return &http.Client{
|
||||
Transport: &http.Transport{
|
||||
Proxy: http.ProxyURL(url),
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
type relayServer struct {
|
||||
httpServer *httptest.Server
|
||||
errc chan error
|
||||
sshConfig *ssh.ServerConfig
|
||||
channels map[string]channelHandler
|
||||
accessToken string
|
||||
|
||||
serverConn *ssh.ServerConn
|
||||
}
|
||||
|
||||
type relayServerOption func(*relayServer)
|
||||
type channelHandler func(context.Context, ssh.NewChannel) error
|
||||
|
||||
func newMockrelayServer(opts ...relayServerOption) (*relayServer, error) {
|
||||
server := &relayServer{
|
||||
errc: make(chan error),
|
||||
sshConfig: &ssh.ServerConfig{
|
||||
NoClientAuth: true,
|
||||
},
|
||||
}
|
||||
|
||||
// Create a private key with the crypto package
|
||||
key, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate key: %w", err)
|
||||
}
|
||||
|
||||
privateKeyPEM := pem.EncodeToMemory(
|
||||
&pem.Block{
|
||||
Type: "RSA PRIVATE KEY",
|
||||
Bytes: x509.MarshalPKCS1PrivateKey(key),
|
||||
},
|
||||
)
|
||||
|
||||
// Parse the private key
|
||||
sshPrivateKey, err := ssh.ParsePrivateKey(privateKeyPEM)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse private key: %w", err)
|
||||
}
|
||||
|
||||
server.sshConfig.AddHostKey(ssh.Signer(sshPrivateKey))
|
||||
|
||||
server.httpServer = httptest.NewServer(http.HandlerFunc(makeConnection(server)))
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(server)
|
||||
}
|
||||
|
||||
return server, nil
|
||||
}
|
||||
|
||||
func withAccessToken(accessToken string) func(*relayServer) {
|
||||
return func(server *relayServer) {
|
||||
server.accessToken = accessToken
|
||||
}
|
||||
}
|
||||
|
||||
func (rs *relayServer) URL() string {
|
||||
return rs.httpServer.URL
|
||||
}
|
||||
|
||||
func (rs *relayServer) Err() <-chan error {
|
||||
return rs.errc
|
||||
}
|
||||
|
||||
func (rs *relayServer) sendError(err error) {
|
||||
select {
|
||||
case rs.errc <- err:
|
||||
default:
|
||||
// channel is blocked with a previous error, so we ignore this one
|
||||
}
|
||||
}
|
||||
|
||||
func (rs *relayServer) ForwardPort(ctx context.Context, port uint16) error {
|
||||
pfr := messages.NewPortForwardRequest("127.0.0.1", uint32(port))
|
||||
b, err := pfr.Marshal()
|
||||
if err != nil {
|
||||
return fmt.Errorf("error marshaling port forward request: %w", err)
|
||||
}
|
||||
|
||||
replied, data, err := rs.serverConn.SendRequest(messages.PortForwardRequestType, true, b)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error sending port forward request: %w", err)
|
||||
}
|
||||
|
||||
if !replied {
|
||||
return fmt.Errorf("port forward request not replied")
|
||||
}
|
||||
|
||||
if data == nil {
|
||||
return fmt.Errorf("no data returned")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func makeConnection(server *relayServer) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
if server.accessToken != "" {
|
||||
if r.Header.Get("Authorization") != server.accessToken {
|
||||
server.sendError(fmt.Errorf("invalid access token"))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
upgrader := websocket.Upgrader{}
|
||||
c, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
server.sendError(fmt.Errorf("error upgrading to websocket: %w", err))
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
if err := c.Close(); err != nil {
|
||||
server.sendError(fmt.Errorf("error closing websocket: %w", err))
|
||||
}
|
||||
}()
|
||||
|
||||
socketConn := newSocketConn(c)
|
||||
serverConn, chans, reqs, err := ssh.NewServerConn(socketConn, server.sshConfig)
|
||||
if err != nil {
|
||||
server.sendError(fmt.Errorf("error creating ssh server conn: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
go handleRequests(ctx, convertRequests(reqs))
|
||||
|
||||
server.serverConn = serverConn
|
||||
if err := handleChannels(ctx, server, chans); err != nil {
|
||||
server.sendError(fmt.Errorf("error handling channels: %w", err))
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (sr *sshRequest) Type() string {
|
||||
return sr.request.Type
|
||||
}
|
||||
|
||||
type sshRequest struct {
|
||||
request *ssh.Request
|
||||
}
|
||||
|
||||
// Reply method for sshRequest to satisfy the tunnelssh.SSHRequest interface
|
||||
func (sr *sshRequest) Reply(success bool, message []byte) error {
|
||||
return sr.request.Reply(success, message)
|
||||
}
|
||||
|
||||
// convertRequests function
|
||||
func convertRequests(reqs <-chan *ssh.Request) <-chan tunnelssh.SSHRequest {
|
||||
out := make(chan tunnelssh.SSHRequest)
|
||||
go func() {
|
||||
for req := range reqs {
|
||||
out <- &sshRequest{req}
|
||||
}
|
||||
close(out)
|
||||
}()
|
||||
return out
|
||||
}
|
||||
|
||||
func handleChannels(ctx context.Context, server *relayServer, chans <-chan ssh.NewChannel) error {
|
||||
errc := make(chan error, 1)
|
||||
go func() {
|
||||
for ch := range chans {
|
||||
if handler, ok := server.channels[ch.ChannelType()]; ok {
|
||||
if err := handler(ctx, ch); err != nil {
|
||||
errc <- err
|
||||
return
|
||||
}
|
||||
} else {
|
||||
// generic accept of the channel to not block
|
||||
_, _, err := ch.Accept()
|
||||
if err != nil {
|
||||
errc <- fmt.Errorf("error accepting channel: %w", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
return awaitError(ctx, errc)
|
||||
}
|
||||
|
||||
func handleRequests(ctx context.Context, reqs <-chan tunnelssh.SSHRequest) {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case req, ok := <-reqs:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
if req.Type() == "RefreshPorts" {
|
||||
_ = req.Reply(true, nil)
|
||||
continue
|
||||
} else {
|
||||
_ = req.Reply(false, nil)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func awaitError(ctx context.Context, errc <-chan error) error {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case err := <-errc:
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
type socketConn struct {
|
||||
*websocket.Conn
|
||||
|
||||
reader io.Reader
|
||||
writeMutex sync.Mutex
|
||||
readMutex sync.Mutex
|
||||
}
|
||||
|
||||
func newSocketConn(conn *websocket.Conn) *socketConn {
|
||||
return &socketConn{Conn: conn}
|
||||
}
|
||||
|
||||
func (s *socketConn) Read(b []byte) (int, error) {
|
||||
s.readMutex.Lock()
|
||||
defer s.readMutex.Unlock()
|
||||
|
||||
if s.reader == nil {
|
||||
msgType, r, err := s.Conn.NextReader()
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("error getting next reader: %w", err)
|
||||
}
|
||||
if msgType != websocket.BinaryMessage {
|
||||
return 0, fmt.Errorf("invalid message type")
|
||||
}
|
||||
s.reader = r
|
||||
}
|
||||
|
||||
bytesRead, err := s.reader.Read(b)
|
||||
if err != nil {
|
||||
s.reader = nil
|
||||
|
||||
if err == io.EOF {
|
||||
err = nil
|
||||
}
|
||||
}
|
||||
|
||||
return bytesRead, err
|
||||
}
|
||||
|
||||
func (s *socketConn) Write(b []byte) (int, error) {
|
||||
s.writeMutex.Lock()
|
||||
defer s.writeMutex.Unlock()
|
||||
|
||||
w, err := s.Conn.NextWriter(websocket.BinaryMessage)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("error getting next writer: %w", err)
|
||||
}
|
||||
|
||||
n, err := w.Write(b)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("error writing: %w", err)
|
||||
}
|
||||
|
||||
if err := w.Close(); err != nil {
|
||||
return 0, fmt.Errorf("error closing writer: %w", err)
|
||||
}
|
||||
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (s *socketConn) SetDeadline(deadline time.Time) error {
|
||||
if err := s.Conn.SetReadDeadline(deadline); err != nil {
|
||||
return err
|
||||
}
|
||||
return s.Conn.SetWriteDeadline(deadline)
|
||||
}
|
||||
253
internal/codespaces/portforwarder/port_forwarder.go
Normal file
253
internal/codespaces/portforwarder/port_forwarder.go
Normal file
|
|
@ -0,0 +1,253 @@
|
|||
package portforwarder
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
|
||||
"github.com/cli/cli/v2/internal/codespaces/connection"
|
||||
"github.com/microsoft/dev-tunnels/go/tunnels"
|
||||
)
|
||||
|
||||
const (
|
||||
githubSubjectId = "1"
|
||||
InternalPortTag = "InternalPort"
|
||||
UserForwardedPortTag = "UserForwardedPort"
|
||||
)
|
||||
|
||||
const (
|
||||
PrivatePortVisibility = "private"
|
||||
OrgPortVisibility = "org"
|
||||
PublicPortVisibility = "public"
|
||||
)
|
||||
|
||||
type PortForwarder struct {
|
||||
connection connection.CodespaceConnection
|
||||
}
|
||||
|
||||
// NewPortForwarder returns a new PortForwarder for the specified codespace.
|
||||
func NewPortForwarder(ctx context.Context, codespaceConnection *connection.CodespaceConnection) (fwd *PortForwarder, err error) {
|
||||
return &PortForwarder{
|
||||
connection: *codespaceConnection,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ForwardAndConnectToPort forwards a port and connects to it via a local TCP port.
|
||||
func (fwd *PortForwarder) ForwardAndConnectToPort(ctx context.Context, remotePort uint16, listen *net.TCPListener, keepAlive bool, internal bool) error {
|
||||
return fwd.ForwardPort(ctx, remotePort, listen, keepAlive, true, internal, "")
|
||||
}
|
||||
|
||||
// ForwardPort forwards a port and optionally connects to it via a local TCP port.
|
||||
func (fwd *PortForwarder) ForwardPort(ctx context.Context, remotePort uint16, listen *net.TCPListener, keepAlive bool, connect bool, internal bool, visibility string) error {
|
||||
tunnelPort := tunnels.NewTunnelPort(remotePort, "", "", tunnels.TunnelProtocolHttp)
|
||||
|
||||
// If no visibility is provided, Dev Tunnels will use the default (private)
|
||||
if visibility != "" {
|
||||
// Check if the requested visibility is allowed
|
||||
allowed := false
|
||||
for _, allowedVisibility := range fwd.connection.AllowedPortPrivacySettings {
|
||||
if allowedVisibility == visibility {
|
||||
allowed = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// If the requested visibility is not allowed, return an error
|
||||
if !allowed {
|
||||
return fmt.Errorf("visibility %s is not allowed", visibility)
|
||||
}
|
||||
|
||||
accessControlEntries := visibilityToAccessControlEntries(visibility)
|
||||
if len(accessControlEntries) > 0 {
|
||||
tunnelPort.AccessControl = &tunnels.TunnelAccessControl{
|
||||
Entries: accessControlEntries,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Tag the port as internal or user forwarded so we know if it needs to be shown in the UI
|
||||
if internal {
|
||||
tunnelPort.Tags = []string{InternalPortTag}
|
||||
} else {
|
||||
tunnelPort.Tags = []string{UserForwardedPortTag}
|
||||
}
|
||||
|
||||
// Create the tunnel port
|
||||
_, err := fwd.connection.TunnelManager.CreateTunnelPort(ctx, fwd.connection.Tunnel, tunnelPort, fwd.connection.Options)
|
||||
if err != nil && !strings.Contains(err.Error(), "409") {
|
||||
return fmt.Errorf("create tunnel port failed: %v", err)
|
||||
}
|
||||
|
||||
done := make(chan error)
|
||||
go func() {
|
||||
// Connect to the tunnel
|
||||
err = fwd.connection.TunnelClient.Connect(ctx, "")
|
||||
if err != nil {
|
||||
done <- fmt.Errorf("connect failed: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Inform the host that we've forwarded the port locally
|
||||
err = fwd.connection.TunnelClient.RefreshPorts(ctx)
|
||||
if err != nil {
|
||||
done <- fmt.Errorf("refresh ports failed: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// If we don't want to connect to the port, exit early
|
||||
if !connect {
|
||||
done <- nil
|
||||
return
|
||||
}
|
||||
|
||||
// Ensure the port is forwarded before connecting
|
||||
err = fwd.connection.TunnelClient.WaitForForwardedPort(ctx, remotePort)
|
||||
if err != nil {
|
||||
done <- fmt.Errorf("wait for forwarded port failed: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Connect to the forwarded port via a local TCP port
|
||||
err = fwd.connection.TunnelClient.ConnectToForwardedPort(ctx, listen, remotePort)
|
||||
if err != nil {
|
||||
done <- fmt.Errorf("connect to forwarded port failed: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
done <- nil
|
||||
}()
|
||||
select {
|
||||
case err := <-done:
|
||||
if err != nil {
|
||||
return fmt.Errorf("error connecting to tunnel: %w", err)
|
||||
}
|
||||
return nil
|
||||
case <-ctx.Done():
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// ListPorts fetches the list of ports that are currently forwarded.
|
||||
func (fwd *PortForwarder) ListPorts(ctx context.Context) (ports []*tunnels.TunnelPort, err error) {
|
||||
ports, err = fwd.connection.TunnelManager.ListTunnelPorts(ctx, fwd.connection.Tunnel, fwd.connection.Options)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error listing ports: %w", err)
|
||||
}
|
||||
|
||||
return ports, nil
|
||||
}
|
||||
|
||||
// UpdatePortVisibility changes the visibility (private, org, public) of the specified port.
|
||||
func (fwd *PortForwarder) UpdatePortVisibility(ctx context.Context, remotePort int, visibility string) error {
|
||||
tunnelPort, err := fwd.connection.TunnelManager.GetTunnelPort(ctx, fwd.connection.Tunnel, remotePort, fwd.connection.Options)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error getting tunnel port: %w", err)
|
||||
}
|
||||
|
||||
// If the port visibility isn't changing, don't do anything
|
||||
if AccessControlEntriesToVisibility(tunnelPort.AccessControl.Entries) == visibility {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete the existing tunnel port to update
|
||||
err = fwd.connection.TunnelManager.DeleteTunnelPort(ctx, fwd.connection.Tunnel, uint16(remotePort), fwd.connection.Options)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error deleting tunnel port: %w", err)
|
||||
}
|
||||
|
||||
done := make(chan error)
|
||||
go func() {
|
||||
// Connect to the tunnel
|
||||
err = fwd.connection.TunnelClient.Connect(ctx, "")
|
||||
if err != nil {
|
||||
done <- fmt.Errorf("connect failed: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Inform the host that we've deleted the port
|
||||
err = fwd.connection.TunnelClient.RefreshPorts(ctx)
|
||||
if err != nil {
|
||||
done <- fmt.Errorf("refresh ports failed: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
done <- nil
|
||||
}()
|
||||
|
||||
// Wait for the done channel to be closed
|
||||
select {
|
||||
case err := <-done:
|
||||
if err != nil {
|
||||
return fmt.Errorf("error connecting to tunnel: %w", err)
|
||||
}
|
||||
|
||||
// Re-forward the port with the updated visibility
|
||||
err = fwd.ForwardPort(ctx, uint16(remotePort), nil, false, false, false, visibility)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error forwarding port: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
case <-ctx.Done():
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// AccessControlEntriesToVisibility converts the access control entries used by Dev Tunnels to a friendly visibility value.
|
||||
func AccessControlEntriesToVisibility(accessControlEntries []tunnels.TunnelAccessControlEntry) string {
|
||||
for _, entry := range accessControlEntries {
|
||||
// If we have the anonymous type (and we're not denying it), it's public
|
||||
if (entry.Type == tunnels.TunnelAccessControlEntryTypeAnonymous) && (!entry.IsDeny) {
|
||||
return PublicPortVisibility
|
||||
}
|
||||
|
||||
// If we have the organizations type (and we're not denying it), it's org
|
||||
if (entry.Provider == string(tunnels.TunnelAuthenticationSchemeGitHub)) && (!entry.IsDeny) {
|
||||
return OrgPortVisibility
|
||||
}
|
||||
}
|
||||
|
||||
// Else, it's private
|
||||
return PrivatePortVisibility
|
||||
}
|
||||
|
||||
// visibilityToAccessControlEntries converts the given visibility to access control entries that can be used by Dev Tunnels.
|
||||
func visibilityToAccessControlEntries(visibility string) []tunnels.TunnelAccessControlEntry {
|
||||
switch visibility {
|
||||
case PublicPortVisibility:
|
||||
return []tunnels.TunnelAccessControlEntry{{
|
||||
Type: tunnels.TunnelAccessControlEntryTypeAnonymous,
|
||||
Subjects: []string{},
|
||||
Scopes: []string{string(tunnels.TunnelAccessScopeConnect)},
|
||||
}}
|
||||
case OrgPortVisibility:
|
||||
return []tunnels.TunnelAccessControlEntry{{
|
||||
Type: tunnels.TunnelAccessControlEntryTypeOrganizations,
|
||||
Subjects: []string{githubSubjectId},
|
||||
Scopes: []string{
|
||||
string(tunnels.TunnelAccessScopeConnect),
|
||||
},
|
||||
Provider: string(tunnels.TunnelAuthenticationSchemeGitHub),
|
||||
}}
|
||||
default:
|
||||
// The tunnel manager doesn't accept empty access control entries, so we need to return a deny entry
|
||||
return []tunnels.TunnelAccessControlEntry{{
|
||||
Type: tunnels.TunnelAccessControlEntryTypeOrganizations,
|
||||
Subjects: []string{githubSubjectId},
|
||||
Scopes: []string{},
|
||||
IsDeny: true,
|
||||
}}
|
||||
}
|
||||
}
|
||||
|
||||
// IsInternalPort returns true if the port is internal.
|
||||
func IsInternalPort(port *tunnels.TunnelPort) bool {
|
||||
for _, tag := range port.Tags {
|
||||
if strings.EqualFold(tag, InternalPortTag) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
139
internal/codespaces/portforwarder/port_forwarder_test.go
Normal file
139
internal/codespaces/portforwarder/port_forwarder_test.go
Normal file
|
|
@ -0,0 +1,139 @@
|
|||
package portforwarder
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/cli/cli/v2/internal/codespaces/api"
|
||||
"github.com/cli/cli/v2/internal/codespaces/connection"
|
||||
"github.com/microsoft/dev-tunnels/go/tunnels"
|
||||
)
|
||||
|
||||
func TestNewPortForwarder(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
// Create a mock codespace
|
||||
codespace := &api.Codespace{
|
||||
Connection: api.CodespaceConnection{
|
||||
TunnelProperties: api.TunnelProperties{
|
||||
ConnectAccessToken: "connect-token",
|
||||
ManagePortsAccessToken: "manage-ports-token",
|
||||
ServiceUri: "http://global.rel.tunnels.api.visualstudio.com/",
|
||||
TunnelId: "tunnel-id",
|
||||
ClusterId: "usw2",
|
||||
Domain: "domain.com",
|
||||
},
|
||||
},
|
||||
RuntimeConstraints: api.RuntimeConstraints{
|
||||
AllowedPortPrivacySettings: []string{"public", "private"},
|
||||
},
|
||||
}
|
||||
|
||||
// Create the mock HTTP client
|
||||
httpClient, err := connection.NewMockHttpClient()
|
||||
if err != nil {
|
||||
t.Fatalf("NewHttpClient returned an error: %v", err)
|
||||
}
|
||||
|
||||
// Call the function being tested
|
||||
conn, err := connection.NewCodespaceConnection(ctx, codespace, httpClient)
|
||||
if err != nil {
|
||||
t.Fatalf("NewCodespaceConnection returned an error: %v", err)
|
||||
}
|
||||
|
||||
// Create the new port forwarder
|
||||
portForwarder, err := NewPortForwarder(ctx, conn)
|
||||
if err != nil {
|
||||
t.Fatalf("NewPortForwarder returned an error: %v", err)
|
||||
}
|
||||
|
||||
// Check that the port forwarder was created successfully
|
||||
if portForwarder == nil {
|
||||
t.Fatal("NewPortForwarder returned nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccessControlEntriesToVisibility(t *testing.T) {
|
||||
publicAccessControlEntry := []tunnels.TunnelAccessControlEntry{{
|
||||
Type: tunnels.TunnelAccessControlEntryTypeAnonymous,
|
||||
}}
|
||||
orgAccessControlEntry := []tunnels.TunnelAccessControlEntry{{
|
||||
Provider: string(tunnels.TunnelAuthenticationSchemeGitHub),
|
||||
}}
|
||||
privateAccessControlEntry := []tunnels.TunnelAccessControlEntry{}
|
||||
orgIsDenyAccessControlEntry := []tunnels.TunnelAccessControlEntry{{
|
||||
Provider: string(tunnels.TunnelAuthenticationSchemeGitHub),
|
||||
IsDeny: true,
|
||||
}}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
accessControlEntries []tunnels.TunnelAccessControlEntry
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "public",
|
||||
accessControlEntries: publicAccessControlEntry,
|
||||
expected: PublicPortVisibility,
|
||||
},
|
||||
{
|
||||
name: "org",
|
||||
accessControlEntries: orgAccessControlEntry,
|
||||
expected: OrgPortVisibility,
|
||||
},
|
||||
{
|
||||
name: "private",
|
||||
accessControlEntries: privateAccessControlEntry,
|
||||
expected: PrivatePortVisibility,
|
||||
},
|
||||
{
|
||||
name: "orgIsDeny",
|
||||
accessControlEntries: orgIsDenyAccessControlEntry,
|
||||
expected: PrivatePortVisibility,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
visibility := AccessControlEntriesToVisibility(test.accessControlEntries)
|
||||
if visibility != test.expected {
|
||||
t.Errorf("expected %q, got %q", test.expected, visibility)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsInternalPort(t *testing.T) {
|
||||
internalPort := &tunnels.TunnelPort{
|
||||
Tags: []string{"InternalPort"},
|
||||
}
|
||||
userForwardedPort := &tunnels.TunnelPort{
|
||||
Tags: []string{"UserForwardedPort"},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
port *tunnels.TunnelPort
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "internal",
|
||||
port: internalPort,
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "user-forwarded",
|
||||
port: userForwardedPort,
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
isInternal := IsInternalPort(test.port)
|
||||
if isInternal != test.expected {
|
||||
t.Errorf("expected %v, got %v", test.expected, isInternal)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
@ -8,6 +8,7 @@ import (
|
|||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"sort"
|
||||
"strings"
|
||||
|
|
@ -104,6 +105,7 @@ type apiClient interface {
|
|||
ListDevContainers(ctx context.Context, repoID int, branch string, limit int) (devcontainers []api.DevContainerEntry, err error)
|
||||
GetCodespaceRepoSuggestions(ctx context.Context, partialSearch string, params api.RepoSearchParameters) ([]string, error)
|
||||
GetCodespaceBillableOwner(ctx context.Context, nwo string) (*api.User, error)
|
||||
HTTPClient() (*http.Client, error)
|
||||
}
|
||||
|
||||
var errNoCodespaces = errors.New("you have no codespaces")
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ package codespace
|
|||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"sync"
|
||||
|
||||
codespacesAPI "github.com/cli/cli/v2/internal/codespaces/api"
|
||||
|
|
@ -40,15 +41,15 @@ import (
|
|||
// GetCodespacesMachinesFunc: func(ctx context.Context, repoID int, branch string, location string, devcontainerPath string) ([]*codespacesAPI.Machine, error) {
|
||||
// panic("mock out the GetCodespacesMachines method")
|
||||
// },
|
||||
// HTTPClientFunc: func() (*http.Client, error) {
|
||||
// panic("mock out the HTTPClient method")
|
||||
// },
|
||||
// GetOrgMemberCodespaceFunc: func(ctx context.Context, orgName string, userName string, codespaceName string) (*codespacesAPI.Codespace, error) {
|
||||
// panic("mock out the GetOrgMemberCodespace method")
|
||||
// },
|
||||
// GetRepositoryFunc: func(ctx context.Context, nwo string) (*codespacesAPI.Repository, error) {
|
||||
// panic("mock out the GetRepository method")
|
||||
// },
|
||||
// ServerURLFunc: func() string {
|
||||
// panic("mock out the ServerURL method")
|
||||
// },
|
||||
// GetUserFunc: func(ctx context.Context) (*codespacesAPI.User, error) {
|
||||
// panic("mock out the GetUser method")
|
||||
// },
|
||||
|
|
@ -58,6 +59,9 @@ import (
|
|||
// ListDevContainersFunc: func(ctx context.Context, repoID int, branch string, limit int) ([]codespacesAPI.DevContainerEntry, error) {
|
||||
// panic("mock out the ListDevContainers method")
|
||||
// },
|
||||
// ServerURLFunc: func() string {
|
||||
// panic("mock out the ServerURL method")
|
||||
// },
|
||||
// StartCodespaceFunc: func(ctx context.Context, name string) error {
|
||||
// panic("mock out the StartCodespace method")
|
||||
// },
|
||||
|
|
@ -95,15 +99,15 @@ type apiClientMock struct {
|
|||
// GetCodespacesMachinesFunc mocks the GetCodespacesMachines method.
|
||||
GetCodespacesMachinesFunc func(ctx context.Context, repoID int, branch string, location string, devcontainerPath string) ([]*codespacesAPI.Machine, error)
|
||||
|
||||
// HTTPClientFunc mocks the HTTPClient method.
|
||||
HTTPClientFunc func() (*http.Client, error)
|
||||
|
||||
// GetOrgMemberCodespaceFunc mocks the GetOrgMemberCodespace method.
|
||||
GetOrgMemberCodespaceFunc func(ctx context.Context, orgName string, userName string, codespaceName string) (*codespacesAPI.Codespace, error)
|
||||
|
||||
// GetRepositoryFunc mocks the GetRepository method.
|
||||
GetRepositoryFunc func(ctx context.Context, nwo string) (*codespacesAPI.Repository, error)
|
||||
|
||||
// ServerURLFunc mocks the ServerURL method.
|
||||
ServerURLFunc func() string
|
||||
|
||||
// GetUserFunc mocks the GetUser method.
|
||||
GetUserFunc func(ctx context.Context) (*codespacesAPI.User, error)
|
||||
|
||||
|
|
@ -113,6 +117,9 @@ type apiClientMock struct {
|
|||
// ListDevContainersFunc mocks the ListDevContainers method.
|
||||
ListDevContainersFunc func(ctx context.Context, repoID int, branch string, limit int) ([]codespacesAPI.DevContainerEntry, error)
|
||||
|
||||
// ServerURLFunc mocks the ServerURL method.
|
||||
ServerURLFunc func() string
|
||||
|
||||
// StartCodespaceFunc mocks the StartCodespace method.
|
||||
StartCodespaceFunc func(ctx context.Context, name string) error
|
||||
|
||||
|
|
@ -195,6 +202,9 @@ type apiClientMock struct {
|
|||
// DevcontainerPath is the devcontainerPath argument value.
|
||||
DevcontainerPath string
|
||||
}
|
||||
// HTTPClient holds details about calls to the HTTPClient method.
|
||||
HTTPClient []struct {
|
||||
}
|
||||
// GetOrgMemberCodespace holds details about calls to the GetOrgMemberCodespace method.
|
||||
GetOrgMemberCodespace []struct {
|
||||
// Ctx is the ctx argument value.
|
||||
|
|
@ -213,9 +223,6 @@ type apiClientMock struct {
|
|||
// Nwo is the nwo argument value.
|
||||
Nwo string
|
||||
}
|
||||
// ServerURL holds details about calls to the ServerURL method.
|
||||
ServerURL []struct {
|
||||
}
|
||||
// GetUser holds details about calls to the GetUser method.
|
||||
GetUser []struct {
|
||||
// Ctx is the ctx argument value.
|
||||
|
|
@ -239,6 +246,9 @@ type apiClientMock struct {
|
|||
// Limit is the limit argument value.
|
||||
Limit int
|
||||
}
|
||||
// ServerURL holds details about calls to the ServerURL method.
|
||||
ServerURL []struct {
|
||||
}
|
||||
// StartCodespace holds details about calls to the StartCodespace method.
|
||||
StartCodespace []struct {
|
||||
// Ctx is the ctx argument value.
|
||||
|
|
@ -266,12 +276,13 @@ type apiClientMock struct {
|
|||
lockGetCodespaceRepoSuggestions sync.RWMutex
|
||||
lockGetCodespaceRepositoryContents sync.RWMutex
|
||||
lockGetCodespacesMachines sync.RWMutex
|
||||
lockHTTPClient sync.RWMutex
|
||||
lockGetOrgMemberCodespace sync.RWMutex
|
||||
lockGetRepository sync.RWMutex
|
||||
lockServerURL sync.RWMutex
|
||||
lockGetUser sync.RWMutex
|
||||
lockListCodespaces sync.RWMutex
|
||||
lockListDevContainers sync.RWMutex
|
||||
lockServerURL sync.RWMutex
|
||||
lockStartCodespace sync.RWMutex
|
||||
lockStopCodespace sync.RWMutex
|
||||
}
|
||||
|
|
@ -600,6 +611,33 @@ func (mock *apiClientMock) GetCodespacesMachinesCalls() []struct {
|
|||
return calls
|
||||
}
|
||||
|
||||
// HTTPClient calls HTTPClientFunc.
|
||||
func (mock *apiClientMock) HTTPClient() (*http.Client, error) {
|
||||
if mock.HTTPClientFunc == nil {
|
||||
panic("apiClientMock.HTTPClientFunc: method is nil but apiClient.HTTPClient was just called")
|
||||
}
|
||||
callInfo := struct {
|
||||
}{}
|
||||
mock.lockHTTPClient.Lock()
|
||||
mock.calls.HTTPClient = append(mock.calls.HTTPClient, callInfo)
|
||||
mock.lockHTTPClient.Unlock()
|
||||
return mock.HTTPClientFunc()
|
||||
}
|
||||
|
||||
// HTTPClientCalls gets all the calls that were made to HTTPClient.
|
||||
// Check the length with:
|
||||
//
|
||||
// len(mockedapiClient.HTTPClientCalls())
|
||||
func (mock *apiClientMock) HTTPClientCalls() []struct {
|
||||
} {
|
||||
var calls []struct {
|
||||
}
|
||||
mock.lockHTTPClient.RLock()
|
||||
calls = mock.calls.HTTPClient
|
||||
mock.lockHTTPClient.RUnlock()
|
||||
return calls
|
||||
}
|
||||
|
||||
// GetOrgMemberCodespace calls GetOrgMemberCodespaceFunc.
|
||||
func (mock *apiClientMock) GetOrgMemberCodespace(ctx context.Context, orgName string, userName string, codespaceName string) (*codespacesAPI.Codespace, error) {
|
||||
if mock.GetOrgMemberCodespaceFunc == nil {
|
||||
|
|
@ -680,33 +718,6 @@ func (mock *apiClientMock) GetRepositoryCalls() []struct {
|
|||
return calls
|
||||
}
|
||||
|
||||
// ServerURL calls ServerURLFunc.
|
||||
func (mock *apiClientMock) ServerURL() string {
|
||||
if mock.ServerURLFunc == nil {
|
||||
panic("apiClientMock.ServerURLFunc: method is nil but apiClient.ServerURL was just called")
|
||||
}
|
||||
callInfo := struct {
|
||||
}{}
|
||||
mock.lockServerURL.Lock()
|
||||
mock.calls.ServerURL = append(mock.calls.ServerURL, callInfo)
|
||||
mock.lockServerURL.Unlock()
|
||||
return mock.ServerURLFunc()
|
||||
}
|
||||
|
||||
// ServerURLCalls gets all the calls that were made to ServerURL.
|
||||
// Check the length with:
|
||||
//
|
||||
// len(mockedapiClient.ServerURLCalls())
|
||||
func (mock *apiClientMock) ServerURLCalls() []struct {
|
||||
} {
|
||||
var calls []struct {
|
||||
}
|
||||
mock.lockServerURL.RLock()
|
||||
calls = mock.calls.ServerURL
|
||||
mock.lockServerURL.RUnlock()
|
||||
return calls
|
||||
}
|
||||
|
||||
// GetUser calls GetUserFunc.
|
||||
func (mock *apiClientMock) GetUser(ctx context.Context) (*codespacesAPI.User, error) {
|
||||
if mock.GetUserFunc == nil {
|
||||
|
|
@ -819,6 +830,33 @@ func (mock *apiClientMock) ListDevContainersCalls() []struct {
|
|||
return calls
|
||||
}
|
||||
|
||||
// ServerURL calls ServerURLFunc.
|
||||
func (mock *apiClientMock) ServerURL() string {
|
||||
if mock.ServerURLFunc == nil {
|
||||
panic("apiClientMock.ServerURLFunc: method is nil but apiClient.ServerURL was just called")
|
||||
}
|
||||
callInfo := struct {
|
||||
}{}
|
||||
mock.lockServerURL.Lock()
|
||||
mock.calls.ServerURL = append(mock.calls.ServerURL, callInfo)
|
||||
mock.lockServerURL.Unlock()
|
||||
return mock.ServerURLFunc()
|
||||
}
|
||||
|
||||
// ServerURLCalls gets all the calls that were made to ServerURL.
|
||||
// Check the length with:
|
||||
//
|
||||
// len(mockedapiClient.ServerURLCalls())
|
||||
func (mock *apiClientMock) ServerURLCalls() []struct {
|
||||
} {
|
||||
var calls []struct {
|
||||
}
|
||||
mock.lockServerURL.RLock()
|
||||
calls = mock.calls.ServerURL
|
||||
mock.lockServerURL.RUnlock()
|
||||
return calls
|
||||
}
|
||||
|
||||
// StartCodespace calls StartCodespaceFunc.
|
||||
func (mock *apiClientMock) StartCodespace(ctx context.Context, name string) error {
|
||||
if mock.StartCodespaceFunc == nil {
|
||||
|
|
|
|||
|
|
@ -6,26 +6,21 @@ import (
|
|||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/cli/cli/v2/internal/codespaces"
|
||||
"github.com/cli/cli/v2/internal/codespaces/api"
|
||||
"github.com/cli/cli/v2/internal/codespaces/portforwarder"
|
||||
"github.com/cli/cli/v2/internal/tableprinter"
|
||||
"github.com/cli/cli/v2/pkg/cmdutil"
|
||||
"github.com/cli/cli/v2/pkg/liveshare"
|
||||
"github.com/cli/cli/v2/utils"
|
||||
"github.com/microsoft/dev-tunnels/go/tunnels"
|
||||
"github.com/muhammadmuzzammil1998/jsonc"
|
||||
"github.com/spf13/cobra"
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
|
||||
const (
|
||||
vscodeServerPortName = "VSCodeServerInternal"
|
||||
codespacesInternalPortName = "CodespacesInternal"
|
||||
)
|
||||
|
||||
// newPortsCmd returns a Cobra "ports" command that displays a table of available ports,
|
||||
// according to the specified flags.
|
||||
func newPortsCmd(app *App) *cobra.Command {
|
||||
|
|
@ -62,15 +57,19 @@ func (a *App) ListPorts(ctx context.Context, selector *CodespaceSelector, export
|
|||
|
||||
devContainerCh := getDevContainer(ctx, a.apiClient, codespace)
|
||||
|
||||
session, err := startLiveShareSession(ctx, codespace, a, false, "")
|
||||
codespaceConnection, err := codespaces.GetCodespaceConnection(ctx, a, a.apiClient, codespace)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("error connecting to codespace: %w", err)
|
||||
}
|
||||
defer safeClose(session, &err)
|
||||
|
||||
var ports []*liveshare.Port
|
||||
fwd, err := portforwarder.NewPortForwarder(ctx, codespaceConnection)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create port forwarder: %w", err)
|
||||
}
|
||||
|
||||
var ports []*tunnels.TunnelPort
|
||||
err = a.RunWithProgress("Fetching ports", func() (err error) {
|
||||
ports, err = session.GetSharedServers(ctx)
|
||||
ports, err = fwd.ListPorts(ctx)
|
||||
return
|
||||
})
|
||||
if err != nil {
|
||||
|
|
@ -87,9 +86,10 @@ func (a *App) ListPorts(ctx context.Context, selector *CodespaceSelector, export
|
|||
|
||||
for _, p := range ports {
|
||||
// filter out internal ports from list
|
||||
if strings.HasPrefix(p.SessionName, vscodeServerPortName) || strings.HasPrefix(p.SessionName, codespacesInternalPortName) {
|
||||
if portforwarder.IsInternalPort(p) {
|
||||
continue
|
||||
}
|
||||
|
||||
portInfos = append(portInfos, &portInfo{
|
||||
Port: p,
|
||||
codespace: codespace,
|
||||
|
|
@ -107,40 +107,42 @@ func (a *App) ListPorts(ctx context.Context, selector *CodespaceSelector, export
|
|||
}
|
||||
|
||||
cs := a.io.ColorScheme()
|
||||
//nolint:staticcheck // SA1019: utils.NewTablePrinter is deprecated: use internal/tableprinter
|
||||
tp := utils.NewTablePrinter(a.io)
|
||||
tp := tableprinter.New(a.io)
|
||||
|
||||
if tp.IsTTY() {
|
||||
tp.AddField("LABEL", nil, nil)
|
||||
tp.AddField("PORT", nil, nil)
|
||||
tp.AddField("VISIBILITY", nil, nil)
|
||||
tp.AddField("BROWSE URL", nil, nil)
|
||||
if a.io.IsStdoutTTY() {
|
||||
tp.AddField("LABEL")
|
||||
tp.AddField("PORT")
|
||||
tp.AddField("VISIBILITY")
|
||||
tp.AddField("BROWSE URL")
|
||||
tp.EndRow()
|
||||
}
|
||||
|
||||
for _, port := range portInfos {
|
||||
tp.AddField(port.Label(), nil, nil)
|
||||
tp.AddField(strconv.Itoa(port.SourcePort), nil, cs.Yellow)
|
||||
tp.AddField(port.Privacy, nil, nil)
|
||||
tp.AddField(port.BrowseURL(), nil, nil)
|
||||
// Convert the ACE to a friendly visibility string (private, org, public)
|
||||
visibility := portforwarder.AccessControlEntriesToVisibility(port.Port.AccessControl.Entries)
|
||||
|
||||
tp.AddField(port.Label())
|
||||
tp.AddField(cs.Yellow(fmt.Sprintf("%d", port.Port.PortNumber)))
|
||||
tp.AddField(visibility)
|
||||
tp.AddField(port.BrowseURL())
|
||||
tp.EndRow()
|
||||
}
|
||||
return tp.Render()
|
||||
}
|
||||
|
||||
type portInfo struct {
|
||||
*liveshare.Port
|
||||
Port *tunnels.TunnelPort
|
||||
codespace *api.Codespace
|
||||
devContainer *devContainer
|
||||
}
|
||||
|
||||
func (pi *portInfo) BrowseURL() string {
|
||||
return fmt.Sprintf("https://%s-%d.preview.app.github.dev", pi.codespace.Name, pi.Port.SourcePort)
|
||||
return fmt.Sprintf("https://%s-%d.app.github.dev", pi.codespace.Name, pi.Port.PortNumber)
|
||||
}
|
||||
|
||||
func (pi *portInfo) Label() string {
|
||||
if pi.devContainer != nil {
|
||||
portStr := strconv.Itoa(pi.Port.SourcePort)
|
||||
portStr := strconv.Itoa(int(pi.Port.PortNumber))
|
||||
if attributes, ok := pi.devContainer.PortAttributes[portStr]; ok {
|
||||
return attributes.Label
|
||||
}
|
||||
|
|
@ -150,7 +152,6 @@ func (pi *portInfo) Label() string {
|
|||
|
||||
var portFields = []string{
|
||||
"sourcePort",
|
||||
// "destinationPort", // TODO(mislav): this appears to always be blank?
|
||||
"visibility",
|
||||
"label",
|
||||
"browseUrl",
|
||||
|
|
@ -162,11 +163,9 @@ func (pi *portInfo) ExportData(fields []string) map[string]interface{} {
|
|||
for _, f := range fields {
|
||||
switch f {
|
||||
case "sourcePort":
|
||||
data[f] = pi.Port.SourcePort
|
||||
case "destinationPort":
|
||||
data[f] = pi.Port.DestinationPort
|
||||
data[f] = pi.Port.PortNumber
|
||||
case "visibility":
|
||||
data[f] = pi.Port.Privacy
|
||||
data[f] = portforwarder.AccessControlEntriesToVisibility(pi.Port.AccessControl.Entries)
|
||||
case "label":
|
||||
data[f] = pi.Label()
|
||||
case "browseUrl":
|
||||
|
|
@ -235,30 +234,6 @@ func newPortsVisibilityCmd(app *App, selector *CodespaceSelector) *cobra.Command
|
|||
}
|
||||
}
|
||||
|
||||
type ErrUpdatingPortVisibility struct {
|
||||
port int
|
||||
visibility string
|
||||
err error
|
||||
}
|
||||
|
||||
func newErrUpdatingPortVisibility(port int, visibility string, err error) *ErrUpdatingPortVisibility {
|
||||
return &ErrUpdatingPortVisibility{
|
||||
port: port,
|
||||
visibility: visibility,
|
||||
err: err,
|
||||
}
|
||||
}
|
||||
|
||||
func (e *ErrUpdatingPortVisibility) Error() string {
|
||||
return fmt.Sprintf("error waiting for port %d to update to %s: %s", e.port, e.visibility, e.err)
|
||||
}
|
||||
|
||||
func (e *ErrUpdatingPortVisibility) Unwrap() error {
|
||||
return e.err
|
||||
}
|
||||
|
||||
var errUpdatePortVisibilityForbidden = errors.New("organization admin has forbidden this privacy setting")
|
||||
|
||||
func (a *App) UpdatePortVisibility(ctx context.Context, selector *CodespaceSelector, args []string) (err error) {
|
||||
ports, err := a.parsePortVisibilities(args)
|
||||
if err != nil {
|
||||
|
|
@ -270,47 +245,28 @@ func (a *App) UpdatePortVisibility(ctx context.Context, selector *CodespaceSelec
|
|||
return err
|
||||
}
|
||||
|
||||
session, err := codespaces.ConnectToLiveshare(ctx, a, noopLogger(), a.apiClient, codespace)
|
||||
codespaceConnection, err := codespaces.GetCodespaceConnection(ctx, a, a.apiClient, codespace)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error connecting to codespace: %w", err)
|
||||
}
|
||||
defer safeClose(session, &err)
|
||||
|
||||
fwd, err := portforwarder.NewPortForwarder(ctx, codespaceConnection)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create port forwarder: %w", err)
|
||||
}
|
||||
|
||||
// TODO: check if port visibility can be updated in parallel instead of sequentially
|
||||
for _, port := range ports {
|
||||
err := a.RunWithProgress(fmt.Sprintf("Updating port %d visibility to: %s", port.number, port.visibility), func() (err error) {
|
||||
// wait for success or failure
|
||||
g, ctx := errgroup.WithContext(ctx)
|
||||
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
g.Go(func() error {
|
||||
updateNotif, err := session.WaitForPortNotification(ctx, port.number, liveshare.PortChangeKindUpdate)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error waiting for port %d to update: %w", port.number, err)
|
||||
|
||||
}
|
||||
if !updateNotif.Success {
|
||||
if updateNotif.StatusCode == http.StatusForbidden {
|
||||
return newErrUpdatingPortVisibility(port.number, port.visibility, errUpdatePortVisibilityForbidden)
|
||||
}
|
||||
return newErrUpdatingPortVisibility(port.number, port.visibility, errors.New(updateNotif.ErrorDetail))
|
||||
|
||||
}
|
||||
return nil // success
|
||||
})
|
||||
|
||||
g.Go(func() error {
|
||||
err := session.UpdateSharedServerPrivacy(ctx, port.number, port.visibility)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error updating port %d to %s: %w", port.number, port.visibility, err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
// wait for success or failure
|
||||
err = g.Wait()
|
||||
return
|
||||
err = fwd.UpdatePortVisibility(ctx, port.number, port.visibility)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error updating port %d to %s: %w", port.number, port.visibility, err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
|
|
@ -367,11 +323,10 @@ func (a *App) ForwardPorts(ctx context.Context, selector *CodespaceSelector, por
|
|||
return err
|
||||
}
|
||||
|
||||
session, err := codespaces.ConnectToLiveshare(ctx, a, noopLogger(), a.apiClient, codespace)
|
||||
codespaceConnection, err := codespaces.GetCodespaceConnection(ctx, a, a.apiClient, codespace)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error connecting to codespace: %w", err)
|
||||
}
|
||||
defer safeClose(session, &err)
|
||||
|
||||
// Run forwarding of all ports concurrently, aborting all of
|
||||
// them at the first failure, including cancellation of the context.
|
||||
|
|
@ -386,9 +341,11 @@ func (a *App) ForwardPorts(ctx context.Context, selector *CodespaceSelector, por
|
|||
defer listen.Close()
|
||||
|
||||
a.errLogger.Printf("Forwarding ports: remote %d <=> local %d", pair.remote, pair.local)
|
||||
name := fmt.Sprintf("share-%d", pair.remote)
|
||||
fwd := liveshare.NewPortForwarder(session, name, pair.remote, false)
|
||||
return fwd.ForwardToListener(ctx, listen) // error always non-nil
|
||||
fwd, err := portforwarder.NewPortForwarder(ctx, codespaceConnection)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create port forwarder: %w", err)
|
||||
}
|
||||
return fwd.ForwardAndConnectToPort(ctx, uint16(pair.remote), listen, false, false)
|
||||
})
|
||||
}
|
||||
return group.Wait() // first error
|
||||
|
|
|
|||
|
|
@ -2,18 +2,34 @@ package codespace
|
|||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/cli/cli/v2/internal/codespaces/api"
|
||||
"github.com/cli/cli/v2/internal/codespaces/connection"
|
||||
"github.com/cli/cli/v2/pkg/iostreams"
|
||||
"github.com/cli/cli/v2/pkg/liveshare"
|
||||
livesharetest "github.com/cli/cli/v2/pkg/liveshare/test"
|
||||
"github.com/sourcegraph/jsonrpc2"
|
||||
)
|
||||
|
||||
func TestListPorts(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
mockApi := GetMockApi(false)
|
||||
ios, _, _, _ := iostreams.Test()
|
||||
|
||||
a := &App{
|
||||
io: ios,
|
||||
apiClient: mockApi,
|
||||
}
|
||||
|
||||
selector := &CodespaceSelector{api: a.apiClient, codespaceName: "codespace-name"}
|
||||
err := a.ListPorts(ctx, selector, nil)
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPortsUpdateVisibilitySuccess(t *testing.T) {
|
||||
portVisibilities := []portVisibility{
|
||||
{
|
||||
|
|
@ -26,175 +42,35 @@ func TestPortsUpdateVisibilitySuccess(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
eventResponses := []string{
|
||||
"serverSharing.sharingSucceeded",
|
||||
"serverSharing.sharingSucceeded",
|
||||
}
|
||||
|
||||
portsData := []liveshare.PortNotification{
|
||||
{
|
||||
Success: true,
|
||||
Port: 80,
|
||||
ChangeKind: liveshare.PortChangeKindUpdate,
|
||||
},
|
||||
{
|
||||
Success: true,
|
||||
Port: 9999,
|
||||
ChangeKind: liveshare.PortChangeKindUpdate,
|
||||
},
|
||||
}
|
||||
|
||||
err := runUpdateVisibilityTest(t, portVisibilities, eventResponses, portsData)
|
||||
|
||||
err := runUpdateVisibilityTest(t, portVisibilities, true)
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPortsUpdateVisibilityFailure403(t *testing.T) {
|
||||
portVisibilities := []portVisibility{
|
||||
{
|
||||
number: 80,
|
||||
visibility: "org",
|
||||
},
|
||||
{
|
||||
number: 9999,
|
||||
visibility: "public",
|
||||
},
|
||||
}
|
||||
|
||||
eventResponses := []string{
|
||||
"serverSharing.sharingSucceeded",
|
||||
"serverSharing.sharingFailed",
|
||||
}
|
||||
|
||||
portsData := []liveshare.PortNotification{
|
||||
{
|
||||
Success: true,
|
||||
Port: 80,
|
||||
ChangeKind: liveshare.PortChangeKindUpdate,
|
||||
},
|
||||
{
|
||||
Success: false,
|
||||
Port: 9999,
|
||||
ChangeKind: liveshare.PortChangeKindUpdate,
|
||||
ErrorDetail: "test error",
|
||||
StatusCode: 403,
|
||||
},
|
||||
}
|
||||
|
||||
err := runUpdateVisibilityTest(t, portVisibilities, eventResponses, portsData)
|
||||
if err == nil {
|
||||
t.Fatalf("runUpdateVisibilityTest succeeded unexpectedly")
|
||||
}
|
||||
|
||||
if errors.Unwrap(err) != errUpdatePortVisibilityForbidden {
|
||||
t.Errorf("expected: %v, got: %v", errUpdatePortVisibilityForbidden, errors.Unwrap(err))
|
||||
}
|
||||
}
|
||||
|
||||
func TestPortsUpdateVisibilityFailure(t *testing.T) {
|
||||
portVisibilities := []portVisibility{
|
||||
{
|
||||
number: 80,
|
||||
visibility: "org",
|
||||
},
|
||||
{
|
||||
number: 9999,
|
||||
visibility: "public",
|
||||
},
|
||||
}
|
||||
|
||||
eventResponses := []string{
|
||||
"serverSharing.sharingSucceeded",
|
||||
"serverSharing.sharingFailed",
|
||||
}
|
||||
|
||||
portsData := []liveshare.PortNotification{
|
||||
{
|
||||
Success: true,
|
||||
Port: 80,
|
||||
ChangeKind: liveshare.PortChangeKindUpdate,
|
||||
},
|
||||
{
|
||||
Success: false,
|
||||
Port: 9999,
|
||||
ChangeKind: liveshare.PortChangeKindUpdate,
|
||||
ErrorDetail: "test error",
|
||||
number: 80,
|
||||
visibility: "org",
|
||||
},
|
||||
}
|
||||
|
||||
err := runUpdateVisibilityTest(t, portVisibilities, eventResponses, portsData)
|
||||
err := runUpdateVisibilityTest(t, portVisibilities, false)
|
||||
if err == nil {
|
||||
t.Fatalf("runUpdateVisibilityTest succeeded unexpectedly")
|
||||
}
|
||||
|
||||
var expectedErr *ErrUpdatingPortVisibility
|
||||
if !errors.As(err, &expectedErr) {
|
||||
t.Errorf("expected: %v, got: %v", expectedErr, err)
|
||||
}
|
||||
}
|
||||
|
||||
type joinWorkspaceResult struct {
|
||||
SessionNumber int `json:"sessionNumber"`
|
||||
}
|
||||
|
||||
func runUpdateVisibilityTest(t *testing.T, portVisibilities []portVisibility, eventResponses []string, portsData []liveshare.PortNotification) error {
|
||||
t.Helper()
|
||||
if os.Getenv("GITHUB_ACTIONS") == "true" {
|
||||
t.Skip("fails intermittently in CI: https://github.com/cli/cli/issues/5663")
|
||||
}
|
||||
|
||||
joinWorkspace := func(conn *jsonrpc2.Conn, req *jsonrpc2.Request) (interface{}, error) {
|
||||
return joinWorkspaceResult{1}, nil
|
||||
}
|
||||
const sessionToken = "session-token"
|
||||
|
||||
ch := make(chan *jsonrpc2.Conn, 1)
|
||||
updateSharedVisibility := func(conn *jsonrpc2.Conn, rpcReq *jsonrpc2.Request) (interface{}, error) {
|
||||
ch <- conn
|
||||
return nil, nil
|
||||
}
|
||||
testServer, err := livesharetest.NewServer(
|
||||
livesharetest.WithNonSecure(),
|
||||
livesharetest.WithPassword(sessionToken),
|
||||
livesharetest.WithService("workspace.joinWorkspace", joinWorkspace),
|
||||
livesharetest.WithService("serverSharing.updateSharedServerPrivacy", updateSharedVisibility),
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to create test server: %w", err)
|
||||
}
|
||||
|
||||
func runUpdateVisibilityTest(t *testing.T, portVisibilities []portVisibility, allowOrgPorts bool) error {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
go func() {
|
||||
for i, pd := range portsData {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case conn := <-ch:
|
||||
_, _ = conn.DispatchCall(ctx, eventResponses[i], pd, nil)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
mockApi := &apiClientMock{
|
||||
GetCodespaceFunc: func(ctx context.Context, codespaceName string, includeConnection bool) (*api.Codespace, error) {
|
||||
return &api.Codespace{
|
||||
Name: "codespace-name",
|
||||
State: api.CodespaceStateAvailable,
|
||||
Connection: api.CodespaceConnection{
|
||||
SessionID: "session-id",
|
||||
SessionToken: sessionToken,
|
||||
RelayEndpoint: testServer.URL(),
|
||||
RelaySAS: "relay-sas",
|
||||
HostPublicKeys: []string{livesharetest.SSHPublicKey},
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
mockApi := GetMockApi(allowOrgPorts)
|
||||
ios, _, _, _ := iostreams.Test()
|
||||
|
||||
a := &App{
|
||||
|
|
@ -251,6 +127,44 @@ func TestPendingOperationDisallowsForwardPorts(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func GetMockApi(allowOrgPorts bool) *apiClientMock {
|
||||
return &apiClientMock{
|
||||
GetCodespaceFunc: func(ctx context.Context, codespaceName string, includeConnection bool) (*api.Codespace, error) {
|
||||
allowedPortPrivacySettings := []string{"public", "private"}
|
||||
if allowOrgPorts {
|
||||
allowedPortPrivacySettings = append(allowedPortPrivacySettings, "org")
|
||||
}
|
||||
|
||||
return &api.Codespace{
|
||||
Name: "codespace-name",
|
||||
State: api.CodespaceStateAvailable,
|
||||
Connection: api.CodespaceConnection{
|
||||
TunnelProperties: api.TunnelProperties{
|
||||
ConnectAccessToken: "tunnel access-token",
|
||||
ManagePortsAccessToken: "manage-ports-token",
|
||||
ServiceUri: "http://global.rel.tunnels.api.visualstudio.com/",
|
||||
TunnelId: "tunnel-id",
|
||||
ClusterId: "usw2",
|
||||
Domain: "domain.com",
|
||||
},
|
||||
},
|
||||
RuntimeConstraints: api.RuntimeConstraints{
|
||||
AllowedPortPrivacySettings: allowedPortPrivacySettings,
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
StartCodespaceFunc: func(ctx context.Context, codespaceName string) error {
|
||||
return nil
|
||||
},
|
||||
GetCodespaceRepositoryContentsFunc: func(ctx context.Context, codespace *api.Codespace, path string) ([]byte, error) {
|
||||
return nil, nil
|
||||
},
|
||||
HTTPClientFunc: func() (*http.Client, error) {
|
||||
return connection.NewMockHttpClient()
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func testingPortsApp() *App {
|
||||
disabledCodespace := &api.Codespace{
|
||||
Name: "disabledCodespace",
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue