From f669a10cf9869f4171e94582fc4326cd77ff4b07 Mon Sep 17 00:00:00 2001 From: Caleb Brose <5447118+cmbrose@users.noreply.github.com> Date: Wed, 25 Jan 2023 14:57:21 -0600 Subject: [PATCH 1/8] Fix race conditions in invoker_test (#6905) --- .../codespace/codespace_host_service.v1.pb.go | 2 +- .../codespace_host_service.v1.proto.mock.go | 168 +++++++++++ internal/codespaces/rpc/generate.md | 3 +- internal/codespaces/rpc/generate.sh | 15 +- internal/codespaces/rpc/invoker.go | 6 +- internal/codespaces/rpc/invoker_test.go | 281 ++++++++++++++---- .../jupyter_server_host_service.v1.pb.go | 2 +- ...pyter_server_host_service.v1.proto.mock.go | 118 ++++++++ .../rpc/ssh/ssh_server_host_service.v1.pb.go | 2 +- .../ssh_server_host_service.v1.proto.mock.go | 118 ++++++++ internal/codespaces/rpc/test/server.go | 117 -------- internal/codespaces/rpc/test/session.go | 2 +- 12 files changed, 640 insertions(+), 194 deletions(-) create mode 100644 internal/codespaces/rpc/codespace/codespace_host_service.v1.proto.mock.go create mode 100644 internal/codespaces/rpc/jupyter/jupyter_server_host_service.v1.proto.mock.go create mode 100644 internal/codespaces/rpc/ssh/ssh_server_host_service.v1.proto.mock.go delete mode 100644 internal/codespaces/rpc/test/server.go diff --git a/internal/codespaces/rpc/codespace/codespace_host_service.v1.pb.go b/internal/codespaces/rpc/codespace/codespace_host_service.v1.pb.go index 21b908838..6da7f9e39 100644 --- a/internal/codespaces/rpc/codespace/codespace_host_service.v1.pb.go +++ b/internal/codespaces/rpc/codespace/codespace_host_service.v1.pb.go @@ -1,6 +1,6 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: -// protoc-gen-go v1.28.0 +// protoc-gen-go v1.28.1 // protoc v3.21.12 // source: codespace/codespace_host_service.v1.proto diff --git a/internal/codespaces/rpc/codespace/codespace_host_service.v1.proto.mock.go b/internal/codespaces/rpc/codespace/codespace_host_service.v1.proto.mock.go new file mode 100644 index 000000000..246849fe0 --- /dev/null +++ b/internal/codespaces/rpc/codespace/codespace_host_service.v1.proto.mock.go @@ -0,0 +1,168 @@ +// Code generated by moq; DO NOT EDIT. +// github.com/matryer/moq + +package codespace + +import ( + context "context" + sync "sync" +) + +// Ensure, that CodespaceHostServerMock does implement CodespaceHostServer. +// If this is not the case, regenerate this file with moq. +var _ CodespaceHostServer = &CodespaceHostServerMock{} + +// CodespaceHostServerMock is a mock implementation of CodespaceHostServer. +// +// func TestSomethingThatUsesCodespaceHostServer(t *testing.T) { +// +// // make and configure a mocked CodespaceHostServer +// mockedCodespaceHostServer := &CodespaceHostServerMock{ +// NotifyCodespaceOfClientActivityFunc: func(contextMoqParam context.Context, notifyCodespaceOfClientActivityRequest *NotifyCodespaceOfClientActivityRequest) (*NotifyCodespaceOfClientActivityResponse, error) { +// panic("mock out the NotifyCodespaceOfClientActivity method") +// }, +// RebuildContainerAsyncFunc: func(contextMoqParam context.Context, rebuildContainerRequest *RebuildContainerRequest) (*RebuildContainerResponse, error) { +// panic("mock out the RebuildContainerAsync method") +// }, +// mustEmbedUnimplementedCodespaceHostServerFunc: func() { +// panic("mock out the mustEmbedUnimplementedCodespaceHostServer method") +// }, +// } +// +// // use mockedCodespaceHostServer in code that requires CodespaceHostServer +// // and then make assertions. +// +// } +type CodespaceHostServerMock struct { + // NotifyCodespaceOfClientActivityFunc mocks the NotifyCodespaceOfClientActivity method. + NotifyCodespaceOfClientActivityFunc func(contextMoqParam context.Context, notifyCodespaceOfClientActivityRequest *NotifyCodespaceOfClientActivityRequest) (*NotifyCodespaceOfClientActivityResponse, error) + + // RebuildContainerAsyncFunc mocks the RebuildContainerAsync method. + RebuildContainerAsyncFunc func(contextMoqParam context.Context, rebuildContainerRequest *RebuildContainerRequest) (*RebuildContainerResponse, error) + + // mustEmbedUnimplementedCodespaceHostServerFunc mocks the mustEmbedUnimplementedCodespaceHostServer method. + mustEmbedUnimplementedCodespaceHostServerFunc func() + + // calls tracks calls to the methods. + calls struct { + // NotifyCodespaceOfClientActivity holds details about calls to the NotifyCodespaceOfClientActivity method. + NotifyCodespaceOfClientActivity []struct { + // ContextMoqParam is the contextMoqParam argument value. + ContextMoqParam context.Context + // NotifyCodespaceOfClientActivityRequest is the notifyCodespaceOfClientActivityRequest argument value. + NotifyCodespaceOfClientActivityRequest *NotifyCodespaceOfClientActivityRequest + } + // RebuildContainerAsync holds details about calls to the RebuildContainerAsync method. + RebuildContainerAsync []struct { + // ContextMoqParam is the contextMoqParam argument value. + ContextMoqParam context.Context + // RebuildContainerRequest is the rebuildContainerRequest argument value. + RebuildContainerRequest *RebuildContainerRequest + } + // mustEmbedUnimplementedCodespaceHostServer holds details about calls to the mustEmbedUnimplementedCodespaceHostServer method. + mustEmbedUnimplementedCodespaceHostServer []struct { + } + } + lockNotifyCodespaceOfClientActivity sync.RWMutex + lockRebuildContainerAsync sync.RWMutex + lockmustEmbedUnimplementedCodespaceHostServer sync.RWMutex +} + +// NotifyCodespaceOfClientActivity calls NotifyCodespaceOfClientActivityFunc. +func (mock *CodespaceHostServerMock) NotifyCodespaceOfClientActivity(contextMoqParam context.Context, notifyCodespaceOfClientActivityRequest *NotifyCodespaceOfClientActivityRequest) (*NotifyCodespaceOfClientActivityResponse, error) { + if mock.NotifyCodespaceOfClientActivityFunc == nil { + panic("CodespaceHostServerMock.NotifyCodespaceOfClientActivityFunc: method is nil but CodespaceHostServer.NotifyCodespaceOfClientActivity was just called") + } + callInfo := struct { + ContextMoqParam context.Context + NotifyCodespaceOfClientActivityRequest *NotifyCodespaceOfClientActivityRequest + }{ + ContextMoqParam: contextMoqParam, + NotifyCodespaceOfClientActivityRequest: notifyCodespaceOfClientActivityRequest, + } + mock.lockNotifyCodespaceOfClientActivity.Lock() + mock.calls.NotifyCodespaceOfClientActivity = append(mock.calls.NotifyCodespaceOfClientActivity, callInfo) + mock.lockNotifyCodespaceOfClientActivity.Unlock() + return mock.NotifyCodespaceOfClientActivityFunc(contextMoqParam, notifyCodespaceOfClientActivityRequest) +} + +// NotifyCodespaceOfClientActivityCalls gets all the calls that were made to NotifyCodespaceOfClientActivity. +// Check the length with: +// +// len(mockedCodespaceHostServer.NotifyCodespaceOfClientActivityCalls()) +func (mock *CodespaceHostServerMock) NotifyCodespaceOfClientActivityCalls() []struct { + ContextMoqParam context.Context + NotifyCodespaceOfClientActivityRequest *NotifyCodespaceOfClientActivityRequest +} { + var calls []struct { + ContextMoqParam context.Context + NotifyCodespaceOfClientActivityRequest *NotifyCodespaceOfClientActivityRequest + } + mock.lockNotifyCodespaceOfClientActivity.RLock() + calls = mock.calls.NotifyCodespaceOfClientActivity + mock.lockNotifyCodespaceOfClientActivity.RUnlock() + return calls +} + +// RebuildContainerAsync calls RebuildContainerAsyncFunc. +func (mock *CodespaceHostServerMock) RebuildContainerAsync(contextMoqParam context.Context, rebuildContainerRequest *RebuildContainerRequest) (*RebuildContainerResponse, error) { + if mock.RebuildContainerAsyncFunc == nil { + panic("CodespaceHostServerMock.RebuildContainerAsyncFunc: method is nil but CodespaceHostServer.RebuildContainerAsync was just called") + } + callInfo := struct { + ContextMoqParam context.Context + RebuildContainerRequest *RebuildContainerRequest + }{ + ContextMoqParam: contextMoqParam, + RebuildContainerRequest: rebuildContainerRequest, + } + mock.lockRebuildContainerAsync.Lock() + mock.calls.RebuildContainerAsync = append(mock.calls.RebuildContainerAsync, callInfo) + mock.lockRebuildContainerAsync.Unlock() + return mock.RebuildContainerAsyncFunc(contextMoqParam, rebuildContainerRequest) +} + +// RebuildContainerAsyncCalls gets all the calls that were made to RebuildContainerAsync. +// Check the length with: +// +// len(mockedCodespaceHostServer.RebuildContainerAsyncCalls()) +func (mock *CodespaceHostServerMock) RebuildContainerAsyncCalls() []struct { + ContextMoqParam context.Context + RebuildContainerRequest *RebuildContainerRequest +} { + var calls []struct { + ContextMoqParam context.Context + RebuildContainerRequest *RebuildContainerRequest + } + mock.lockRebuildContainerAsync.RLock() + calls = mock.calls.RebuildContainerAsync + mock.lockRebuildContainerAsync.RUnlock() + return calls +} + +// mustEmbedUnimplementedCodespaceHostServer calls mustEmbedUnimplementedCodespaceHostServerFunc. +func (mock *CodespaceHostServerMock) mustEmbedUnimplementedCodespaceHostServer() { + if mock.mustEmbedUnimplementedCodespaceHostServerFunc == nil { + panic("CodespaceHostServerMock.mustEmbedUnimplementedCodespaceHostServerFunc: method is nil but CodespaceHostServer.mustEmbedUnimplementedCodespaceHostServer was just called") + } + callInfo := struct { + }{} + mock.lockmustEmbedUnimplementedCodespaceHostServer.Lock() + mock.calls.mustEmbedUnimplementedCodespaceHostServer = append(mock.calls.mustEmbedUnimplementedCodespaceHostServer, callInfo) + mock.lockmustEmbedUnimplementedCodespaceHostServer.Unlock() + mock.mustEmbedUnimplementedCodespaceHostServerFunc() +} + +// mustEmbedUnimplementedCodespaceHostServerCalls gets all the calls that were made to mustEmbedUnimplementedCodespaceHostServer. +// Check the length with: +// +// len(mockedCodespaceHostServer.mustEmbedUnimplementedCodespaceHostServerCalls()) +func (mock *CodespaceHostServerMock) mustEmbedUnimplementedCodespaceHostServerCalls() []struct { +} { + var calls []struct { + } + mock.lockmustEmbedUnimplementedCodespaceHostServer.RLock() + calls = mock.calls.mustEmbedUnimplementedCodespaceHostServer + mock.lockmustEmbedUnimplementedCodespaceHostServer.RUnlock() + return calls +} diff --git a/internal/codespaces/rpc/generate.md b/internal/codespaces/rpc/generate.md index 7ae1dcc1a..d0d6bbc9d 100644 --- a/internal/codespaces/rpc/generate.md +++ b/internal/codespaces/rpc/generate.md @@ -6,7 +6,8 @@ Instructions for generating and adding gRPC protocol buffers. 1. [Download `protoc`](https://grpc.io/docs/protoc-installation/) 2. [Download protocol compiler plugins for Go](https://grpc.io/docs/languages/go/quickstart/) -3. Run `./generate.sh` from the `internal/codespaces/grpc` directory +3. Install moq: `go install github.com/matryer/moq@latest` +4. Run `./generate.sh` from the `internal/codespaces/rpc` directory ## Add New Protocol Buffers diff --git a/internal/codespaces/rpc/generate.sh b/internal/codespaces/rpc/generate.sh index 159803bbe..4ba2f898a 100755 --- a/internal/codespaces/rpc/generate.sh +++ b/internal/codespaces/rpc/generate.sh @@ -15,14 +15,21 @@ if ! protoc-gen-go-grpc --version; then fi function generate { - local contract="$1" + local dir="$1" + local proto="$2" + + local contract="$dir/$proto" protoc --go_out=. --go_opt=paths=source_relative --go-grpc_out=. --go-grpc_opt=paths=source_relative "$contract" echo "Generated protocol buffers for $contract" + + services=$(cat "$contract" | grep -Eo "service .+ {" | awk '{print $2 "Server"}') + moq -out $contract.mock.go $dir $services + echo "Generated mock protocols for $contract" } -generate jupyter/jupyter_server_host_service.v1.proto -generate codespace/codespace_host_service.v1.proto -generate ssh/ssh_server_host_service.v1.proto +generate jupyter jupyter_server_host_service.v1.proto +generate codespace codespace_host_service.v1.proto +generate ssh ssh_server_host_service.v1.proto echo 'Done!' diff --git a/internal/codespaces/rpc/invoker.go b/internal/codespaces/rpc/invoker.go index fa3c2897a..bb2e25a55 100644 --- a/internal/codespaces/rpc/invoker.go +++ b/internal/codespaces/rpc/invoker.go @@ -130,6 +130,9 @@ func connect(ctx context.Context, session liveshare.LiveshareSession) (Invoker, invoker.codespaceClient = codespace.NewCodespaceHostClient(conn) invoker.sshClient = ssh.NewSshServerHostClient(conn) + // Send initial connection heartbeat (no need to throw if we fail to get a response from the server) + _ = invoker.notifyCodespaceOfClientActivity(ctx, connectedEventName) + // Start the activity heatbeats go invoker.heartbeat(pfctx, 1*time.Minute) @@ -253,9 +256,6 @@ func (i *invoker) heartbeat(ctx context.Context, interval time.Duration) { ticker := time.NewTicker(interval) defer ticker.Stop() - // Send initial connection heartbeat (no need to throw if we fail to get a response from the server) - _ = i.notifyCodespaceOfClientActivity(ctx, connectedEventName) - for { select { case <-ctx.Done(): diff --git a/internal/codespaces/rpc/invoker_test.go b/internal/codespaces/rpc/invoker_test.go index c01148cf1..ba3e13ac3 100644 --- a/internal/codespaces/rpc/invoker_test.go +++ b/internal/codespaces/rpc/invoker_test.go @@ -3,84 +3,159 @@ package rpc import ( "context" "fmt" - "log" - "os" + "net" + "strconv" "testing" + "github.com/cli/cli/v2/internal/codespaces/rpc/codespace" + "github.com/cli/cli/v2/internal/codespaces/rpc/jupyter" + "github.com/cli/cli/v2/internal/codespaces/rpc/ssh" rpctest "github.com/cli/cli/v2/internal/codespaces/rpc/test" + "google.golang.org/grpc" ) -func startServer(t *testing.T) { - t.Helper() - if os.Getenv("GITHUB_ACTIONS") == "true" { - t.Skip("fails intermittently in CI: https://github.com/cli/cli/issues/5663") +type mockServer struct { + jupyter.JupyterServerHostServerMock + codespace.CodespaceHostServerMock + ssh.SshServerHostServerMock +} + +func newMockServer() *mockServer { + server := &mockServer{} + + server.CodespaceHostServerMock.NotifyCodespaceOfClientActivityFunc = func(context.Context, *codespace.NotifyCodespaceOfClientActivityRequest) (*codespace.NotifyCodespaceOfClientActivityResponse, error) { + return &codespace.NotifyCodespaceOfClientActivityResponse{ + Message: "", + Result: true, + }, nil + } + + return server +} + +// runTestGrpcServer serves grpc requests over the provided Listener using the mockServer for mocked callbacks. +// It does not return until the Context is cancelled and the server fully shuts down. +func runTestGrpcServer(ctx context.Context, listener net.Listener, server *mockServer) error { + s := grpc.NewServer() + jupyter.RegisterJupyterServerHostServer(s, server) + codespace.RegisterCodespaceHostServer(s, server) + ssh.RegisterSshServerHostServer(s, server) + + ch := make(chan error, 1) + go func() { ch <- s.Serve(listener) }() + + select { + case <-ctx.Done(): + s.Stop() + <-ch + return nil + case err := <-ch: + return err + } +} + +// createTestInvoker is the main test setup function. It returns an Invoker using the provided mockServer, as well as a shutdown function. +// The Invoker does not need to be closed directly, that will be handled by the shutdown function. +func createTestInvoker(t *testing.T, server *mockServer) (Invoker, func(), error) { + listener, err := net.Listen("tcp", "127.0.0.1:16634") + if err != nil { + return nil, nil, fmt.Errorf("failed to listen: %w", err) } ctx, cancel := context.WithCancel(context.Background()) + ch := make(chan error) + go func() { ch <- runTestGrpcServer(ctx, listener, server) }() - // Start the gRPC server in the background - go func() { - err := rpctest.StartServer(ctx) - if err != nil && err != context.Canceled { - log.Println(fmt.Errorf("error starting test server: %v", err)) - } - }() - - // Stop the gRPC server when the test is done - t.Cleanup(func() { + close := func() { cancel() - }) -} - -func createTestInvoker(t *testing.T) Invoker { - t.Helper() - - // Clear the stored client activity - rpctest.NotifyReceivedActivity = "" + <-ch + listener.Close() + } invoker, err := CreateInvoker(context.Background(), &rpctest.Session{}) if err != nil { - t.Fatalf("error connecting to internal server: %v", err) + close() + return nil, nil, fmt.Errorf("error connecting to internal server: %w", err) } - t.Cleanup(func() { - testNotifyCodespaceOfClientActivity(t) + return invoker, func() { invoker.Close() - }) - - return invoker + close() + }, nil } // Test that the RPC invoker notifies the codespace of client activity on connection -func testNotifyCodespaceOfClientActivity(t *testing.T) { - if rpctest.NotifyReceivedActivity != connectedEventName { - t.Fatalf("expected %s, got %s", connectedEventName, rpctest.NotifyMessage) +func verifyNotifyCodespaceOfClientActivity(t *testing.T, server *mockServer) { + calls := server.CodespaceHostServerMock.NotifyCodespaceOfClientActivityCalls() + if len(calls) == 0 { + t.Fatalf("no client activity calls") } + + for _, call := range calls { + activities := call.NotifyCodespaceOfClientActivityRequest.GetClientActivities() + if activities[0] == connectedEventName { + return + } + } + + t.Fatalf("no activity named %s", connectedEventName) } // Test that the RPC invoker returns the correct port and URL when the JupyterLab server starts successfully func TestStartJupyterServerSuccess(t *testing.T) { - startServer(t) - invoker := createTestInvoker(t) + resp := jupyter.GetRunningServerResponse{ + Port: strconv.Itoa(1234), + ServerUrl: "http://localhost:1234?token=1234", + Message: "", + Result: true, + } + + server := newMockServer() + server.JupyterServerHostServerMock.GetRunningServerFunc = func(context.Context, *jupyter.GetRunningServerRequest) (*jupyter.GetRunningServerResponse, error) { + return &resp, nil + } + + invoker, stop, err := createTestInvoker(t, server) + if err != nil { + t.Fatalf("error connecting to internal server: %v", err) + } + defer stop() + port, url, err := invoker.StartJupyterServer(context.Background()) if err != nil { t.Fatalf("expected %v, got %v", nil, err) } - if port != rpctest.JupyterPort { - t.Fatalf("expected %d, got %d", rpctest.JupyterPort, port) + if strconv.Itoa(port) != resp.Port { + t.Fatalf("expected %s, got %d", resp.Port, port) } - if url != rpctest.JupyterServerUrl { - t.Fatalf("expected %s, got %s", rpctest.JupyterServerUrl, url) + if url != resp.ServerUrl { + t.Fatalf("expected %s, got %s", resp.ServerUrl, url) } + + verifyNotifyCodespaceOfClientActivity(t, server) } // Test that the RPC invoker returns an error when the JupyterLab server fails to start func TestStartJupyterServerFailure(t *testing.T) { - startServer(t) - invoker := createTestInvoker(t) - rpctest.JupyterMessage = "error message" - rpctest.JupyterResult = false - errorMessage := fmt.Sprintf("failed to start JupyterLab: %s", rpctest.JupyterMessage) + resp := jupyter.GetRunningServerResponse{ + Port: strconv.Itoa(1234), + ServerUrl: "http://localhost:1234?token=1234", + Message: "error message", + Result: false, + } + + server := newMockServer() + server.JupyterServerHostServerMock.GetRunningServerFunc = func(context.Context, *jupyter.GetRunningServerRequest) (*jupyter.GetRunningServerResponse, error) { + return &resp, nil + } + + invoker, stop, err := createTestInvoker(t, server) + if err != nil { + t.Fatalf("error connecting to internal server: %v", err) + } + defer stop() + + errorMessage := fmt.Sprintf("failed to start JupyterLab: %s", resp.Message) port, url, err := invoker.StartJupyterServer(context.Background()) if err.Error() != errorMessage { t.Fatalf("expected %v, got %v", errorMessage, err) @@ -91,35 +166,79 @@ func TestStartJupyterServerFailure(t *testing.T) { if url != "" { t.Fatalf("expected %s, got %s", "", url) } + + verifyNotifyCodespaceOfClientActivity(t, server) } // Test that the RPC invoker doesn't throw an error when requesting an incremental rebuild func TestRebuildContainerIncremental(t *testing.T) { - startServer(t) - invoker := createTestInvoker(t) - err := invoker.RebuildContainer(context.Background(), false) + resp := codespace.RebuildContainerResponse{ + RebuildContainer: true, + } + + server := newMockServer() + server.RebuildContainerAsyncFunc = func(context.Context, *codespace.RebuildContainerRequest) (*codespace.RebuildContainerResponse, error) { + return &resp, nil + } + + invoker, stop, err := createTestInvoker(t, server) + if err != nil { + t.Fatalf("error connecting to internal server: %v", err) + } + defer stop() + + err = invoker.RebuildContainer(context.Background(), false) if err != nil { t.Fatalf("expected %v, got %v", nil, err) } + + verifyNotifyCodespaceOfClientActivity(t, server) } // Test that the RPC invoker doesn't throw an error when requesting a full rebuild func TestRebuildContainerFull(t *testing.T) { - startServer(t) - invoker := createTestInvoker(t) - err := invoker.RebuildContainer(context.Background(), true) + resp := codespace.RebuildContainerResponse{ + RebuildContainer: true, + } + + server := newMockServer() + server.RebuildContainerAsyncFunc = func(context.Context, *codespace.RebuildContainerRequest) (*codespace.RebuildContainerResponse, error) { + return &resp, nil + } + + invoker, stop, err := createTestInvoker(t, server) + if err != nil { + t.Fatalf("error connecting to internal server: %v", err) + } + defer stop() + + err = invoker.RebuildContainer(context.Background(), true) if err != nil { t.Fatalf("expected %v, got %v", nil, err) } + + verifyNotifyCodespaceOfClientActivity(t, server) } // Test that the RPC invoker throws an error when the rebuild fails func TestRebuildContainerFailure(t *testing.T) { - startServer(t) - invoker := createTestInvoker(t) - rpctest.RebuildContainer = false + resp := codespace.RebuildContainerResponse{ + RebuildContainer: false, + } + + server := newMockServer() + server.RebuildContainerAsyncFunc = func(context.Context, *codespace.RebuildContainerRequest) (*codespace.RebuildContainerResponse, error) { + return &resp, nil + } + + invoker, stop, err := createTestInvoker(t, server) + if err != nil { + t.Fatalf("error connecting to internal server: %v", err) + } + defer stop() + errorMessage := "couldn't rebuild codespace" - err := invoker.RebuildContainer(context.Background(), true) + err = invoker.RebuildContainer(context.Background(), true) if err.Error() != errorMessage { t.Fatalf("expected %v, got %v", errorMessage, err) } @@ -127,27 +246,59 @@ func TestRebuildContainerFailure(t *testing.T) { // Test that the RPC invoker returns the correct port and user when the SSH server starts successfully func TestStartSSHServerSuccess(t *testing.T) { - startServer(t) - invoker := createTestInvoker(t) + resp := ssh.StartRemoteServerResponse{ + ServerPort: strconv.Itoa(1234), + User: "test", + Message: "", + Result: true, + } + + server := newMockServer() + server.StartRemoteServerAsyncFunc = func(context.Context, *ssh.StartRemoteServerRequest) (*ssh.StartRemoteServerResponse, error) { + return &resp, nil + } + + invoker, stop, err := createTestInvoker(t, server) + if err != nil { + t.Fatalf("error connecting to internal server: %v", err) + } + defer stop() + port, user, err := invoker.StartSSHServer(context.Background()) if err != nil { t.Fatalf("expected %v, got %v", nil, err) } - if port != rpctest.SshServerPort { - t.Fatalf("expected %d, got %d", rpctest.SshServerPort, port) + if strconv.Itoa(port) != resp.ServerPort { + t.Fatalf("expected %s, got %d", resp.ServerPort, port) } - if user != rpctest.SshUser { - t.Fatalf("expected %s, got %s", rpctest.SshUser, user) + if user != resp.User { + t.Fatalf("expected %s, got %s", resp.User, user) } + + verifyNotifyCodespaceOfClientActivity(t, server) } // Test that the RPC invoker returns an error when the SSH server fails to start func TestStartSSHServerFailure(t *testing.T) { - startServer(t) - invoker := createTestInvoker(t) - rpctest.SshMessage = "error message" - rpctest.SshResult = false - errorMessage := fmt.Sprintf("failed to start SSH server: %s", rpctest.SshMessage) + resp := ssh.StartRemoteServerResponse{ + ServerPort: strconv.Itoa(1234), + User: "test", + Message: "error message", + Result: false, + } + + server := newMockServer() + server.StartRemoteServerAsyncFunc = func(context.Context, *ssh.StartRemoteServerRequest) (*ssh.StartRemoteServerResponse, error) { + return &resp, nil + } + + invoker, stop, err := createTestInvoker(t, server) + if err != nil { + t.Fatalf("error connecting to internal server: %v", err) + } + defer stop() + + errorMessage := fmt.Sprintf("failed to start SSH server: %s", resp.Message) port, user, err := invoker.StartSSHServer(context.Background()) if err.Error() != errorMessage { t.Fatalf("expected %v, got %v", errorMessage, err) diff --git a/internal/codespaces/rpc/jupyter/jupyter_server_host_service.v1.pb.go b/internal/codespaces/rpc/jupyter/jupyter_server_host_service.v1.pb.go index 8e11c6a32..b8f400d3c 100644 --- a/internal/codespaces/rpc/jupyter/jupyter_server_host_service.v1.pb.go +++ b/internal/codespaces/rpc/jupyter/jupyter_server_host_service.v1.pb.go @@ -1,6 +1,6 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: -// protoc-gen-go v1.28.0 +// protoc-gen-go v1.28.1 // protoc v3.21.12 // source: jupyter/jupyter_server_host_service.v1.proto diff --git a/internal/codespaces/rpc/jupyter/jupyter_server_host_service.v1.proto.mock.go b/internal/codespaces/rpc/jupyter/jupyter_server_host_service.v1.proto.mock.go new file mode 100644 index 000000000..12ea0bb5b --- /dev/null +++ b/internal/codespaces/rpc/jupyter/jupyter_server_host_service.v1.proto.mock.go @@ -0,0 +1,118 @@ +// Code generated by moq; DO NOT EDIT. +// github.com/matryer/moq + +package jupyter + +import ( + context "context" + sync "sync" +) + +// Ensure, that JupyterServerHostServerMock does implement JupyterServerHostServer. +// If this is not the case, regenerate this file with moq. +var _ JupyterServerHostServer = &JupyterServerHostServerMock{} + +// JupyterServerHostServerMock is a mock implementation of JupyterServerHostServer. +// +// func TestSomethingThatUsesJupyterServerHostServer(t *testing.T) { +// +// // make and configure a mocked JupyterServerHostServer +// mockedJupyterServerHostServer := &JupyterServerHostServerMock{ +// GetRunningServerFunc: func(contextMoqParam context.Context, getRunningServerRequest *GetRunningServerRequest) (*GetRunningServerResponse, error) { +// panic("mock out the GetRunningServer method") +// }, +// mustEmbedUnimplementedJupyterServerHostServerFunc: func() { +// panic("mock out the mustEmbedUnimplementedJupyterServerHostServer method") +// }, +// } +// +// // use mockedJupyterServerHostServer in code that requires JupyterServerHostServer +// // and then make assertions. +// +// } +type JupyterServerHostServerMock struct { + // GetRunningServerFunc mocks the GetRunningServer method. + GetRunningServerFunc func(contextMoqParam context.Context, getRunningServerRequest *GetRunningServerRequest) (*GetRunningServerResponse, error) + + // mustEmbedUnimplementedJupyterServerHostServerFunc mocks the mustEmbedUnimplementedJupyterServerHostServer method. + mustEmbedUnimplementedJupyterServerHostServerFunc func() + + // calls tracks calls to the methods. + calls struct { + // GetRunningServer holds details about calls to the GetRunningServer method. + GetRunningServer []struct { + // ContextMoqParam is the contextMoqParam argument value. + ContextMoqParam context.Context + // GetRunningServerRequest is the getRunningServerRequest argument value. + GetRunningServerRequest *GetRunningServerRequest + } + // mustEmbedUnimplementedJupyterServerHostServer holds details about calls to the mustEmbedUnimplementedJupyterServerHostServer method. + mustEmbedUnimplementedJupyterServerHostServer []struct { + } + } + lockGetRunningServer sync.RWMutex + lockmustEmbedUnimplementedJupyterServerHostServer sync.RWMutex +} + +// GetRunningServer calls GetRunningServerFunc. +func (mock *JupyterServerHostServerMock) GetRunningServer(contextMoqParam context.Context, getRunningServerRequest *GetRunningServerRequest) (*GetRunningServerResponse, error) { + if mock.GetRunningServerFunc == nil { + panic("JupyterServerHostServerMock.GetRunningServerFunc: method is nil but JupyterServerHostServer.GetRunningServer was just called") + } + callInfo := struct { + ContextMoqParam context.Context + GetRunningServerRequest *GetRunningServerRequest + }{ + ContextMoqParam: contextMoqParam, + GetRunningServerRequest: getRunningServerRequest, + } + mock.lockGetRunningServer.Lock() + mock.calls.GetRunningServer = append(mock.calls.GetRunningServer, callInfo) + mock.lockGetRunningServer.Unlock() + return mock.GetRunningServerFunc(contextMoqParam, getRunningServerRequest) +} + +// GetRunningServerCalls gets all the calls that were made to GetRunningServer. +// Check the length with: +// +// len(mockedJupyterServerHostServer.GetRunningServerCalls()) +func (mock *JupyterServerHostServerMock) GetRunningServerCalls() []struct { + ContextMoqParam context.Context + GetRunningServerRequest *GetRunningServerRequest +} { + var calls []struct { + ContextMoqParam context.Context + GetRunningServerRequest *GetRunningServerRequest + } + mock.lockGetRunningServer.RLock() + calls = mock.calls.GetRunningServer + mock.lockGetRunningServer.RUnlock() + return calls +} + +// mustEmbedUnimplementedJupyterServerHostServer calls mustEmbedUnimplementedJupyterServerHostServerFunc. +func (mock *JupyterServerHostServerMock) mustEmbedUnimplementedJupyterServerHostServer() { + if mock.mustEmbedUnimplementedJupyterServerHostServerFunc == nil { + panic("JupyterServerHostServerMock.mustEmbedUnimplementedJupyterServerHostServerFunc: method is nil but JupyterServerHostServer.mustEmbedUnimplementedJupyterServerHostServer was just called") + } + callInfo := struct { + }{} + mock.lockmustEmbedUnimplementedJupyterServerHostServer.Lock() + mock.calls.mustEmbedUnimplementedJupyterServerHostServer = append(mock.calls.mustEmbedUnimplementedJupyterServerHostServer, callInfo) + mock.lockmustEmbedUnimplementedJupyterServerHostServer.Unlock() + mock.mustEmbedUnimplementedJupyterServerHostServerFunc() +} + +// mustEmbedUnimplementedJupyterServerHostServerCalls gets all the calls that were made to mustEmbedUnimplementedJupyterServerHostServer. +// Check the length with: +// +// len(mockedJupyterServerHostServer.mustEmbedUnimplementedJupyterServerHostServerCalls()) +func (mock *JupyterServerHostServerMock) mustEmbedUnimplementedJupyterServerHostServerCalls() []struct { +} { + var calls []struct { + } + mock.lockmustEmbedUnimplementedJupyterServerHostServer.RLock() + calls = mock.calls.mustEmbedUnimplementedJupyterServerHostServer + mock.lockmustEmbedUnimplementedJupyterServerHostServer.RUnlock() + return calls +} diff --git a/internal/codespaces/rpc/ssh/ssh_server_host_service.v1.pb.go b/internal/codespaces/rpc/ssh/ssh_server_host_service.v1.pb.go index c495eb781..3dd22f583 100644 --- a/internal/codespaces/rpc/ssh/ssh_server_host_service.v1.pb.go +++ b/internal/codespaces/rpc/ssh/ssh_server_host_service.v1.pb.go @@ -1,6 +1,6 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: -// protoc-gen-go v1.28.0 +// protoc-gen-go v1.28.1 // protoc v3.21.12 // source: ssh/ssh_server_host_service.v1.proto diff --git a/internal/codespaces/rpc/ssh/ssh_server_host_service.v1.proto.mock.go b/internal/codespaces/rpc/ssh/ssh_server_host_service.v1.proto.mock.go new file mode 100644 index 000000000..d11e99461 --- /dev/null +++ b/internal/codespaces/rpc/ssh/ssh_server_host_service.v1.proto.mock.go @@ -0,0 +1,118 @@ +// Code generated by moq; DO NOT EDIT. +// github.com/matryer/moq + +package ssh + +import ( + context "context" + sync "sync" +) + +// Ensure, that SshServerHostServerMock does implement SshServerHostServer. +// If this is not the case, regenerate this file with moq. +var _ SshServerHostServer = &SshServerHostServerMock{} + +// SshServerHostServerMock is a mock implementation of SshServerHostServer. +// +// func TestSomethingThatUsesSshServerHostServer(t *testing.T) { +// +// // make and configure a mocked SshServerHostServer +// mockedSshServerHostServer := &SshServerHostServerMock{ +// StartRemoteServerAsyncFunc: func(contextMoqParam context.Context, startRemoteServerRequest *StartRemoteServerRequest) (*StartRemoteServerResponse, error) { +// panic("mock out the StartRemoteServerAsync method") +// }, +// mustEmbedUnimplementedSshServerHostServerFunc: func() { +// panic("mock out the mustEmbedUnimplementedSshServerHostServer method") +// }, +// } +// +// // use mockedSshServerHostServer in code that requires SshServerHostServer +// // and then make assertions. +// +// } +type SshServerHostServerMock struct { + // StartRemoteServerAsyncFunc mocks the StartRemoteServerAsync method. + StartRemoteServerAsyncFunc func(contextMoqParam context.Context, startRemoteServerRequest *StartRemoteServerRequest) (*StartRemoteServerResponse, error) + + // mustEmbedUnimplementedSshServerHostServerFunc mocks the mustEmbedUnimplementedSshServerHostServer method. + mustEmbedUnimplementedSshServerHostServerFunc func() + + // calls tracks calls to the methods. + calls struct { + // StartRemoteServerAsync holds details about calls to the StartRemoteServerAsync method. + StartRemoteServerAsync []struct { + // ContextMoqParam is the contextMoqParam argument value. + ContextMoqParam context.Context + // StartRemoteServerRequest is the startRemoteServerRequest argument value. + StartRemoteServerRequest *StartRemoteServerRequest + } + // mustEmbedUnimplementedSshServerHostServer holds details about calls to the mustEmbedUnimplementedSshServerHostServer method. + mustEmbedUnimplementedSshServerHostServer []struct { + } + } + lockStartRemoteServerAsync sync.RWMutex + lockmustEmbedUnimplementedSshServerHostServer sync.RWMutex +} + +// StartRemoteServerAsync calls StartRemoteServerAsyncFunc. +func (mock *SshServerHostServerMock) StartRemoteServerAsync(contextMoqParam context.Context, startRemoteServerRequest *StartRemoteServerRequest) (*StartRemoteServerResponse, error) { + if mock.StartRemoteServerAsyncFunc == nil { + panic("SshServerHostServerMock.StartRemoteServerAsyncFunc: method is nil but SshServerHostServer.StartRemoteServerAsync was just called") + } + callInfo := struct { + ContextMoqParam context.Context + StartRemoteServerRequest *StartRemoteServerRequest + }{ + ContextMoqParam: contextMoqParam, + StartRemoteServerRequest: startRemoteServerRequest, + } + mock.lockStartRemoteServerAsync.Lock() + mock.calls.StartRemoteServerAsync = append(mock.calls.StartRemoteServerAsync, callInfo) + mock.lockStartRemoteServerAsync.Unlock() + return mock.StartRemoteServerAsyncFunc(contextMoqParam, startRemoteServerRequest) +} + +// StartRemoteServerAsyncCalls gets all the calls that were made to StartRemoteServerAsync. +// Check the length with: +// +// len(mockedSshServerHostServer.StartRemoteServerAsyncCalls()) +func (mock *SshServerHostServerMock) StartRemoteServerAsyncCalls() []struct { + ContextMoqParam context.Context + StartRemoteServerRequest *StartRemoteServerRequest +} { + var calls []struct { + ContextMoqParam context.Context + StartRemoteServerRequest *StartRemoteServerRequest + } + mock.lockStartRemoteServerAsync.RLock() + calls = mock.calls.StartRemoteServerAsync + mock.lockStartRemoteServerAsync.RUnlock() + return calls +} + +// mustEmbedUnimplementedSshServerHostServer calls mustEmbedUnimplementedSshServerHostServerFunc. +func (mock *SshServerHostServerMock) mustEmbedUnimplementedSshServerHostServer() { + if mock.mustEmbedUnimplementedSshServerHostServerFunc == nil { + panic("SshServerHostServerMock.mustEmbedUnimplementedSshServerHostServerFunc: method is nil but SshServerHostServer.mustEmbedUnimplementedSshServerHostServer was just called") + } + callInfo := struct { + }{} + mock.lockmustEmbedUnimplementedSshServerHostServer.Lock() + mock.calls.mustEmbedUnimplementedSshServerHostServer = append(mock.calls.mustEmbedUnimplementedSshServerHostServer, callInfo) + mock.lockmustEmbedUnimplementedSshServerHostServer.Unlock() + mock.mustEmbedUnimplementedSshServerHostServerFunc() +} + +// mustEmbedUnimplementedSshServerHostServerCalls gets all the calls that were made to mustEmbedUnimplementedSshServerHostServer. +// Check the length with: +// +// len(mockedSshServerHostServer.mustEmbedUnimplementedSshServerHostServerCalls()) +func (mock *SshServerHostServerMock) mustEmbedUnimplementedSshServerHostServerCalls() []struct { +} { + var calls []struct { + } + mock.lockmustEmbedUnimplementedSshServerHostServer.RLock() + calls = mock.calls.mustEmbedUnimplementedSshServerHostServer + mock.lockmustEmbedUnimplementedSshServerHostServer.RUnlock() + return calls +} diff --git a/internal/codespaces/rpc/test/server.go b/internal/codespaces/rpc/test/server.go deleted file mode 100644 index 07b80bb83..000000000 --- a/internal/codespaces/rpc/test/server.go +++ /dev/null @@ -1,117 +0,0 @@ -package test - -import ( - "context" - "fmt" - "net" - "strconv" - - "github.com/cli/cli/v2/internal/codespaces/rpc/codespace" - "github.com/cli/cli/v2/internal/codespaces/rpc/jupyter" - "github.com/cli/cli/v2/internal/codespaces/rpc/ssh" - "google.golang.org/grpc" -) - -const ( - ServerPort = 50051 -) - -// Mock responses for the `GetRunningServer` RPC method -var ( - JupyterPort = 1234 - JupyterServerUrl = "http://localhost:1234?token=1234" - JupyterMessage = "" - JupyterResult = true -) - -// Mock responses for the `RebuildContainerAsync` RPC method -var ( - RebuildContainer = true -) - -// Mock responses for the `NotifyCodespaceOfClientActivity` RPC method -// NotifyMessage is used to store the activity that was sent to the server -var ( - NotifyMessage = "" - NotifyResult = true - NotifyReceivedActivity = "" -) - -// Mock responses for the `StartRemoteServerAsync` RPC method -var ( - SshServerPort = 1234 - SshUser = "test" - SshMessage = "" - SshResult = true -) - -type server struct { - jupyter.UnimplementedJupyterServerHostServer - codespace.CodespaceHostServer - ssh.SshServerHostServer -} - -func (s *server) GetRunningServer(ctx context.Context, in *jupyter.GetRunningServerRequest) (*jupyter.GetRunningServerResponse, error) { - return &jupyter.GetRunningServerResponse{ - Port: strconv.Itoa(JupyterPort), - ServerUrl: JupyterServerUrl, - Message: JupyterMessage, - Result: JupyterResult, - }, nil -} - -func (s *server) RebuildContainerAsync(ctx context.Context, in *codespace.RebuildContainerRequest) (*codespace.RebuildContainerResponse, error) { - return &codespace.RebuildContainerResponse{ - RebuildContainer: RebuildContainer, - }, nil -} - -func (s *server) NotifyCodespaceOfClientActivity(ctx context.Context, in *codespace.NotifyCodespaceOfClientActivityRequest) (*codespace.NotifyCodespaceOfClientActivityResponse, error) { - // If there is at least one client activity, set NotifyReceivedActivity to the first one (should be "connected") - if len(in.GetClientActivities()) > 0 { - NotifyReceivedActivity = in.GetClientActivities()[0] - } - - return &codespace.NotifyCodespaceOfClientActivityResponse{ - Message: NotifyMessage, - Result: NotifyResult, - }, nil -} - -func (s *server) StartRemoteServerAsync(ctx context.Context, in *ssh.StartRemoteServerRequest) (*ssh.StartRemoteServerResponse, error) { - return &ssh.StartRemoteServerResponse{ - ServerPort: strconv.Itoa(SshServerPort), - User: SshUser, - Message: SshMessage, - Result: SshResult, - }, nil -} - -// Starts the mock gRPC server listening on port 50051 -func StartServer(ctx context.Context) error { - listener, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", ServerPort)) - if err != nil { - return fmt.Errorf("failed to listen: %v", err) - } - defer listener.Close() - - s := grpc.NewServer() - jupyter.RegisterJupyterServerHostServer(s, &server{}) - codespace.RegisterCodespaceHostServer(s, &server{}) - ssh.RegisterSshServerHostServer(s, &server{}) - - ch := make(chan error, 1) - go func() { - if err := s.Serve(listener); err != nil { - ch <- fmt.Errorf("failed to serve: %v", err) - } - }() - - select { - case <-ctx.Done(): - s.Stop() - return ctx.Err() - case err := <-ch: - return err - } -} diff --git a/internal/codespaces/rpc/test/session.go b/internal/codespaces/rpc/test/session.go index 607451392..531d4c33f 100644 --- a/internal/codespaces/rpc/test/session.go +++ b/internal/codespaces/rpc/test/session.go @@ -29,7 +29,7 @@ func (s *Session) GetKeepAliveReason() string { } func (s *Session) StartSharing(ctx context.Context, sessionName string, port int) (liveshare.ChannelID, error) { - conn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", ServerPort)) + conn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", port)) if err != nil { return liveshare.ChannelID{}, err } From bab1b00c397eea36b7291a8340a0eb240266a900 Mon Sep 17 00:00:00 2001 From: Damien Sedgwick Date: Thu, 26 Jan 2023 11:48:21 +0000 Subject: [PATCH 2/8] Rename `--confirm` flag to `--yes` for various destructive commands (#6915) --- pkg/cmd/gpg-key/delete/delete.go | 6 ++++-- pkg/cmd/gpg-key/delete/delete_test.go | 6 +++--- pkg/cmd/issue/delete/delete.go | 3 +++ pkg/cmd/label/delete.go | 4 +++- pkg/cmd/label/delete_test.go | 4 ++-- pkg/cmd/repo/archive/archive.go | 8 ++++++-- pkg/cmd/repo/archive/archive_test.go | 2 +- pkg/cmd/repo/rename/rename.go | 7 +++++-- pkg/cmd/repo/rename/rename_test.go | 6 +++--- pkg/cmd/ssh-key/delete/delete.go | 8 ++++++-- pkg/cmd/ssh-key/delete/delete_test.go | 6 +++--- 11 files changed, 39 insertions(+), 21 deletions(-) diff --git a/pkg/cmd/gpg-key/delete/delete.go b/pkg/cmd/gpg-key/delete/delete.go index bb5277a38..b8569f72d 100644 --- a/pkg/cmd/gpg-key/delete/delete.go +++ b/pkg/cmd/gpg-key/delete/delete.go @@ -38,7 +38,7 @@ func NewCmdDelete(f *cmdutil.Factory, runF func(*DeleteOptions) error) *cobra.Co opts.KeyID = args[0] if !opts.IO.CanPrompt() && !opts.Confirmed { - return cmdutil.FlagErrorf("--confirm required when not running interactively") + return cmdutil.FlagErrorf("--yes required when not running interactively") } if runF != nil { @@ -48,7 +48,9 @@ func NewCmdDelete(f *cmdutil.Factory, runF func(*DeleteOptions) error) *cobra.Co }, } - cmd.Flags().BoolVarP(&opts.Confirmed, "confirm", "y", false, "Skip the confirmation prompt") + cmd.Flags().BoolVar(&opts.Confirmed, "confirm", false, "Skip the confirmation prompt") + _ = cmd.Flags().MarkDeprecated("confirm", "use `--yes` instead") + cmd.Flags().BoolVarP(&opts.Confirmed, "yes", "y", false, "Skip the confirmation prompt") return cmd } diff --git a/pkg/cmd/gpg-key/delete/delete_test.go b/pkg/cmd/gpg-key/delete/delete_test.go index 2835f9cb2..a7f0fda67 100644 --- a/pkg/cmd/gpg-key/delete/delete_test.go +++ b/pkg/cmd/gpg-key/delete/delete_test.go @@ -32,7 +32,7 @@ func TestNewCmdDelete(t *testing.T) { { name: "confirm flag tty", tty: true, - input: "ABC123 --confirm", + input: "ABC123 --yes", output: DeleteOptions{KeyID: "ABC123", Confirmed: true}, }, { @@ -45,11 +45,11 @@ func TestNewCmdDelete(t *testing.T) { name: "no tty", input: "ABC123", wantErr: true, - wantErrMsg: "--confirm required when not running interactively", + wantErrMsg: "--yes required when not running interactively", }, { name: "confirm flag no tty", - input: "ABC123 --confirm", + input: "ABC123 --yes", output: DeleteOptions{KeyID: "ABC123", Confirmed: true}, }, { diff --git a/pkg/cmd/issue/delete/delete.go b/pkg/cmd/issue/delete/delete.go index e9297c65a..f0c68de42 100644 --- a/pkg/cmd/issue/delete/delete.go +++ b/pkg/cmd/issue/delete/delete.go @@ -49,11 +49,14 @@ func NewCmdDelete(f *cmdutil.Factory, runF func(*DeleteOptions) error) *cobra.Co if runF != nil { return runF(opts) } + return deleteRun(opts) }, } cmd.Flags().BoolVar(&opts.Confirmed, "confirm", false, "confirm deletion without prompting") + _ = cmd.Flags().MarkDeprecated("confirm", "use `--yes` instead") + cmd.Flags().BoolVar(&opts.Confirmed, "yes", false, "confirm deletion without prompting") return cmd } diff --git a/pkg/cmd/label/delete.go b/pkg/cmd/label/delete.go index 8788b7215..c9d8f4cae 100644 --- a/pkg/cmd/label/delete.go +++ b/pkg/cmd/label/delete.go @@ -42,7 +42,7 @@ func newCmdDelete(f *cmdutil.Factory, runF func(*deleteOptions) error) *cobra.Co opts.Name = args[0] if !opts.IO.CanPrompt() && !opts.Confirmed { - return cmdutil.FlagErrorf("--confirm required when not running interactively") + return cmdutil.FlagErrorf("--yes required when not running interactively") } if runF != nil { @@ -53,6 +53,8 @@ func newCmdDelete(f *cmdutil.Factory, runF func(*deleteOptions) error) *cobra.Co } cmd.Flags().BoolVar(&opts.Confirmed, "confirm", false, "Confirm deletion without prompting") + _ = cmd.Flags().MarkDeprecated("confirm", "use `--yes` instead") + cmd.Flags().BoolVar(&opts.Confirmed, "yes", false, "Confirm deletion without prompting") return cmd } diff --git a/pkg/cmd/label/delete_test.go b/pkg/cmd/label/delete_test.go index 3bee0d987..f43b8d5bd 100644 --- a/pkg/cmd/label/delete_test.go +++ b/pkg/cmd/label/delete_test.go @@ -37,14 +37,14 @@ func TestNewCmdDelete(t *testing.T) { }, { name: "confirm argument", - input: "test --confirm", + input: "test --yes", output: deleteOptions{Name: "test", Confirmed: true}, }, { name: "confirm no tty", input: "test", wantErr: true, - wantErrMsg: "--confirm required when not running interactively", + wantErrMsg: "--yes required when not running interactively", }, } diff --git a/pkg/cmd/repo/archive/archive.go b/pkg/cmd/repo/archive/archive.go index 9e18ea1c1..50b1e7908 100644 --- a/pkg/cmd/repo/archive/archive.go +++ b/pkg/cmd/repo/archive/archive.go @@ -47,16 +47,20 @@ With no argument, archives the current repository.`), } if !opts.Confirmed && !opts.IO.CanPrompt() { - return cmdutil.FlagErrorf("--confirm required when not running interactively") + return cmdutil.FlagErrorf("--yes required when not running interactively") } + if runF != nil { return runF(opts) } + return archiveRun(opts) }, } - cmd.Flags().BoolVarP(&opts.Confirmed, "confirm", "y", false, "Skip the confirmation prompt") + cmd.Flags().BoolVar(&opts.Confirmed, "confirm", false, "Skip the confirmation prompt") + _ = cmd.Flags().MarkDeprecated("confirm", "use `--yes` instead") + cmd.Flags().BoolVarP(&opts.Confirmed, "yes", "y", false, "Skip the confirmation prompt") return cmd } diff --git a/pkg/cmd/repo/archive/archive_test.go b/pkg/cmd/repo/archive/archive_test.go index 6d681e784..02aab3383 100644 --- a/pkg/cmd/repo/archive/archive_test.go +++ b/pkg/cmd/repo/archive/archive_test.go @@ -26,7 +26,7 @@ func TestNewCmdArchive(t *testing.T) { { name: "no arguments no tty", input: "", - errMsg: "--confirm required when not running interactively", + errMsg: "--yes required when not running interactively", wantErr: true, }, { diff --git a/pkg/cmd/repo/rename/rename.go b/pkg/cmd/repo/rename/rename.go index 2d07cd8b0..f979e8101 100644 --- a/pkg/cmd/repo/rename/rename.go +++ b/pkg/cmd/repo/rename/rename.go @@ -60,7 +60,7 @@ func NewCmdRename(f *cmdutil.Factory, runf func(*RenameOptions) error) *cobra.Co if len(args) == 1 && !confirm && !opts.HasRepoOverride { if !opts.IO.CanPrompt() { - return cmdutil.FlagErrorf("--confirm required when passing a single argument") + return cmdutil.FlagErrorf("--yes required when passing a single argument") } opts.DoConfirm = true } @@ -68,12 +68,15 @@ func NewCmdRename(f *cmdutil.Factory, runf func(*RenameOptions) error) *cobra.Co if runf != nil { return runf(opts) } + return renameRun(opts) }, } cmdutil.EnableRepoOverride(cmd, f) - cmd.Flags().BoolVarP(&confirm, "confirm", "y", false, "skip confirmation prompt") + cmd.Flags().BoolVar(&confirm, "confirm", false, "Skip confirmation prompt") + _ = cmd.Flags().MarkDeprecated("confirm", "use `--yes` instead") + cmd.Flags().BoolVarP(&confirm, "yes", "y", false, "Skip the confirmation prompt") return cmd } diff --git a/pkg/cmd/repo/rename/rename_test.go b/pkg/cmd/repo/rename/rename_test.go index 523b2ba4a..611e3dcda 100644 --- a/pkg/cmd/repo/rename/rename_test.go +++ b/pkg/cmd/repo/rename/rename_test.go @@ -35,7 +35,7 @@ func TestNewCmdRename(t *testing.T) { }, { name: "one argument no tty confirmed", - input: "REPO --confirm", + input: "REPO --yes", output: RenameOptions{ newRepoSelector: "REPO", }, @@ -43,12 +43,12 @@ func TestNewCmdRename(t *testing.T) { { name: "one argument no tty", input: "REPO", - errMsg: "--confirm required when passing a single argument", + errMsg: "--yes required when passing a single argument", wantErr: true, }, { name: "one argument tty confirmed", - input: "REPO --confirm", + input: "REPO --yes", tty: true, output: RenameOptions{ newRepoSelector: "REPO", diff --git a/pkg/cmd/ssh-key/delete/delete.go b/pkg/cmd/ssh-key/delete/delete.go index a11cbd769..de53f1391 100644 --- a/pkg/cmd/ssh-key/delete/delete.go +++ b/pkg/cmd/ssh-key/delete/delete.go @@ -37,17 +37,21 @@ func NewCmdDelete(f *cmdutil.Factory, runF func(*DeleteOptions) error) *cobra.Co opts.KeyID = args[0] if !opts.IO.CanPrompt() && !opts.Confirmed { - return cmdutil.FlagErrorf("--confirm required when not running interactively") + return cmdutil.FlagErrorf("--yes required when not running interactively") } if runF != nil { return runF(opts) } + return deleteRun(opts) }, } - cmd.Flags().BoolVarP(&opts.Confirmed, "confirm", "y", false, "Skip the confirmation prompt") + cmd.Flags().BoolVar(&opts.Confirmed, "confirm", false, "Skip the confirmation prompt") + _ = cmd.Flags().MarkDeprecated("confirm", "use `--yes` instead") + cmd.Flags().BoolVarP(&opts.Confirmed, "yes", "y", false, "Skip the confirmation prompt") + return cmd } diff --git a/pkg/cmd/ssh-key/delete/delete_test.go b/pkg/cmd/ssh-key/delete/delete_test.go index 437443c55..85e79de3a 100644 --- a/pkg/cmd/ssh-key/delete/delete_test.go +++ b/pkg/cmd/ssh-key/delete/delete_test.go @@ -32,7 +32,7 @@ func TestNewCmdDelete(t *testing.T) { { name: "confirm flag tty", tty: true, - input: "123 --confirm", + input: "123 --yes", output: DeleteOptions{KeyID: "123", Confirmed: true}, }, { @@ -45,11 +45,11 @@ func TestNewCmdDelete(t *testing.T) { name: "no tty", input: "123", wantErr: true, - wantErrMsg: "--confirm required when not running interactively", + wantErrMsg: "--yes required when not running interactively", }, { name: "confirm flag no tty", - input: "123 --confirm", + input: "123 --yes", output: DeleteOptions{KeyID: "123", Confirmed: true}, }, { From d2f3e89ad3bded7c2ff93b1b09fccbcce61d4674 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mislav=20Marohni=C4=87?= Date: Fri, 27 Jan 2023 18:08:56 +0100 Subject: [PATCH 3/8] Fix ignoring ProjectsV2-specific errors for GHES --- api/queries_projects_v2.go | 8 ++++---- api/queries_projects_v2_test.go | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/api/queries_projects_v2.go b/api/queries_projects_v2.go index 9609524f1..e3b214f69 100644 --- a/api/queries_projects_v2.go +++ b/api/queries_projects_v2.go @@ -10,10 +10,10 @@ import ( const ( errorProjectsV2ReadScope = "field requires one of the following scopes: ['read:project']" - errorProjectsV2RepositoryField = "Field 'ProjectsV2' doesn't exist on type 'Repository'" - errorProjectsV2OrganizationField = "Field 'ProjectsV2' doesn't exist on type 'Organization'" - errorProjectsV2IssueField = "Field 'ProjectItems' doesn't exist on type 'Issue'" - errorProjectsV2PullRequestField = "Field 'ProjectItems' doesn't exist on type 'PullRequest'" + errorProjectsV2RepositoryField = "Field 'projectsV2' doesn't exist on type 'Repository'" + errorProjectsV2OrganizationField = "Field 'projectsV2' doesn't exist on type 'Organization'" + errorProjectsV2IssueField = "Field 'projectItems' doesn't exist on type 'Issue'" + errorProjectsV2PullRequestField = "Field 'projectItems' doesn't exist on type 'PullRequest'" ) // UpdateProjectV2Items uses the addProjectV2ItemById and the deleteProjectV2Item mutations diff --git a/api/queries_projects_v2_test.go b/api/queries_projects_v2_test.go index 693405126..bf6a618ba 100644 --- a/api/queries_projects_v2_test.go +++ b/api/queries_projects_v2_test.go @@ -221,22 +221,22 @@ func TestProjectsV2IgnorableError(t *testing.T) { }, { name: "repository projectsV2 field error", - errMsg: "Field 'ProjectsV2' doesn't exist on type 'Repository'", + errMsg: "Field 'projectsV2' doesn't exist on type 'Repository'", expectOut: true, }, { name: "organization projectsV2 field error", - errMsg: "Field 'ProjectsV2' doesn't exist on type 'Organization'", + errMsg: "Field 'projectsV2' doesn't exist on type 'Organization'", expectOut: true, }, { name: "issue projectItems field error", - errMsg: "Field 'ProjectItems' doesn't exist on type 'Issue'", + errMsg: "Field 'projectItems' doesn't exist on type 'Issue'", expectOut: true, }, { name: "pullRequest projectItems field error", - errMsg: "Field 'ProjectItems' doesn't exist on type 'PullRequest'", + errMsg: "Field 'projectItems' doesn't exist on type 'PullRequest'", expectOut: true, }, { From f6431ca001f5b5b26bb192f35504128f032ef550 Mon Sep 17 00:00:00 2001 From: Josh Gross Date: Fri, 27 Jan 2023 14:36:35 -0500 Subject: [PATCH 4/8] Use int64 repository IDs for Codespaces user secrets --- pkg/cmd/secret/set/http.go | 22 +++------------------- pkg/cmd/secret/set/set_test.go | 6 +++--- 2 files changed, 6 insertions(+), 22 deletions(-) diff --git a/pkg/cmd/secret/set/http.go b/pkg/cmd/secret/set/http.go index f36c2b59e..2ea2e8eff 100644 --- a/pkg/cmd/secret/set/http.go +++ b/pkg/cmd/secret/set/http.go @@ -5,7 +5,6 @@ import ( "encoding/json" "fmt" "sort" - "strconv" "strings" "github.com/cli/cli/v2/api" @@ -20,13 +19,6 @@ type SecretPayload struct { KeyID string `json:"key_id"` } -// The Codespaces Secret API currently expects repositories IDs as strings -type CodespacesSecretPayload struct { - EncryptedValue string `json:"encrypted_value"` - Repositories []string `json:"selected_repository_ids,omitempty"` - KeyID string `json:"key_id"` -} - type PubKey struct { ID string `json:"key_id"` Key string @@ -59,7 +51,7 @@ func getEnvPubKey(client *api.Client, repo ghrepo.Interface, envName string) (*P ghrepo.FullName(repo), envName)) } -func putSecret(client *api.Client, host, path string, payload interface{}) error { +func putSecret(client *api.Client, host, path string, payload SecretPayload) error { payloadBytes, err := json.Marshal(payload) if err != nil { return fmt.Errorf("failed to serialize: %w", err) @@ -82,19 +74,11 @@ func putOrgSecret(client *api.Client, host string, pk *PubKey, orgName, visibili } func putUserSecret(client *api.Client, host string, pk *PubKey, key, eValue string, repositoryIDs []int64) error { - payload := CodespacesSecretPayload{ + payload := SecretPayload{ EncryptedValue: eValue, KeyID: pk.ID, + Repositories: repositoryIDs, } - - if len(repositoryIDs) > 0 { - repositoryStringIDs := make([]string, len(repositoryIDs)) - for i, id := range repositoryIDs { - repositoryStringIDs[i] = strconv.FormatInt(id, 10) - } - payload.Repositories = repositoryStringIDs - } - path := fmt.Sprintf("user/codespaces/secrets/%s", key) return putSecret(client, host, path, payload) } diff --git a/pkg/cmd/secret/set/set_test.go b/pkg/cmd/secret/set/set_test.go index d9313f81f..babc5dc4f 100644 --- a/pkg/cmd/secret/set/set_test.go +++ b/pkg/cmd/secret/set/set_test.go @@ -426,7 +426,7 @@ func Test_setRun_user(t *testing.T) { name string opts *SetOptions wantVisibility shared.Visibility - wantRepositories []string + wantRepositories []int64 }{ { name: "all vis", @@ -442,7 +442,7 @@ func Test_setRun_user(t *testing.T) { Visibility: shared.Selected, RepositoryNames: []string{"cli/cli", "github/hub"}, }, - wantRepositories: []string{"212613049", "401025"}, + wantRepositories: []int64{212613049, 401025}, }, } @@ -481,7 +481,7 @@ func Test_setRun_user(t *testing.T) { data, err := io.ReadAll(reg.Requests[len(reg.Requests)-1].Body) assert.NoError(t, err) - var payload CodespacesSecretPayload + var payload SecretPayload err = json.Unmarshal(data, &payload) assert.NoError(t, err) assert.Equal(t, payload.KeyID, "123") From 1786ece4a4275b035d9bb88ac5bd656444b5ea47 Mon Sep 17 00:00:00 2001 From: Sam Coe Date: Tue, 31 Jan 2023 08:55:41 +1100 Subject: [PATCH 5/8] Change uint8 to uint64 to hold job with more than 255 run attempts (#6935) --- pkg/cmd/run/shared/shared.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/cmd/run/shared/shared.go b/pkg/cmd/run/shared/shared.go index 557882968..a1a71c4f9 100644 --- a/pkg/cmd/run/shared/shared.go +++ b/pkg/cmd/run/shared/shared.go @@ -77,7 +77,7 @@ type Run struct { workflowName string // cache column WorkflowID int64 `json:"workflow_id"` Number int64 `json:"run_number"` - Attempts uint8 `json:"run_attempt"` + Attempts uint64 `json:"run_attempt"` HeadBranch string `json:"head_branch"` JobsURL string `json:"jobs_url"` HeadCommit Commit `json:"head_commit"` From 1233bd44395306d5eddf8e6ab5344857709385f0 Mon Sep 17 00:00:00 2001 From: Sam Coe Date: Thu, 2 Feb 2023 07:12:22 +1100 Subject: [PATCH 6/8] Special case setting dependabot org secrets (#6941) --- pkg/cmd/secret/set/http.go | 29 ++++++++++++++++++-- pkg/cmd/secret/set/set_test.go | 49 +++++++++++++++++++++++++--------- 2 files changed, 63 insertions(+), 15 deletions(-) diff --git a/pkg/cmd/secret/set/http.go b/pkg/cmd/secret/set/http.go index 2ea2e8eff..d5c2bf436 100644 --- a/pkg/cmd/secret/set/http.go +++ b/pkg/cmd/secret/set/http.go @@ -5,6 +5,7 @@ import ( "encoding/json" "fmt" "sort" + "strconv" "strings" "github.com/cli/cli/v2/api" @@ -19,6 +20,13 @@ type SecretPayload struct { KeyID string `json:"key_id"` } +type DependabotSecretPayload struct { + EncryptedValue string `json:"encrypted_value"` + Visibility string `json:"visibility,omitempty"` + Repositories []string `json:"selected_repository_ids,omitempty"` + KeyID string `json:"key_id"` +} + type PubKey struct { ID string `json:"key_id"` Key string @@ -51,7 +59,7 @@ func getEnvPubKey(client *api.Client, repo ghrepo.Interface, envName string) (*P ghrepo.FullName(repo), envName)) } -func putSecret(client *api.Client, host, path string, payload SecretPayload) error { +func putSecret(client *api.Client, host, path string, payload interface{}) error { payloadBytes, err := json.Marshal(payload) if err != nil { return fmt.Errorf("failed to serialize: %w", err) @@ -62,13 +70,30 @@ func putSecret(client *api.Client, host, path string, payload SecretPayload) err } func putOrgSecret(client *api.Client, host string, pk *PubKey, orgName, visibility, secretName, eValue string, repositoryIDs []int64, app shared.App) error { + path := fmt.Sprintf("orgs/%s/%s/secrets/%s", orgName, app, secretName) + + if app == shared.Dependabot { + repos := make([]string, len(repositoryIDs)) + for i, id := range repositoryIDs { + repos[i] = strconv.FormatInt(id, 10) + } + + payload := DependabotSecretPayload{ + EncryptedValue: eValue, + KeyID: pk.ID, + Repositories: repos, + Visibility: visibility, + } + + return putSecret(client, host, path, payload) + } + payload := SecretPayload{ EncryptedValue: eValue, KeyID: pk.ID, Repositories: repositoryIDs, Visibility: visibility, } - path := fmt.Sprintf("orgs/%s/%s/secrets/%s", orgName, app, secretName) return putSecret(client, host, path, payload) } diff --git a/pkg/cmd/secret/set/set_test.go b/pkg/cmd/secret/set/set_test.go index babc5dc4f..592527c5f 100644 --- a/pkg/cmd/secret/set/set_test.go +++ b/pkg/cmd/secret/set/set_test.go @@ -333,11 +333,12 @@ func Test_setRun_env(t *testing.T) { func Test_setRun_org(t *testing.T) { tests := []struct { - name string - opts *SetOptions - wantVisibility shared.Visibility - wantRepositories []int64 - wantApp string + name string + opts *SetOptions + wantVisibility shared.Visibility + wantRepositories []int64 + wantDependabotRepositories []string + wantApp string }{ { name: "all vis", @@ -362,10 +363,21 @@ func Test_setRun_org(t *testing.T) { opts: &SetOptions{ OrgName: "UmbrellaCorporation", Visibility: shared.All, - Application: "dependabot", + Application: shared.Dependabot, }, wantApp: "dependabot", }, + { + name: "Dependabot selected visibility", + opts: &SetOptions{ + OrgName: "UmbrellaCorporation", + Visibility: shared.Selected, + Application: shared.Dependabot, + RepositoryNames: []string{"birkin", "UmbrellaCorporation/wesker"}, + }, + wantDependabotRepositories: []string{"1", "2"}, + wantApp: "dependabot", + }, } for _, tt := range tests { @@ -410,13 +422,24 @@ func Test_setRun_org(t *testing.T) { data, err := io.ReadAll(reg.Requests[len(reg.Requests)-1].Body) assert.NoError(t, err) - var payload SecretPayload - err = json.Unmarshal(data, &payload) - assert.NoError(t, err) - assert.Equal(t, payload.KeyID, "123") - assert.Equal(t, payload.EncryptedValue, "UKYUCbHd0DJemxa3AOcZ6XcsBwALG9d4bpB8ZT0gSV39vl3BHiGSgj8zJapDxgB2BwqNqRhpjC4=") - assert.Equal(t, payload.Visibility, tt.opts.Visibility) - assert.ElementsMatch(t, payload.Repositories, tt.wantRepositories) + + if tt.opts.Application == shared.Dependabot { + var payload DependabotSecretPayload + err = json.Unmarshal(data, &payload) + assert.NoError(t, err) + assert.Equal(t, payload.KeyID, "123") + assert.Equal(t, payload.EncryptedValue, "UKYUCbHd0DJemxa3AOcZ6XcsBwALG9d4bpB8ZT0gSV39vl3BHiGSgj8zJapDxgB2BwqNqRhpjC4=") + assert.Equal(t, payload.Visibility, tt.opts.Visibility) + assert.ElementsMatch(t, payload.Repositories, tt.wantDependabotRepositories) + } else { + var payload SecretPayload + err = json.Unmarshal(data, &payload) + assert.NoError(t, err) + assert.Equal(t, payload.KeyID, "123") + assert.Equal(t, payload.EncryptedValue, "UKYUCbHd0DJemxa3AOcZ6XcsBwALG9d4bpB8ZT0gSV39vl3BHiGSgj8zJapDxgB2BwqNqRhpjC4=") + assert.Equal(t, payload.Visibility, tt.opts.Visibility) + assert.ElementsMatch(t, payload.Repositories, tt.wantRepositories) + } }) } } From ced071feaea37478ea9bb7ad656077eceb209e01 Mon Sep 17 00:00:00 2001 From: Sam Coe Date: Thu, 2 Feb 2023 08:19:30 +1100 Subject: [PATCH 7/8] Sanitize ANSII control characters returned from the server (#6916) --- api/http_client.go | 2 + api/sanitize_ascii.go | 193 +++++++++++++++++++++++++++++++++++++ api/sanitize_ascii_test.go | 50 ++++++++++ 3 files changed, 245 insertions(+) create mode 100644 api/sanitize_ascii.go create mode 100644 api/sanitize_ascii_test.go diff --git a/api/http_client.go b/api/http_client.go index 81693cbd1..83f228409 100644 --- a/api/http_client.go +++ b/api/http_client.go @@ -64,6 +64,8 @@ func NewHTTPClient(opts HTTPClientOptions) (*http.Client, error) { client.Transport = AddAuthTokenHeader(client.Transport, opts.Config) } + client.Transport = AddASCIISanitizer(client.Transport) + return client, nil } diff --git a/api/sanitize_ascii.go b/api/sanitize_ascii.go new file mode 100644 index 000000000..92741a147 --- /dev/null +++ b/api/sanitize_ascii.go @@ -0,0 +1,193 @@ +package api + +import ( + "bytes" + "errors" + "io" + "net/http" + "regexp" + "strings" +) + +var jsonTypeRE = regexp.MustCompile(`[/+]json($|;)`) + +// GitHub servers return non-printable characters as their unicode code point values. +// The values of \u0000 to \u001F represent C0 ASCII control characters and +// the values of \u0080 to \u009F represent C1 ASCII control characters. These control +// characters will be interpreted by the terminal, this behaviour can be used maliciously +// as an attack vector, especially the control character \u001B. This function wraps +// JSON response bodies in a ReadCloser that transforms C0 and C1 control characters +// to their caret and hex notations respectively so that the terminal will not interpret them. +func AddASCIISanitizer(rt http.RoundTripper) http.RoundTripper { + return &funcTripper{roundTrip: func(req *http.Request) (*http.Response, error) { + res, err := rt.RoundTrip(req) + if err != nil || !jsonTypeRE.MatchString(res.Header.Get("Content-Type")) { + return res, err + } + res.Body = &sanitizeASCIIReadCloser{ReadCloser: res.Body} + return res, err + }} +} + +// sanitizeASCIIReadCloser implements the ReadCloser interface. +type sanitizeASCIIReadCloser struct { + io.ReadCloser + addBackslash bool + previousWindow []byte +} + +// Read uses a sliding window alogorithm to detect C0 and C1 +// ASCII control sequences as they are read and replaces them +// with equivelent inert characters. Characters that are not part +// of a control sequence not modified. +func (s *sanitizeASCIIReadCloser) Read(out []byte) (int, error) { + var readErr error + var outIndex int + var bufIndex int + var bufLen int + var window []byte + buf := make([]byte, len(out)) + + bufLen, readErr = s.ReadCloser.Read(buf) + if readErr != nil && !errors.Is(readErr, io.EOF) { + if bufLen > 0 { + // Do not sanitize if there was a read error that is not EOF. + bufLen = copy(out, buf) + } + return bufLen, readErr + } + + if s.previousWindow != nil { + buf = append(s.previousWindow, buf...) + bufLen += len(s.previousWindow) + } + + for { + remaining := min(6, (bufLen - bufIndex)) + window = buf[bufIndex : bufIndex+remaining] + if remaining < 6 { + break + } + + if bytes.HasPrefix(window, []byte(`\u00`)) { + repl, _ := mapControlCharacterToCaret(window) + if s.addBackslash { + repl = append([]byte{92}, repl...) + } + l := len(repl) + for j := 0; j < l; j++ { + out[outIndex] = repl[j] + outIndex++ + } + bufIndex += 6 + s.addBackslash = false + continue + } + + if window[0] == '\\' { + s.addBackslash = !s.addBackslash + } else { + s.addBackslash = false + } + + out[outIndex] = buf[bufIndex] + outIndex++ + bufIndex++ + } + + if readErr != nil && errors.Is(readErr, io.EOF) { + remaining := bufLen - bufIndex + for j := 0; j < remaining; j++ { + out[outIndex] = window[j] + outIndex++ + bufIndex++ + } + } else { + s.previousWindow = window + } + + return outIndex, readErr +} + +// mapControlCharacterToCaret maps C0 control sequences to caret notation +// and C1 control sequences to hex notation. C1 control sequences do not +// have caret notation representation. +func mapControlCharacterToCaret(b []byte) ([]byte, bool) { + m := map[string]string{ + `\u0000`: `^@`, + `\u0001`: `^A`, + `\u0002`: `^B`, + `\u0003`: `^C`, + `\u0004`: `^D`, + `\u0005`: `^E`, + `\u0006`: `^F`, + `\u0007`: `^G`, + `\u0008`: `^H`, + `\u0009`: `^I`, + `\u000a`: `^J`, + `\u000b`: `^K`, + `\u000c`: `^L`, + `\u000d`: `^M`, + `\u000e`: `^N`, + `\u000f`: `^O`, + `\u0010`: `^P`, + `\u0011`: `^Q`, + `\u0012`: `^R`, + `\u0013`: `^S`, + `\u0014`: `^T`, + `\u0015`: `^U`, + `\u0016`: `^V`, + `\u0017`: `^W`, + `\u0018`: `^X`, + `\u0019`: `^Y`, + `\u001a`: `^Z`, + `\u001b`: `^[`, + `\u001c`: `^\\`, + `\u001d`: `^]`, + `\u001e`: `^^`, + `\u001f`: `^_`, + `\u0080`: `\\200`, + `\u0081`: `\\201`, + `\u0082`: `\\202`, + `\u0083`: `\\203`, + `\u0084`: `\\204`, + `\u0085`: `\\205`, + `\u0086`: `\\206`, + `\u0087`: `\\207`, + `\u0088`: `\\210`, + `\u0089`: `\\211`, + `\u008a`: `\\212`, + `\u008b`: `\\213`, + `\u008c`: `\\214`, + `\u008d`: `\\215`, + `\u008e`: `\\216`, + `\u008f`: `\\217`, + `\u0090`: `\\220`, + `\u0091`: `\\221`, + `\u0092`: `\\222`, + `\u0093`: `\\223`, + `\u0094`: `\\224`, + `\u0095`: `\\225`, + `\u0096`: `\\226`, + `\u0097`: `\\227`, + `\u0098`: `\\230`, + `\u0099`: `\\231`, + `\u009a`: `\\232`, + `\u009b`: `\\233`, + `\u009c`: `\\234`, + `\u009d`: `\\235`, + `\u009e`: `\\236`, + `\u009f`: `\\237`, + } + if c, ok := m[strings.ToLower(string(b))]; ok { + return []byte(c), true + } + return b, false +} + +func min(a, b int) int { + if a < b { + return a + } + return b +} diff --git a/api/sanitize_ascii_test.go b/api/sanitize_ascii_test.go new file mode 100644 index 000000000..9b405edc8 --- /dev/null +++ b/api/sanitize_ascii_test.go @@ -0,0 +1,50 @@ +package api + +import ( + "encoding/json" + "fmt" + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestHTTPClient_SanitizeASCIIControlCharacters(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + issue := Issue{ + Title: "\u001B[31mRed Title\u001B[0m", + Body: "1\u0001 2\u0002 3\u0003 4\u0004 5\u0005 6\u0006 7\u0007 8\u0008 9\t A\r\n B\u000b C\u000c D\r\n E\u000e F\u000f", + Author: Author{ + ID: "1", + Name: "10\u0010 11\u0011 12\u0012 13\u0013 14\u0014 15\u0015 16\u0016 17\u0017 18\u0018 19\u0019 1A\u001a 1B\u001b 1C\u001c 1D\u001d 1E\u001e 1F\u001f", + Login: "monalisa", + }, + ActiveLockReason: "Escaped \u001B \\u001B \\\u001B \\\\u001B", + } + responseData, _ := json.Marshal(issue) + w.Header().Set("Content-Type", "application/json; charset=utf-8") + fmt.Fprint(w, string(responseData)) + })) + defer ts.Close() + + client, err := NewHTTPClient(HTTPClientOptions{}) + require.NoError(t, err) + req, err := http.NewRequest("GET", ts.URL, nil) + require.NoError(t, err) + res, err := client.Do(req) + require.NoError(t, err) + body, err := io.ReadAll(res.Body) + res.Body.Close() + require.NoError(t, err) + var issue Issue + err = json.Unmarshal(body, &issue) + require.NoError(t, err) + assert.Equal(t, "^[[31mRed Title^[[0m", issue.Title) + assert.Equal(t, "1^A 2^B 3^C 4^D 5^E 6^F 7^G 8^H 9\t A\r\n B^K C^L D\r\n E^N F^O", issue.Body) + assert.Equal(t, "10^P 11^Q 12^R 13^S 14^T 15^U 16^V 17^W 18^X 19^Y 1A^Z 1B^[ 1C^\\ 1D^] 1E^^ 1F^_", issue.Author.Name) + assert.Equal(t, "monalisa", issue.Author.Login) + assert.Equal(t, "Escaped ^[ \\^[ \\^[ \\\\^[", issue.ActiveLockReason) +} From c36aece5cdd9bfa60a84c1260c59cefb01e964c3 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 2 Feb 2023 14:03:41 +0000 Subject: [PATCH 8/8] Bump github.com/cli/oauth from 1.0.0 to 1.0.1 Bumps [github.com/cli/oauth](https://github.com/cli/oauth) from 1.0.0 to 1.0.1. - [Release notes](https://github.com/cli/oauth/releases) - [Commits](https://github.com/cli/oauth/compare/v1.0.0...v1.0.1) --- updated-dependencies: - dependency-name: github.com/cli/oauth dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] --- go.mod | 2 +- go.sum | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/go.mod b/go.mod index 7f5fa8ba0..e37f5edcb 100644 --- a/go.mod +++ b/go.mod @@ -10,7 +10,7 @@ require ( github.com/charmbracelet/glamour v0.5.1-0.20220727184942-e70ff2d969da github.com/charmbracelet/lipgloss v0.5.0 github.com/cli/go-gh v1.0.0 - github.com/cli/oauth v1.0.0 + github.com/cli/oauth v1.0.1 github.com/cli/safeexec v1.0.1 github.com/cpuguy83/go-md2man/v2 v2.0.2 github.com/creack/pty v1.1.18 diff --git a/go.sum b/go.sum index a0bd195f1..bb252437d 100644 --- a/go.sum +++ b/go.sum @@ -62,8 +62,8 @@ github.com/cli/crypto v0.0.0-20210929142629-6be313f59b03 h1:3f4uHLfWx4/WlnMPXGai github.com/cli/crypto v0.0.0-20210929142629-6be313f59b03/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= github.com/cli/go-gh v1.0.0 h1:zE1YUAUYqGXNZuICEBeOkIMJ5F50BS0ftvtoWGlsEFI= github.com/cli/go-gh v1.0.0/go.mod h1:bqxLdCoTZ73BuiPEJx4olcO/XKhVZaFDchFagYRBweE= -github.com/cli/oauth v1.0.0 h1:zuatYn8BRWWO98y2jNXK4RKOryU1u6JTqPrdSPW5pSE= -github.com/cli/oauth v1.0.0/go.mod h1:qd/FX8ZBD6n1sVNQO3aIdRxeu5LGw9WhKnYhIIoC2A4= +github.com/cli/oauth v1.0.1 h1:pXnTFl/qUegXHK531Dv0LNjW4mLx626eS42gnzfXJPA= +github.com/cli/oauth v1.0.1/go.mod h1:qd/FX8ZBD6n1sVNQO3aIdRxeu5LGw9WhKnYhIIoC2A4= github.com/cli/safeexec v1.0.0/go.mod h1:Z/D4tTN8Vs5gXYHDCbaM1S/anmEDnJb1iW0+EJ5zx3Q= github.com/cli/safeexec v1.0.1 h1:e/C79PbXF4yYTN/wauC4tviMxEV13BwljGj0N9j+N00= github.com/cli/safeexec v1.0.1/go.mod h1:Z/D4tTN8Vs5gXYHDCbaM1S/anmEDnJb1iW0+EJ5zx3Q=