Merge branch 'trunk' into gh-ext-browse-followup
This commit is contained in:
commit
5dea3a923f
33 changed files with 997 additions and 256 deletions
|
|
@ -64,6 +64,8 @@ func NewHTTPClient(opts HTTPClientOptions) (*http.Client, error) {
|
|||
client.Transport = AddAuthTokenHeader(client.Transport, opts.Config)
|
||||
}
|
||||
|
||||
client.Transport = AddASCIISanitizer(client.Transport)
|
||||
|
||||
return client, nil
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -10,10 +10,10 @@ import (
|
|||
|
||||
const (
|
||||
errorProjectsV2ReadScope = "field requires one of the following scopes: ['read:project']"
|
||||
errorProjectsV2RepositoryField = "Field 'ProjectsV2' doesn't exist on type 'Repository'"
|
||||
errorProjectsV2OrganizationField = "Field 'ProjectsV2' doesn't exist on type 'Organization'"
|
||||
errorProjectsV2IssueField = "Field 'ProjectItems' doesn't exist on type 'Issue'"
|
||||
errorProjectsV2PullRequestField = "Field 'ProjectItems' doesn't exist on type 'PullRequest'"
|
||||
errorProjectsV2RepositoryField = "Field 'projectsV2' doesn't exist on type 'Repository'"
|
||||
errorProjectsV2OrganizationField = "Field 'projectsV2' doesn't exist on type 'Organization'"
|
||||
errorProjectsV2IssueField = "Field 'projectItems' doesn't exist on type 'Issue'"
|
||||
errorProjectsV2PullRequestField = "Field 'projectItems' doesn't exist on type 'PullRequest'"
|
||||
)
|
||||
|
||||
// UpdateProjectV2Items uses the addProjectV2ItemById and the deleteProjectV2Item mutations
|
||||
|
|
|
|||
|
|
@ -221,22 +221,22 @@ func TestProjectsV2IgnorableError(t *testing.T) {
|
|||
},
|
||||
{
|
||||
name: "repository projectsV2 field error",
|
||||
errMsg: "Field 'ProjectsV2' doesn't exist on type 'Repository'",
|
||||
errMsg: "Field 'projectsV2' doesn't exist on type 'Repository'",
|
||||
expectOut: true,
|
||||
},
|
||||
{
|
||||
name: "organization projectsV2 field error",
|
||||
errMsg: "Field 'ProjectsV2' doesn't exist on type 'Organization'",
|
||||
errMsg: "Field 'projectsV2' doesn't exist on type 'Organization'",
|
||||
expectOut: true,
|
||||
},
|
||||
{
|
||||
name: "issue projectItems field error",
|
||||
errMsg: "Field 'ProjectItems' doesn't exist on type 'Issue'",
|
||||
errMsg: "Field 'projectItems' doesn't exist on type 'Issue'",
|
||||
expectOut: true,
|
||||
},
|
||||
{
|
||||
name: "pullRequest projectItems field error",
|
||||
errMsg: "Field 'ProjectItems' doesn't exist on type 'PullRequest'",
|
||||
errMsg: "Field 'projectItems' doesn't exist on type 'PullRequest'",
|
||||
expectOut: true,
|
||||
},
|
||||
{
|
||||
|
|
|
|||
193
api/sanitize_ascii.go
Normal file
193
api/sanitize_ascii.go
Normal file
|
|
@ -0,0 +1,193 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var jsonTypeRE = regexp.MustCompile(`[/+]json($|;)`)
|
||||
|
||||
// GitHub servers return non-printable characters as their unicode code point values.
|
||||
// The values of \u0000 to \u001F represent C0 ASCII control characters and
|
||||
// the values of \u0080 to \u009F represent C1 ASCII control characters. These control
|
||||
// characters will be interpreted by the terminal, this behaviour can be used maliciously
|
||||
// as an attack vector, especially the control character \u001B. This function wraps
|
||||
// JSON response bodies in a ReadCloser that transforms C0 and C1 control characters
|
||||
// to their caret and hex notations respectively so that the terminal will not interpret them.
|
||||
func AddASCIISanitizer(rt http.RoundTripper) http.RoundTripper {
|
||||
return &funcTripper{roundTrip: func(req *http.Request) (*http.Response, error) {
|
||||
res, err := rt.RoundTrip(req)
|
||||
if err != nil || !jsonTypeRE.MatchString(res.Header.Get("Content-Type")) {
|
||||
return res, err
|
||||
}
|
||||
res.Body = &sanitizeASCIIReadCloser{ReadCloser: res.Body}
|
||||
return res, err
|
||||
}}
|
||||
}
|
||||
|
||||
// sanitizeASCIIReadCloser implements the ReadCloser interface.
|
||||
type sanitizeASCIIReadCloser struct {
|
||||
io.ReadCloser
|
||||
addBackslash bool
|
||||
previousWindow []byte
|
||||
}
|
||||
|
||||
// Read uses a sliding window alogorithm to detect C0 and C1
|
||||
// ASCII control sequences as they are read and replaces them
|
||||
// with equivelent inert characters. Characters that are not part
|
||||
// of a control sequence not modified.
|
||||
func (s *sanitizeASCIIReadCloser) Read(out []byte) (int, error) {
|
||||
var readErr error
|
||||
var outIndex int
|
||||
var bufIndex int
|
||||
var bufLen int
|
||||
var window []byte
|
||||
buf := make([]byte, len(out))
|
||||
|
||||
bufLen, readErr = s.ReadCloser.Read(buf)
|
||||
if readErr != nil && !errors.Is(readErr, io.EOF) {
|
||||
if bufLen > 0 {
|
||||
// Do not sanitize if there was a read error that is not EOF.
|
||||
bufLen = copy(out, buf)
|
||||
}
|
||||
return bufLen, readErr
|
||||
}
|
||||
|
||||
if s.previousWindow != nil {
|
||||
buf = append(s.previousWindow, buf...)
|
||||
bufLen += len(s.previousWindow)
|
||||
}
|
||||
|
||||
for {
|
||||
remaining := min(6, (bufLen - bufIndex))
|
||||
window = buf[bufIndex : bufIndex+remaining]
|
||||
if remaining < 6 {
|
||||
break
|
||||
}
|
||||
|
||||
if bytes.HasPrefix(window, []byte(`\u00`)) {
|
||||
repl, _ := mapControlCharacterToCaret(window)
|
||||
if s.addBackslash {
|
||||
repl = append([]byte{92}, repl...)
|
||||
}
|
||||
l := len(repl)
|
||||
for j := 0; j < l; j++ {
|
||||
out[outIndex] = repl[j]
|
||||
outIndex++
|
||||
}
|
||||
bufIndex += 6
|
||||
s.addBackslash = false
|
||||
continue
|
||||
}
|
||||
|
||||
if window[0] == '\\' {
|
||||
s.addBackslash = !s.addBackslash
|
||||
} else {
|
||||
s.addBackslash = false
|
||||
}
|
||||
|
||||
out[outIndex] = buf[bufIndex]
|
||||
outIndex++
|
||||
bufIndex++
|
||||
}
|
||||
|
||||
if readErr != nil && errors.Is(readErr, io.EOF) {
|
||||
remaining := bufLen - bufIndex
|
||||
for j := 0; j < remaining; j++ {
|
||||
out[outIndex] = window[j]
|
||||
outIndex++
|
||||
bufIndex++
|
||||
}
|
||||
} else {
|
||||
s.previousWindow = window
|
||||
}
|
||||
|
||||
return outIndex, readErr
|
||||
}
|
||||
|
||||
// mapControlCharacterToCaret maps C0 control sequences to caret notation
|
||||
// and C1 control sequences to hex notation. C1 control sequences do not
|
||||
// have caret notation representation.
|
||||
func mapControlCharacterToCaret(b []byte) ([]byte, bool) {
|
||||
m := map[string]string{
|
||||
`\u0000`: `^@`,
|
||||
`\u0001`: `^A`,
|
||||
`\u0002`: `^B`,
|
||||
`\u0003`: `^C`,
|
||||
`\u0004`: `^D`,
|
||||
`\u0005`: `^E`,
|
||||
`\u0006`: `^F`,
|
||||
`\u0007`: `^G`,
|
||||
`\u0008`: `^H`,
|
||||
`\u0009`: `^I`,
|
||||
`\u000a`: `^J`,
|
||||
`\u000b`: `^K`,
|
||||
`\u000c`: `^L`,
|
||||
`\u000d`: `^M`,
|
||||
`\u000e`: `^N`,
|
||||
`\u000f`: `^O`,
|
||||
`\u0010`: `^P`,
|
||||
`\u0011`: `^Q`,
|
||||
`\u0012`: `^R`,
|
||||
`\u0013`: `^S`,
|
||||
`\u0014`: `^T`,
|
||||
`\u0015`: `^U`,
|
||||
`\u0016`: `^V`,
|
||||
`\u0017`: `^W`,
|
||||
`\u0018`: `^X`,
|
||||
`\u0019`: `^Y`,
|
||||
`\u001a`: `^Z`,
|
||||
`\u001b`: `^[`,
|
||||
`\u001c`: `^\\`,
|
||||
`\u001d`: `^]`,
|
||||
`\u001e`: `^^`,
|
||||
`\u001f`: `^_`,
|
||||
`\u0080`: `\\200`,
|
||||
`\u0081`: `\\201`,
|
||||
`\u0082`: `\\202`,
|
||||
`\u0083`: `\\203`,
|
||||
`\u0084`: `\\204`,
|
||||
`\u0085`: `\\205`,
|
||||
`\u0086`: `\\206`,
|
||||
`\u0087`: `\\207`,
|
||||
`\u0088`: `\\210`,
|
||||
`\u0089`: `\\211`,
|
||||
`\u008a`: `\\212`,
|
||||
`\u008b`: `\\213`,
|
||||
`\u008c`: `\\214`,
|
||||
`\u008d`: `\\215`,
|
||||
`\u008e`: `\\216`,
|
||||
`\u008f`: `\\217`,
|
||||
`\u0090`: `\\220`,
|
||||
`\u0091`: `\\221`,
|
||||
`\u0092`: `\\222`,
|
||||
`\u0093`: `\\223`,
|
||||
`\u0094`: `\\224`,
|
||||
`\u0095`: `\\225`,
|
||||
`\u0096`: `\\226`,
|
||||
`\u0097`: `\\227`,
|
||||
`\u0098`: `\\230`,
|
||||
`\u0099`: `\\231`,
|
||||
`\u009a`: `\\232`,
|
||||
`\u009b`: `\\233`,
|
||||
`\u009c`: `\\234`,
|
||||
`\u009d`: `\\235`,
|
||||
`\u009e`: `\\236`,
|
||||
`\u009f`: `\\237`,
|
||||
}
|
||||
if c, ok := m[strings.ToLower(string(b))]; ok {
|
||||
return []byte(c), true
|
||||
}
|
||||
return b, false
|
||||
}
|
||||
|
||||
func min(a, b int) int {
|
||||
if a < b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
50
api/sanitize_ascii_test.go
Normal file
50
api/sanitize_ascii_test.go
Normal file
|
|
@ -0,0 +1,50 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestHTTPClient_SanitizeASCIIControlCharacters(t *testing.T) {
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
issue := Issue{
|
||||
Title: "\u001B[31mRed Title\u001B[0m",
|
||||
Body: "1\u0001 2\u0002 3\u0003 4\u0004 5\u0005 6\u0006 7\u0007 8\u0008 9\t A\r\n B\u000b C\u000c D\r\n E\u000e F\u000f",
|
||||
Author: Author{
|
||||
ID: "1",
|
||||
Name: "10\u0010 11\u0011 12\u0012 13\u0013 14\u0014 15\u0015 16\u0016 17\u0017 18\u0018 19\u0019 1A\u001a 1B\u001b 1C\u001c 1D\u001d 1E\u001e 1F\u001f",
|
||||
Login: "monalisa",
|
||||
},
|
||||
ActiveLockReason: "Escaped \u001B \\u001B \\\u001B \\\\u001B",
|
||||
}
|
||||
responseData, _ := json.Marshal(issue)
|
||||
w.Header().Set("Content-Type", "application/json; charset=utf-8")
|
||||
fmt.Fprint(w, string(responseData))
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
client, err := NewHTTPClient(HTTPClientOptions{})
|
||||
require.NoError(t, err)
|
||||
req, err := http.NewRequest("GET", ts.URL, nil)
|
||||
require.NoError(t, err)
|
||||
res, err := client.Do(req)
|
||||
require.NoError(t, err)
|
||||
body, err := io.ReadAll(res.Body)
|
||||
res.Body.Close()
|
||||
require.NoError(t, err)
|
||||
var issue Issue
|
||||
err = json.Unmarshal(body, &issue)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "^[[31mRed Title^[[0m", issue.Title)
|
||||
assert.Equal(t, "1^A 2^B 3^C 4^D 5^E 6^F 7^G 8^H 9\t A\r\n B^K C^L D\r\n E^N F^O", issue.Body)
|
||||
assert.Equal(t, "10^P 11^Q 12^R 13^S 14^T 15^U 16^V 17^W 18^X 19^Y 1A^Z 1B^[ 1C^\\ 1D^] 1E^^ 1F^_", issue.Author.Name)
|
||||
assert.Equal(t, "monalisa", issue.Author.Login)
|
||||
assert.Equal(t, "Escaped ^[ \\^[ \\^[ \\\\^[", issue.ActiveLockReason)
|
||||
}
|
||||
2
go.mod
2
go.mod
|
|
@ -10,7 +10,7 @@ require (
|
|||
github.com/charmbracelet/glamour v0.5.1-0.20220727184942-e70ff2d969da
|
||||
github.com/charmbracelet/lipgloss v0.5.0
|
||||
github.com/cli/go-gh v1.0.0
|
||||
github.com/cli/oauth v1.0.0
|
||||
github.com/cli/oauth v1.0.1
|
||||
github.com/cli/safeexec v1.0.1
|
||||
github.com/cpuguy83/go-md2man/v2 v2.0.2
|
||||
github.com/creack/pty v1.1.18
|
||||
|
|
|
|||
4
go.sum
4
go.sum
|
|
@ -62,8 +62,8 @@ github.com/cli/crypto v0.0.0-20210929142629-6be313f59b03 h1:3f4uHLfWx4/WlnMPXGai
|
|||
github.com/cli/crypto v0.0.0-20210929142629-6be313f59b03/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
||||
github.com/cli/go-gh v1.0.0 h1:zE1YUAUYqGXNZuICEBeOkIMJ5F50BS0ftvtoWGlsEFI=
|
||||
github.com/cli/go-gh v1.0.0/go.mod h1:bqxLdCoTZ73BuiPEJx4olcO/XKhVZaFDchFagYRBweE=
|
||||
github.com/cli/oauth v1.0.0 h1:zuatYn8BRWWO98y2jNXK4RKOryU1u6JTqPrdSPW5pSE=
|
||||
github.com/cli/oauth v1.0.0/go.mod h1:qd/FX8ZBD6n1sVNQO3aIdRxeu5LGw9WhKnYhIIoC2A4=
|
||||
github.com/cli/oauth v1.0.1 h1:pXnTFl/qUegXHK531Dv0LNjW4mLx626eS42gnzfXJPA=
|
||||
github.com/cli/oauth v1.0.1/go.mod h1:qd/FX8ZBD6n1sVNQO3aIdRxeu5LGw9WhKnYhIIoC2A4=
|
||||
github.com/cli/safeexec v1.0.0/go.mod h1:Z/D4tTN8Vs5gXYHDCbaM1S/anmEDnJb1iW0+EJ5zx3Q=
|
||||
github.com/cli/safeexec v1.0.1 h1:e/C79PbXF4yYTN/wauC4tviMxEV13BwljGj0N9j+N00=
|
||||
github.com/cli/safeexec v1.0.1/go.mod h1:Z/D4tTN8Vs5gXYHDCbaM1S/anmEDnJb1iW0+EJ5zx3Q=
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
// Code generated by protoc-gen-go. DO NOT EDIT.
|
||||
// versions:
|
||||
// protoc-gen-go v1.28.0
|
||||
// protoc-gen-go v1.28.1
|
||||
// protoc v3.21.12
|
||||
// source: codespace/codespace_host_service.v1.proto
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,168 @@
|
|||
// Code generated by moq; DO NOT EDIT.
|
||||
// github.com/matryer/moq
|
||||
|
||||
package codespace
|
||||
|
||||
import (
|
||||
context "context"
|
||||
sync "sync"
|
||||
)
|
||||
|
||||
// Ensure, that CodespaceHostServerMock does implement CodespaceHostServer.
|
||||
// If this is not the case, regenerate this file with moq.
|
||||
var _ CodespaceHostServer = &CodespaceHostServerMock{}
|
||||
|
||||
// CodespaceHostServerMock is a mock implementation of CodespaceHostServer.
|
||||
//
|
||||
// func TestSomethingThatUsesCodespaceHostServer(t *testing.T) {
|
||||
//
|
||||
// // make and configure a mocked CodespaceHostServer
|
||||
// mockedCodespaceHostServer := &CodespaceHostServerMock{
|
||||
// NotifyCodespaceOfClientActivityFunc: func(contextMoqParam context.Context, notifyCodespaceOfClientActivityRequest *NotifyCodespaceOfClientActivityRequest) (*NotifyCodespaceOfClientActivityResponse, error) {
|
||||
// panic("mock out the NotifyCodespaceOfClientActivity method")
|
||||
// },
|
||||
// RebuildContainerAsyncFunc: func(contextMoqParam context.Context, rebuildContainerRequest *RebuildContainerRequest) (*RebuildContainerResponse, error) {
|
||||
// panic("mock out the RebuildContainerAsync method")
|
||||
// },
|
||||
// mustEmbedUnimplementedCodespaceHostServerFunc: func() {
|
||||
// panic("mock out the mustEmbedUnimplementedCodespaceHostServer method")
|
||||
// },
|
||||
// }
|
||||
//
|
||||
// // use mockedCodespaceHostServer in code that requires CodespaceHostServer
|
||||
// // and then make assertions.
|
||||
//
|
||||
// }
|
||||
type CodespaceHostServerMock struct {
|
||||
// NotifyCodespaceOfClientActivityFunc mocks the NotifyCodespaceOfClientActivity method.
|
||||
NotifyCodespaceOfClientActivityFunc func(contextMoqParam context.Context, notifyCodespaceOfClientActivityRequest *NotifyCodespaceOfClientActivityRequest) (*NotifyCodespaceOfClientActivityResponse, error)
|
||||
|
||||
// RebuildContainerAsyncFunc mocks the RebuildContainerAsync method.
|
||||
RebuildContainerAsyncFunc func(contextMoqParam context.Context, rebuildContainerRequest *RebuildContainerRequest) (*RebuildContainerResponse, error)
|
||||
|
||||
// mustEmbedUnimplementedCodespaceHostServerFunc mocks the mustEmbedUnimplementedCodespaceHostServer method.
|
||||
mustEmbedUnimplementedCodespaceHostServerFunc func()
|
||||
|
||||
// calls tracks calls to the methods.
|
||||
calls struct {
|
||||
// NotifyCodespaceOfClientActivity holds details about calls to the NotifyCodespaceOfClientActivity method.
|
||||
NotifyCodespaceOfClientActivity []struct {
|
||||
// ContextMoqParam is the contextMoqParam argument value.
|
||||
ContextMoqParam context.Context
|
||||
// NotifyCodespaceOfClientActivityRequest is the notifyCodespaceOfClientActivityRequest argument value.
|
||||
NotifyCodespaceOfClientActivityRequest *NotifyCodespaceOfClientActivityRequest
|
||||
}
|
||||
// RebuildContainerAsync holds details about calls to the RebuildContainerAsync method.
|
||||
RebuildContainerAsync []struct {
|
||||
// ContextMoqParam is the contextMoqParam argument value.
|
||||
ContextMoqParam context.Context
|
||||
// RebuildContainerRequest is the rebuildContainerRequest argument value.
|
||||
RebuildContainerRequest *RebuildContainerRequest
|
||||
}
|
||||
// mustEmbedUnimplementedCodespaceHostServer holds details about calls to the mustEmbedUnimplementedCodespaceHostServer method.
|
||||
mustEmbedUnimplementedCodespaceHostServer []struct {
|
||||
}
|
||||
}
|
||||
lockNotifyCodespaceOfClientActivity sync.RWMutex
|
||||
lockRebuildContainerAsync sync.RWMutex
|
||||
lockmustEmbedUnimplementedCodespaceHostServer sync.RWMutex
|
||||
}
|
||||
|
||||
// NotifyCodespaceOfClientActivity calls NotifyCodespaceOfClientActivityFunc.
|
||||
func (mock *CodespaceHostServerMock) NotifyCodespaceOfClientActivity(contextMoqParam context.Context, notifyCodespaceOfClientActivityRequest *NotifyCodespaceOfClientActivityRequest) (*NotifyCodespaceOfClientActivityResponse, error) {
|
||||
if mock.NotifyCodespaceOfClientActivityFunc == nil {
|
||||
panic("CodespaceHostServerMock.NotifyCodespaceOfClientActivityFunc: method is nil but CodespaceHostServer.NotifyCodespaceOfClientActivity was just called")
|
||||
}
|
||||
callInfo := struct {
|
||||
ContextMoqParam context.Context
|
||||
NotifyCodespaceOfClientActivityRequest *NotifyCodespaceOfClientActivityRequest
|
||||
}{
|
||||
ContextMoqParam: contextMoqParam,
|
||||
NotifyCodespaceOfClientActivityRequest: notifyCodespaceOfClientActivityRequest,
|
||||
}
|
||||
mock.lockNotifyCodespaceOfClientActivity.Lock()
|
||||
mock.calls.NotifyCodespaceOfClientActivity = append(mock.calls.NotifyCodespaceOfClientActivity, callInfo)
|
||||
mock.lockNotifyCodespaceOfClientActivity.Unlock()
|
||||
return mock.NotifyCodespaceOfClientActivityFunc(contextMoqParam, notifyCodespaceOfClientActivityRequest)
|
||||
}
|
||||
|
||||
// NotifyCodespaceOfClientActivityCalls gets all the calls that were made to NotifyCodespaceOfClientActivity.
|
||||
// Check the length with:
|
||||
//
|
||||
// len(mockedCodespaceHostServer.NotifyCodespaceOfClientActivityCalls())
|
||||
func (mock *CodespaceHostServerMock) NotifyCodespaceOfClientActivityCalls() []struct {
|
||||
ContextMoqParam context.Context
|
||||
NotifyCodespaceOfClientActivityRequest *NotifyCodespaceOfClientActivityRequest
|
||||
} {
|
||||
var calls []struct {
|
||||
ContextMoqParam context.Context
|
||||
NotifyCodespaceOfClientActivityRequest *NotifyCodespaceOfClientActivityRequest
|
||||
}
|
||||
mock.lockNotifyCodespaceOfClientActivity.RLock()
|
||||
calls = mock.calls.NotifyCodespaceOfClientActivity
|
||||
mock.lockNotifyCodespaceOfClientActivity.RUnlock()
|
||||
return calls
|
||||
}
|
||||
|
||||
// RebuildContainerAsync calls RebuildContainerAsyncFunc.
|
||||
func (mock *CodespaceHostServerMock) RebuildContainerAsync(contextMoqParam context.Context, rebuildContainerRequest *RebuildContainerRequest) (*RebuildContainerResponse, error) {
|
||||
if mock.RebuildContainerAsyncFunc == nil {
|
||||
panic("CodespaceHostServerMock.RebuildContainerAsyncFunc: method is nil but CodespaceHostServer.RebuildContainerAsync was just called")
|
||||
}
|
||||
callInfo := struct {
|
||||
ContextMoqParam context.Context
|
||||
RebuildContainerRequest *RebuildContainerRequest
|
||||
}{
|
||||
ContextMoqParam: contextMoqParam,
|
||||
RebuildContainerRequest: rebuildContainerRequest,
|
||||
}
|
||||
mock.lockRebuildContainerAsync.Lock()
|
||||
mock.calls.RebuildContainerAsync = append(mock.calls.RebuildContainerAsync, callInfo)
|
||||
mock.lockRebuildContainerAsync.Unlock()
|
||||
return mock.RebuildContainerAsyncFunc(contextMoqParam, rebuildContainerRequest)
|
||||
}
|
||||
|
||||
// RebuildContainerAsyncCalls gets all the calls that were made to RebuildContainerAsync.
|
||||
// Check the length with:
|
||||
//
|
||||
// len(mockedCodespaceHostServer.RebuildContainerAsyncCalls())
|
||||
func (mock *CodespaceHostServerMock) RebuildContainerAsyncCalls() []struct {
|
||||
ContextMoqParam context.Context
|
||||
RebuildContainerRequest *RebuildContainerRequest
|
||||
} {
|
||||
var calls []struct {
|
||||
ContextMoqParam context.Context
|
||||
RebuildContainerRequest *RebuildContainerRequest
|
||||
}
|
||||
mock.lockRebuildContainerAsync.RLock()
|
||||
calls = mock.calls.RebuildContainerAsync
|
||||
mock.lockRebuildContainerAsync.RUnlock()
|
||||
return calls
|
||||
}
|
||||
|
||||
// mustEmbedUnimplementedCodespaceHostServer calls mustEmbedUnimplementedCodespaceHostServerFunc.
|
||||
func (mock *CodespaceHostServerMock) mustEmbedUnimplementedCodespaceHostServer() {
|
||||
if mock.mustEmbedUnimplementedCodespaceHostServerFunc == nil {
|
||||
panic("CodespaceHostServerMock.mustEmbedUnimplementedCodespaceHostServerFunc: method is nil but CodespaceHostServer.mustEmbedUnimplementedCodespaceHostServer was just called")
|
||||
}
|
||||
callInfo := struct {
|
||||
}{}
|
||||
mock.lockmustEmbedUnimplementedCodespaceHostServer.Lock()
|
||||
mock.calls.mustEmbedUnimplementedCodespaceHostServer = append(mock.calls.mustEmbedUnimplementedCodespaceHostServer, callInfo)
|
||||
mock.lockmustEmbedUnimplementedCodespaceHostServer.Unlock()
|
||||
mock.mustEmbedUnimplementedCodespaceHostServerFunc()
|
||||
}
|
||||
|
||||
// mustEmbedUnimplementedCodespaceHostServerCalls gets all the calls that were made to mustEmbedUnimplementedCodespaceHostServer.
|
||||
// Check the length with:
|
||||
//
|
||||
// len(mockedCodespaceHostServer.mustEmbedUnimplementedCodespaceHostServerCalls())
|
||||
func (mock *CodespaceHostServerMock) mustEmbedUnimplementedCodespaceHostServerCalls() []struct {
|
||||
} {
|
||||
var calls []struct {
|
||||
}
|
||||
mock.lockmustEmbedUnimplementedCodespaceHostServer.RLock()
|
||||
calls = mock.calls.mustEmbedUnimplementedCodespaceHostServer
|
||||
mock.lockmustEmbedUnimplementedCodespaceHostServer.RUnlock()
|
||||
return calls
|
||||
}
|
||||
|
|
@ -6,7 +6,8 @@ Instructions for generating and adding gRPC protocol buffers.
|
|||
|
||||
1. [Download `protoc`](https://grpc.io/docs/protoc-installation/)
|
||||
2. [Download protocol compiler plugins for Go](https://grpc.io/docs/languages/go/quickstart/)
|
||||
3. Run `./generate.sh` from the `internal/codespaces/grpc` directory
|
||||
3. Install moq: `go install github.com/matryer/moq@latest`
|
||||
4. Run `./generate.sh` from the `internal/codespaces/rpc` directory
|
||||
|
||||
## Add New Protocol Buffers
|
||||
|
||||
|
|
|
|||
|
|
@ -15,14 +15,21 @@ if ! protoc-gen-go-grpc --version; then
|
|||
fi
|
||||
|
||||
function generate {
|
||||
local contract="$1"
|
||||
local dir="$1"
|
||||
local proto="$2"
|
||||
|
||||
local contract="$dir/$proto"
|
||||
|
||||
protoc --go_out=. --go_opt=paths=source_relative --go-grpc_out=. --go-grpc_opt=paths=source_relative "$contract"
|
||||
echo "Generated protocol buffers for $contract"
|
||||
|
||||
services=$(cat "$contract" | grep -Eo "service .+ {" | awk '{print $2 "Server"}')
|
||||
moq -out $contract.mock.go $dir $services
|
||||
echo "Generated mock protocols for $contract"
|
||||
}
|
||||
|
||||
generate jupyter/jupyter_server_host_service.v1.proto
|
||||
generate codespace/codespace_host_service.v1.proto
|
||||
generate ssh/ssh_server_host_service.v1.proto
|
||||
generate jupyter jupyter_server_host_service.v1.proto
|
||||
generate codespace codespace_host_service.v1.proto
|
||||
generate ssh ssh_server_host_service.v1.proto
|
||||
|
||||
echo 'Done!'
|
||||
|
|
|
|||
|
|
@ -130,6 +130,9 @@ func connect(ctx context.Context, session liveshare.LiveshareSession) (Invoker,
|
|||
invoker.codespaceClient = codespace.NewCodespaceHostClient(conn)
|
||||
invoker.sshClient = ssh.NewSshServerHostClient(conn)
|
||||
|
||||
// Send initial connection heartbeat (no need to throw if we fail to get a response from the server)
|
||||
_ = invoker.notifyCodespaceOfClientActivity(ctx, connectedEventName)
|
||||
|
||||
// Start the activity heatbeats
|
||||
go invoker.heartbeat(pfctx, 1*time.Minute)
|
||||
|
||||
|
|
@ -253,9 +256,6 @@ func (i *invoker) heartbeat(ctx context.Context, interval time.Duration) {
|
|||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
// Send initial connection heartbeat (no need to throw if we fail to get a response from the server)
|
||||
_ = i.notifyCodespaceOfClientActivity(ctx, connectedEventName)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
|
|
|
|||
|
|
@ -3,84 +3,159 @@ package rpc
|
|||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"net"
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
"github.com/cli/cli/v2/internal/codespaces/rpc/codespace"
|
||||
"github.com/cli/cli/v2/internal/codespaces/rpc/jupyter"
|
||||
"github.com/cli/cli/v2/internal/codespaces/rpc/ssh"
|
||||
rpctest "github.com/cli/cli/v2/internal/codespaces/rpc/test"
|
||||
"google.golang.org/grpc"
|
||||
)
|
||||
|
||||
func startServer(t *testing.T) {
|
||||
t.Helper()
|
||||
if os.Getenv("GITHUB_ACTIONS") == "true" {
|
||||
t.Skip("fails intermittently in CI: https://github.com/cli/cli/issues/5663")
|
||||
type mockServer struct {
|
||||
jupyter.JupyterServerHostServerMock
|
||||
codespace.CodespaceHostServerMock
|
||||
ssh.SshServerHostServerMock
|
||||
}
|
||||
|
||||
func newMockServer() *mockServer {
|
||||
server := &mockServer{}
|
||||
|
||||
server.CodespaceHostServerMock.NotifyCodespaceOfClientActivityFunc = func(context.Context, *codespace.NotifyCodespaceOfClientActivityRequest) (*codespace.NotifyCodespaceOfClientActivityResponse, error) {
|
||||
return &codespace.NotifyCodespaceOfClientActivityResponse{
|
||||
Message: "",
|
||||
Result: true,
|
||||
}, nil
|
||||
}
|
||||
|
||||
return server
|
||||
}
|
||||
|
||||
// runTestGrpcServer serves grpc requests over the provided Listener using the mockServer for mocked callbacks.
|
||||
// It does not return until the Context is cancelled and the server fully shuts down.
|
||||
func runTestGrpcServer(ctx context.Context, listener net.Listener, server *mockServer) error {
|
||||
s := grpc.NewServer()
|
||||
jupyter.RegisterJupyterServerHostServer(s, server)
|
||||
codespace.RegisterCodespaceHostServer(s, server)
|
||||
ssh.RegisterSshServerHostServer(s, server)
|
||||
|
||||
ch := make(chan error, 1)
|
||||
go func() { ch <- s.Serve(listener) }()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
s.Stop()
|
||||
<-ch
|
||||
return nil
|
||||
case err := <-ch:
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// createTestInvoker is the main test setup function. It returns an Invoker using the provided mockServer, as well as a shutdown function.
|
||||
// The Invoker does not need to be closed directly, that will be handled by the shutdown function.
|
||||
func createTestInvoker(t *testing.T, server *mockServer) (Invoker, func(), error) {
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:16634")
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to listen: %w", err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
ch := make(chan error)
|
||||
go func() { ch <- runTestGrpcServer(ctx, listener, server) }()
|
||||
|
||||
// Start the gRPC server in the background
|
||||
go func() {
|
||||
err := rpctest.StartServer(ctx)
|
||||
if err != nil && err != context.Canceled {
|
||||
log.Println(fmt.Errorf("error starting test server: %v", err))
|
||||
}
|
||||
}()
|
||||
|
||||
// Stop the gRPC server when the test is done
|
||||
t.Cleanup(func() {
|
||||
close := func() {
|
||||
cancel()
|
||||
})
|
||||
}
|
||||
|
||||
func createTestInvoker(t *testing.T) Invoker {
|
||||
t.Helper()
|
||||
|
||||
// Clear the stored client activity
|
||||
rpctest.NotifyReceivedActivity = ""
|
||||
<-ch
|
||||
listener.Close()
|
||||
}
|
||||
|
||||
invoker, err := CreateInvoker(context.Background(), &rpctest.Session{})
|
||||
if err != nil {
|
||||
t.Fatalf("error connecting to internal server: %v", err)
|
||||
close()
|
||||
return nil, nil, fmt.Errorf("error connecting to internal server: %w", err)
|
||||
}
|
||||
|
||||
t.Cleanup(func() {
|
||||
testNotifyCodespaceOfClientActivity(t)
|
||||
return invoker, func() {
|
||||
invoker.Close()
|
||||
})
|
||||
|
||||
return invoker
|
||||
close()
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Test that the RPC invoker notifies the codespace of client activity on connection
|
||||
func testNotifyCodespaceOfClientActivity(t *testing.T) {
|
||||
if rpctest.NotifyReceivedActivity != connectedEventName {
|
||||
t.Fatalf("expected %s, got %s", connectedEventName, rpctest.NotifyMessage)
|
||||
func verifyNotifyCodespaceOfClientActivity(t *testing.T, server *mockServer) {
|
||||
calls := server.CodespaceHostServerMock.NotifyCodespaceOfClientActivityCalls()
|
||||
if len(calls) == 0 {
|
||||
t.Fatalf("no client activity calls")
|
||||
}
|
||||
|
||||
for _, call := range calls {
|
||||
activities := call.NotifyCodespaceOfClientActivityRequest.GetClientActivities()
|
||||
if activities[0] == connectedEventName {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
t.Fatalf("no activity named %s", connectedEventName)
|
||||
}
|
||||
|
||||
// Test that the RPC invoker returns the correct port and URL when the JupyterLab server starts successfully
|
||||
func TestStartJupyterServerSuccess(t *testing.T) {
|
||||
startServer(t)
|
||||
invoker := createTestInvoker(t)
|
||||
resp := jupyter.GetRunningServerResponse{
|
||||
Port: strconv.Itoa(1234),
|
||||
ServerUrl: "http://localhost:1234?token=1234",
|
||||
Message: "",
|
||||
Result: true,
|
||||
}
|
||||
|
||||
server := newMockServer()
|
||||
server.JupyterServerHostServerMock.GetRunningServerFunc = func(context.Context, *jupyter.GetRunningServerRequest) (*jupyter.GetRunningServerResponse, error) {
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
invoker, stop, err := createTestInvoker(t, server)
|
||||
if err != nil {
|
||||
t.Fatalf("error connecting to internal server: %v", err)
|
||||
}
|
||||
defer stop()
|
||||
|
||||
port, url, err := invoker.StartJupyterServer(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("expected %v, got %v", nil, err)
|
||||
}
|
||||
if port != rpctest.JupyterPort {
|
||||
t.Fatalf("expected %d, got %d", rpctest.JupyterPort, port)
|
||||
if strconv.Itoa(port) != resp.Port {
|
||||
t.Fatalf("expected %s, got %d", resp.Port, port)
|
||||
}
|
||||
if url != rpctest.JupyterServerUrl {
|
||||
t.Fatalf("expected %s, got %s", rpctest.JupyterServerUrl, url)
|
||||
if url != resp.ServerUrl {
|
||||
t.Fatalf("expected %s, got %s", resp.ServerUrl, url)
|
||||
}
|
||||
|
||||
verifyNotifyCodespaceOfClientActivity(t, server)
|
||||
}
|
||||
|
||||
// Test that the RPC invoker returns an error when the JupyterLab server fails to start
|
||||
func TestStartJupyterServerFailure(t *testing.T) {
|
||||
startServer(t)
|
||||
invoker := createTestInvoker(t)
|
||||
rpctest.JupyterMessage = "error message"
|
||||
rpctest.JupyterResult = false
|
||||
errorMessage := fmt.Sprintf("failed to start JupyterLab: %s", rpctest.JupyterMessage)
|
||||
resp := jupyter.GetRunningServerResponse{
|
||||
Port: strconv.Itoa(1234),
|
||||
ServerUrl: "http://localhost:1234?token=1234",
|
||||
Message: "error message",
|
||||
Result: false,
|
||||
}
|
||||
|
||||
server := newMockServer()
|
||||
server.JupyterServerHostServerMock.GetRunningServerFunc = func(context.Context, *jupyter.GetRunningServerRequest) (*jupyter.GetRunningServerResponse, error) {
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
invoker, stop, err := createTestInvoker(t, server)
|
||||
if err != nil {
|
||||
t.Fatalf("error connecting to internal server: %v", err)
|
||||
}
|
||||
defer stop()
|
||||
|
||||
errorMessage := fmt.Sprintf("failed to start JupyterLab: %s", resp.Message)
|
||||
port, url, err := invoker.StartJupyterServer(context.Background())
|
||||
if err.Error() != errorMessage {
|
||||
t.Fatalf("expected %v, got %v", errorMessage, err)
|
||||
|
|
@ -91,35 +166,79 @@ func TestStartJupyterServerFailure(t *testing.T) {
|
|||
if url != "" {
|
||||
t.Fatalf("expected %s, got %s", "", url)
|
||||
}
|
||||
|
||||
verifyNotifyCodespaceOfClientActivity(t, server)
|
||||
}
|
||||
|
||||
// Test that the RPC invoker doesn't throw an error when requesting an incremental rebuild
|
||||
func TestRebuildContainerIncremental(t *testing.T) {
|
||||
startServer(t)
|
||||
invoker := createTestInvoker(t)
|
||||
err := invoker.RebuildContainer(context.Background(), false)
|
||||
resp := codespace.RebuildContainerResponse{
|
||||
RebuildContainer: true,
|
||||
}
|
||||
|
||||
server := newMockServer()
|
||||
server.RebuildContainerAsyncFunc = func(context.Context, *codespace.RebuildContainerRequest) (*codespace.RebuildContainerResponse, error) {
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
invoker, stop, err := createTestInvoker(t, server)
|
||||
if err != nil {
|
||||
t.Fatalf("error connecting to internal server: %v", err)
|
||||
}
|
||||
defer stop()
|
||||
|
||||
err = invoker.RebuildContainer(context.Background(), false)
|
||||
if err != nil {
|
||||
t.Fatalf("expected %v, got %v", nil, err)
|
||||
}
|
||||
|
||||
verifyNotifyCodespaceOfClientActivity(t, server)
|
||||
}
|
||||
|
||||
// Test that the RPC invoker doesn't throw an error when requesting a full rebuild
|
||||
func TestRebuildContainerFull(t *testing.T) {
|
||||
startServer(t)
|
||||
invoker := createTestInvoker(t)
|
||||
err := invoker.RebuildContainer(context.Background(), true)
|
||||
resp := codespace.RebuildContainerResponse{
|
||||
RebuildContainer: true,
|
||||
}
|
||||
|
||||
server := newMockServer()
|
||||
server.RebuildContainerAsyncFunc = func(context.Context, *codespace.RebuildContainerRequest) (*codespace.RebuildContainerResponse, error) {
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
invoker, stop, err := createTestInvoker(t, server)
|
||||
if err != nil {
|
||||
t.Fatalf("error connecting to internal server: %v", err)
|
||||
}
|
||||
defer stop()
|
||||
|
||||
err = invoker.RebuildContainer(context.Background(), true)
|
||||
if err != nil {
|
||||
t.Fatalf("expected %v, got %v", nil, err)
|
||||
}
|
||||
|
||||
verifyNotifyCodespaceOfClientActivity(t, server)
|
||||
}
|
||||
|
||||
// Test that the RPC invoker throws an error when the rebuild fails
|
||||
func TestRebuildContainerFailure(t *testing.T) {
|
||||
startServer(t)
|
||||
invoker := createTestInvoker(t)
|
||||
rpctest.RebuildContainer = false
|
||||
resp := codespace.RebuildContainerResponse{
|
||||
RebuildContainer: false,
|
||||
}
|
||||
|
||||
server := newMockServer()
|
||||
server.RebuildContainerAsyncFunc = func(context.Context, *codespace.RebuildContainerRequest) (*codespace.RebuildContainerResponse, error) {
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
invoker, stop, err := createTestInvoker(t, server)
|
||||
if err != nil {
|
||||
t.Fatalf("error connecting to internal server: %v", err)
|
||||
}
|
||||
defer stop()
|
||||
|
||||
errorMessage := "couldn't rebuild codespace"
|
||||
err := invoker.RebuildContainer(context.Background(), true)
|
||||
err = invoker.RebuildContainer(context.Background(), true)
|
||||
if err.Error() != errorMessage {
|
||||
t.Fatalf("expected %v, got %v", errorMessage, err)
|
||||
}
|
||||
|
|
@ -127,27 +246,59 @@ func TestRebuildContainerFailure(t *testing.T) {
|
|||
|
||||
// Test that the RPC invoker returns the correct port and user when the SSH server starts successfully
|
||||
func TestStartSSHServerSuccess(t *testing.T) {
|
||||
startServer(t)
|
||||
invoker := createTestInvoker(t)
|
||||
resp := ssh.StartRemoteServerResponse{
|
||||
ServerPort: strconv.Itoa(1234),
|
||||
User: "test",
|
||||
Message: "",
|
||||
Result: true,
|
||||
}
|
||||
|
||||
server := newMockServer()
|
||||
server.StartRemoteServerAsyncFunc = func(context.Context, *ssh.StartRemoteServerRequest) (*ssh.StartRemoteServerResponse, error) {
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
invoker, stop, err := createTestInvoker(t, server)
|
||||
if err != nil {
|
||||
t.Fatalf("error connecting to internal server: %v", err)
|
||||
}
|
||||
defer stop()
|
||||
|
||||
port, user, err := invoker.StartSSHServer(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("expected %v, got %v", nil, err)
|
||||
}
|
||||
if port != rpctest.SshServerPort {
|
||||
t.Fatalf("expected %d, got %d", rpctest.SshServerPort, port)
|
||||
if strconv.Itoa(port) != resp.ServerPort {
|
||||
t.Fatalf("expected %s, got %d", resp.ServerPort, port)
|
||||
}
|
||||
if user != rpctest.SshUser {
|
||||
t.Fatalf("expected %s, got %s", rpctest.SshUser, user)
|
||||
if user != resp.User {
|
||||
t.Fatalf("expected %s, got %s", resp.User, user)
|
||||
}
|
||||
|
||||
verifyNotifyCodespaceOfClientActivity(t, server)
|
||||
}
|
||||
|
||||
// Test that the RPC invoker returns an error when the SSH server fails to start
|
||||
func TestStartSSHServerFailure(t *testing.T) {
|
||||
startServer(t)
|
||||
invoker := createTestInvoker(t)
|
||||
rpctest.SshMessage = "error message"
|
||||
rpctest.SshResult = false
|
||||
errorMessage := fmt.Sprintf("failed to start SSH server: %s", rpctest.SshMessage)
|
||||
resp := ssh.StartRemoteServerResponse{
|
||||
ServerPort: strconv.Itoa(1234),
|
||||
User: "test",
|
||||
Message: "error message",
|
||||
Result: false,
|
||||
}
|
||||
|
||||
server := newMockServer()
|
||||
server.StartRemoteServerAsyncFunc = func(context.Context, *ssh.StartRemoteServerRequest) (*ssh.StartRemoteServerResponse, error) {
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
invoker, stop, err := createTestInvoker(t, server)
|
||||
if err != nil {
|
||||
t.Fatalf("error connecting to internal server: %v", err)
|
||||
}
|
||||
defer stop()
|
||||
|
||||
errorMessage := fmt.Sprintf("failed to start SSH server: %s", resp.Message)
|
||||
port, user, err := invoker.StartSSHServer(context.Background())
|
||||
if err.Error() != errorMessage {
|
||||
t.Fatalf("expected %v, got %v", errorMessage, err)
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
// Code generated by protoc-gen-go. DO NOT EDIT.
|
||||
// versions:
|
||||
// protoc-gen-go v1.28.0
|
||||
// protoc-gen-go v1.28.1
|
||||
// protoc v3.21.12
|
||||
// source: jupyter/jupyter_server_host_service.v1.proto
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,118 @@
|
|||
// Code generated by moq; DO NOT EDIT.
|
||||
// github.com/matryer/moq
|
||||
|
||||
package jupyter
|
||||
|
||||
import (
|
||||
context "context"
|
||||
sync "sync"
|
||||
)
|
||||
|
||||
// Ensure, that JupyterServerHostServerMock does implement JupyterServerHostServer.
|
||||
// If this is not the case, regenerate this file with moq.
|
||||
var _ JupyterServerHostServer = &JupyterServerHostServerMock{}
|
||||
|
||||
// JupyterServerHostServerMock is a mock implementation of JupyterServerHostServer.
|
||||
//
|
||||
// func TestSomethingThatUsesJupyterServerHostServer(t *testing.T) {
|
||||
//
|
||||
// // make and configure a mocked JupyterServerHostServer
|
||||
// mockedJupyterServerHostServer := &JupyterServerHostServerMock{
|
||||
// GetRunningServerFunc: func(contextMoqParam context.Context, getRunningServerRequest *GetRunningServerRequest) (*GetRunningServerResponse, error) {
|
||||
// panic("mock out the GetRunningServer method")
|
||||
// },
|
||||
// mustEmbedUnimplementedJupyterServerHostServerFunc: func() {
|
||||
// panic("mock out the mustEmbedUnimplementedJupyterServerHostServer method")
|
||||
// },
|
||||
// }
|
||||
//
|
||||
// // use mockedJupyterServerHostServer in code that requires JupyterServerHostServer
|
||||
// // and then make assertions.
|
||||
//
|
||||
// }
|
||||
type JupyterServerHostServerMock struct {
|
||||
// GetRunningServerFunc mocks the GetRunningServer method.
|
||||
GetRunningServerFunc func(contextMoqParam context.Context, getRunningServerRequest *GetRunningServerRequest) (*GetRunningServerResponse, error)
|
||||
|
||||
// mustEmbedUnimplementedJupyterServerHostServerFunc mocks the mustEmbedUnimplementedJupyterServerHostServer method.
|
||||
mustEmbedUnimplementedJupyterServerHostServerFunc func()
|
||||
|
||||
// calls tracks calls to the methods.
|
||||
calls struct {
|
||||
// GetRunningServer holds details about calls to the GetRunningServer method.
|
||||
GetRunningServer []struct {
|
||||
// ContextMoqParam is the contextMoqParam argument value.
|
||||
ContextMoqParam context.Context
|
||||
// GetRunningServerRequest is the getRunningServerRequest argument value.
|
||||
GetRunningServerRequest *GetRunningServerRequest
|
||||
}
|
||||
// mustEmbedUnimplementedJupyterServerHostServer holds details about calls to the mustEmbedUnimplementedJupyterServerHostServer method.
|
||||
mustEmbedUnimplementedJupyterServerHostServer []struct {
|
||||
}
|
||||
}
|
||||
lockGetRunningServer sync.RWMutex
|
||||
lockmustEmbedUnimplementedJupyterServerHostServer sync.RWMutex
|
||||
}
|
||||
|
||||
// GetRunningServer calls GetRunningServerFunc.
|
||||
func (mock *JupyterServerHostServerMock) GetRunningServer(contextMoqParam context.Context, getRunningServerRequest *GetRunningServerRequest) (*GetRunningServerResponse, error) {
|
||||
if mock.GetRunningServerFunc == nil {
|
||||
panic("JupyterServerHostServerMock.GetRunningServerFunc: method is nil but JupyterServerHostServer.GetRunningServer was just called")
|
||||
}
|
||||
callInfo := struct {
|
||||
ContextMoqParam context.Context
|
||||
GetRunningServerRequest *GetRunningServerRequest
|
||||
}{
|
||||
ContextMoqParam: contextMoqParam,
|
||||
GetRunningServerRequest: getRunningServerRequest,
|
||||
}
|
||||
mock.lockGetRunningServer.Lock()
|
||||
mock.calls.GetRunningServer = append(mock.calls.GetRunningServer, callInfo)
|
||||
mock.lockGetRunningServer.Unlock()
|
||||
return mock.GetRunningServerFunc(contextMoqParam, getRunningServerRequest)
|
||||
}
|
||||
|
||||
// GetRunningServerCalls gets all the calls that were made to GetRunningServer.
|
||||
// Check the length with:
|
||||
//
|
||||
// len(mockedJupyterServerHostServer.GetRunningServerCalls())
|
||||
func (mock *JupyterServerHostServerMock) GetRunningServerCalls() []struct {
|
||||
ContextMoqParam context.Context
|
||||
GetRunningServerRequest *GetRunningServerRequest
|
||||
} {
|
||||
var calls []struct {
|
||||
ContextMoqParam context.Context
|
||||
GetRunningServerRequest *GetRunningServerRequest
|
||||
}
|
||||
mock.lockGetRunningServer.RLock()
|
||||
calls = mock.calls.GetRunningServer
|
||||
mock.lockGetRunningServer.RUnlock()
|
||||
return calls
|
||||
}
|
||||
|
||||
// mustEmbedUnimplementedJupyterServerHostServer calls mustEmbedUnimplementedJupyterServerHostServerFunc.
|
||||
func (mock *JupyterServerHostServerMock) mustEmbedUnimplementedJupyterServerHostServer() {
|
||||
if mock.mustEmbedUnimplementedJupyterServerHostServerFunc == nil {
|
||||
panic("JupyterServerHostServerMock.mustEmbedUnimplementedJupyterServerHostServerFunc: method is nil but JupyterServerHostServer.mustEmbedUnimplementedJupyterServerHostServer was just called")
|
||||
}
|
||||
callInfo := struct {
|
||||
}{}
|
||||
mock.lockmustEmbedUnimplementedJupyterServerHostServer.Lock()
|
||||
mock.calls.mustEmbedUnimplementedJupyterServerHostServer = append(mock.calls.mustEmbedUnimplementedJupyterServerHostServer, callInfo)
|
||||
mock.lockmustEmbedUnimplementedJupyterServerHostServer.Unlock()
|
||||
mock.mustEmbedUnimplementedJupyterServerHostServerFunc()
|
||||
}
|
||||
|
||||
// mustEmbedUnimplementedJupyterServerHostServerCalls gets all the calls that were made to mustEmbedUnimplementedJupyterServerHostServer.
|
||||
// Check the length with:
|
||||
//
|
||||
// len(mockedJupyterServerHostServer.mustEmbedUnimplementedJupyterServerHostServerCalls())
|
||||
func (mock *JupyterServerHostServerMock) mustEmbedUnimplementedJupyterServerHostServerCalls() []struct {
|
||||
} {
|
||||
var calls []struct {
|
||||
}
|
||||
mock.lockmustEmbedUnimplementedJupyterServerHostServer.RLock()
|
||||
calls = mock.calls.mustEmbedUnimplementedJupyterServerHostServer
|
||||
mock.lockmustEmbedUnimplementedJupyterServerHostServer.RUnlock()
|
||||
return calls
|
||||
}
|
||||
|
|
@ -1,6 +1,6 @@
|
|||
// Code generated by protoc-gen-go. DO NOT EDIT.
|
||||
// versions:
|
||||
// protoc-gen-go v1.28.0
|
||||
// protoc-gen-go v1.28.1
|
||||
// protoc v3.21.12
|
||||
// source: ssh/ssh_server_host_service.v1.proto
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,118 @@
|
|||
// Code generated by moq; DO NOT EDIT.
|
||||
// github.com/matryer/moq
|
||||
|
||||
package ssh
|
||||
|
||||
import (
|
||||
context "context"
|
||||
sync "sync"
|
||||
)
|
||||
|
||||
// Ensure, that SshServerHostServerMock does implement SshServerHostServer.
|
||||
// If this is not the case, regenerate this file with moq.
|
||||
var _ SshServerHostServer = &SshServerHostServerMock{}
|
||||
|
||||
// SshServerHostServerMock is a mock implementation of SshServerHostServer.
|
||||
//
|
||||
// func TestSomethingThatUsesSshServerHostServer(t *testing.T) {
|
||||
//
|
||||
// // make and configure a mocked SshServerHostServer
|
||||
// mockedSshServerHostServer := &SshServerHostServerMock{
|
||||
// StartRemoteServerAsyncFunc: func(contextMoqParam context.Context, startRemoteServerRequest *StartRemoteServerRequest) (*StartRemoteServerResponse, error) {
|
||||
// panic("mock out the StartRemoteServerAsync method")
|
||||
// },
|
||||
// mustEmbedUnimplementedSshServerHostServerFunc: func() {
|
||||
// panic("mock out the mustEmbedUnimplementedSshServerHostServer method")
|
||||
// },
|
||||
// }
|
||||
//
|
||||
// // use mockedSshServerHostServer in code that requires SshServerHostServer
|
||||
// // and then make assertions.
|
||||
//
|
||||
// }
|
||||
type SshServerHostServerMock struct {
|
||||
// StartRemoteServerAsyncFunc mocks the StartRemoteServerAsync method.
|
||||
StartRemoteServerAsyncFunc func(contextMoqParam context.Context, startRemoteServerRequest *StartRemoteServerRequest) (*StartRemoteServerResponse, error)
|
||||
|
||||
// mustEmbedUnimplementedSshServerHostServerFunc mocks the mustEmbedUnimplementedSshServerHostServer method.
|
||||
mustEmbedUnimplementedSshServerHostServerFunc func()
|
||||
|
||||
// calls tracks calls to the methods.
|
||||
calls struct {
|
||||
// StartRemoteServerAsync holds details about calls to the StartRemoteServerAsync method.
|
||||
StartRemoteServerAsync []struct {
|
||||
// ContextMoqParam is the contextMoqParam argument value.
|
||||
ContextMoqParam context.Context
|
||||
// StartRemoteServerRequest is the startRemoteServerRequest argument value.
|
||||
StartRemoteServerRequest *StartRemoteServerRequest
|
||||
}
|
||||
// mustEmbedUnimplementedSshServerHostServer holds details about calls to the mustEmbedUnimplementedSshServerHostServer method.
|
||||
mustEmbedUnimplementedSshServerHostServer []struct {
|
||||
}
|
||||
}
|
||||
lockStartRemoteServerAsync sync.RWMutex
|
||||
lockmustEmbedUnimplementedSshServerHostServer sync.RWMutex
|
||||
}
|
||||
|
||||
// StartRemoteServerAsync calls StartRemoteServerAsyncFunc.
|
||||
func (mock *SshServerHostServerMock) StartRemoteServerAsync(contextMoqParam context.Context, startRemoteServerRequest *StartRemoteServerRequest) (*StartRemoteServerResponse, error) {
|
||||
if mock.StartRemoteServerAsyncFunc == nil {
|
||||
panic("SshServerHostServerMock.StartRemoteServerAsyncFunc: method is nil but SshServerHostServer.StartRemoteServerAsync was just called")
|
||||
}
|
||||
callInfo := struct {
|
||||
ContextMoqParam context.Context
|
||||
StartRemoteServerRequest *StartRemoteServerRequest
|
||||
}{
|
||||
ContextMoqParam: contextMoqParam,
|
||||
StartRemoteServerRequest: startRemoteServerRequest,
|
||||
}
|
||||
mock.lockStartRemoteServerAsync.Lock()
|
||||
mock.calls.StartRemoteServerAsync = append(mock.calls.StartRemoteServerAsync, callInfo)
|
||||
mock.lockStartRemoteServerAsync.Unlock()
|
||||
return mock.StartRemoteServerAsyncFunc(contextMoqParam, startRemoteServerRequest)
|
||||
}
|
||||
|
||||
// StartRemoteServerAsyncCalls gets all the calls that were made to StartRemoteServerAsync.
|
||||
// Check the length with:
|
||||
//
|
||||
// len(mockedSshServerHostServer.StartRemoteServerAsyncCalls())
|
||||
func (mock *SshServerHostServerMock) StartRemoteServerAsyncCalls() []struct {
|
||||
ContextMoqParam context.Context
|
||||
StartRemoteServerRequest *StartRemoteServerRequest
|
||||
} {
|
||||
var calls []struct {
|
||||
ContextMoqParam context.Context
|
||||
StartRemoteServerRequest *StartRemoteServerRequest
|
||||
}
|
||||
mock.lockStartRemoteServerAsync.RLock()
|
||||
calls = mock.calls.StartRemoteServerAsync
|
||||
mock.lockStartRemoteServerAsync.RUnlock()
|
||||
return calls
|
||||
}
|
||||
|
||||
// mustEmbedUnimplementedSshServerHostServer calls mustEmbedUnimplementedSshServerHostServerFunc.
|
||||
func (mock *SshServerHostServerMock) mustEmbedUnimplementedSshServerHostServer() {
|
||||
if mock.mustEmbedUnimplementedSshServerHostServerFunc == nil {
|
||||
panic("SshServerHostServerMock.mustEmbedUnimplementedSshServerHostServerFunc: method is nil but SshServerHostServer.mustEmbedUnimplementedSshServerHostServer was just called")
|
||||
}
|
||||
callInfo := struct {
|
||||
}{}
|
||||
mock.lockmustEmbedUnimplementedSshServerHostServer.Lock()
|
||||
mock.calls.mustEmbedUnimplementedSshServerHostServer = append(mock.calls.mustEmbedUnimplementedSshServerHostServer, callInfo)
|
||||
mock.lockmustEmbedUnimplementedSshServerHostServer.Unlock()
|
||||
mock.mustEmbedUnimplementedSshServerHostServerFunc()
|
||||
}
|
||||
|
||||
// mustEmbedUnimplementedSshServerHostServerCalls gets all the calls that were made to mustEmbedUnimplementedSshServerHostServer.
|
||||
// Check the length with:
|
||||
//
|
||||
// len(mockedSshServerHostServer.mustEmbedUnimplementedSshServerHostServerCalls())
|
||||
func (mock *SshServerHostServerMock) mustEmbedUnimplementedSshServerHostServerCalls() []struct {
|
||||
} {
|
||||
var calls []struct {
|
||||
}
|
||||
mock.lockmustEmbedUnimplementedSshServerHostServer.RLock()
|
||||
calls = mock.calls.mustEmbedUnimplementedSshServerHostServer
|
||||
mock.lockmustEmbedUnimplementedSshServerHostServer.RUnlock()
|
||||
return calls
|
||||
}
|
||||
|
|
@ -1,117 +0,0 @@
|
|||
package test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"strconv"
|
||||
|
||||
"github.com/cli/cli/v2/internal/codespaces/rpc/codespace"
|
||||
"github.com/cli/cli/v2/internal/codespaces/rpc/jupyter"
|
||||
"github.com/cli/cli/v2/internal/codespaces/rpc/ssh"
|
||||
"google.golang.org/grpc"
|
||||
)
|
||||
|
||||
const (
|
||||
ServerPort = 50051
|
||||
)
|
||||
|
||||
// Mock responses for the `GetRunningServer` RPC method
|
||||
var (
|
||||
JupyterPort = 1234
|
||||
JupyterServerUrl = "http://localhost:1234?token=1234"
|
||||
JupyterMessage = ""
|
||||
JupyterResult = true
|
||||
)
|
||||
|
||||
// Mock responses for the `RebuildContainerAsync` RPC method
|
||||
var (
|
||||
RebuildContainer = true
|
||||
)
|
||||
|
||||
// Mock responses for the `NotifyCodespaceOfClientActivity` RPC method
|
||||
// NotifyMessage is used to store the activity that was sent to the server
|
||||
var (
|
||||
NotifyMessage = ""
|
||||
NotifyResult = true
|
||||
NotifyReceivedActivity = ""
|
||||
)
|
||||
|
||||
// Mock responses for the `StartRemoteServerAsync` RPC method
|
||||
var (
|
||||
SshServerPort = 1234
|
||||
SshUser = "test"
|
||||
SshMessage = ""
|
||||
SshResult = true
|
||||
)
|
||||
|
||||
type server struct {
|
||||
jupyter.UnimplementedJupyterServerHostServer
|
||||
codespace.CodespaceHostServer
|
||||
ssh.SshServerHostServer
|
||||
}
|
||||
|
||||
func (s *server) GetRunningServer(ctx context.Context, in *jupyter.GetRunningServerRequest) (*jupyter.GetRunningServerResponse, error) {
|
||||
return &jupyter.GetRunningServerResponse{
|
||||
Port: strconv.Itoa(JupyterPort),
|
||||
ServerUrl: JupyterServerUrl,
|
||||
Message: JupyterMessage,
|
||||
Result: JupyterResult,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *server) RebuildContainerAsync(ctx context.Context, in *codespace.RebuildContainerRequest) (*codespace.RebuildContainerResponse, error) {
|
||||
return &codespace.RebuildContainerResponse{
|
||||
RebuildContainer: RebuildContainer,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *server) NotifyCodespaceOfClientActivity(ctx context.Context, in *codespace.NotifyCodespaceOfClientActivityRequest) (*codespace.NotifyCodespaceOfClientActivityResponse, error) {
|
||||
// If there is at least one client activity, set NotifyReceivedActivity to the first one (should be "connected")
|
||||
if len(in.GetClientActivities()) > 0 {
|
||||
NotifyReceivedActivity = in.GetClientActivities()[0]
|
||||
}
|
||||
|
||||
return &codespace.NotifyCodespaceOfClientActivityResponse{
|
||||
Message: NotifyMessage,
|
||||
Result: NotifyResult,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *server) StartRemoteServerAsync(ctx context.Context, in *ssh.StartRemoteServerRequest) (*ssh.StartRemoteServerResponse, error) {
|
||||
return &ssh.StartRemoteServerResponse{
|
||||
ServerPort: strconv.Itoa(SshServerPort),
|
||||
User: SshUser,
|
||||
Message: SshMessage,
|
||||
Result: SshResult,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Starts the mock gRPC server listening on port 50051
|
||||
func StartServer(ctx context.Context) error {
|
||||
listener, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", ServerPort))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to listen: %v", err)
|
||||
}
|
||||
defer listener.Close()
|
||||
|
||||
s := grpc.NewServer()
|
||||
jupyter.RegisterJupyterServerHostServer(s, &server{})
|
||||
codespace.RegisterCodespaceHostServer(s, &server{})
|
||||
ssh.RegisterSshServerHostServer(s, &server{})
|
||||
|
||||
ch := make(chan error, 1)
|
||||
go func() {
|
||||
if err := s.Serve(listener); err != nil {
|
||||
ch <- fmt.Errorf("failed to serve: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
s.Stop()
|
||||
return ctx.Err()
|
||||
case err := <-ch:
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
|
@ -29,7 +29,7 @@ func (s *Session) GetKeepAliveReason() string {
|
|||
}
|
||||
|
||||
func (s *Session) StartSharing(ctx context.Context, sessionName string, port int) (liveshare.ChannelID, error) {
|
||||
conn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", ServerPort))
|
||||
conn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", port))
|
||||
if err != nil {
|
||||
return liveshare.ChannelID{}, err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -38,7 +38,7 @@ func NewCmdDelete(f *cmdutil.Factory, runF func(*DeleteOptions) error) *cobra.Co
|
|||
opts.KeyID = args[0]
|
||||
|
||||
if !opts.IO.CanPrompt() && !opts.Confirmed {
|
||||
return cmdutil.FlagErrorf("--confirm required when not running interactively")
|
||||
return cmdutil.FlagErrorf("--yes required when not running interactively")
|
||||
}
|
||||
|
||||
if runF != nil {
|
||||
|
|
@ -48,7 +48,9 @@ func NewCmdDelete(f *cmdutil.Factory, runF func(*DeleteOptions) error) *cobra.Co
|
|||
},
|
||||
}
|
||||
|
||||
cmd.Flags().BoolVarP(&opts.Confirmed, "confirm", "y", false, "Skip the confirmation prompt")
|
||||
cmd.Flags().BoolVar(&opts.Confirmed, "confirm", false, "Skip the confirmation prompt")
|
||||
_ = cmd.Flags().MarkDeprecated("confirm", "use `--yes` instead")
|
||||
cmd.Flags().BoolVarP(&opts.Confirmed, "yes", "y", false, "Skip the confirmation prompt")
|
||||
return cmd
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -32,7 +32,7 @@ func TestNewCmdDelete(t *testing.T) {
|
|||
{
|
||||
name: "confirm flag tty",
|
||||
tty: true,
|
||||
input: "ABC123 --confirm",
|
||||
input: "ABC123 --yes",
|
||||
output: DeleteOptions{KeyID: "ABC123", Confirmed: true},
|
||||
},
|
||||
{
|
||||
|
|
@ -45,11 +45,11 @@ func TestNewCmdDelete(t *testing.T) {
|
|||
name: "no tty",
|
||||
input: "ABC123",
|
||||
wantErr: true,
|
||||
wantErrMsg: "--confirm required when not running interactively",
|
||||
wantErrMsg: "--yes required when not running interactively",
|
||||
},
|
||||
{
|
||||
name: "confirm flag no tty",
|
||||
input: "ABC123 --confirm",
|
||||
input: "ABC123 --yes",
|
||||
output: DeleteOptions{KeyID: "ABC123", Confirmed: true},
|
||||
},
|
||||
{
|
||||
|
|
|
|||
|
|
@ -49,11 +49,14 @@ func NewCmdDelete(f *cmdutil.Factory, runF func(*DeleteOptions) error) *cobra.Co
|
|||
if runF != nil {
|
||||
return runF(opts)
|
||||
}
|
||||
|
||||
return deleteRun(opts)
|
||||
},
|
||||
}
|
||||
|
||||
cmd.Flags().BoolVar(&opts.Confirmed, "confirm", false, "confirm deletion without prompting")
|
||||
_ = cmd.Flags().MarkDeprecated("confirm", "use `--yes` instead")
|
||||
cmd.Flags().BoolVar(&opts.Confirmed, "yes", false, "confirm deletion without prompting")
|
||||
return cmd
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -42,7 +42,7 @@ func newCmdDelete(f *cmdutil.Factory, runF func(*deleteOptions) error) *cobra.Co
|
|||
opts.Name = args[0]
|
||||
|
||||
if !opts.IO.CanPrompt() && !opts.Confirmed {
|
||||
return cmdutil.FlagErrorf("--confirm required when not running interactively")
|
||||
return cmdutil.FlagErrorf("--yes required when not running interactively")
|
||||
}
|
||||
|
||||
if runF != nil {
|
||||
|
|
@ -53,6 +53,8 @@ func newCmdDelete(f *cmdutil.Factory, runF func(*deleteOptions) error) *cobra.Co
|
|||
}
|
||||
|
||||
cmd.Flags().BoolVar(&opts.Confirmed, "confirm", false, "Confirm deletion without prompting")
|
||||
_ = cmd.Flags().MarkDeprecated("confirm", "use `--yes` instead")
|
||||
cmd.Flags().BoolVar(&opts.Confirmed, "yes", false, "Confirm deletion without prompting")
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
|
|
|||
|
|
@ -37,14 +37,14 @@ func TestNewCmdDelete(t *testing.T) {
|
|||
},
|
||||
{
|
||||
name: "confirm argument",
|
||||
input: "test --confirm",
|
||||
input: "test --yes",
|
||||
output: deleteOptions{Name: "test", Confirmed: true},
|
||||
},
|
||||
{
|
||||
name: "confirm no tty",
|
||||
input: "test",
|
||||
wantErr: true,
|
||||
wantErrMsg: "--confirm required when not running interactively",
|
||||
wantErrMsg: "--yes required when not running interactively",
|
||||
},
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -47,16 +47,20 @@ With no argument, archives the current repository.`),
|
|||
}
|
||||
|
||||
if !opts.Confirmed && !opts.IO.CanPrompt() {
|
||||
return cmdutil.FlagErrorf("--confirm required when not running interactively")
|
||||
return cmdutil.FlagErrorf("--yes required when not running interactively")
|
||||
}
|
||||
|
||||
if runF != nil {
|
||||
return runF(opts)
|
||||
}
|
||||
|
||||
return archiveRun(opts)
|
||||
},
|
||||
}
|
||||
|
||||
cmd.Flags().BoolVarP(&opts.Confirmed, "confirm", "y", false, "Skip the confirmation prompt")
|
||||
cmd.Flags().BoolVar(&opts.Confirmed, "confirm", false, "Skip the confirmation prompt")
|
||||
_ = cmd.Flags().MarkDeprecated("confirm", "use `--yes` instead")
|
||||
cmd.Flags().BoolVarP(&opts.Confirmed, "yes", "y", false, "Skip the confirmation prompt")
|
||||
return cmd
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@ func TestNewCmdArchive(t *testing.T) {
|
|||
{
|
||||
name: "no arguments no tty",
|
||||
input: "",
|
||||
errMsg: "--confirm required when not running interactively",
|
||||
errMsg: "--yes required when not running interactively",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
|
|
|
|||
|
|
@ -60,7 +60,7 @@ func NewCmdRename(f *cmdutil.Factory, runf func(*RenameOptions) error) *cobra.Co
|
|||
|
||||
if len(args) == 1 && !confirm && !opts.HasRepoOverride {
|
||||
if !opts.IO.CanPrompt() {
|
||||
return cmdutil.FlagErrorf("--confirm required when passing a single argument")
|
||||
return cmdutil.FlagErrorf("--yes required when passing a single argument")
|
||||
}
|
||||
opts.DoConfirm = true
|
||||
}
|
||||
|
|
@ -68,12 +68,15 @@ func NewCmdRename(f *cmdutil.Factory, runf func(*RenameOptions) error) *cobra.Co
|
|||
if runf != nil {
|
||||
return runf(opts)
|
||||
}
|
||||
|
||||
return renameRun(opts)
|
||||
},
|
||||
}
|
||||
|
||||
cmdutil.EnableRepoOverride(cmd, f)
|
||||
cmd.Flags().BoolVarP(&confirm, "confirm", "y", false, "skip confirmation prompt")
|
||||
cmd.Flags().BoolVar(&confirm, "confirm", false, "Skip confirmation prompt")
|
||||
_ = cmd.Flags().MarkDeprecated("confirm", "use `--yes` instead")
|
||||
cmd.Flags().BoolVarP(&confirm, "yes", "y", false, "Skip the confirmation prompt")
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
|
|
|||
|
|
@ -35,7 +35,7 @@ func TestNewCmdRename(t *testing.T) {
|
|||
},
|
||||
{
|
||||
name: "one argument no tty confirmed",
|
||||
input: "REPO --confirm",
|
||||
input: "REPO --yes",
|
||||
output: RenameOptions{
|
||||
newRepoSelector: "REPO",
|
||||
},
|
||||
|
|
@ -43,12 +43,12 @@ func TestNewCmdRename(t *testing.T) {
|
|||
{
|
||||
name: "one argument no tty",
|
||||
input: "REPO",
|
||||
errMsg: "--confirm required when passing a single argument",
|
||||
errMsg: "--yes required when passing a single argument",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "one argument tty confirmed",
|
||||
input: "REPO --confirm",
|
||||
input: "REPO --yes",
|
||||
tty: true,
|
||||
output: RenameOptions{
|
||||
newRepoSelector: "REPO",
|
||||
|
|
|
|||
|
|
@ -77,7 +77,7 @@ type Run struct {
|
|||
workflowName string // cache column
|
||||
WorkflowID int64 `json:"workflow_id"`
|
||||
Number int64 `json:"run_number"`
|
||||
Attempts uint8 `json:"run_attempt"`
|
||||
Attempts uint64 `json:"run_attempt"`
|
||||
HeadBranch string `json:"head_branch"`
|
||||
JobsURL string `json:"jobs_url"`
|
||||
HeadCommit Commit `json:"head_commit"`
|
||||
|
|
|
|||
|
|
@ -20,9 +20,9 @@ type SecretPayload struct {
|
|||
KeyID string `json:"key_id"`
|
||||
}
|
||||
|
||||
// The Codespaces Secret API currently expects repositories IDs as strings
|
||||
type CodespacesSecretPayload struct {
|
||||
type DependabotSecretPayload struct {
|
||||
EncryptedValue string `json:"encrypted_value"`
|
||||
Visibility string `json:"visibility,omitempty"`
|
||||
Repositories []string `json:"selected_repository_ids,omitempty"`
|
||||
KeyID string `json:"key_id"`
|
||||
}
|
||||
|
|
@ -70,31 +70,40 @@ func putSecret(client *api.Client, host, path string, payload interface{}) error
|
|||
}
|
||||
|
||||
func putOrgSecret(client *api.Client, host string, pk *PubKey, orgName, visibility, secretName, eValue string, repositoryIDs []int64, app shared.App) error {
|
||||
path := fmt.Sprintf("orgs/%s/%s/secrets/%s", orgName, app, secretName)
|
||||
|
||||
if app == shared.Dependabot {
|
||||
repos := make([]string, len(repositoryIDs))
|
||||
for i, id := range repositoryIDs {
|
||||
repos[i] = strconv.FormatInt(id, 10)
|
||||
}
|
||||
|
||||
payload := DependabotSecretPayload{
|
||||
EncryptedValue: eValue,
|
||||
KeyID: pk.ID,
|
||||
Repositories: repos,
|
||||
Visibility: visibility,
|
||||
}
|
||||
|
||||
return putSecret(client, host, path, payload)
|
||||
}
|
||||
|
||||
payload := SecretPayload{
|
||||
EncryptedValue: eValue,
|
||||
KeyID: pk.ID,
|
||||
Repositories: repositoryIDs,
|
||||
Visibility: visibility,
|
||||
}
|
||||
path := fmt.Sprintf("orgs/%s/%s/secrets/%s", orgName, app, secretName)
|
||||
|
||||
return putSecret(client, host, path, payload)
|
||||
}
|
||||
|
||||
func putUserSecret(client *api.Client, host string, pk *PubKey, key, eValue string, repositoryIDs []int64) error {
|
||||
payload := CodespacesSecretPayload{
|
||||
payload := SecretPayload{
|
||||
EncryptedValue: eValue,
|
||||
KeyID: pk.ID,
|
||||
Repositories: repositoryIDs,
|
||||
}
|
||||
|
||||
if len(repositoryIDs) > 0 {
|
||||
repositoryStringIDs := make([]string, len(repositoryIDs))
|
||||
for i, id := range repositoryIDs {
|
||||
repositoryStringIDs[i] = strconv.FormatInt(id, 10)
|
||||
}
|
||||
payload.Repositories = repositoryStringIDs
|
||||
}
|
||||
|
||||
path := fmt.Sprintf("user/codespaces/secrets/%s", key)
|
||||
return putSecret(client, host, path, payload)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -333,11 +333,12 @@ func Test_setRun_env(t *testing.T) {
|
|||
|
||||
func Test_setRun_org(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
opts *SetOptions
|
||||
wantVisibility shared.Visibility
|
||||
wantRepositories []int64
|
||||
wantApp string
|
||||
name string
|
||||
opts *SetOptions
|
||||
wantVisibility shared.Visibility
|
||||
wantRepositories []int64
|
||||
wantDependabotRepositories []string
|
||||
wantApp string
|
||||
}{
|
||||
{
|
||||
name: "all vis",
|
||||
|
|
@ -362,10 +363,21 @@ func Test_setRun_org(t *testing.T) {
|
|||
opts: &SetOptions{
|
||||
OrgName: "UmbrellaCorporation",
|
||||
Visibility: shared.All,
|
||||
Application: "dependabot",
|
||||
Application: shared.Dependabot,
|
||||
},
|
||||
wantApp: "dependabot",
|
||||
},
|
||||
{
|
||||
name: "Dependabot selected visibility",
|
||||
opts: &SetOptions{
|
||||
OrgName: "UmbrellaCorporation",
|
||||
Visibility: shared.Selected,
|
||||
Application: shared.Dependabot,
|
||||
RepositoryNames: []string{"birkin", "UmbrellaCorporation/wesker"},
|
||||
},
|
||||
wantDependabotRepositories: []string{"1", "2"},
|
||||
wantApp: "dependabot",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
|
|
@ -410,13 +422,24 @@ func Test_setRun_org(t *testing.T) {
|
|||
|
||||
data, err := io.ReadAll(reg.Requests[len(reg.Requests)-1].Body)
|
||||
assert.NoError(t, err)
|
||||
var payload SecretPayload
|
||||
err = json.Unmarshal(data, &payload)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, payload.KeyID, "123")
|
||||
assert.Equal(t, payload.EncryptedValue, "UKYUCbHd0DJemxa3AOcZ6XcsBwALG9d4bpB8ZT0gSV39vl3BHiGSgj8zJapDxgB2BwqNqRhpjC4=")
|
||||
assert.Equal(t, payload.Visibility, tt.opts.Visibility)
|
||||
assert.ElementsMatch(t, payload.Repositories, tt.wantRepositories)
|
||||
|
||||
if tt.opts.Application == shared.Dependabot {
|
||||
var payload DependabotSecretPayload
|
||||
err = json.Unmarshal(data, &payload)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, payload.KeyID, "123")
|
||||
assert.Equal(t, payload.EncryptedValue, "UKYUCbHd0DJemxa3AOcZ6XcsBwALG9d4bpB8ZT0gSV39vl3BHiGSgj8zJapDxgB2BwqNqRhpjC4=")
|
||||
assert.Equal(t, payload.Visibility, tt.opts.Visibility)
|
||||
assert.ElementsMatch(t, payload.Repositories, tt.wantDependabotRepositories)
|
||||
} else {
|
||||
var payload SecretPayload
|
||||
err = json.Unmarshal(data, &payload)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, payload.KeyID, "123")
|
||||
assert.Equal(t, payload.EncryptedValue, "UKYUCbHd0DJemxa3AOcZ6XcsBwALG9d4bpB8ZT0gSV39vl3BHiGSgj8zJapDxgB2BwqNqRhpjC4=")
|
||||
assert.Equal(t, payload.Visibility, tt.opts.Visibility)
|
||||
assert.ElementsMatch(t, payload.Repositories, tt.wantRepositories)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
@ -426,7 +449,7 @@ func Test_setRun_user(t *testing.T) {
|
|||
name string
|
||||
opts *SetOptions
|
||||
wantVisibility shared.Visibility
|
||||
wantRepositories []string
|
||||
wantRepositories []int64
|
||||
}{
|
||||
{
|
||||
name: "all vis",
|
||||
|
|
@ -442,7 +465,7 @@ func Test_setRun_user(t *testing.T) {
|
|||
Visibility: shared.Selected,
|
||||
RepositoryNames: []string{"cli/cli", "github/hub"},
|
||||
},
|
||||
wantRepositories: []string{"212613049", "401025"},
|
||||
wantRepositories: []int64{212613049, 401025},
|
||||
},
|
||||
}
|
||||
|
||||
|
|
@ -481,7 +504,7 @@ func Test_setRun_user(t *testing.T) {
|
|||
|
||||
data, err := io.ReadAll(reg.Requests[len(reg.Requests)-1].Body)
|
||||
assert.NoError(t, err)
|
||||
var payload CodespacesSecretPayload
|
||||
var payload SecretPayload
|
||||
err = json.Unmarshal(data, &payload)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, payload.KeyID, "123")
|
||||
|
|
|
|||
|
|
@ -37,17 +37,21 @@ func NewCmdDelete(f *cmdutil.Factory, runF func(*DeleteOptions) error) *cobra.Co
|
|||
opts.KeyID = args[0]
|
||||
|
||||
if !opts.IO.CanPrompt() && !opts.Confirmed {
|
||||
return cmdutil.FlagErrorf("--confirm required when not running interactively")
|
||||
return cmdutil.FlagErrorf("--yes required when not running interactively")
|
||||
}
|
||||
|
||||
if runF != nil {
|
||||
return runF(opts)
|
||||
}
|
||||
|
||||
return deleteRun(opts)
|
||||
},
|
||||
}
|
||||
|
||||
cmd.Flags().BoolVarP(&opts.Confirmed, "confirm", "y", false, "Skip the confirmation prompt")
|
||||
cmd.Flags().BoolVar(&opts.Confirmed, "confirm", false, "Skip the confirmation prompt")
|
||||
_ = cmd.Flags().MarkDeprecated("confirm", "use `--yes` instead")
|
||||
cmd.Flags().BoolVarP(&opts.Confirmed, "yes", "y", false, "Skip the confirmation prompt")
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -32,7 +32,7 @@ func TestNewCmdDelete(t *testing.T) {
|
|||
{
|
||||
name: "confirm flag tty",
|
||||
tty: true,
|
||||
input: "123 --confirm",
|
||||
input: "123 --yes",
|
||||
output: DeleteOptions{KeyID: "123", Confirmed: true},
|
||||
},
|
||||
{
|
||||
|
|
@ -45,11 +45,11 @@ func TestNewCmdDelete(t *testing.T) {
|
|||
name: "no tty",
|
||||
input: "123",
|
||||
wantErr: true,
|
||||
wantErrMsg: "--confirm required when not running interactively",
|
||||
wantErrMsg: "--yes required when not running interactively",
|
||||
},
|
||||
{
|
||||
name: "confirm flag no tty",
|
||||
input: "123 --confirm",
|
||||
input: "123 --yes",
|
||||
output: DeleteOptions{KeyID: "123", Confirmed: true},
|
||||
},
|
||||
{
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue