Forward codespace ports over Dev Tunnels

This commit is contained in:
David Gardiner 2023-09-19 13:36:06 -07:00
parent 48b0d53d0e
commit e059f32aa5
13 changed files with 1271 additions and 303 deletions

View file

@ -201,6 +201,7 @@ type Codespace struct {
GitStatus CodespaceGitStatus `json:"git_status"`
Connection CodespaceConnection `json:"connection"`
Machine CodespaceMachine `json:"machine"`
RuntimeConstraints RuntimeConstraints `json:"runtime_constraints"`
VSCSTarget string `json:"vscs_target"`
PendingOperation bool `json:"pending_operation"`
PendingOperationDisabledReason string `json:"pending_operation_disabled_reason"`
@ -246,11 +247,25 @@ const (
)
type CodespaceConnection struct {
SessionID string `json:"sessionId"`
SessionToken string `json:"sessionToken"`
RelayEndpoint string `json:"relayEndpoint"`
RelaySAS string `json:"relaySas"`
HostPublicKeys []string `json:"hostPublicKeys"`
SessionID string `json:"sessionId"`
SessionToken string `json:"sessionToken"`
RelayEndpoint string `json:"relayEndpoint"`
RelaySAS string `json:"relaySas"`
HostPublicKeys []string `json:"hostPublicKeys"`
TunnelProperties TunnelProperties `json:"tunnelProperties"`
}
type TunnelProperties struct {
ConnectAccessToken string `json:"connectAccessToken"`
ManagePortsAccessToken string `json:"managePortsAccessToken"`
ServiceUri string `json:"serviceUri"`
TunnelId string `json:"tunnelId"`
ClusterId string `json:"clusterId"`
Domain string `json:"domain"`
}
type RuntimeConstraints struct {
AllowedPortPrivacySettings []string `json:"allowed_port_privacy_settings"`
}
// ListCodespaceFields is the list of exportable fields for a codespace when using the `gh cs list` command.
@ -1162,3 +1177,13 @@ func (a *API) withRetry(f func() (*http.Response, error)) (*http.Response, error
return nil, fmt.Errorf("received response with status code %d", resp.StatusCode)
}, backoff.WithMaxRetries(bo, 3))
}
// HTTPClient returns the HTTP client used to make requests to the API.
func (a *API) HTTPClient() (*http.Client, error) {
httpClient, err := a.client()
if err != nil {
return nil, err
}
return httpClient, nil
}

View file

@ -5,24 +5,42 @@ import (
"errors"
"fmt"
"net"
"net/http"
"time"
"github.com/cenkalti/backoff/v4"
"github.com/cli/cli/v2/internal/codespaces/api"
"github.com/cli/cli/v2/internal/codespaces/connection"
"github.com/cli/cli/v2/pkg/liveshare"
)
func connectionReady(codespace *api.Codespace) bool {
func connectionReady(codespace *api.Codespace, usingDevTunnels bool) bool {
// If the codespace is not available, it is not ready
if codespace.State != api.CodespaceStateAvailable {
return false
}
// If using Dev Tunnels, we need to check that we have all of the required tunnel properties
if usingDevTunnels {
return codespace.Connection.TunnelProperties.ConnectAccessToken != "" &&
codespace.Connection.TunnelProperties.ManagePortsAccessToken != "" &&
codespace.Connection.TunnelProperties.ServiceUri != "" &&
codespace.Connection.TunnelProperties.TunnelId != "" &&
codespace.Connection.TunnelProperties.ClusterId != "" &&
codespace.Connection.TunnelProperties.Domain != ""
}
// If not using Dev Tunnels, we need to check that we have all of the required Live Share properties
return codespace.Connection.SessionID != "" &&
codespace.Connection.SessionToken != "" &&
codespace.Connection.RelayEndpoint != "" &&
codespace.Connection.RelaySAS != "" &&
codespace.State == api.CodespaceStateAvailable
codespace.Connection.RelaySAS != ""
}
type apiClient interface {
GetCodespace(ctx context.Context, name string, includeConnection bool) (*api.Codespace, error)
StartCodespace(ctx context.Context, name string) error
HTTPClient() (*http.Client, error)
}
type progressIndicator interface {
@ -43,9 +61,48 @@ func (e *TimeoutError) Error() string {
return e.message
}
// ConnectToLiveshare waits for a Codespace to become running,
// and connects to it using a Live Share session.
// GetCodespaceConnection waits until a codespace is able
// to be connected to and initializes a connection to it.
func GetCodespaceConnection(ctx context.Context, progress progressIndicator, apiClient apiClient, codespace *api.Codespace) (*connection.CodespaceConnection, error) {
codespace, err := waitUntilCodespaceConnectionReady(ctx, progress, apiClient, codespace, true)
if err != nil {
return nil, err
}
progress.StartProgressIndicatorWithLabel("Connecting to codespace")
defer progress.StopProgressIndicator()
httpClient, err := apiClient.HTTPClient()
if err != nil {
return nil, fmt.Errorf("error getting http client: %w", err)
}
return connection.NewCodespaceConnection(ctx, codespace, httpClient)
}
// ConnectToLiveshare waits until a codespace is able to be
// connected to and connects to it using a Live Share session.
func ConnectToLiveshare(ctx context.Context, progress progressIndicator, sessionLogger logger, apiClient apiClient, codespace *api.Codespace) (*liveshare.Session, error) {
codespace, err := waitUntilCodespaceConnectionReady(ctx, progress, apiClient, codespace, false)
if err != nil {
return nil, err
}
progress.StartProgressIndicatorWithLabel("Connecting to codespace")
defer progress.StopProgressIndicator()
return liveshare.Connect(ctx, liveshare.Options{
SessionID: codespace.Connection.SessionID,
SessionToken: codespace.Connection.SessionToken,
RelaySAS: codespace.Connection.RelaySAS,
RelayEndpoint: codespace.Connection.RelayEndpoint,
HostPublicKeys: codespace.Connection.HostPublicKeys,
Logger: sessionLogger,
})
}
// waitUntilCodespaceConnectionReady waits for a Codespace to be running and is able to be connected to.
func waitUntilCodespaceConnectionReady(ctx context.Context, progress progressIndicator, apiClient apiClient, codespace *api.Codespace, usingDevTunnels bool) (*api.Codespace, error) {
if codespace.State != api.CodespaceStateAvailable {
progress.StartProgressIndicatorWithLabel("Starting codespace")
defer progress.StopProgressIndicator()
@ -54,7 +111,7 @@ func ConnectToLiveshare(ctx context.Context, progress progressIndicator, session
}
}
if !connectionReady(codespace) {
if !connectionReady(codespace, usingDevTunnels) {
expBackoff := backoff.NewExponentialBackOff()
expBackoff.Multiplier = 1.1
expBackoff.MaxInterval = 10 * time.Second
@ -67,7 +124,7 @@ func ConnectToLiveshare(ctx context.Context, progress progressIndicator, session
return backoff.Permanent(fmt.Errorf("error getting codespace: %w", err))
}
if connectionReady(codespace) {
if connectionReady(codespace, usingDevTunnels) {
return nil
}
@ -83,17 +140,7 @@ func ConnectToLiveshare(ctx context.Context, progress progressIndicator, session
}
}
progress.StartProgressIndicatorWithLabel("Connecting to codespace")
defer progress.StopProgressIndicator()
return liveshare.Connect(ctx, liveshare.Options{
SessionID: codespace.Connection.SessionID,
SessionToken: codespace.Connection.SessionToken,
RelaySAS: codespace.Connection.RelaySAS,
RelayEndpoint: codespace.Connection.RelayEndpoint,
HostPublicKeys: codespace.Connection.HostPublicKeys,
Logger: sessionLogger,
})
return codespace, nil
}
// ListenTCP starts a localhost tcp listener on 127.0.0.1 (unless allInterfaces is true) and returns the listener and bound port

View file

@ -0,0 +1,116 @@
package connection
import (
"context"
"fmt"
"io"
"log"
"net/http"
"net/url"
"github.com/cli/cli/v2/internal/codespaces/api"
"github.com/microsoft/dev-tunnels/go/tunnels"
)
const (
clientName = "gh"
)
type CodespaceConnection struct {
tunnelProperties api.TunnelProperties
TunnelManager *tunnels.Manager
TunnelClient *tunnels.Client
Options *tunnels.TunnelRequestOptions
Tunnel *tunnels.Tunnel
AllowedPortPrivacySettings []string
}
// NewCodespaceConnection initializes a connection to a codespace.
// This connections allows for port forwarding which enables the
// use of most features of the codespace command.
func NewCodespaceConnection(ctx context.Context, codespace *api.Codespace, httpClient *http.Client) (connection *CodespaceConnection, err error) {
// Get the tunnel properties
tunnelProperties := codespace.Connection.TunnelProperties
// Create the tunnel manager
tunnelManager, err := getTunnelManager(tunnelProperties, httpClient)
if err != nil {
return nil, fmt.Errorf("error getting tunnel management client: %w", err)
}
// Calculate allowed port privacy settings
allowedPortPrivacySettings := codespace.RuntimeConstraints.AllowedPortPrivacySettings
// Get the access tokens
connectToken := tunnelProperties.ConnectAccessToken
managementToken := tunnelProperties.ManagePortsAccessToken
// Create the tunnel definition
tunnel := &tunnels.Tunnel{
AccessTokens: map[tunnels.TunnelAccessScope]string{tunnels.TunnelAccessScopeConnect: connectToken, tunnels.TunnelAccessScopeManagePorts: managementToken},
TunnelID: tunnelProperties.TunnelId,
ClusterID: tunnelProperties.ClusterId,
Domain: tunnelProperties.Domain,
}
// Create options
options := &tunnels.TunnelRequestOptions{
IncludePorts: true,
}
// Create the tunnel client (not connected yet)
tunnelClient, err := getTunnelClient(ctx, tunnelManager, tunnel, options)
if err != nil {
return nil, fmt.Errorf("error getting tunnel client: %w", err)
}
return &CodespaceConnection{
tunnelProperties: tunnelProperties,
TunnelManager: tunnelManager,
TunnelClient: tunnelClient,
Options: options,
Tunnel: tunnel,
AllowedPortPrivacySettings: allowedPortPrivacySettings,
}, nil
}
// getTunnelManager creates a tunnel manager for the given codespace.
// The tunnel manager is used to get the tunnel hosted in the codespace that we
// want to connect to and perform operations on ports (add, remove, list, etc.).
func getTunnelManager(tunnelProperties api.TunnelProperties, httpClient *http.Client) (tunnelManager *tunnels.Manager, err error) {
userAgent := []tunnels.UserAgent{{Name: clientName}}
url, err := url.Parse(tunnelProperties.ServiceUri)
if err != nil {
return nil, fmt.Errorf("error parsing tunnel service uri: %w", err)
}
// Create the tunnel manager
tunnelManager, err = tunnels.NewManager(userAgent, nil, url, httpClient)
if err != nil {
return nil, fmt.Errorf("error creating tunnel manager: %w", err)
}
return tunnelManager, nil
}
// getTunnelClient creates a tunnel client for the given tunnel.
// The tunnel client is used to connect to the the tunnel and allows
// for ports to be forwarded locally.
func getTunnelClient(ctx context.Context, tunnelManager *tunnels.Manager, tunnel *tunnels.Tunnel, options *tunnels.TunnelRequestOptions) (tunnelClient *tunnels.Client, err error) {
// Get the tunnel that we want to connect to
codespaceTunnel, err := tunnelManager.GetTunnel(ctx, tunnel, options)
if err != nil {
return nil, fmt.Errorf("error getting tunnel: %w", err)
}
// Copy the access tokens from the tunnel definition
codespaceTunnel.AccessTokens = tunnel.AccessTokens
// We need to pass false for accept local connections because we don't want to automatically connect to all forwarded ports
tunnelClient, err = tunnels.NewClient(log.New(io.Discard, "", log.LstdFlags), codespaceTunnel, false)
if err != nil {
return nil, fmt.Errorf("error creating tunnel client: %w", err)
}
return tunnelClient, nil
}

View file

@ -0,0 +1,75 @@
package connection
import (
"context"
"reflect"
"testing"
"github.com/cli/cli/v2/internal/codespaces/api"
"github.com/microsoft/dev-tunnels/go/tunnels"
)
func TestNewCodespaceConnection(t *testing.T) {
ctx := context.Background()
// Create a mock codespace
connection := api.CodespaceConnection{
TunnelProperties: api.TunnelProperties{
ConnectAccessToken: "connect-token",
ManagePortsAccessToken: "manage-ports-token",
ServiceUri: "http://global.rel.tunnels.api.visualstudio.com/",
TunnelId: "tunnel-id",
ClusterId: "usw2",
Domain: "domain.com",
},
}
allowedPortPrivacySettings := []string{"public", "private"}
codespace := &api.Codespace{
Connection: connection,
RuntimeConstraints: api.RuntimeConstraints{AllowedPortPrivacySettings: allowedPortPrivacySettings},
}
// Create the mock HTTP client
httpClient, err := NewMockHttpClient()
if err != nil {
t.Fatalf("NewHttpClient returned an error: %v", err)
}
// Create the connection
conn, err := NewCodespaceConnection(ctx, codespace, httpClient)
if err != nil {
t.Fatalf("NewCodespaceConnection returned an error: %v", err)
}
// Check that the connection was created successfully
if conn == nil {
t.Fatal("NewCodespaceConnection returned nil")
}
// Verify that the connection contains the expected tunnel properties
if conn.tunnelProperties != connection.TunnelProperties {
t.Fatalf("NewCodespaceConnection returned a connection with unexpected tunnel properties: %+v", conn.tunnelProperties)
}
// Verify that the connection contains the expected tunnel
expectedTunnel := &tunnels.Tunnel{
AccessTokens: map[tunnels.TunnelAccessScope]string{tunnels.TunnelAccessScopeConnect: connection.TunnelProperties.ConnectAccessToken, tunnels.TunnelAccessScopeManagePorts: connection.TunnelProperties.ManagePortsAccessToken},
TunnelID: connection.TunnelProperties.TunnelId,
ClusterID: connection.TunnelProperties.ClusterId,
Domain: connection.TunnelProperties.Domain,
}
if !reflect.DeepEqual(conn.Tunnel, expectedTunnel) {
t.Fatalf("NewCodespaceConnection returned a connection with unexpected tunnel: %+v", conn.Tunnel)
}
// Verify that the connection contains the expected tunnel options
expectedOptions := &tunnels.TunnelRequestOptions{IncludePorts: true}
if !reflect.DeepEqual(conn.Options, expectedOptions) {
t.Fatalf("NewCodespaceConnection returned a connection with unexpected options: %+v", conn.Options)
}
// Verify that the connection contains the expected allowed port privacy settings
if !reflect.DeepEqual(conn.AllowedPortPrivacySettings, allowedPortPrivacySettings) {
t.Fatalf("NewCodespaceConnection returned a connection with unexpected allowed port privacy settings: %+v", conn.AllowedPortPrivacySettings)
}
}

View file

@ -0,0 +1,396 @@
package connection
import (
"context"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"encoding/json"
"encoding/pem"
"fmt"
"io"
"log"
"net/http"
"net/http/httptest"
"net/url"
"regexp"
"strings"
"sync"
"time"
"github.com/gorilla/websocket"
"github.com/microsoft/dev-tunnels/go/tunnels"
tunnelssh "github.com/microsoft/dev-tunnels/go/tunnels/ssh"
"github.com/microsoft/dev-tunnels/go/tunnels/ssh/messages"
"golang.org/x/crypto/ssh"
)
func NewMockHttpClient() (*http.Client, error) {
accessToken := "tunnel access-token"
relayServer, err := newMockrelayServer(withAccessToken(accessToken))
if err != nil {
return nil, fmt.Errorf("NewrelayServer returned an error: %w", err)
}
hostURL := strings.Replace(relayServer.URL(), "http://", "ws://", 1)
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var response []byte
if r.URL.Path == "/api/v1/tunnels/tunnel-id" {
tunnel := &tunnels.Tunnel{
AccessTokens: map[tunnels.TunnelAccessScope]string{
tunnels.TunnelAccessScopeConnect: accessToken,
},
Endpoints: []tunnels.TunnelEndpoint{
{
HostID: "host1",
TunnelRelayTunnelEndpoint: tunnels.TunnelRelayTunnelEndpoint{
ClientRelayURI: hostURL,
},
},
},
}
response, err = json.Marshal(*tunnel)
if err != nil {
log.Fatalf("json.Marshal returned an error: %v", err)
}
} else if strings.HasPrefix(r.URL.Path, "/api/v1/tunnels/tunnel-id/ports") {
// Use regex to check if the path ends with a number
match, err := regexp.MatchString(`\/\d+$`, r.URL.Path)
if err != nil {
log.Fatalf("regexp.MatchString returned an error: %v", err)
}
// If the path ends with a number, it's a request for a specific port
if match || r.Method == http.MethodPost {
if r.Method == http.MethodDelete {
w.WriteHeader(http.StatusOK)
return
}
tunnelPort := &tunnels.TunnelPort{
AccessControl: &tunnels.TunnelAccessControl{
Entries: []tunnels.TunnelAccessControlEntry{},
},
}
// Convert the tunnel to JSON and write it to the response
response, err = json.Marshal(*tunnelPort)
if err != nil {
log.Fatalf("json.Marshal returned an error: %v", err)
}
} else {
// If the path doesn't end with a number and we aren't making a POST request, return an array of ports
tunnelPorts := []tunnels.TunnelPort{
{
AccessControl: &tunnels.TunnelAccessControl{
Entries: []tunnels.TunnelAccessControlEntry{},
},
},
}
response, err = json.Marshal(tunnelPorts)
if err != nil {
log.Fatalf("json.Marshal returned an error: %v", err)
}
}
} else {
w.WriteHeader(http.StatusNotFound)
return
}
// Write the response
_, _ = w.Write(response)
}))
url, err := url.Parse(mockServer.URL)
if err != nil {
return nil, fmt.Errorf("url.Parse returned an error: %w", err)
}
return &http.Client{
Transport: &http.Transport{
Proxy: http.ProxyURL(url),
},
}, nil
}
type relayServer struct {
httpServer *httptest.Server
errc chan error
sshConfig *ssh.ServerConfig
channels map[string]channelHandler
accessToken string
serverConn *ssh.ServerConn
}
type relayServerOption func(*relayServer)
type channelHandler func(context.Context, ssh.NewChannel) error
func newMockrelayServer(opts ...relayServerOption) (*relayServer, error) {
server := &relayServer{
errc: make(chan error),
sshConfig: &ssh.ServerConfig{
NoClientAuth: true,
},
}
// Create a private key with the crypto package
key, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
return nil, fmt.Errorf("failed to generate key: %w", err)
}
privateKeyPEM := pem.EncodeToMemory(
&pem.Block{
Type: "RSA PRIVATE KEY",
Bytes: x509.MarshalPKCS1PrivateKey(key),
},
)
// Parse the private key
sshPrivateKey, err := ssh.ParsePrivateKey(privateKeyPEM)
if err != nil {
return nil, fmt.Errorf("failed to parse private key: %w", err)
}
server.sshConfig.AddHostKey(ssh.Signer(sshPrivateKey))
server.httpServer = httptest.NewServer(http.HandlerFunc(makeConnection(server)))
for _, opt := range opts {
opt(server)
}
return server, nil
}
func withAccessToken(accessToken string) func(*relayServer) {
return func(server *relayServer) {
server.accessToken = accessToken
}
}
func (rs *relayServer) URL() string {
return rs.httpServer.URL
}
func (rs *relayServer) Err() <-chan error {
return rs.errc
}
func (rs *relayServer) sendError(err error) {
select {
case rs.errc <- err:
default:
// channel is blocked with a previous error, so we ignore this one
}
}
func (rs *relayServer) ForwardPort(ctx context.Context, port uint16) error {
pfr := messages.NewPortForwardRequest("127.0.0.1", uint32(port))
b, err := pfr.Marshal()
if err != nil {
return fmt.Errorf("error marshaling port forward request: %w", err)
}
replied, data, err := rs.serverConn.SendRequest(messages.PortForwardRequestType, true, b)
if err != nil {
return fmt.Errorf("error sending port forward request: %w", err)
}
if !replied {
return fmt.Errorf("port forward request not replied")
}
if data == nil {
return fmt.Errorf("no data returned")
}
return nil
}
func makeConnection(server *relayServer) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
if server.accessToken != "" {
if r.Header.Get("Authorization") != server.accessToken {
server.sendError(fmt.Errorf("invalid access token"))
return
}
}
upgrader := websocket.Upgrader{}
c, err := upgrader.Upgrade(w, r, nil)
if err != nil {
server.sendError(fmt.Errorf("error upgrading to websocket: %w", err))
return
}
defer func() {
if err := c.Close(); err != nil {
server.sendError(fmt.Errorf("error closing websocket: %w", err))
}
}()
socketConn := newSocketConn(c)
serverConn, chans, reqs, err := ssh.NewServerConn(socketConn, server.sshConfig)
if err != nil {
server.sendError(fmt.Errorf("error creating ssh server conn: %w", err))
return
}
go handleRequests(ctx, convertRequests(reqs))
server.serverConn = serverConn
if err := handleChannels(ctx, server, chans); err != nil {
server.sendError(fmt.Errorf("error handling channels: %w", err))
return
}
}
}
func (sr *sshRequest) Type() string {
return sr.request.Type
}
type sshRequest struct {
request *ssh.Request
}
// Reply method for sshRequest to satisfy the tunnelssh.SSHRequest interface
func (sr *sshRequest) Reply(success bool, message []byte) error {
return sr.request.Reply(success, message)
}
// convertRequests function
func convertRequests(reqs <-chan *ssh.Request) <-chan tunnelssh.SSHRequest {
out := make(chan tunnelssh.SSHRequest)
go func() {
for req := range reqs {
out <- &sshRequest{req}
}
close(out)
}()
return out
}
func handleChannels(ctx context.Context, server *relayServer, chans <-chan ssh.NewChannel) error {
errc := make(chan error, 1)
go func() {
for ch := range chans {
if handler, ok := server.channels[ch.ChannelType()]; ok {
if err := handler(ctx, ch); err != nil {
errc <- err
return
}
} else {
// generic accept of the channel to not block
_, _, err := ch.Accept()
if err != nil {
errc <- fmt.Errorf("error accepting channel: %w", err)
return
}
}
}
}()
return awaitError(ctx, errc)
}
func handleRequests(ctx context.Context, reqs <-chan tunnelssh.SSHRequest) {
for {
select {
case <-ctx.Done():
return
case req, ok := <-reqs:
if !ok {
return
}
if req.Type() == "RefreshPorts" {
_ = req.Reply(true, nil)
continue
} else {
_ = req.Reply(false, nil)
}
}
}
}
func awaitError(ctx context.Context, errc <-chan error) error {
select {
case <-ctx.Done():
return ctx.Err()
case err := <-errc:
return err
}
}
type socketConn struct {
*websocket.Conn
reader io.Reader
writeMutex sync.Mutex
readMutex sync.Mutex
}
func newSocketConn(conn *websocket.Conn) *socketConn {
return &socketConn{Conn: conn}
}
func (s *socketConn) Read(b []byte) (int, error) {
s.readMutex.Lock()
defer s.readMutex.Unlock()
if s.reader == nil {
msgType, r, err := s.Conn.NextReader()
if err != nil {
return 0, fmt.Errorf("error getting next reader: %w", err)
}
if msgType != websocket.BinaryMessage {
return 0, fmt.Errorf("invalid message type")
}
s.reader = r
}
bytesRead, err := s.reader.Read(b)
if err != nil {
s.reader = nil
if err == io.EOF {
err = nil
}
}
return bytesRead, err
}
func (s *socketConn) Write(b []byte) (int, error) {
s.writeMutex.Lock()
defer s.writeMutex.Unlock()
w, err := s.Conn.NextWriter(websocket.BinaryMessage)
if err != nil {
return 0, fmt.Errorf("error getting next writer: %w", err)
}
n, err := w.Write(b)
if err != nil {
return 0, fmt.Errorf("error writing: %w", err)
}
if err := w.Close(); err != nil {
return 0, fmt.Errorf("error closing writer: %w", err)
}
return n, nil
}
func (s *socketConn) SetDeadline(deadline time.Time) error {
if err := s.Conn.SetReadDeadline(deadline); err != nil {
return err
}
return s.Conn.SetWriteDeadline(deadline)
}

View file

@ -0,0 +1,253 @@
package portforwarder
import (
"context"
"fmt"
"net"
"strings"
"github.com/cli/cli/v2/internal/codespaces/connection"
"github.com/microsoft/dev-tunnels/go/tunnels"
)
const (
githubSubjectId = "1"
InternalPortTag = "InternalPort"
UserForwardedPortTag = "UserForwardedPort"
)
const (
PrivatePortVisibility = "private"
OrgPortVisibility = "org"
PublicPortVisibility = "public"
)
type PortForwarder struct {
connection connection.CodespaceConnection
}
// NewPortForwarder returns a new PortForwarder for the specified codespace.
func NewPortForwarder(ctx context.Context, codespaceConnection *connection.CodespaceConnection) (fwd *PortForwarder, err error) {
return &PortForwarder{
connection: *codespaceConnection,
}, nil
}
// ForwardAndConnectToPort forwards a port and connects to it via a local TCP port.
func (fwd *PortForwarder) ForwardAndConnectToPort(ctx context.Context, remotePort uint16, listen *net.TCPListener, keepAlive bool, internal bool) error {
return fwd.ForwardPort(ctx, remotePort, listen, keepAlive, true, internal, "")
}
// ForwardPort forwards a port and optionally connects to it via a local TCP port.
func (fwd *PortForwarder) ForwardPort(ctx context.Context, remotePort uint16, listen *net.TCPListener, keepAlive bool, connect bool, internal bool, visibility string) error {
tunnelPort := tunnels.NewTunnelPort(remotePort, "", "", tunnels.TunnelProtocolHttp)
// If no visibility is provided, Dev Tunnels will use the default (private)
if visibility != "" {
// Check if the requested visibility is allowed
allowed := false
for _, allowedVisibility := range fwd.connection.AllowedPortPrivacySettings {
if allowedVisibility == visibility {
allowed = true
break
}
}
// If the requested visibility is not allowed, return an error
if !allowed {
return fmt.Errorf("visibility %s is not allowed", visibility)
}
accessControlEntries := visibilityToAccessControlEntries(visibility)
if len(accessControlEntries) > 0 {
tunnelPort.AccessControl = &tunnels.TunnelAccessControl{
Entries: accessControlEntries,
}
}
}
// Tag the port as internal or user forwarded so we know if it needs to be shown in the UI
if internal {
tunnelPort.Tags = []string{InternalPortTag}
} else {
tunnelPort.Tags = []string{UserForwardedPortTag}
}
// Create the tunnel port
_, err := fwd.connection.TunnelManager.CreateTunnelPort(ctx, fwd.connection.Tunnel, tunnelPort, fwd.connection.Options)
if err != nil && !strings.Contains(err.Error(), "409") {
return fmt.Errorf("create tunnel port failed: %v", err)
}
done := make(chan error)
go func() {
// Connect to the tunnel
err = fwd.connection.TunnelClient.Connect(ctx, "")
if err != nil {
done <- fmt.Errorf("connect failed: %v", err)
return
}
// Inform the host that we've forwarded the port locally
err = fwd.connection.TunnelClient.RefreshPorts(ctx)
if err != nil {
done <- fmt.Errorf("refresh ports failed: %v", err)
return
}
// If we don't want to connect to the port, exit early
if !connect {
done <- nil
return
}
// Ensure the port is forwarded before connecting
err = fwd.connection.TunnelClient.WaitForForwardedPort(ctx, remotePort)
if err != nil {
done <- fmt.Errorf("wait for forwarded port failed: %v", err)
return
}
// Connect to the forwarded port via a local TCP port
err = fwd.connection.TunnelClient.ConnectToForwardedPort(ctx, listen, remotePort)
if err != nil {
done <- fmt.Errorf("connect to forwarded port failed: %v", err)
return
}
done <- nil
}()
select {
case err := <-done:
if err != nil {
return fmt.Errorf("error connecting to tunnel: %w", err)
}
return nil
case <-ctx.Done():
return nil
}
}
// ListPorts fetches the list of ports that are currently forwarded.
func (fwd *PortForwarder) ListPorts(ctx context.Context) (ports []*tunnels.TunnelPort, err error) {
ports, err = fwd.connection.TunnelManager.ListTunnelPorts(ctx, fwd.connection.Tunnel, fwd.connection.Options)
if err != nil {
return nil, fmt.Errorf("error listing ports: %w", err)
}
return ports, nil
}
// UpdatePortVisibility changes the visibility (private, org, public) of the specified port.
func (fwd *PortForwarder) UpdatePortVisibility(ctx context.Context, remotePort int, visibility string) error {
tunnelPort, err := fwd.connection.TunnelManager.GetTunnelPort(ctx, fwd.connection.Tunnel, remotePort, fwd.connection.Options)
if err != nil {
return fmt.Errorf("error getting tunnel port: %w", err)
}
// If the port visibility isn't changing, don't do anything
if AccessControlEntriesToVisibility(tunnelPort.AccessControl.Entries) == visibility {
return nil
}
// Delete the existing tunnel port to update
err = fwd.connection.TunnelManager.DeleteTunnelPort(ctx, fwd.connection.Tunnel, uint16(remotePort), fwd.connection.Options)
if err != nil {
return fmt.Errorf("error deleting tunnel port: %w", err)
}
done := make(chan error)
go func() {
// Connect to the tunnel
err = fwd.connection.TunnelClient.Connect(ctx, "")
if err != nil {
done <- fmt.Errorf("connect failed: %v", err)
return
}
// Inform the host that we've deleted the port
err = fwd.connection.TunnelClient.RefreshPorts(ctx)
if err != nil {
done <- fmt.Errorf("refresh ports failed: %v", err)
return
}
done <- nil
}()
// Wait for the done channel to be closed
select {
case err := <-done:
if err != nil {
return fmt.Errorf("error connecting to tunnel: %w", err)
}
// Re-forward the port with the updated visibility
err = fwd.ForwardPort(ctx, uint16(remotePort), nil, false, false, false, visibility)
if err != nil {
return fmt.Errorf("error forwarding port: %w", err)
}
return nil
case <-ctx.Done():
return nil
}
}
// AccessControlEntriesToVisibility converts the access control entries used by Dev Tunnels to a friendly visibility value.
func AccessControlEntriesToVisibility(accessControlEntries []tunnels.TunnelAccessControlEntry) string {
for _, entry := range accessControlEntries {
// If we have the anonymous type (and we're not denying it), it's public
if (entry.Type == tunnels.TunnelAccessControlEntryTypeAnonymous) && (!entry.IsDeny) {
return PublicPortVisibility
}
// If we have the organizations type (and we're not denying it), it's org
if (entry.Provider == string(tunnels.TunnelAuthenticationSchemeGitHub)) && (!entry.IsDeny) {
return OrgPortVisibility
}
}
// Else, it's private
return PrivatePortVisibility
}
// visibilityToAccessControlEntries converts the given visibility to access control entries that can be used by Dev Tunnels.
func visibilityToAccessControlEntries(visibility string) []tunnels.TunnelAccessControlEntry {
switch visibility {
case PublicPortVisibility:
return []tunnels.TunnelAccessControlEntry{{
Type: tunnels.TunnelAccessControlEntryTypeAnonymous,
Subjects: []string{},
Scopes: []string{string(tunnels.TunnelAccessScopeConnect)},
}}
case OrgPortVisibility:
return []tunnels.TunnelAccessControlEntry{{
Type: tunnels.TunnelAccessControlEntryTypeOrganizations,
Subjects: []string{githubSubjectId},
Scopes: []string{
string(tunnels.TunnelAccessScopeConnect),
},
Provider: string(tunnels.TunnelAuthenticationSchemeGitHub),
}}
default:
// The tunnel manager doesn't accept empty access control entries, so we need to return a deny entry
return []tunnels.TunnelAccessControlEntry{{
Type: tunnels.TunnelAccessControlEntryTypeOrganizations,
Subjects: []string{githubSubjectId},
Scopes: []string{},
IsDeny: true,
}}
}
}
// IsInternalPort returns true if the port is internal.
func IsInternalPort(port *tunnels.TunnelPort) bool {
for _, tag := range port.Tags {
if strings.EqualFold(tag, InternalPortTag) {
return true
}
}
return false
}

