Forward codespace ports over Dev Tunnels
This commit is contained in:
parent
48b0d53d0e
commit
e059f32aa5
13 changed files with 1271 additions and 303 deletions
|
|
@ -8,6 +8,7 @@ import (
|
|||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"sort"
|
||||
"strings"
|
||||
|
|
@ -104,6 +105,7 @@ type apiClient interface {
|
|||
ListDevContainers(ctx context.Context, repoID int, branch string, limit int) (devcontainers []api.DevContainerEntry, err error)
|
||||
GetCodespaceRepoSuggestions(ctx context.Context, partialSearch string, params api.RepoSearchParameters) ([]string, error)
|
||||
GetCodespaceBillableOwner(ctx context.Context, nwo string) (*api.User, error)
|
||||
HTTPClient() (*http.Client, error)
|
||||
}
|
||||
|
||||
var errNoCodespaces = errors.New("you have no codespaces")
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ package codespace
|
|||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"sync"
|
||||
|
||||
codespacesAPI "github.com/cli/cli/v2/internal/codespaces/api"
|
||||
|
|
@ -40,15 +41,15 @@ import (
|
|||
// GetCodespacesMachinesFunc: func(ctx context.Context, repoID int, branch string, location string, devcontainerPath string) ([]*codespacesAPI.Machine, error) {
|
||||
// panic("mock out the GetCodespacesMachines method")
|
||||
// },
|
||||
// HTTPClientFunc: func() (*http.Client, error) {
|
||||
// panic("mock out the HTTPClient method")
|
||||
// },
|
||||
// GetOrgMemberCodespaceFunc: func(ctx context.Context, orgName string, userName string, codespaceName string) (*codespacesAPI.Codespace, error) {
|
||||
// panic("mock out the GetOrgMemberCodespace method")
|
||||
// },
|
||||
// GetRepositoryFunc: func(ctx context.Context, nwo string) (*codespacesAPI.Repository, error) {
|
||||
// panic("mock out the GetRepository method")
|
||||
// },
|
||||
// ServerURLFunc: func() string {
|
||||
// panic("mock out the ServerURL method")
|
||||
// },
|
||||
// GetUserFunc: func(ctx context.Context) (*codespacesAPI.User, error) {
|
||||
// panic("mock out the GetUser method")
|
||||
// },
|
||||
|
|
@ -58,6 +59,9 @@ import (
|
|||
// ListDevContainersFunc: func(ctx context.Context, repoID int, branch string, limit int) ([]codespacesAPI.DevContainerEntry, error) {
|
||||
// panic("mock out the ListDevContainers method")
|
||||
// },
|
||||
// ServerURLFunc: func() string {
|
||||
// panic("mock out the ServerURL method")
|
||||
// },
|
||||
// StartCodespaceFunc: func(ctx context.Context, name string) error {
|
||||
// panic("mock out the StartCodespace method")
|
||||
// },
|
||||
|
|
@ -95,15 +99,15 @@ type apiClientMock struct {
|
|||
// GetCodespacesMachinesFunc mocks the GetCodespacesMachines method.
|
||||
GetCodespacesMachinesFunc func(ctx context.Context, repoID int, branch string, location string, devcontainerPath string) ([]*codespacesAPI.Machine, error)
|
||||
|
||||
// HTTPClientFunc mocks the HTTPClient method.
|
||||
HTTPClientFunc func() (*http.Client, error)
|
||||
|
||||
// GetOrgMemberCodespaceFunc mocks the GetOrgMemberCodespace method.
|
||||
GetOrgMemberCodespaceFunc func(ctx context.Context, orgName string, userName string, codespaceName string) (*codespacesAPI.Codespace, error)
|
||||
|
||||
// GetRepositoryFunc mocks the GetRepository method.
|
||||
GetRepositoryFunc func(ctx context.Context, nwo string) (*codespacesAPI.Repository, error)
|
||||
|
||||
// ServerURLFunc mocks the ServerURL method.
|
||||
ServerURLFunc func() string
|
||||
|
||||
// GetUserFunc mocks the GetUser method.
|
||||
GetUserFunc func(ctx context.Context) (*codespacesAPI.User, error)
|
||||
|
||||
|
|
@ -113,6 +117,9 @@ type apiClientMock struct {
|
|||
// ListDevContainersFunc mocks the ListDevContainers method.
|
||||
ListDevContainersFunc func(ctx context.Context, repoID int, branch string, limit int) ([]codespacesAPI.DevContainerEntry, error)
|
||||
|
||||
// ServerURLFunc mocks the ServerURL method.
|
||||
ServerURLFunc func() string
|
||||
|
||||
// StartCodespaceFunc mocks the StartCodespace method.
|
||||
StartCodespaceFunc func(ctx context.Context, name string) error
|
||||
|
||||
|
|
@ -195,6 +202,9 @@ type apiClientMock struct {
|
|||
// DevcontainerPath is the devcontainerPath argument value.
|
||||
DevcontainerPath string
|
||||
}
|
||||
// HTTPClient holds details about calls to the HTTPClient method.
|
||||
HTTPClient []struct {
|
||||
}
|
||||
// GetOrgMemberCodespace holds details about calls to the GetOrgMemberCodespace method.
|
||||
GetOrgMemberCodespace []struct {
|
||||
// Ctx is the ctx argument value.
|
||||
|
|
@ -213,9 +223,6 @@ type apiClientMock struct {
|
|||
// Nwo is the nwo argument value.
|
||||
Nwo string
|
||||
}
|
||||
// ServerURL holds details about calls to the ServerURL method.
|
||||
ServerURL []struct {
|
||||
}
|
||||
// GetUser holds details about calls to the GetUser method.
|
||||
GetUser []struct {
|
||||
// Ctx is the ctx argument value.
|
||||
|
|
@ -239,6 +246,9 @@ type apiClientMock struct {
|
|||
// Limit is the limit argument value.
|
||||
Limit int
|
||||
}
|
||||
// ServerURL holds details about calls to the ServerURL method.
|
||||
ServerURL []struct {
|
||||
}
|
||||
// StartCodespace holds details about calls to the StartCodespace method.
|
||||
StartCodespace []struct {
|
||||
// Ctx is the ctx argument value.
|
||||
|
|
@ -266,12 +276,13 @@ type apiClientMock struct {
|
|||
lockGetCodespaceRepoSuggestions sync.RWMutex
|
||||
lockGetCodespaceRepositoryContents sync.RWMutex
|
||||
lockGetCodespacesMachines sync.RWMutex
|
||||
lockHTTPClient sync.RWMutex
|
||||
lockGetOrgMemberCodespace sync.RWMutex
|
||||
lockGetRepository sync.RWMutex
|
||||
lockServerURL sync.RWMutex
|
||||
lockGetUser sync.RWMutex
|
||||
lockListCodespaces sync.RWMutex
|
||||
lockListDevContainers sync.RWMutex
|
||||
lockServerURL sync.RWMutex
|
||||
lockStartCodespace sync.RWMutex
|
||||
lockStopCodespace sync.RWMutex
|
||||
}
|
||||
|
|
@ -600,6 +611,33 @@ func (mock *apiClientMock) GetCodespacesMachinesCalls() []struct {
|
|||
return calls
|
||||
}
|
||||
|
||||
// HTTPClient calls HTTPClientFunc.
|
||||
func (mock *apiClientMock) HTTPClient() (*http.Client, error) {
|
||||
if mock.HTTPClientFunc == nil {
|
||||
panic("apiClientMock.HTTPClientFunc: method is nil but apiClient.HTTPClient was just called")
|
||||
}
|
||||
callInfo := struct {
|
||||
}{}
|
||||
mock.lockHTTPClient.Lock()
|
||||
mock.calls.HTTPClient = append(mock.calls.HTTPClient, callInfo)
|
||||
mock.lockHTTPClient.Unlock()
|
||||
return mock.HTTPClientFunc()
|
||||
}
|
||||
|
||||
// HTTPClientCalls gets all the calls that were made to HTTPClient.
|
||||
// Check the length with:
|
||||
//
|
||||
// len(mockedapiClient.HTTPClientCalls())
|
||||
func (mock *apiClientMock) HTTPClientCalls() []struct {
|
||||
} {
|
||||
var calls []struct {
|
||||
}
|
||||
mock.lockHTTPClient.RLock()
|
||||
calls = mock.calls.HTTPClient
|
||||
mock.lockHTTPClient.RUnlock()
|
||||
return calls
|
||||
}
|
||||
|
||||
// GetOrgMemberCodespace calls GetOrgMemberCodespaceFunc.
|
||||
func (mock *apiClientMock) GetOrgMemberCodespace(ctx context.Context, orgName string, userName string, codespaceName string) (*codespacesAPI.Codespace, error) {
|
||||
if mock.GetOrgMemberCodespaceFunc == nil {
|
||||
|
|
@ -680,33 +718,6 @@ func (mock *apiClientMock) GetRepositoryCalls() []struct {
|
|||
return calls
|
||||
}
|
||||
|
||||
// ServerURL calls ServerURLFunc.
|
||||
func (mock *apiClientMock) ServerURL() string {
|
||||
if mock.ServerURLFunc == nil {
|
||||
panic("apiClientMock.ServerURLFunc: method is nil but apiClient.ServerURL was just called")
|
||||
}
|
||||
callInfo := struct {
|
||||
}{}
|
||||
mock.lockServerURL.Lock()
|
||||
mock.calls.ServerURL = append(mock.calls.ServerURL, callInfo)
|
||||
mock.lockServerURL.Unlock()
|
||||
return mock.ServerURLFunc()
|
||||
}
|
||||
|
||||
// ServerURLCalls gets all the calls that were made to ServerURL.
|
||||
// Check the length with:
|
||||
//
|
||||
// len(mockedapiClient.ServerURLCalls())
|
||||
func (mock *apiClientMock) ServerURLCalls() []struct {
|
||||
} {
|
||||
var calls []struct {
|
||||
}
|
||||
mock.lockServerURL.RLock()
|
||||
calls = mock.calls.ServerURL
|
||||
mock.lockServerURL.RUnlock()
|
||||
return calls
|
||||
}
|
||||
|
||||
// GetUser calls GetUserFunc.
|
||||
func (mock *apiClientMock) GetUser(ctx context.Context) (*codespacesAPI.User, error) {
|
||||
if mock.GetUserFunc == nil {
|
||||
|
|
@ -819,6 +830,33 @@ func (mock *apiClientMock) ListDevContainersCalls() []struct {
|
|||
return calls
|
||||
}
|
||||
|
||||
// ServerURL calls ServerURLFunc.
|
||||
func (mock *apiClientMock) ServerURL() string {
|
||||
if mock.ServerURLFunc == nil {
|
||||
panic("apiClientMock.ServerURLFunc: method is nil but apiClient.ServerURL was just called")
|
||||
}
|
||||
callInfo := struct {
|
||||
}{}
|
||||
mock.lockServerURL.Lock()
|
||||
mock.calls.ServerURL = append(mock.calls.ServerURL, callInfo)
|
||||
mock.lockServerURL.Unlock()
|
||||
return mock.ServerURLFunc()
|
||||
}
|
||||
|
||||
// ServerURLCalls gets all the calls that were made to ServerURL.
|
||||
// Check the length with:
|
||||
//
|
||||
// len(mockedapiClient.ServerURLCalls())
|
||||
func (mock *apiClientMock) ServerURLCalls() []struct {
|
||||
} {
|
||||
var calls []struct {
|
||||
}
|
||||
mock.lockServerURL.RLock()
|
||||
calls = mock.calls.ServerURL
|
||||
mock.lockServerURL.RUnlock()
|
||||
return calls
|
||||
}
|
||||
|
||||
// StartCodespace calls StartCodespaceFunc.
|
||||
func (mock *apiClientMock) StartCodespace(ctx context.Context, name string) error {
|
||||
if mock.StartCodespaceFunc == nil {
|
||||
|
|
|
|||
|
|
@ -6,26 +6,21 @@ import (
|
|||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/cli/cli/v2/internal/codespaces"
|
||||
"github.com/cli/cli/v2/internal/codespaces/api"
|
||||
"github.com/cli/cli/v2/internal/codespaces/portforwarder"
|
||||
"github.com/cli/cli/v2/internal/tableprinter"
|
||||
"github.com/cli/cli/v2/pkg/cmdutil"
|
||||
"github.com/cli/cli/v2/pkg/liveshare"
|
||||
"github.com/cli/cli/v2/utils"
|
||||
"github.com/microsoft/dev-tunnels/go/tunnels"
|
||||
"github.com/muhammadmuzzammil1998/jsonc"
|
||||
"github.com/spf13/cobra"
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
|
||||
const (
|
||||
vscodeServerPortName = "VSCodeServerInternal"
|
||||
codespacesInternalPortName = "CodespacesInternal"
|
||||
)
|
||||
|
||||
// newPortsCmd returns a Cobra "ports" command that displays a table of available ports,
|
||||
// according to the specified flags.
|
||||
func newPortsCmd(app *App) *cobra.Command {
|
||||
|
|
@ -62,15 +57,19 @@ func (a *App) ListPorts(ctx context.Context, selector *CodespaceSelector, export
|
|||
|
||||
devContainerCh := getDevContainer(ctx, a.apiClient, codespace)
|
||||
|
||||
session, err := startLiveShareSession(ctx, codespace, a, false, "")
|
||||
codespaceConnection, err := codespaces.GetCodespaceConnection(ctx, a, a.apiClient, codespace)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("error connecting to codespace: %w", err)
|
||||
}
|
||||
defer safeClose(session, &err)
|
||||
|
||||
var ports []*liveshare.Port
|
||||
fwd, err := portforwarder.NewPortForwarder(ctx, codespaceConnection)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create port forwarder: %w", err)
|
||||
}
|
||||
|
||||
var ports []*tunnels.TunnelPort
|
||||
err = a.RunWithProgress("Fetching ports", func() (err error) {
|
||||
ports, err = session.GetSharedServers(ctx)
|
||||
ports, err = fwd.ListPorts(ctx)
|
||||
return
|
||||
})
|
||||
if err != nil {
|
||||
|
|
@ -87,9 +86,10 @@ func (a *App) ListPorts(ctx context.Context, selector *CodespaceSelector, export
|
|||
|
||||
for _, p := range ports {
|
||||
// filter out internal ports from list
|
||||
if strings.HasPrefix(p.SessionName, vscodeServerPortName) || strings.HasPrefix(p.SessionName, codespacesInternalPortName) {
|
||||
if portforwarder.IsInternalPort(p) {
|
||||
continue
|
||||
}
|
||||
|
||||
portInfos = append(portInfos, &portInfo{
|
||||
Port: p,
|
||||
codespace: codespace,
|
||||
|
|
@ -107,40 +107,42 @@ func (a *App) ListPorts(ctx context.Context, selector *CodespaceSelector, export
|
|||
}
|
||||
|
||||
cs := a.io.ColorScheme()
|
||||
//nolint:staticcheck // SA1019: utils.NewTablePrinter is deprecated: use internal/tableprinter
|
||||
tp := utils.NewTablePrinter(a.io)
|
||||
tp := tableprinter.New(a.io)
|
||||
|
||||
if tp.IsTTY() {
|
||||
tp.AddField("LABEL", nil, nil)
|
||||
tp.AddField("PORT", nil, nil)
|
||||
tp.AddField("VISIBILITY", nil, nil)
|
||||
tp.AddField("BROWSE URL", nil, nil)
|
||||
if a.io.IsStdoutTTY() {
|
||||
tp.AddField("LABEL")
|
||||
tp.AddField("PORT")
|
||||
tp.AddField("VISIBILITY")
|
||||
tp.AddField("BROWSE URL")
|
||||
tp.EndRow()
|
||||
}
|
||||
|
||||
for _, port := range portInfos {
|
||||
tp.AddField(port.Label(), nil, nil)
|
||||
tp.AddField(strconv.Itoa(port.SourcePort), nil, cs.Yellow)
|
||||
tp.AddField(port.Privacy, nil, nil)
|
||||
tp.AddField(port.BrowseURL(), nil, nil)
|
||||
// Convert the ACE to a friendly visibility string (private, org, public)
|
||||
visibility := portforwarder.AccessControlEntriesToVisibility(port.Port.AccessControl.Entries)
|
||||
|
||||
tp.AddField(port.Label())
|
||||
tp.AddField(cs.Yellow(fmt.Sprintf("%d", port.Port.PortNumber)))
|
||||
tp.AddField(visibility)
|
||||
tp.AddField(port.BrowseURL())
|
||||
tp.EndRow()
|
||||
}
|
||||
return tp.Render()
|
||||
}
|
||||
|
||||
type portInfo struct {
|
||||
*liveshare.Port
|
||||
Port *tunnels.TunnelPort
|
||||
codespace *api.Codespace
|
||||
devContainer *devContainer
|
||||
}
|
||||
|
||||
func (pi *portInfo) BrowseURL() string {
|
||||
return fmt.Sprintf("https://%s-%d.preview.app.github.dev", pi.codespace.Name, pi.Port.SourcePort)
|
||||
return fmt.Sprintf("https://%s-%d.app.github.dev", pi.codespace.Name, pi.Port.PortNumber)
|
||||
}
|
||||
|
||||
func (pi *portInfo) Label() string {
|
||||
if pi.devContainer != nil {
|
||||
portStr := strconv.Itoa(pi.Port.SourcePort)
|
||||
portStr := strconv.Itoa(int(pi.Port.PortNumber))
|
||||
if attributes, ok := pi.devContainer.PortAttributes[portStr]; ok {
|
||||
return attributes.Label
|
||||
}
|
||||
|
|
@ -150,7 +152,6 @@ func (pi *portInfo) Label() string {
|
|||
|
||||
var portFields = []string{
|
||||
"sourcePort",
|
||||
// "destinationPort", // TODO(mislav): this appears to always be blank?
|
||||
"visibility",
|
||||
"label",
|
||||
"browseUrl",
|
||||
|
|
@ -162,11 +163,9 @@ func (pi *portInfo) ExportData(fields []string) map[string]interface{} {
|
|||
for _, f := range fields {
|
||||
switch f {
|
||||
case "sourcePort":
|
||||
data[f] = pi.Port.SourcePort
|
||||
case "destinationPort":
|
||||
data[f] = pi.Port.DestinationPort
|
||||
data[f] = pi.Port.PortNumber
|
||||
case "visibility":
|
||||
data[f] = pi.Port.Privacy
|
||||
data[f] = portforwarder.AccessControlEntriesToVisibility(pi.Port.AccessControl.Entries)
|
||||
case "label":
|
||||
data[f] = pi.Label()
|
||||
case "browseUrl":
|
||||
|
|
@ -235,30 +234,6 @@ func newPortsVisibilityCmd(app *App, selector *CodespaceSelector) *cobra.Command
|
|||
}
|
||||
}
|
||||
|
||||
type ErrUpdatingPortVisibility struct {
|
||||
port int
|
||||
visibility string
|
||||
err error
|
||||
}
|
||||
|
||||
func newErrUpdatingPortVisibility(port int, visibility string, err error) *ErrUpdatingPortVisibility {
|
||||
return &ErrUpdatingPortVisibility{
|
||||
port: port,
|
||||
visibility: visibility,
|
||||
err: err,
|
||||
}
|
||||
}
|
||||
|
||||
func (e *ErrUpdatingPortVisibility) Error() string {
|
||||
return fmt.Sprintf("error waiting for port %d to update to %s: %s", e.port, e.visibility, e.err)
|
||||
}
|
||||
|
||||
func (e *ErrUpdatingPortVisibility) Unwrap() error {
|
||||
return e.err
|
||||
}
|
||||
|
||||
var errUpdatePortVisibilityForbidden = errors.New("organization admin has forbidden this privacy setting")
|
||||
|
||||
func (a *App) UpdatePortVisibility(ctx context.Context, selector *CodespaceSelector, args []string) (err error) {
|
||||
ports, err := a.parsePortVisibilities(args)
|
||||
if err != nil {
|
||||
|
|
@ -270,47 +245,28 @@ func (a *App) UpdatePortVisibility(ctx context.Context, selector *CodespaceSelec
|
|||
return err
|
||||
}
|
||||
|
||||
session, err := codespaces.ConnectToLiveshare(ctx, a, noopLogger(), a.apiClient, codespace)
|
||||
codespaceConnection, err := codespaces.GetCodespaceConnection(ctx, a, a.apiClient, codespace)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error connecting to codespace: %w", err)
|
||||
}
|
||||
defer safeClose(session, &err)
|
||||
|
||||
fwd, err := portforwarder.NewPortForwarder(ctx, codespaceConnection)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create port forwarder: %w", err)
|
||||
}
|
||||
|
||||
// TODO: check if port visibility can be updated in parallel instead of sequentially
|
||||
for _, port := range ports {
|
||||
err := a.RunWithProgress(fmt.Sprintf("Updating port %d visibility to: %s", port.number, port.visibility), func() (err error) {
|
||||
// wait for success or failure
|
||||
g, ctx := errgroup.WithContext(ctx)
|
||||
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
g.Go(func() error {
|
||||
updateNotif, err := session.WaitForPortNotification(ctx, port.number, liveshare.PortChangeKindUpdate)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error waiting for port %d to update: %w", port.number, err)
|
||||
|
||||
}
|
||||
if !updateNotif.Success {
|
||||
if updateNotif.StatusCode == http.StatusForbidden {
|
||||
return newErrUpdatingPortVisibility(port.number, port.visibility, errUpdatePortVisibilityForbidden)
|
||||
}
|
||||
return newErrUpdatingPortVisibility(port.number, port.visibility, errors.New(updateNotif.ErrorDetail))
|
||||
|
||||
}
|
||||
return nil // success
|
||||
})
|
||||
|
||||
g.Go(func() error {
|
||||
err := session.UpdateSharedServerPrivacy(ctx, port.number, port.visibility)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error updating port %d to %s: %w", port.number, port.visibility, err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
// wait for success or failure
|
||||
err = g.Wait()
|
||||
return
|
||||
err = fwd.UpdatePortVisibility(ctx, port.number, port.visibility)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error updating port %d to %s: %w", port.number, port.visibility, err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
|
|
@ -367,11 +323,10 @@ func (a *App) ForwardPorts(ctx context.Context, selector *CodespaceSelector, por
|
|||
return err
|
||||
}
|
||||
|
||||
session, err := codespaces.ConnectToLiveshare(ctx, a, noopLogger(), a.apiClient, codespace)
|
||||
codespaceConnection, err := codespaces.GetCodespaceConnection(ctx, a, a.apiClient, codespace)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error connecting to codespace: %w", err)
|
||||
}
|
||||
defer safeClose(session, &err)
|
||||
|
||||
// Run forwarding of all ports concurrently, aborting all of
|
||||
// them at the first failure, including cancellation of the context.
|
||||
|
|
@ -386,9 +341,11 @@ func (a *App) ForwardPorts(ctx context.Context, selector *CodespaceSelector, por
|
|||
defer listen.Close()
|
||||
|
||||
a.errLogger.Printf("Forwarding ports: remote %d <=> local %d", pair.remote, pair.local)
|
||||
name := fmt.Sprintf("share-%d", pair.remote)
|
||||
fwd := liveshare.NewPortForwarder(session, name, pair.remote, false)
|
||||
return fwd.ForwardToListener(ctx, listen) // error always non-nil
|
||||
fwd, err := portforwarder.NewPortForwarder(ctx, codespaceConnection)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create port forwarder: %w", err)
|
||||
}
|
||||
return fwd.ForwardAndConnectToPort(ctx, uint16(pair.remote), listen, false, false)
|
||||
})
|
||||
}
|
||||
return group.Wait() // first error
|
||||
|
|
|
|||
|
|
@ -2,18 +2,34 @@ package codespace
|
|||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/cli/cli/v2/internal/codespaces/api"
|
||||
"github.com/cli/cli/v2/internal/codespaces/connection"
|
||||
"github.com/cli/cli/v2/pkg/iostreams"
|
||||
"github.com/cli/cli/v2/pkg/liveshare"
|
||||
livesharetest "github.com/cli/cli/v2/pkg/liveshare/test"
|
||||
"github.com/sourcegraph/jsonrpc2"
|
||||
)
|
||||
|
||||
func TestListPorts(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
mockApi := GetMockApi(false)
|
||||
ios, _, _, _ := iostreams.Test()
|
||||
|
||||
a := &App{
|
||||
io: ios,
|
||||
apiClient: mockApi,
|
||||
}
|
||||
|
||||
selector := &CodespaceSelector{api: a.apiClient, codespaceName: "codespace-name"}
|
||||
err := a.ListPorts(ctx, selector, nil)
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPortsUpdateVisibilitySuccess(t *testing.T) {
|
||||
portVisibilities := []portVisibility{
|
||||
{
|
||||
|
|
@ -26,175 +42,35 @@ func TestPortsUpdateVisibilitySuccess(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
eventResponses := []string{
|
||||
"serverSharing.sharingSucceeded",
|
||||
"serverSharing.sharingSucceeded",
|
||||
}
|
||||
|
||||
portsData := []liveshare.PortNotification{
|
||||
{
|
||||
Success: true,
|
||||
Port: 80,
|
||||
ChangeKind: liveshare.PortChangeKindUpdate,
|
||||
},
|
||||
{
|
||||
Success: true,
|
||||
Port: 9999,
|
||||
ChangeKind: liveshare.PortChangeKindUpdate,
|
||||
},
|
||||
}
|
||||
|
||||
err := runUpdateVisibilityTest(t, portVisibilities, eventResponses, portsData)
|
||||
|
||||
err := runUpdateVisibilityTest(t, portVisibilities, true)
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPortsUpdateVisibilityFailure403(t *testing.T) {
|
||||
portVisibilities := []portVisibility{
|
||||
{
|
||||
number: 80,
|
||||
visibility: "org",
|
||||
},
|
||||
{
|
||||
number: 9999,
|
||||
visibility: "public",
|
||||
},
|
||||
}
|
||||
|
||||
eventResponses := []string{
|
||||
"serverSharing.sharingSucceeded",
|
||||
"serverSharing.sharingFailed",
|
||||
}
|
||||
|
||||
portsData := []liveshare.PortNotification{
|
||||
{
|
||||
Success: true,
|
||||
Port: 80,
|
||||
ChangeKind: liveshare.PortChangeKindUpdate,
|
||||
},
|
||||
{
|
||||
Success: false,
|
||||
Port: 9999,
|
||||
ChangeKind: liveshare.PortChangeKindUpdate,
|
||||
ErrorDetail: "test error",
|
||||
StatusCode: 403,
|
||||
},
|
||||
}
|
||||
|
||||
err := runUpdateVisibilityTest(t, portVisibilities, eventResponses, portsData)
|
||||
if err == nil {
|
||||
t.Fatalf("runUpdateVisibilityTest succeeded unexpectedly")
|
||||
}
|
||||
|
||||
if errors.Unwrap(err) != errUpdatePortVisibilityForbidden {
|
||||
t.Errorf("expected: %v, got: %v", errUpdatePortVisibilityForbidden, errors.Unwrap(err))
|
||||
}
|
||||
}
|
||||
|
||||
func TestPortsUpdateVisibilityFailure(t *testing.T) {
|
||||
portVisibilities := []portVisibility{
|
||||
{
|
||||
number: 80,
|
||||
visibility: "org",
|
||||
},
|
||||
{
|
||||
number: 9999,
|
||||
visibility: "public",
|
||||
},
|
||||
}
|
||||
|
||||
eventResponses := []string{
|
||||
"serverSharing.sharingSucceeded",
|
||||
"serverSharing.sharingFailed",
|
||||
}
|
||||
|
||||
portsData := []liveshare.PortNotification{
|
||||
{
|
||||
Success: true,
|
||||
Port: 80,
|
||||
ChangeKind: liveshare.PortChangeKindUpdate,
|
||||
},
|
||||
{
|
||||
Success: false,
|
||||
Port: 9999,
|
||||
ChangeKind: liveshare.PortChangeKindUpdate,
|
||||
ErrorDetail: "test error",
|
||||
number: 80,
|
||||
visibility: "org",
|
||||
},
|
||||
}
|
||||
|
||||
err := runUpdateVisibilityTest(t, portVisibilities, eventResponses, portsData)
|
||||
err := runUpdateVisibilityTest(t, portVisibilities, false)
|
||||
if err == nil {
|
||||
t.Fatalf("runUpdateVisibilityTest succeeded unexpectedly")
|
||||
}
|
||||
|
||||
var expectedErr *ErrUpdatingPortVisibility
|
||||
if !errors.As(err, &expectedErr) {
|
||||
t.Errorf("expected: %v, got: %v", expectedErr, err)
|
||||
}
|
||||
}
|
||||
|
||||
type joinWorkspaceResult struct {
|
||||
SessionNumber int `json:"sessionNumber"`
|
||||
}
|
||||
|
||||
func runUpdateVisibilityTest(t *testing.T, portVisibilities []portVisibility, eventResponses []string, portsData []liveshare.PortNotification) error {
|
||||
t.Helper()
|
||||
if os.Getenv("GITHUB_ACTIONS") == "true" {
|
||||
t.Skip("fails intermittently in CI: https://github.com/cli/cli/issues/5663")
|
||||
}
|
||||
|
||||
joinWorkspace := func(conn *jsonrpc2.Conn, req *jsonrpc2.Request) (interface{}, error) {
|
||||
return joinWorkspaceResult{1}, nil
|
||||
}
|
||||
const sessionToken = "session-token"
|
||||
|
||||
ch := make(chan *jsonrpc2.Conn, 1)
|
||||
updateSharedVisibility := func(conn *jsonrpc2.Conn, rpcReq *jsonrpc2.Request) (interface{}, error) {
|
||||
ch <- conn
|
||||
return nil, nil
|
||||
}
|
||||
testServer, err := livesharetest.NewServer(
|
||||
livesharetest.WithNonSecure(),
|
||||
livesharetest.WithPassword(sessionToken),
|
||||
livesharetest.WithService("workspace.joinWorkspace", joinWorkspace),
|
||||
livesharetest.WithService("serverSharing.updateSharedServerPrivacy", updateSharedVisibility),
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to create test server: %w", err)
|
||||
}
|
||||
|
||||
func runUpdateVisibilityTest(t *testing.T, portVisibilities []portVisibility, allowOrgPorts bool) error {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
go func() {
|
||||
for i, pd := range portsData {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case conn := <-ch:
|
||||
_, _ = conn.DispatchCall(ctx, eventResponses[i], pd, nil)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
mockApi := &apiClientMock{
|
||||
GetCodespaceFunc: func(ctx context.Context, codespaceName string, includeConnection bool) (*api.Codespace, error) {
|
||||
return &api.Codespace{
|
||||
Name: "codespace-name",
|
||||
State: api.CodespaceStateAvailable,
|
||||
Connection: api.CodespaceConnection{
|
||||
SessionID: "session-id",
|
||||
SessionToken: sessionToken,
|
||||
RelayEndpoint: testServer.URL(),
|
||||
RelaySAS: "relay-sas",
|
||||
HostPublicKeys: []string{livesharetest.SSHPublicKey},
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
mockApi := GetMockApi(allowOrgPorts)
|
||||
ios, _, _, _ := iostreams.Test()
|
||||
|
||||
a := &App{
|
||||
|
|
@ -251,6 +127,44 @@ func TestPendingOperationDisallowsForwardPorts(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func GetMockApi(allowOrgPorts bool) *apiClientMock {
|
||||
return &apiClientMock{
|
||||
GetCodespaceFunc: func(ctx context.Context, codespaceName string, includeConnection bool) (*api.Codespace, error) {
|
||||
allowedPortPrivacySettings := []string{"public", "private"}
|
||||
if allowOrgPorts {
|
||||
allowedPortPrivacySettings = append(allowedPortPrivacySettings, "org")
|
||||
}
|
||||
|
||||
return &api.Codespace{
|
||||
Name: "codespace-name",
|
||||
State: api.CodespaceStateAvailable,
|
||||
Connection: api.CodespaceConnection{
|
||||
TunnelProperties: api.TunnelProperties{
|
||||
ConnectAccessToken: "tunnel access-token",
|
||||
ManagePortsAccessToken: "manage-ports-token",
|
||||
ServiceUri: "http://global.rel.tunnels.api.visualstudio.com/",
|
||||
TunnelId: "tunnel-id",
|
||||
ClusterId: "usw2",
|
||||
Domain: "domain.com",
|
||||
},
|
||||
},
|
||||
RuntimeConstraints: api.RuntimeConstraints{
|
||||
AllowedPortPrivacySettings: allowedPortPrivacySettings,
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
StartCodespaceFunc: func(ctx context.Context, codespaceName string) error {
|
||||
return nil
|
||||
},
|
||||
GetCodespaceRepositoryContentsFunc: func(ctx context.Context, codespace *api.Codespace, path string) ([]byte, error) {
|
||||
return nil, nil
|
||||
},
|
||||
HTTPClientFunc: func() (*http.Client, error) {
|
||||
return connection.NewMockHttpClient()
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func testingPortsApp() *App {
|
||||
disabledCodespace := &api.Codespace{
|
||||
Name: "disabledCodespace",
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue