Merge branch 'trunk' into gh-ext-browse-followup

This commit is contained in:
Nate Smith 2023-02-02 12:27:17 -08:00 committed by GitHub
commit 5dea3a923f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
33 changed files with 997 additions and 256 deletions

View file

@ -64,6 +64,8 @@ func NewHTTPClient(opts HTTPClientOptions) (*http.Client, error) {
client.Transport = AddAuthTokenHeader(client.Transport, opts.Config)
}
client.Transport = AddASCIISanitizer(client.Transport)
return client, nil
}

View file

@ -10,10 +10,10 @@ import (
const (
errorProjectsV2ReadScope = "field requires one of the following scopes: ['read:project']"
errorProjectsV2RepositoryField = "Field 'ProjectsV2' doesn't exist on type 'Repository'"
errorProjectsV2OrganizationField = "Field 'ProjectsV2' doesn't exist on type 'Organization'"
errorProjectsV2IssueField = "Field 'ProjectItems' doesn't exist on type 'Issue'"
errorProjectsV2PullRequestField = "Field 'ProjectItems' doesn't exist on type 'PullRequest'"
errorProjectsV2RepositoryField = "Field 'projectsV2' doesn't exist on type 'Repository'"
errorProjectsV2OrganizationField = "Field 'projectsV2' doesn't exist on type 'Organization'"
errorProjectsV2IssueField = "Field 'projectItems' doesn't exist on type 'Issue'"
errorProjectsV2PullRequestField = "Field 'projectItems' doesn't exist on type 'PullRequest'"
)
// UpdateProjectV2Items uses the addProjectV2ItemById and the deleteProjectV2Item mutations

View file

@ -221,22 +221,22 @@ func TestProjectsV2IgnorableError(t *testing.T) {
},
{
name: "repository projectsV2 field error",
errMsg: "Field 'ProjectsV2' doesn't exist on type 'Repository'",
errMsg: "Field 'projectsV2' doesn't exist on type 'Repository'",
expectOut: true,
},
{
name: "organization projectsV2 field error",
errMsg: "Field 'ProjectsV2' doesn't exist on type 'Organization'",
errMsg: "Field 'projectsV2' doesn't exist on type 'Organization'",
expectOut: true,
},
{
name: "issue projectItems field error",
errMsg: "Field 'ProjectItems' doesn't exist on type 'Issue'",
errMsg: "Field 'projectItems' doesn't exist on type 'Issue'",
expectOut: true,
},
{
name: "pullRequest projectItems field error",
errMsg: "Field 'ProjectItems' doesn't exist on type 'PullRequest'",
errMsg: "Field 'projectItems' doesn't exist on type 'PullRequest'",
expectOut: true,
},
{

193
api/sanitize_ascii.go Normal file
View file

@ -0,0 +1,193 @@
package api
import (
"bytes"
"errors"
"io"
"net/http"
"regexp"
"strings"
)
var jsonTypeRE = regexp.MustCompile(`[/+]json($|;)`)
// GitHub servers return non-printable characters as their unicode code point values.
// The values of \u0000 to \u001F represent C0 ASCII control characters and
// the values of \u0080 to \u009F represent C1 ASCII control characters. These control
// characters will be interpreted by the terminal, this behaviour can be used maliciously
// as an attack vector, especially the control character \u001B. This function wraps
// JSON response bodies in a ReadCloser that transforms C0 and C1 control characters
// to their caret and hex notations respectively so that the terminal will not interpret them.
func AddASCIISanitizer(rt http.RoundTripper) http.RoundTripper {
return &funcTripper{roundTrip: func(req *http.Request) (*http.Response, error) {
res, err := rt.RoundTrip(req)
if err != nil || !jsonTypeRE.MatchString(res.Header.Get("Content-Type")) {
return res, err
}
res.Body = &sanitizeASCIIReadCloser{ReadCloser: res.Body}
return res, err
}}
}
// sanitizeASCIIReadCloser implements the ReadCloser interface.
type sanitizeASCIIReadCloser struct {
io.ReadCloser
addBackslash bool
previousWindow []byte
}
// Read uses a sliding window alogorithm to detect C0 and C1
// ASCII control sequences as they are read and replaces them
// with equivelent inert characters. Characters that are not part
// of a control sequence not modified.
func (s *sanitizeASCIIReadCloser) Read(out []byte) (int, error) {
var readErr error
var outIndex int
var bufIndex int
var bufLen int
var window []byte
buf := make([]byte, len(out))
bufLen, readErr = s.ReadCloser.Read(buf)
if readErr != nil && !errors.Is(readErr, io.EOF) {
if bufLen > 0 {
// Do not sanitize if there was a read error that is not EOF.
bufLen = copy(out, buf)
}
return bufLen, readErr
}
if s.previousWindow != nil {
buf = append(s.previousWindow, buf...)
bufLen += len(s.previousWindow)
}
for {
remaining := min(6, (bufLen - bufIndex))
window = buf[bufIndex : bufIndex+remaining]
if remaining < 6 {
break
}
if bytes.HasPrefix(window, []byte(`\u00`)) {
repl, _ := mapControlCharacterToCaret(window)
if s.addBackslash {
repl = append([]byte{92}, repl...)
}
l := len(repl)
for j := 0; j < l; j++ {
out[outIndex] = repl[j]
outIndex++
}
bufIndex += 6
s.addBackslash = false
continue
}
if window[0] == '\\' {
s.addBackslash = !s.addBackslash
} else {
s.addBackslash = false
}
out[outIndex] = buf[bufIndex]
outIndex++
bufIndex++
}
if readErr != nil && errors.Is(readErr, io.EOF) {
remaining := bufLen - bufIndex
for j := 0; j < remaining; j++ {
out[outIndex] = window[j]
outIndex++
bufIndex++
}
} else {
s.previousWindow = window
}
return outIndex, readErr
}
// mapControlCharacterToCaret maps C0 control sequences to caret notation
// and C1 control sequences to hex notation. C1 control sequences do not
// have caret notation representation.
func mapControlCharacterToCaret(b []byte) ([]byte, bool) {
m := map[string]string{
`\u0000`: `^@`,
`\u0001`: `^A`,
`\u0002`: `^B`,
`\u0003`: `^C`,
`\u0004`: `^D`,
`\u0005`: `^E`,
`\u0006`: `^F`,
`\u0007`: `^G`,
`\u0008`: `^H`,
`\u0009`: `^I`,
`\u000a`: `^J`,
`\u000b`: `^K`,
`\u000c`: `^L`,
`\u000d`: `^M`,
`\u000e`: `^N`,
`\u000f`: `^O`,
`\u0010`: `^P`,
`\u0011`: `^Q`,
`\u0012`: `^R`,
`\u0013`: `^S`,
`\u0014`: `^T`,
`\u0015`: `^U`,
`\u0016`: `^V`,
`\u0017`: `^W`,
`\u0018`: `^X`,
`\u0019`: `^Y`,
`\u001a`: `^Z`,
`\u001b`: `^[`,
`\u001c`: `^\\`,
`\u001d`: `^]`,
`\u001e`: `^^`,
`\u001f`: `^_`,
`\u0080`: `\\200`,
`\u0081`: `\\201`,
`\u0082`: `\\202`,
`\u0083`: `\\203`,
`\u0084`: `\\204`,
`\u0085`: `\\205`,
`\u0086`: `\\206`,
`\u0087`: `\\207`,
`\u0088`: `\\210`,
`\u0089`: `\\211`,
`\u008a`: `\\212`,
`\u008b`: `\\213`,
`\u008c`: `\\214`,
`\u008d`: `\\215`,
`\u008e`: `\\216`,
`\u008f`: `\\217`,
`\u0090`: `\\220`,
`\u0091`: `\\221`,
`\u0092`: `\\222`,
`\u0093`: `\\223`,
`\u0094`: `\\224`,
`\u0095`: `\\225`,
`\u0096`: `\\226`,
`\u0097`: `\\227`,
`\u0098`: `\\230`,
`\u0099`: `\\231`,
`\u009a`: `\\232`,
`\u009b`: `\\233`,
`\u009c`: `\\234`,
`\u009d`: `\\235`,
`\u009e`: `\\236`,
`\u009f`: `\\237`,
}
if c, ok := m[strings.ToLower(string(b))]; ok {
return []byte(c), true
}
return b, false
}
func min(a, b int) int {
if a < b {
return a
}
return b
}

View file

@ -0,0 +1,50 @@
package api
import (
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestHTTPClient_SanitizeASCIIControlCharacters(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
issue := Issue{
Title: "\u001B[31mRed Title\u001B[0m",
Body: "1\u0001 2\u0002 3\u0003 4\u0004 5\u0005 6\u0006 7\u0007 8\u0008 9\t A\r\n B\u000b C\u000c D\r\n E\u000e F\u000f",
Author: Author{
ID: "1",
Name: "10\u0010 11\u0011 12\u0012 13\u0013 14\u0014 15\u0015 16\u0016 17\u0017 18\u0018 19\u0019 1A\u001a 1B\u001b 1C\u001c 1D\u001d 1E\u001e 1F\u001f",
Login: "monalisa",
},
ActiveLockReason: "Escaped \u001B \\u001B \\\u001B \\\\u001B",
}
responseData, _ := json.Marshal(issue)
w.Header().Set("Content-Type", "application/json; charset=utf-8")
fmt.Fprint(w, string(responseData))
}))
defer ts.Close()
client, err := NewHTTPClient(HTTPClientOptions{})
require.NoError(t, err)
req, err := http.NewRequest("GET", ts.URL, nil)
require.NoError(t, err)
res, err := client.Do(req)
require.NoError(t, err)
body, err := io.ReadAll(res.Body)
res.Body.Close()
require.NoError(t, err)
var issue Issue
err = json.Unmarshal(body, &issue)
require.NoError(t, err)
assert.Equal(t, "^[[31mRed Title^[[0m", issue.Title)
assert.Equal(t, "1^A 2^B 3^C 4^D 5^E 6^F 7^G 8^H 9\t A\r\n B^K C^L D\r\n E^N F^O", issue.Body)
assert.Equal(t, "10^P 11^Q 12^R 13^S 14^T 15^U 16^V 17^W 18^X 19^Y 1A^Z 1B^[ 1C^\\ 1D^] 1E^^ 1F^_", issue.Author.Name)
assert.Equal(t, "monalisa", issue.Author.Login)
assert.Equal(t, "Escaped ^[ \\^[ \\^[ \\\\^[", issue.ActiveLockReason)
}

2
go.mod
View file

@ -10,7 +10,7 @@ require (
github.com/charmbracelet/glamour v0.5.1-0.20220727184942-e70ff2d969da
github.com/charmbracelet/lipgloss v0.5.0
github.com/cli/go-gh v1.0.0
github.com/cli/oauth v1.0.0
github.com/cli/oauth v1.0.1
github.com/cli/safeexec v1.0.1
github.com/cpuguy83/go-md2man/v2 v2.0.2
github.com/creack/pty v1.1.18

4
go.sum
View file

@ -62,8 +62,8 @@ github.com/cli/crypto v0.0.0-20210929142629-6be313f59b03 h1:3f4uHLfWx4/WlnMPXGai
github.com/cli/crypto v0.0.0-20210929142629-6be313f59b03/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
github.com/cli/go-gh v1.0.0 h1:zE1YUAUYqGXNZuICEBeOkIMJ5F50BS0ftvtoWGlsEFI=
github.com/cli/go-gh v1.0.0/go.mod h1:bqxLdCoTZ73BuiPEJx4olcO/XKhVZaFDchFagYRBweE=
github.com/cli/oauth v1.0.0 h1:zuatYn8BRWWO98y2jNXK4RKOryU1u6JTqPrdSPW5pSE=
github.com/cli/oauth v1.0.0/go.mod h1:qd/FX8ZBD6n1sVNQO3aIdRxeu5LGw9WhKnYhIIoC2A4=
github.com/cli/oauth v1.0.1 h1:pXnTFl/qUegXHK531Dv0LNjW4mLx626eS42gnzfXJPA=
github.com/cli/oauth v1.0.1/go.mod h1:qd/FX8ZBD6n1sVNQO3aIdRxeu5LGw9WhKnYhIIoC2A4=
github.com/cli/safeexec v1.0.0/go.mod h1:Z/D4tTN8Vs5gXYHDCbaM1S/anmEDnJb1iW0+EJ5zx3Q=
github.com/cli/safeexec v1.0.1 h1:e/C79PbXF4yYTN/wauC4tviMxEV13BwljGj0N9j+N00=
github.com/cli/safeexec v1.0.1/go.mod h1:Z/D4tTN8Vs5gXYHDCbaM1S/anmEDnJb1iW0+EJ5zx3Q=

View file

@ -1,6 +1,6 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.28.0
// protoc-gen-go v1.28.1
// protoc v3.21.12
// source: codespace/codespace_host_service.v1.proto

View file

@ -0,0 +1,168 @@
// Code generated by moq; DO NOT EDIT.
// github.com/matryer/moq
package codespace
import (
context "context"
sync "sync"
)
// Ensure, that CodespaceHostServerMock does implement CodespaceHostServer.
// If this is not the case, regenerate this file with moq.
var _ CodespaceHostServer = &CodespaceHostServerMock{}
// CodespaceHostServerMock is a mock implementation of CodespaceHostServer.
//
// func TestSomethingThatUsesCodespaceHostServer(t *testing.T) {
//
// // make and configure a mocked CodespaceHostServer
// mockedCodespaceHostServer := &CodespaceHostServerMock{
// NotifyCodespaceOfClientActivityFunc: func(contextMoqParam context.Context, notifyCodespaceOfClientActivityRequest *NotifyCodespaceOfClientActivityRequest) (*NotifyCodespaceOfClientActivityResponse, error) {
// panic("mock out the NotifyCodespaceOfClientActivity method")
// },
// RebuildContainerAsyncFunc: func(contextMoqParam context.Context, rebuildContainerRequest *RebuildContainerRequest) (*RebuildContainerResponse, error) {
// panic("mock out the RebuildContainerAsync method")
// },
// mustEmbedUnimplementedCodespaceHostServerFunc: func() {
// panic("mock out the mustEmbedUnimplementedCodespaceHostServer method")
// },
// }
//
// // use mockedCodespaceHostServer in code that requires CodespaceHostServer
// // and then make assertions.
//
// }
type CodespaceHostServerMock struct {
// NotifyCodespaceOfClientActivityFunc mocks the NotifyCodespaceOfClientActivity method.
NotifyCodespaceOfClientActivityFunc func(contextMoqParam context.Context, notifyCodespaceOfClientActivityRequest *NotifyCodespaceOfClientActivityRequest) (*NotifyCodespaceOfClientActivityResponse, error)
// RebuildContainerAsyncFunc mocks the RebuildContainerAsync method.
RebuildContainerAsyncFunc func(contextMoqParam context.Context, rebuildContainerRequest *RebuildContainerRequest) (*RebuildContainerResponse, error)
// mustEmbedUnimplementedCodespaceHostServerFunc mocks the mustEmbedUnimplementedCodespaceHostServer method.
mustEmbedUnimplementedCodespaceHostServerFunc func()
// calls tracks calls to the methods.
calls struct {
// NotifyCodespaceOfClientActivity holds details about calls to the NotifyCodespaceOfClientActivity method.
NotifyCodespaceOfClientActivity []struct {
// ContextMoqParam is the contextMoqParam argument value.
ContextMoqParam context.Context
// NotifyCodespaceOfClientActivityRequest is the notifyCodespaceOfClientActivityRequest argument value.
NotifyCodespaceOfClientActivityRequest *NotifyCodespaceOfClientActivityRequest
}
// RebuildContainerAsync holds details about calls to the RebuildContainerAsync method.
RebuildContainerAsync []struct {
// ContextMoqParam is the contextMoqParam argument value.
ContextMoqParam context.Context
// RebuildContainerRequest is the rebuildContainerRequest argument value.
RebuildContainerRequest *RebuildContainerRequest
}
// mustEmbedUnimplementedCodespaceHostServer holds details about calls to the mustEmbedUnimplementedCodespaceHostServer method.
mustEmbedUnimplementedCodespaceHostServer []struct {
}
}
lockNotifyCodespaceOfClientActivity sync.RWMutex
lockRebuildContainerAsync sync.RWMutex
lockmustEmbedUnimplementedCodespaceHostServer sync.RWMutex
}
// NotifyCodespaceOfClientActivity calls NotifyCodespaceOfClientActivityFunc.
func (mock *CodespaceHostServerMock) NotifyCodespaceOfClientActivity(contextMoqParam context.Context, notifyCodespaceOfClientActivityRequest *NotifyCodespaceOfClientActivityRequest) (*NotifyCodespaceOfClientActivityResponse, error) {
if mock.NotifyCodespaceOfClientActivityFunc == nil {
panic("CodespaceHostServerMock.NotifyCodespaceOfClientActivityFunc: method is nil but CodespaceHostServer.NotifyCodespaceOfClientActivity was just called")
}
callInfo := struct {
ContextMoqParam context.Context
NotifyCodespaceOfClientActivityRequest *NotifyCodespaceOfClientActivityRequest
}{
ContextMoqParam: contextMoqParam,
NotifyCodespaceOfClientActivityRequest: notifyCodespaceOfClientActivityRequest,
}
mock.lockNotifyCodespaceOfClientActivity.Lock()
mock.calls.NotifyCodespaceOfClientActivity = append(mock.calls.NotifyCodespaceOfClientActivity, callInfo)
mock.lockNotifyCodespaceOfClientActivity.Unlock()
return mock.NotifyCodespaceOfClientActivityFunc(contextMoqParam, notifyCodespaceOfClientActivityRequest)
}
// NotifyCodespaceOfClientActivityCalls gets all the calls that were made to NotifyCodespaceOfClientActivity.
// Check the length with:
//
// len(mockedCodespaceHostServer.NotifyCodespaceOfClientActivityCalls())
func (mock *CodespaceHostServerMock) NotifyCodespaceOfClientActivityCalls() []struct {
ContextMoqParam context.Context
NotifyCodespaceOfClientActivityRequest *NotifyCodespaceOfClientActivityRequest
} {
var calls []struct {
ContextMoqParam context.Context
NotifyCodespaceOfClientActivityRequest *NotifyCodespaceOfClientActivityRequest
}
mock.lockNotifyCodespaceOfClientActivity.RLock()
calls = mock.calls.NotifyCodespaceOfClientActivity
mock.lockNotifyCodespaceOfClientActivity.RUnlock()
return calls
}
// RebuildContainerAsync calls RebuildContainerAsyncFunc.
func (mock *CodespaceHostServerMock) RebuildContainerAsync(contextMoqParam context.Context, rebuildContainerRequest *RebuildContainerRequest) (*RebuildContainerResponse, error) {
if mock.RebuildContainerAsyncFunc == nil {
panic("CodespaceHostServerMock.RebuildContainerAsyncFunc: method is nil but CodespaceHostServer.RebuildContainerAsync was just called")
}
callInfo := struct {
ContextMoqParam context.Context
RebuildContainerRequest *RebuildContainerRequest
}{
ContextMoqParam: contextMoqParam,
RebuildContainerRequest: rebuildContainerRequest,
}
mock.lockRebuildContainerAsync.Lock()
mock.calls.RebuildContainerAsync = append(mock.calls.RebuildContainerAsync, callInfo)
mock.lockRebuildContainerAsync.Unlock()
return mock.RebuildContainerAsyncFunc(contextMoqParam, rebuildContainerRequest)
}
// RebuildContainerAsyncCalls gets all the calls that were made to RebuildContainerAsync.
// Check the length with:
//
// len(mockedCodespaceHostServer.RebuildContainerAsyncCalls())
func (mock *CodespaceHostServerMock) RebuildContainerAsyncCalls() []struct {
ContextMoqParam context.Context
RebuildContainerRequest *RebuildContainerRequest
} {
var calls []struct {
ContextMoqParam context.Context
RebuildContainerRequest *RebuildContainerRequest
}
mock.lockRebuildContainerAsync.RLock()
calls = mock.calls.RebuildContainerAsync
mock.lockRebuildContainerAsync.RUnlock()
return calls
}
// mustEmbedUnimplementedCodespaceHostServer calls mustEmbedUnimplementedCodespaceHostServerFunc.
func (mock *CodespaceHostServerMock) mustEmbedUnimplementedCodespaceHostServer() {
if mock.mustEmbedUnimplementedCodespaceHostServerFunc == nil {
panic("CodespaceHostServerMock.mustEmbedUnimplementedCodespaceHostServerFunc: method is nil but CodespaceHostServer.mustEmbedUnimplementedCodespaceHostServer was just called")
}
callInfo := struct {
}{}
mock.lockmustEmbedUnimplementedCodespaceHostServer.Lock()
mock.calls.mustEmbedUnimplementedCodespaceHostServer = append(mock.calls.mustEmbedUnimplementedCodespaceHostServer, callInfo)
mock.lockmustEmbedUnimplementedCodespaceHostServer.Unlock()
mock.mustEmbedUnimplementedCodespaceHostServerFunc()
}
// mustEmbedUnimplementedCodespaceHostServerCalls gets all the calls that were made to mustEmbedUnimplementedCodespaceHostServer.
// Check the length with:
//
// len(mockedCodespaceHostServer.mustEmbedUnimplementedCodespaceHostServerCalls())
func (mock *CodespaceHostServerMock) mustEmbedUnimplementedCodespaceHostServerCalls() []struct {
} {
var calls []struct {
}
mock.lockmustEmbedUnimplementedCodespaceHostServer.RLock()
calls = mock.calls.mustEmbedUnimplementedCodespaceHostServer
mock.lockmustEmbedUnimplementedCodespaceHostServer.RUnlock()
return calls
}

View file

@ -6,7 +6,8 @@ Instructions for generating and adding gRPC protocol buffers.
1. [Download `protoc`](https://grpc.io/docs/protoc-installation/)
2. [Download protocol compiler plugins for Go](https://grpc.io/docs/languages/go/quickstart/)
3. Run `./generate.sh` from the `internal/codespaces/grpc` directory
3. Install moq: `go install github.com/matryer/moq@latest`
4. Run `./generate.sh` from the `internal/codespaces/rpc` directory
## Add New Protocol Buffers

View file

@ -15,14 +15,21 @@ if ! protoc-gen-go-grpc --version; then
fi
function generate {
local contract="$1"
local dir="$1"
local proto="$2"
local contract="$dir/$proto"
protoc --go_out=. --go_opt=paths=source_relative --go-grpc_out=. --go-grpc_opt=paths=source_relative "$contract"
echo "Generated protocol buffers for $contract"
services=$(cat "$contract" | grep -Eo "service .+ {" | awk '{print $2 "Server"}')
moq -out $contract.mock.go $dir $services
echo "Generated mock protocols for $contract"
}
generate jupyter/jupyter_server_host_service.v1.proto
generate codespace/codespace_host_service.v1.proto
generate ssh/ssh_server_host_service.v1.proto
generate jupyter jupyter_server_host_service.v1.proto
generate codespace codespace_host_service.v1.proto
generate ssh ssh_server_host_service.v1.proto
echo 'Done!'

View file

@ -130,6 +130,9 @@ func connect(ctx context.Context, session liveshare.LiveshareSession) (Invoker,
invoker.codespaceClient = codespace.NewCodespaceHostClient(conn)
invoker.sshClient = ssh.NewSshServerHostClient(conn)
// Send initial connection heartbeat (no need to throw if we fail to get a response from the server)
_ = invoker.notifyCodespaceOfClientActivity(ctx, connectedEventName)
// Start the activity heatbeats
go invoker.heartbeat(pfctx, 1*time.Minute)
@ -253,9 +256,6 @@ func (i *invoker) heartbeat(ctx context.Context, interval time.Duration) {
ticker := time.NewTicker(interval)
defer ticker.Stop()
// Send initial connection heartbeat (no need to throw if we fail to get a response from the server)
_ = i.notifyCodespaceOfClientActivity(ctx, connectedEventName)
for {
select {
case <-ctx.Done():

View file

@ -3,84 +3,159 @@ package rpc
import (
"context"
"fmt"
"log"
"os"
"net"
"strconv"
"testing"
"github.com/cli/cli/v2/internal/codespaces/rpc/codespace"
"github.com/cli/cli/v2/internal/codespaces/rpc/jupyter"
"github.com/cli/cli/v2/internal/codespaces/rpc/ssh"
rpctest "github.com/cli/cli/v2/internal/codespaces/rpc/test"
"google.golang.org/grpc"
)
func startServer(t *testing.T) {
t.Helper()
if os.Getenv("GITHUB_ACTIONS") == "true" {
t.Skip("fails intermittently in CI: https://github.com/cli/cli/issues/5663")
type mockServer struct {
jupyter.JupyterServerHostServerMock
codespace.CodespaceHostServerMock
ssh.SshServerHostServerMock
}
func newMockServer() *mockServer {
server := &mockServer{}
server.CodespaceHostServerMock.NotifyCodespaceOfClientActivityFunc = func(context.Context, *codespace.NotifyCodespaceOfClientActivityRequest) (*codespace.NotifyCodespaceOfClientActivityResponse, error) {
return &codespace.NotifyCodespaceOfClientActivityResponse{
Message: "",
Result: true,
}, nil
}
return server
}
// runTestGrpcServer serves grpc requests over the provided Listener using the mockServer for mocked callbacks.
// It does not return until the Context is cancelled and the server fully shuts down.
func runTestGrpcServer(ctx context.Context, listener net.Listener, server *mockServer) error {
s := grpc.NewServer()
jupyter.RegisterJupyterServerHostServer(s, server)
codespace.RegisterCodespaceHostServer(s, server)
ssh.RegisterSshServerHostServer(s, server)
ch := make(chan error, 1)
go func() { ch <- s.Serve(listener) }()
select {
case <-ctx.Done():
s.Stop()
<-ch
return nil
case err := <-ch:
return err
}
}
// createTestInvoker is the main test setup function. It returns an Invoker using the provided mockServer, as well as a shutdown function.
// The Invoker does not need to be closed directly, that will be handled by the shutdown function.
func createTestInvoker(t *testing.T, server *mockServer) (Invoker, func(), error) {
listener, err := net.Listen("tcp", "127.0.0.1:16634")
if err != nil {
return nil, nil, fmt.Errorf("failed to listen: %w", err)
}
ctx, cancel := context.WithCancel(context.Background())
ch := make(chan error)
go func() { ch <- runTestGrpcServer(ctx, listener, server) }()
// Start the gRPC server in the background
go func() {
err := rpctest.StartServer(ctx)
if err != nil && err != context.Canceled {
log.Println(fmt.Errorf("error starting test server: %v", err))
}
}()
// Stop the gRPC server when the test is done
t.Cleanup(func() {
close := func() {
cancel()
})
}
func createTestInvoker(t *testing.T) Invoker {
t.Helper()
// Clear the stored client activity
rpctest.NotifyReceivedActivity = ""
<-ch
listener.Close()
}
invoker, err := CreateInvoker(context.Background(), &rpctest.Session{})
if err != nil {
t.Fatalf("error connecting to internal server: %v", err)
close()
return nil, nil, fmt.Errorf("error connecting to internal server: %w", err)
}
t.Cleanup(func() {
testNotifyCodespaceOfClientActivity(t)
return invoker, func() {
invoker.Close()
})
return invoker
close()
}, nil
}
// Test that the RPC invoker notifies the codespace of client activity on connection
func testNotifyCodespaceOfClientActivity(t *testing.T) {
if rpctest.NotifyReceivedActivity != connectedEventName {
t.Fatalf("expected %s, got %s", connectedEventName, rpctest.NotifyMessage)
func verifyNotifyCodespaceOfClientActivity(t *testing.T, server *mockServer) {
calls := server.CodespaceHostServerMock.NotifyCodespaceOfClientActivityCalls()
if len(calls) == 0 {
t.Fatalf("no client activity calls")
}
for _, call := range calls {
activities := call.NotifyCodespaceOfClientActivityRequest.GetClientActivities()
if activities[0] == connectedEventName {
return
}
}
t.Fatalf("no activity named %s", connectedEventName)
}
// Test that the RPC invoker returns the correct port and URL when the JupyterLab server starts successfully
func TestStartJupyterServerSuccess(t *testing.T) {
startServer(t)
invoker := createTestInvoker(t)
resp := jupyter.GetRunningServerResponse{
Port: strconv.Itoa(1234),
ServerUrl: "http://localhost:1234?token=1234",
Message: "",
Result: true,
}
server := newMockServer()
server.JupyterServerHostServerMock.GetRunningServerFunc = func(context.Context, *jupyter.GetRunningServerRequest) (*jupyter.GetRunningServerResponse, error) {
return &resp, nil
}
invoker, stop, err := createTestInvoker(t, server)
if err != nil {
t.Fatalf("error connecting to internal server: %v", err)
}
defer stop()
port, url, err := invoker.StartJupyterServer(context.Background())
if err != nil {
t.Fatalf("expected %v, got %v", nil, err)
}
if port != rpctest.JupyterPort {
t.Fatalf("expected %d, got %d", rpctest.JupyterPort, port)
if strconv.Itoa(port) != resp.Port {
t.Fatalf("expected %s, got %d", resp.Port, port)
}
if url != rpctest.JupyterServerUrl {
t.Fatalf("expected %s, got %s", rpctest.JupyterServerUrl, url)
if url != resp.ServerUrl {
t.Fatalf("expected %s, got %s", resp.ServerUrl, url)
}
verifyNotifyCodespaceOfClientActivity(t, server)
}
// Test that the RPC invoker returns an error when the JupyterLab server fails to start
func TestStartJupyterServerFailure(t *testing.T) {
startServer(t)
invoker := createTestInvoker(t)
rpctest.JupyterMessage = "error message"
rpctest.JupyterResult = false
errorMessage := fmt.Sprintf("failed to start JupyterLab: %s", rpctest.JupyterMessage)
resp := jupyter.GetRunningServerResponse{
Port: strconv.Itoa(1234),
ServerUrl: "http://localhost:1234?token=1234",
Message: "error message",
Result: false,
}
server := newMockServer()
server.JupyterServerHostServerMock.GetRunningServerFunc = func(context.Context, *jupyter.GetRunningServerRequest) (*jupyter.GetRunningServerResponse, error) {
return &resp, nil
}
invoker, stop, err := createTestInvoker(t, server)
if err != nil {
t.Fatalf("error connecting to internal server: %v", err)
}
defer stop()
errorMessage := fmt.Sprintf("failed to start JupyterLab: %s", resp.Message)
port, url, err := invoker.StartJupyterServer(context.Background())
if err.Error() != errorMessage {
t.Fatalf("expected %v, got %v", errorMessage, err)
@ -91,35 +166,79 @@ func TestStartJupyterServerFailure(t *testing.T) {
if url != "" {
t.Fatalf("expected %s, got %s", "", url)
}
verifyNotifyCodespaceOfClientActivity(t, server)
}
// Test that the RPC invoker doesn't throw an error when requesting an incremental rebuild
func TestRebuildContainerIncremental(t *testing.T) {
startServer(t)
invoker := createTestInvoker(t)
err := invoker.RebuildContainer(context.Background(), false)
resp := codespace.RebuildContainerResponse{
RebuildContainer: true,
}
server := newMockServer()
server.RebuildContainerAsyncFunc = func(context.Context, *codespace.RebuildContainerRequest) (*codespace.RebuildContainerResponse, error) {
return &resp, nil
}
invoker, stop, err := createTestInvoker(t, server)
if err != nil {
t.Fatalf("error connecting to internal server: %v", err)
}
defer stop()
err = invoker.RebuildContainer(context.Background(), false)
if err != nil {
t.Fatalf("expected %v, got %v", nil, err)
}
verifyNotifyCodespaceOfClientActivity(t, server)
}
// Test that the RPC invoker doesn't throw an error when requesting a full rebuild
func TestRebuildContainerFull(t *testing.T) {
startServer(t)
invoker := createTestInvoker(t)
err := invoker.RebuildContainer(context.Background(), true)
resp := codespace.RebuildContainerResponse{
RebuildContainer: true,
}
server := newMockServer()
server.RebuildContainerAsyncFunc = func(context.Context, *codespace.RebuildContainerRequest) (*codespace.RebuildContainerResponse, error) {
return &resp, nil
}
invoker, stop, err := createTestInvoker(t, server)
if err != nil {
t.Fatalf("error connecting to internal server: %v", err)
}
defer stop()
err = invoker.RebuildContainer(context.Background(), true)
if err != nil {
t.Fatalf("expected %v, got %v", nil, err)
}
verifyNotifyCodespaceOfClientActivity(t, server)
}
// Test that the RPC invoker throws an error when the rebuild fails
func TestRebuildContainerFailure(t *testing.T) {
startServer(t)
invoker := createTestInvoker(t)
rpctest.RebuildContainer = false
resp := codespace.RebuildContainerResponse{
RebuildContainer: false,
}
server := newMockServer()
server.RebuildContainerAsyncFunc = func(context.Context, *codespace.RebuildContainerRequest) (*codespace.RebuildContainerResponse, error) {
return &resp, nil
}
invoker, stop, err := createTestInvoker(t, server)
if err != nil {
t.Fatalf("error connecting to internal server: %v", err)
}
defer stop()
errorMessage := "couldn't rebuild codespace"
err := invoker.RebuildContainer(context.Background(), true)
err = invoker.RebuildContainer(context.Background(), true)
if err.Error() != errorMessage {
t.Fatalf("expected %v, got %v", errorMessage, err)
}
@ -127,27 +246,59 @@ func TestRebuildContainerFailure(t *testing.T) {
// Test that the RPC invoker returns the correct port and user when the SSH server starts successfully
func TestStartSSHServerSuccess(t *testing.T) {
startServer(t)
invoker := createTestInvoker(t)
resp := ssh.StartRemoteServerResponse{
ServerPort: strconv.Itoa(1234),
User: "test",
Message: "",
Result: true,
}
server := newMockServer()
server.StartRemoteServerAsyncFunc = func(context.Context, *ssh.StartRemoteServerRequest) (*ssh.StartRemoteServerResponse, error) {
return &resp, nil
}
invoker, stop, err := createTestInvoker(t, server)
if err != nil {
t.Fatalf("error connecting to internal server: %v", err)
}
defer stop()
port, user, err := invoker.StartSSHServer(context.Background())
if err != nil {
t.Fatalf("expected %v, got %v", nil, err)
}
if port != rpctest.SshServerPort {
t.Fatalf("expected %d, got %d", rpctest.SshServerPort, port)
if strconv.Itoa(port) != resp.ServerPort {
t.Fatalf("expected %s, got %d", resp.ServerPort, port)
}
if user != rpctest.SshUser {
t.Fatalf("expected %s, got %s", rpctest.SshUser, user)
if user != resp.User {
t.Fatalf("expected %s, got %s", resp.User, user)
}
verifyNotifyCodespaceOfClientActivity(t, server)
}
// Test that the RPC invoker returns an error when the SSH server fails to start
func TestStartSSHServerFailure(t *testing.T) {
startServer(t)
invoker := createTestInvoker(t)
rpctest.SshMessage = "error message"
rpctest.SshResult = false
errorMessage := fmt.Sprintf("failed to start SSH server: %s", rpctest.SshMessage)
resp := ssh.StartRemoteServerResponse{
ServerPort: strconv.Itoa(1234),
User: "test",
Message: "error message",
Result: false,
}
server := newMockServer()
server.StartRemoteServerAsyncFunc = func(context.Context, *ssh.StartRemoteServerRequest) (*ssh.StartRemoteServerResponse, error) {
return &resp, nil
}
invoker, stop, err := createTestInvoker(t, server)
if err != nil {
t.Fatalf("error connecting to internal server: %v", err)
}
defer stop()
errorMessage := fmt.Sprintf("failed to start SSH server: %s", resp.Message)
port, user, err := invoker.StartSSHServer(context.Background())
if err.Error() != errorMessage {
t.Fatalf("expected %v, got %v", errorMessage, err)

View file

@ -1,6 +1,6 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.28.0
// protoc-gen-go v1.28.1
// protoc v3.21.12
// source: jupyter/jupyter_server_host_service.v1.proto

View file

@ -0,0 +1,118 @@
// Code generated by moq; DO NOT EDIT.
// github.com/matryer/moq
package jupyter
import (
context "context"
sync "sync"
)
// Ensure, that JupyterServerHostServerMock does implement JupyterServerHostServer.
// If this is not the case, regenerate this file with moq.
var _ JupyterServerHostServer = &JupyterServerHostServerMock{}
// JupyterServerHostServerMock is a mock implementation of JupyterServerHostServer.
//
// func TestSomethingThatUsesJupyterServerHostServer(t *testing.T) {
//
// // make and configure a mocked JupyterServerHostServer
// mockedJupyterServerHostServer := &JupyterServerHostServerMock{
// GetRunningServerFunc: func(contextMoqParam context.Context, getRunningServerRequest *GetRunningServerRequest) (*GetRunningServerResponse, error) {
// panic("mock out the GetRunningServer method")
// },
// mustEmbedUnimplementedJupyterServerHostServerFunc: func() {
// panic("mock out the mustEmbedUnimplementedJupyterServerHostServer method")
// },
// }
//
// // use mockedJupyterServerHostServer in code that requires JupyterServerHostServer
// // and then make assertions.
//
// }
type JupyterServerHostServerMock struct {
// GetRunningServerFunc mocks the GetRunningServer method.
GetRunningServerFunc func(contextMoqParam context.Context, getRunningServerRequest *GetRunningServerRequest) (*GetRunningServerResponse, error)
// mustEmbedUnimplementedJupyterServerHostServerFunc mocks the mustEmbedUnimplementedJupyterServerHostServer method.
mustEmbedUnimplementedJupyterServerHostServerFunc func()
// calls tracks calls to the methods.
calls struct {
// GetRunningServer holds details about calls to the GetRunningServer method.
GetRunningServer []struct {
// ContextMoqParam is the contextMoqParam argument value.
ContextMoqParam context.Context
// GetRunningServerRequest is the getRunningServerRequest argument value.
GetRunningServerRequest *GetRunningServerRequest
}
// mustEmbedUnimplementedJupyterServerHostServer holds details about calls to the mustEmbedUnimplementedJupyterServerHostServer method.
mustEmbedUnimplementedJupyterServerHostServer []struct {
}
}
lockGetRunningServer sync.RWMutex
lockmustEmbedUnimplementedJupyterServerHostServer sync.RWMutex
}
// GetRunningServer calls GetRunningServerFunc.
func (mock *JupyterServerHostServerMock) GetRunningServer(contextMoqParam context.Context, getRunningServerRequest *GetRunningServerRequest) (*GetRunningServerResponse, error) {
if mock.GetRunningServerFunc == nil {
panic("JupyterServerHostServerMock.GetRunningServerFunc: method is nil but JupyterServerHostServer.GetRunningServer was just called")
}
callInfo := struct {
ContextMoqParam context.Context
GetRunningServerRequest *GetRunningServerRequest
}{
ContextMoqParam: contextMoqParam,
GetRunningServerRequest: getRunningServerRequest,
}
mock.lockGetRunningServer.Lock()
mock.calls.GetRunningServer = append(mock.calls.GetRunningServer, callInfo)
mock.lockGetRunningServer.Unlock()
return mock.GetRunningServerFunc(contextMoqParam, getRunningServerRequest)
}
// GetRunningServerCalls gets all the calls that were made to GetRunningServer.
// Check the length with:
//
// len(mockedJupyterServerHostServer.GetRunningServerCalls())
func (mock *JupyterServerHostServerMock) GetRunningServerCalls() []struct {
ContextMoqParam context.Context
GetRunningServerRequest *GetRunningServerRequest
} {
var calls []struct {
ContextMoqParam context.Context
GetRunningServerRequest *GetRunningServerRequest
}
mock.lockGetRunningServer.RLock()
calls = mock.calls.GetRunningServer
mock.lockGetRunningServer.RUnlock()
return calls
}
// mustEmbedUnimplementedJupyterServerHostServer calls mustEmbedUnimplementedJupyterServerHostServerFunc.
func (mock *JupyterServerHostServerMock) mustEmbedUnimplementedJupyterServerHostServer() {
if mock.mustEmbedUnimplementedJupyterServerHostServerFunc == nil {
panic("JupyterServerHostServerMock.mustEmbedUnimplementedJupyterServerHostServerFunc: method is nil but JupyterServerHostServer.mustEmbedUnimplementedJupyterServerHostServer was just called")
}
callInfo := struct {
}{}
mock.lockmustEmbedUnimplementedJupyterServerHostServer.Lock()
mock.calls.mustEmbedUnimplementedJupyterServerHostServer = append(mock.calls.mustEmbedUnimplementedJupyterServerHostServer, callInfo)
mock.lockmustEmbedUnimplementedJupyterServerHostServer.Unlock()
mock.mustEmbedUnimplementedJupyterServerHostServerFunc()
}
// mustEmbedUnimplementedJupyterServerHostServerCalls gets all the calls that were made to mustEmbedUnimplementedJupyterServerHostServer.
// Check the length with:
//
// len(mockedJupyterServerHostServer.mustEmbedUnimplementedJupyterServerHostServerCalls())
func (mock *JupyterServerHostServerMock) mustEmbedUnimplementedJupyterServerHostServerCalls() []struct {
} {
var calls []struct {
}
mock.lockmustEmbedUnimplementedJupyterServerHostServer.RLock()
calls = mock.calls.mustEmbedUnimplementedJupyterServerHostServer
mock.lockmustEmbedUnimplementedJupyterServerHostServer.RUnlock()
return calls
}

View file

@ -1,6 +1,6 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.28.0
// protoc-gen-go v1.28.1
// protoc v3.21.12
// source: ssh/ssh_server_host_service.v1.proto

View file

@ -0,0 +1,118 @@
// Code generated by moq; DO NOT EDIT.
// github.com/matryer/moq
package ssh
import (
context "context"
sync "sync"
)
// Ensure, that SshServerHostServerMock does implement SshServerHostServer.
// If this is not the case, regenerate this file with moq.
var _ SshServerHostServer = &SshServerHostServerMock{}
// SshServerHostServerMock is a mock implementation of SshServerHostServer.
//
// func TestSomethingThatUsesSshServerHostServer(t *testing.T) {
//
// // make and configure a mocked SshServerHostServer
// mockedSshServerHostServer := &SshServerHostServerMock{
// StartRemoteServerAsyncFunc: func(contextMoqParam context.Context, startRemoteServerRequest *StartRemoteServerRequest) (*StartRemoteServerResponse, error) {
// panic("mock out the StartRemoteServerAsync method")
// },
// mustEmbedUnimplementedSshServerHostServerFunc: func() {
// panic("mock out the mustEmbedUnimplementedSshServerHostServer method")
// },
// }
//
// // use mockedSshServerHostServer in code that requires SshServerHostServer
// // and then make assertions.
//
// }
type SshServerHostServerMock struct {
// StartRemoteServerAsyncFunc mocks the StartRemoteServerAsync method.
StartRemoteServerAsyncFunc func(contextMoqParam context.Context, startRemoteServerRequest *StartRemoteServerRequest) (*StartRemoteServerResponse, error)
// mustEmbedUnimplementedSshServerHostServerFunc mocks the mustEmbedUnimplementedSshServerHostServer method.
mustEmbedUnimplementedSshServerHostServerFunc func()
// calls tracks calls to the methods.
calls struct {
// StartRemoteServerAsync holds details about calls to the StartRemoteServerAsync method.
StartRemoteServerAsync []struct {
// ContextMoqParam is the contextMoqParam argument value.
ContextMoqParam context.Context
// StartRemoteServerRequest is the startRemoteServerRequest argument value.
StartRemoteServerRequest *StartRemoteServerRequest
}
// mustEmbedUnimplementedSshServerHostServer holds details about calls to the mustEmbedUnimplementedSshServerHostServer method.
mustEmbedUnimplementedSshServerHostServer []struct {
}
}
lockStartRemoteServerAsync sync.RWMutex
lockmustEmbedUnimplementedSshServerHostServer sync.RWMutex
}
// StartRemoteServerAsync calls StartRemoteServerAsyncFunc.
func (mock *SshServerHostServerMock) StartRemoteServerAsync(contextMoqParam context.Context, startRemoteServerRequest *StartRemoteServerRequest) (*StartRemoteServerResponse, error) {
if mock.StartRemoteServerAsyncFunc == nil {
panic("SshServerHostServerMock.StartRemoteServerAsyncFunc: method is nil but SshServerHostServer.StartRemoteServerAsync was just called")
}
callInfo := struct {
ContextMoqParam context.Context
StartRemoteServerRequest *StartRemoteServerRequest
}{
ContextMoqParam: contextMoqParam,
StartRemoteServerRequest: startRemoteServerRequest,
}
mock.lockStartRemoteServerAsync.Lock()
mock.calls.StartRemoteServerAsync = append(mock.calls.StartRemoteServerAsync, callInfo)
mock.lockStartRemoteServerAsync.Unlock()
return mock.StartRemoteServerAsyncFunc(contextMoqParam, startRemoteServerRequest)
}
// StartRemoteServerAsyncCalls gets all the calls that were made to StartRemoteServerAsync.
// Check the length with:
//
// len(mockedSshServerHostServer.StartRemoteServerAsyncCalls())
func (mock *SshServerHostServerMock) StartRemoteServerAsyncCalls() []struct {
ContextMoqParam context.Context
StartRemoteServerRequest *StartRemoteServerRequest
} {
var calls []struct {
ContextMoqParam context.Context
StartRemoteServerRequest *StartRemoteServerRequest
}
mock.lockStartRemoteServerAsync.RLock()
calls = mock.calls.StartRemoteServerAsync
mock.lockStartRemoteServerAsync.RUnlock()
return calls
}
// mustEmbedUnimplementedSshServerHostServer calls mustEmbedUnimplementedSshServerHostServerFunc.
func (mock *SshServerHostServerMock) mustEmbedUnimplementedSshServerHostServer() {
if mock.mustEmbedUnimplementedSshServerHostServerFunc == nil {
panic("SshServerHostServerMock.mustEmbedUnimplementedSshServerHostServerFunc: method is nil but SshServerHostServer.mustEmbedUnimplementedSshServerHostServer was just called")
}
callInfo := struct {
}{}
mock.lockmustEmbedUnimplementedSshServerHostServer.Lock()
mock.calls.mustEmbedUnimplementedSshServerHostServer = append(mock.calls.mustEmbedUnimplementedSshServerHostServer, callInfo)
mock.lockmustEmbedUnimplementedSshServerHostServer.Unlock()
mock.mustEmbedUnimplementedSshServerHostServerFunc()
}
// mustEmbedUnimplementedSshServerHostServerCalls gets all the calls that were made to mustEmbedUnimplementedSshServerHostServer.
// Check the length with:
//
// len(mockedSshServerHostServer.mustEmbedUnimplementedSshServerHostServerCalls())
func (mock *SshServerHostServerMock) mustEmbedUnimplementedSshServerHostServerCalls() []struct {
} {
var calls []struct {
}
mock.lockmustEmbedUnimplementedSshServerHostServer.RLock()
calls = mock.calls.mustEmbedUnimplementedSshServerHostServer
mock.lockmustEmbedUnimplementedSshServerHostServer.RUnlock()
return calls
}

View file

@ -1,117 +0,0 @@
package test
import (
"context"
"fmt"
"net"
"strconv"
"github.com/cli/cli/v2/internal/codespaces/rpc/codespace"
"github.com/cli/cli/v2/internal/codespaces/rpc/jupyter"
"github.com/cli/cli/v2/internal/codespaces/rpc/ssh"
"google.golang.org/grpc"
)
const (
ServerPort = 50051
)
// Mock responses for the `GetRunningServer` RPC method
var (
JupyterPort = 1234
JupyterServerUrl = "http://localhost:1234?token=1234"
JupyterMessage = ""
JupyterResult = true
)
// Mock responses for the `RebuildContainerAsync` RPC method
var (
RebuildContainer = true
)
// Mock responses for the `NotifyCodespaceOfClientActivity` RPC method
// NotifyMessage is used to store the activity that was sent to the server
var (
NotifyMessage = ""
NotifyResult = true
NotifyReceivedActivity = ""
)
// Mock responses for the `StartRemoteServerAsync` RPC method
var (
SshServerPort = 1234
SshUser = "test"
SshMessage = ""
SshResult = true
)
type server struct {
jupyter.UnimplementedJupyterServerHostServer
codespace.CodespaceHostServer
ssh.SshServerHostServer
}
func (s *server) GetRunningServer(ctx context.Context, in *jupyter.GetRunningServerRequest) (*jupyter.GetRunningServerResponse, error) {
return &jupyter.GetRunningServerResponse{
Port: strconv.Itoa(JupyterPort),
ServerUrl: JupyterServerUrl,
Message: JupyterMessage,
Result: JupyterResult,
}, nil
}
func (s *server) RebuildContainerAsync(ctx context.Context, in *codespace.RebuildContainerRequest) (*codespace.RebuildContainerResponse, error) {
return &codespace.RebuildContainerResponse{
RebuildContainer: RebuildContainer,
}, nil
}
func (s *server) NotifyCodespaceOfClientActivity(ctx context.Context, in *codespace.NotifyCodespaceOfClientActivityRequest) (*codespace.NotifyCodespaceOfClientActivityResponse, error) {
// If there is at least one client activity, set NotifyReceivedActivity to the first one (should be "connected")
if len(in.GetClientActivities()) > 0 {
NotifyReceivedActivity = in.GetClientActivities()[0]
}
return &codespace.NotifyCodespaceOfClientActivityResponse{
Message: NotifyMessage,
Result: NotifyResult,
}, nil
}
func (s *server) StartRemoteServerAsync(ctx context.Context, in *ssh.StartRemoteServerRequest) (*ssh.StartRemoteServerResponse, error) {
return &ssh.StartRemoteServerResponse{
ServerPort: strconv.Itoa(SshServerPort),
User: SshUser,
Message: SshMessage,
Result: SshResult,
}, nil
}
// Starts the mock gRPC server listening on port 50051
func StartServer(ctx context.Context) error {
listener, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", ServerPort))
if err != nil {
return fmt.Errorf("failed to listen: %v", err)
}
defer listener.Close()
s := grpc.NewServer()
jupyter.RegisterJupyterServerHostServer(s, &server{})
codespace.RegisterCodespaceHostServer(s, &server{})
ssh.RegisterSshServerHostServer(s, &server{})
ch := make(chan error, 1)
go func() {
if err := s.Serve(listener); err != nil {
ch <- fmt.Errorf("failed to serve: %v", err)
}
}()
select {
case <-ctx.Done():
s.Stop()
return ctx.Err()
case err := <-ch:
return err
}
}

View file

@ -29,7 +29,7 @@ func (s *Session) GetKeepAliveReason() string {
}
func (s *Session) StartSharing(ctx context.Context, sessionName string, port int) (liveshare.ChannelID, error) {
conn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", ServerPort))
conn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", port))
if err != nil {
return liveshare.ChannelID{}, err
}

View file

@ -38,7 +38,7 @@ func NewCmdDelete(f *cmdutil.Factory, runF func(*DeleteOptions) error) *cobra.Co
opts.KeyID = args[0]
if !opts.IO.CanPrompt() && !opts.Confirmed {
return cmdutil.FlagErrorf("--confirm required when not running interactively")
return cmdutil.FlagErrorf("--yes required when not running interactively")
}
if runF != nil {
@ -48,7 +48,9 @@ func NewCmdDelete(f *cmdutil.Factory, runF func(*DeleteOptions) error) *cobra.Co
},
}
cmd.Flags().BoolVarP(&opts.Confirmed, "confirm", "y", false, "Skip the confirmation prompt")
cmd.Flags().BoolVar(&opts.Confirmed, "confirm", false, "Skip the confirmation prompt")
_ = cmd.Flags().MarkDeprecated("confirm", "use `--yes` instead")
cmd.Flags().BoolVarP(&opts.Confirmed, "yes", "y", false, "Skip the confirmation prompt")
return cmd
}

View file

@ -32,7 +32,7 @@ func TestNewCmdDelete(t *testing.T) {
{
name: "confirm flag tty",
tty: true,
input: "ABC123 --confirm",
input: "ABC123 --yes",
output: DeleteOptions{KeyID: "ABC123", Confirmed: true},
},
{
@ -45,11 +45,11 @@ func TestNewCmdDelete(t *testing.T) {
name: "no tty",
input: "ABC123",
wantErr: true,
wantErrMsg: "--confirm required when not running interactively",
wantErrMsg: "--yes required when not running interactively",
},
{
name: "confirm flag no tty",
input: "ABC123 --confirm",
input: "ABC123 --yes",
output: DeleteOptions{KeyID: "ABC123", Confirmed: true},
},
{

View file

@ -49,11 +49,14 @@ func NewCmdDelete(f *cmdutil.Factory, runF func(*DeleteOptions) error) *cobra.Co
if runF != nil {
return runF(opts)
}
return deleteRun(opts)
},
}
cmd.Flags().BoolVar(&opts.Confirmed, "confirm", false, "confirm deletion without prompting")
_ = cmd.Flags().MarkDeprecated("confirm", "use `--yes` instead")
cmd.Flags().BoolVar(&opts.Confirmed, "yes", false, "confirm deletion without prompting")
return cmd
}

View file

@ -42,7 +42,7 @@ func newCmdDelete(f *cmdutil.Factory, runF func(*deleteOptions) error) *cobra.Co
opts.Name = args[0]
if !opts.IO.CanPrompt() && !opts.Confirmed {
return cmdutil.FlagErrorf("--confirm required when not running interactively")
return cmdutil.FlagErrorf("--yes required when not running interactively")
}
if runF != nil {
@ -53,6 +53,8 @@ func newCmdDelete(f *cmdutil.Factory, runF func(*deleteOptions) error) *cobra.Co
}
cmd.Flags().BoolVar(&opts.Confirmed, "confirm", false, "Confirm deletion without prompting")
_ = cmd.Flags().MarkDeprecated("confirm", "use `--yes` instead")
cmd.Flags().BoolVar(&opts.Confirmed, "yes", false, "Confirm deletion without prompting")
return cmd
}

View file

@ -37,14 +37,14 @@ func TestNewCmdDelete(t *testing.T) {
},
{
name: "confirm argument",
input: "test --confirm",
input: "test --yes",
output: deleteOptions{Name: "test", Confirmed: true},
},
{
name: "confirm no tty",
input: "test",
wantErr: true,
wantErrMsg: "--confirm required when not running interactively",
wantErrMsg: "--yes required when not running interactively",
},
}

View file

@ -47,16 +47,20 @@ With no argument, archives the current repository.`),
}
if !opts.Confirmed && !opts.IO.CanPrompt() {
return cmdutil.FlagErrorf("--confirm required when not running interactively")
return cmdutil.FlagErrorf("--yes required when not running interactively")
}
if runF != nil {
return runF(opts)
}
return archiveRun(opts)
},
}
cmd.Flags().BoolVarP(&opts.Confirmed, "confirm", "y", false, "Skip the confirmation prompt")
cmd.Flags().BoolVar(&opts.Confirmed, "confirm", false, "Skip the confirmation prompt")
_ = cmd.Flags().MarkDeprecated("confirm", "use `--yes` instead")
cmd.Flags().BoolVarP(&opts.Confirmed, "yes", "y", false, "Skip the confirmation prompt")
return cmd
}

View file

@ -26,7 +26,7 @@ func TestNewCmdArchive(t *testing.T) {
{
name: "no arguments no tty",
input: "",
errMsg: "--confirm required when not running interactively",
errMsg: "--yes required when not running interactively",
wantErr: true,
},
{

View file

@ -60,7 +60,7 @@ func NewCmdRename(f *cmdutil.Factory, runf func(*RenameOptions) error) *cobra.Co
if len(args) == 1 && !confirm && !opts.HasRepoOverride {
if !opts.IO.CanPrompt() {
return cmdutil.FlagErrorf("--confirm required when passing a single argument")
return cmdutil.FlagErrorf("--yes required when passing a single argument")
}
opts.DoConfirm = true
}
@ -68,12 +68,15 @@ func NewCmdRename(f *cmdutil.Factory, runf func(*RenameOptions) error) *cobra.Co
if runf != nil {
return runf(opts)
}
return renameRun(opts)
},
}
cmdutil.EnableRepoOverride(cmd, f)
cmd.Flags().BoolVarP(&confirm, "confirm", "y", false, "skip confirmation prompt")
cmd.Flags().BoolVar(&confirm, "confirm", false, "Skip confirmation prompt")
_ = cmd.Flags().MarkDeprecated("confirm", "use `--yes` instead")
cmd.Flags().BoolVarP(&confirm, "yes", "y", false, "Skip the confirmation prompt")
return cmd
}

View file

@ -35,7 +35,7 @@ func TestNewCmdRename(t *testing.T) {
},
{
name: "one argument no tty confirmed",
input: "REPO --confirm",
input: "REPO --yes",
output: RenameOptions{
newRepoSelector: "REPO",
},
@ -43,12 +43,12 @@ func TestNewCmdRename(t *testing.T) {
{
name: "one argument no tty",
input: "REPO",
errMsg: "--confirm required when passing a single argument",
errMsg: "--yes required when passing a single argument",
wantErr: true,
},
{
name: "one argument tty confirmed",
input: "REPO --confirm",
input: "REPO --yes",
tty: true,
output: RenameOptions{
newRepoSelector: "REPO",

View file

@ -77,7 +77,7 @@ type Run struct {
workflowName string // cache column
WorkflowID int64 `json:"workflow_id"`
Number int64 `json:"run_number"`
Attempts uint8 `json:"run_attempt"`
Attempts uint64 `json:"run_attempt"`
HeadBranch string `json:"head_branch"`
JobsURL string `json:"jobs_url"`
HeadCommit Commit `json:"head_commit"`

View file

@ -20,9 +20,9 @@ type SecretPayload struct {
KeyID string `json:"key_id"`
}
// The Codespaces Secret API currently expects repositories IDs as strings
type CodespacesSecretPayload struct {
type DependabotSecretPayload struct {
EncryptedValue string `json:"encrypted_value"`
Visibility string `json:"visibility,omitempty"`
Repositories []string `json:"selected_repository_ids,omitempty"`
KeyID string `json:"key_id"`
}
@ -70,31 +70,40 @@ func putSecret(client *api.Client, host, path string, payload interface{}) error
}
func putOrgSecret(client *api.Client, host string, pk *PubKey, orgName, visibility, secretName, eValue string, repositoryIDs []int64, app shared.App) error {
path := fmt.Sprintf("orgs/%s/%s/secrets/%s", orgName, app, secretName)
if app == shared.Dependabot {
repos := make([]string, len(repositoryIDs))
for i, id := range repositoryIDs {
repos[i] = strconv.FormatInt(id, 10)
}
payload := DependabotSecretPayload{
EncryptedValue: eValue,
KeyID: pk.ID,
Repositories: repos,
Visibility: visibility,
}
return putSecret(client, host, path, payload)
}
payload := SecretPayload{
EncryptedValue: eValue,
KeyID: pk.ID,
Repositories: repositoryIDs,
Visibility: visibility,
}
path := fmt.Sprintf("orgs/%s/%s/secrets/%s", orgName, app, secretName)
return putSecret(client, host, path, payload)
}
func putUserSecret(client *api.Client, host string, pk *PubKey, key, eValue string, repositoryIDs []int64) error {
payload := CodespacesSecretPayload{
payload := SecretPayload{
EncryptedValue: eValue,
KeyID: pk.ID,
Repositories: repositoryIDs,
}
if len(repositoryIDs) > 0 {
repositoryStringIDs := make([]string, len(repositoryIDs))
for i, id := range repositoryIDs {
repositoryStringIDs[i] = strconv.FormatInt(id, 10)
}
payload.Repositories = repositoryStringIDs
}
path := fmt.Sprintf("user/codespaces/secrets/%s", key)
return putSecret(client, host, path, payload)
}

View file

@ -333,11 +333,12 @@ func Test_setRun_env(t *testing.T) {
func Test_setRun_org(t *testing.T) {
tests := []struct {
name string
opts *SetOptions
wantVisibility shared.Visibility
wantRepositories []int64
wantApp string
name string
opts *SetOptions
wantVisibility shared.Visibility
wantRepositories []int64
wantDependabotRepositories []string
wantApp string
}{
{
name: "all vis",
@ -362,10 +363,21 @@ func Test_setRun_org(t *testing.T) {
opts: &SetOptions{
OrgName: "UmbrellaCorporation",
Visibility: shared.All,
Application: "dependabot",
Application: shared.Dependabot,
},
wantApp: "dependabot",
},
{
name: "Dependabot selected visibility",
opts: &SetOptions{
OrgName: "UmbrellaCorporation",
Visibility: shared.Selected,
Application: shared.Dependabot,
RepositoryNames: []string{"birkin", "UmbrellaCorporation/wesker"},
},
wantDependabotRepositories: []string{"1", "2"},
wantApp: "dependabot",
},
}
for _, tt := range tests {
@ -410,13 +422,24 @@ func Test_setRun_org(t *testing.T) {
data, err := io.ReadAll(reg.Requests[len(reg.Requests)-1].Body)
assert.NoError(t, err)
var payload SecretPayload
err = json.Unmarshal(data, &payload)
assert.NoError(t, err)
assert.Equal(t, payload.KeyID, "123")
assert.Equal(t, payload.EncryptedValue, "UKYUCbHd0DJemxa3AOcZ6XcsBwALG9d4bpB8ZT0gSV39vl3BHiGSgj8zJapDxgB2BwqNqRhpjC4=")
assert.Equal(t, payload.Visibility, tt.opts.Visibility)
assert.ElementsMatch(t, payload.Repositories, tt.wantRepositories)
if tt.opts.Application == shared.Dependabot {
var payload DependabotSecretPayload
err = json.Unmarshal(data, &payload)
assert.NoError(t, err)
assert.Equal(t, payload.KeyID, "123")
assert.Equal(t, payload.EncryptedValue, "UKYUCbHd0DJemxa3AOcZ6XcsBwALG9d4bpB8ZT0gSV39vl3BHiGSgj8zJapDxgB2BwqNqRhpjC4=")
assert.Equal(t, payload.Visibility, tt.opts.Visibility)
assert.ElementsMatch(t, payload.Repositories, tt.wantDependabotRepositories)
} else {
var payload SecretPayload
err = json.Unmarshal(data, &payload)
assert.NoError(t, err)
assert.Equal(t, payload.KeyID, "123")
assert.Equal(t, payload.EncryptedValue, "UKYUCbHd0DJemxa3AOcZ6XcsBwALG9d4bpB8ZT0gSV39vl3BHiGSgj8zJapDxgB2BwqNqRhpjC4=")
assert.Equal(t, payload.Visibility, tt.opts.Visibility)
assert.ElementsMatch(t, payload.Repositories, tt.wantRepositories)
}
})
}
}
@ -426,7 +449,7 @@ func Test_setRun_user(t *testing.T) {
name string
opts *SetOptions
wantVisibility shared.Visibility
wantRepositories []string
wantRepositories []int64
}{
{
name: "all vis",
@ -442,7 +465,7 @@ func Test_setRun_user(t *testing.T) {
Visibility: shared.Selected,
RepositoryNames: []string{"cli/cli", "github/hub"},
},
wantRepositories: []string{"212613049", "401025"},
wantRepositories: []int64{212613049, 401025},
},
}
@ -481,7 +504,7 @@ func Test_setRun_user(t *testing.T) {
data, err := io.ReadAll(reg.Requests[len(reg.Requests)-1].Body)
assert.NoError(t, err)
var payload CodespacesSecretPayload
var payload SecretPayload
err = json.Unmarshal(data, &payload)
assert.NoError(t, err)
assert.Equal(t, payload.KeyID, "123")

View file

@ -37,17 +37,21 @@ func NewCmdDelete(f *cmdutil.Factory, runF func(*DeleteOptions) error) *cobra.Co
opts.KeyID = args[0]
if !opts.IO.CanPrompt() && !opts.Confirmed {
return cmdutil.FlagErrorf("--confirm required when not running interactively")
return cmdutil.FlagErrorf("--yes required when not running interactively")
}
if runF != nil {
return runF(opts)
}
return deleteRun(opts)
},
}
cmd.Flags().BoolVarP(&opts.Confirmed, "confirm", "y", false, "Skip the confirmation prompt")
cmd.Flags().BoolVar(&opts.Confirmed, "confirm", false, "Skip the confirmation prompt")
_ = cmd.Flags().MarkDeprecated("confirm", "use `--yes` instead")
cmd.Flags().BoolVarP(&opts.Confirmed, "yes", "y", false, "Skip the confirmation prompt")
return cmd
}

View file

@ -32,7 +32,7 @@ func TestNewCmdDelete(t *testing.T) {
{
name: "confirm flag tty",
tty: true,
input: "123 --confirm",
input: "123 --yes",
output: DeleteOptions{KeyID: "123", Confirmed: true},
},
{
@ -45,11 +45,11 @@ func TestNewCmdDelete(t *testing.T) {
name: "no tty",
input: "123",
wantErr: true,
wantErrMsg: "--confirm required when not running interactively",
wantErrMsg: "--yes required when not running interactively",
},
{
name: "confirm flag no tty",
input: "123 --confirm",
input: "123 --yes",
output: DeleteOptions{KeyID: "123", Confirmed: true},
},
{