View file

@ -0,0 +1,139 @@
package portforwarder
import (
"context"
"testing"
"github.com/cli/cli/v2/internal/codespaces/api"
"github.com/cli/cli/v2/internal/codespaces/connection"
"github.com/microsoft/dev-tunnels/go/tunnels"
)
func TestNewPortForwarder(t *testing.T) {
ctx := context.Background()
// Create a mock codespace
codespace := &api.Codespace{
Connection: api.CodespaceConnection{
TunnelProperties: api.TunnelProperties{
ConnectAccessToken: "connect-token",
ManagePortsAccessToken: "manage-ports-token",
ServiceUri: "http://global.rel.tunnels.api.visualstudio.com/",
TunnelId: "tunnel-id",
ClusterId: "usw2",
Domain: "domain.com",
},
},
RuntimeConstraints: api.RuntimeConstraints{
AllowedPortPrivacySettings: []string{"public", "private"},
},
}
// Create the mock HTTP client
httpClient, err := connection.NewMockHttpClient()
if err != nil {
t.Fatalf("NewHttpClient returned an error: %v", err)
}
// Call the function being tested
conn, err := connection.NewCodespaceConnection(ctx, codespace, httpClient)
if err != nil {
t.Fatalf("NewCodespaceConnection returned an error: %v", err)
}
// Create the new port forwarder
portForwarder, err := NewPortForwarder(ctx, conn)
if err != nil {
t.Fatalf("NewPortForwarder returned an error: %v", err)
}
// Check that the port forwarder was created successfully
if portForwarder == nil {
t.Fatal("NewPortForwarder returned nil")
}
}
func TestAccessControlEntriesToVisibility(t *testing.T) {
publicAccessControlEntry := []tunnels.TunnelAccessControlEntry{{
Type: tunnels.TunnelAccessControlEntryTypeAnonymous,
}}
orgAccessControlEntry := []tunnels.TunnelAccessControlEntry{{
Provider: string(tunnels.TunnelAuthenticationSchemeGitHub),
}}
privateAccessControlEntry := []tunnels.TunnelAccessControlEntry{}
orgIsDenyAccessControlEntry := []tunnels.TunnelAccessControlEntry{{
Provider: string(tunnels.TunnelAuthenticationSchemeGitHub),
IsDeny: true,
}}
tests := []struct {
name string
accessControlEntries []tunnels.TunnelAccessControlEntry
expected string
}{
{
name: "public",
accessControlEntries: publicAccessControlEntry,
expected: PublicPortVisibility,
},
{
name: "org",
accessControlEntries: orgAccessControlEntry,
expected: OrgPortVisibility,
},
{
name: "private",
accessControlEntries: privateAccessControlEntry,
expected: PrivatePortVisibility,
},
{
name: "orgIsDeny",
accessControlEntries: orgIsDenyAccessControlEntry,
expected: PrivatePortVisibility,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
visibility := AccessControlEntriesToVisibility(test.accessControlEntries)
if visibility != test.expected {
t.Errorf("expected %q, got %q", test.expected, visibility)
}
})
}
}
func TestIsInternalPort(t *testing.T) {
internalPort := &tunnels.TunnelPort{
Tags: []string{"InternalPort"},
}
userForwardedPort := &tunnels.TunnelPort{
Tags: []string{"UserForwardedPort"},
}
tests := []struct {
name string
port *tunnels.TunnelPort
expected bool
}{
{
name: "internal",
port: internalPort,
expected: true,
},
{
name: "user-forwarded",
port: userForwardedPort,
expected: false,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
isInternal := IsInternalPort(test.port)
if isInternal != test.expected {
t.Errorf("expected %v, got %v", test.expected, isInternal)
}
})
}
}