diff --git a/internal/codespaces/rpc/generate.sh b/internal/codespaces/rpc/generate.sh index 08384f3d6..159803bbe 100755 --- a/internal/codespaces/rpc/generate.sh +++ b/internal/codespaces/rpc/generate.sh @@ -23,5 +23,6 @@ function generate { generate jupyter/jupyter_server_host_service.v1.proto generate codespace/codespace_host_service.v1.proto +generate ssh/ssh_server_host_service.v1.proto echo 'Done!' diff --git a/internal/codespaces/rpc/invoker.go b/internal/codespaces/rpc/invoker.go index 9bd6114cf..67a88bb2f 100644 --- a/internal/codespaces/rpc/invoker.go +++ b/internal/codespaces/rpc/invoker.go @@ -7,11 +7,14 @@ import ( "context" "fmt" "net" + "os" "strconv" + "strings" "time" "github.com/cli/cli/v2/internal/codespaces/rpc/codespace" "github.com/cli/cli/v2/internal/codespaces/rpc/jupyter" + "github.com/cli/cli/v2/internal/codespaces/rpc/ssh" "github.com/cli/cli/v2/pkg/liveshare" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" @@ -28,12 +31,16 @@ const ( codespacesInternalSessionName = "CodespacesInternal" ) +type StartSSHServerOptions struct { + UserPublicKeyFile string +} + type Invoker interface { Close() error StartJupyterServer(ctx context.Context) (int, string, error) RebuildContainer(ctx context.Context, full bool) error StartSSHServer(ctx context.Context) (int, string, error) - StartSSHServerWithOptions(ctx context.Context, options liveshare.StartSSHServerOptions) (int, string, error) + StartSSHServerWithOptions(ctx context.Context, options StartSSHServerOptions) (int, string, error) } type invoker struct { @@ -42,6 +49,7 @@ type invoker struct { listener net.Listener jupyterClient jupyter.JupyterServerHostClient codespaceClient codespace.CodespaceHostClient + sshClient ssh.SshServerHostClient cancelPF context.CancelFunc } @@ -118,6 +126,7 @@ func connect(ctx context.Context, session liveshare.LiveshareSession) (Invoker, invoker.conn = conn invoker.jupyterClient = jupyter.NewJupyterServerHostClient(conn) invoker.codespaceClient = codespace.NewCodespaceHostClient(conn) + invoker.sshClient = ssh.NewSshServerHostClient(conn) return invoker, nil } @@ -185,10 +194,38 @@ func (i *invoker) RebuildContainer(ctx context.Context, full bool) error { // Starts a remote SSH server to allow the user to connect to the codespace via SSH func (i *invoker) StartSSHServer(ctx context.Context) (int, string, error) { - return i.session.StartSSHServer(ctx) + return i.StartSSHServerWithOptions(ctx, StartSSHServerOptions{}) } // Starts a remote SSH server to allow the user to connect to the codespace via SSH -func (i *invoker) StartSSHServerWithOptions(ctx context.Context, options liveshare.StartSSHServerOptions) (int, string, error) { - return i.session.StartSSHServerWithOptions(ctx, options) +func (i *invoker) StartSSHServerWithOptions(ctx context.Context, options StartSSHServerOptions) (int, string, error) { + ctx = i.appendMetadata(ctx) + ctx, cancel := context.WithTimeout(ctx, requestTimeout) + defer cancel() + + userPublicKey := "" + if options.UserPublicKeyFile != "" { + publicKeyBytes, err := os.ReadFile(options.UserPublicKeyFile) + if err != nil { + return 0, "", fmt.Errorf("failed to read public key file: %w", err) + } + + userPublicKey = strings.TrimSpace(string(publicKeyBytes)) + } + + response, err := i.sshClient.StartRemoteServerAsync(ctx, &ssh.StartRemoteServerRequest{UserPublicKey: userPublicKey}) + if err != nil { + return 0, "", fmt.Errorf("failed to invoke SSH RPC: %w", err) + } + + if !response.Result { + return 0, "", fmt.Errorf("failed to start SSH server: %s", response.Message) + } + + port, err := strconv.Atoi(response.ServerPort) + if err != nil { + return 0, "", fmt.Errorf("failed to parse SSH server port: %w", err) + } + + return port, response.User, nil } diff --git a/internal/codespaces/rpc/invoker_test.go b/internal/codespaces/rpc/invoker_test.go index 9788cf980..bfed27181 100644 --- a/internal/codespaces/rpc/invoker_test.go +++ b/internal/codespaces/rpc/invoker_test.go @@ -113,3 +113,38 @@ func TestRebuildContainerFailure(t *testing.T) { t.Fatalf("expected %v, got %v", errorMessage, err) } } + +// Test that the RPC invoker returns the correct port and user when the SSH server starts successfully +func TestStartSSHServerSuccess(t *testing.T) { + startServer(t) + invoker := createTestInvoker(t) + port, user, err := invoker.StartSSHServer(context.Background()) + if err != nil { + t.Fatalf("expected %v, got %v", nil, err) + } + if port != rpctest.SshServerPort { + t.Fatalf("expected %d, got %d", rpctest.SshServerPort, port) + } + if user != rpctest.SshUser { + t.Fatalf("expected %s, got %s", rpctest.SshUser, user) + } +} + +// Test that the RPC invoker returns an error when the SSH server fails to start +func TestStartSSHServerFailure(t *testing.T) { + startServer(t) + invoker := createTestInvoker(t) + rpctest.SshMessage = "error message" + rpctest.SshResult = false + errorMessage := fmt.Sprintf("failed to start SSH server: %s", rpctest.SshMessage) + port, user, err := invoker.StartSSHServer(context.Background()) + if err.Error() != errorMessage { + t.Fatalf("expected %v, got %v", errorMessage, err) + } + if port != 0 { + t.Fatalf("expected %d, got %d", 0, port) + } + if user != "" { + t.Fatalf("expected %s, got %s", "", user) + } +} diff --git a/internal/codespaces/rpc/ssh/ssh_server_host_service.v1.pb.go b/internal/codespaces/rpc/ssh/ssh_server_host_service.v1.pb.go new file mode 100644 index 000000000..c495eb781 --- /dev/null +++ b/internal/codespaces/rpc/ssh/ssh_server_host_service.v1.pb.go @@ -0,0 +1,252 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.28.0 +// protoc v3.21.12 +// source: ssh/ssh_server_host_service.v1.proto + +package ssh + +import ( + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + reflect "reflect" + sync "sync" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +type StartRemoteServerRequest struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + UserPublicKey string `protobuf:"bytes,1,opt,name=UserPublicKey,proto3" json:"UserPublicKey,omitempty"` +} + +func (x *StartRemoteServerRequest) Reset() { + *x = StartRemoteServerRequest{} + if protoimpl.UnsafeEnabled { + mi := &file_ssh_ssh_server_host_service_v1_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *StartRemoteServerRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*StartRemoteServerRequest) ProtoMessage() {} + +func (x *StartRemoteServerRequest) ProtoReflect() protoreflect.Message { + mi := &file_ssh_ssh_server_host_service_v1_proto_msgTypes[0] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use StartRemoteServerRequest.ProtoReflect.Descriptor instead. +func (*StartRemoteServerRequest) Descriptor() ([]byte, []int) { + return file_ssh_ssh_server_host_service_v1_proto_rawDescGZIP(), []int{0} +} + +func (x *StartRemoteServerRequest) GetUserPublicKey() string { + if x != nil { + return x.UserPublicKey + } + return "" +} + +type StartRemoteServerResponse struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Result bool `protobuf:"varint,1,opt,name=Result,proto3" json:"Result,omitempty"` + ServerPort string `protobuf:"bytes,2,opt,name=ServerPort,proto3" json:"ServerPort,omitempty"` + User string `protobuf:"bytes,3,opt,name=User,proto3" json:"User,omitempty"` + Message string `protobuf:"bytes,4,opt,name=Message,proto3" json:"Message,omitempty"` +} + +func (x *StartRemoteServerResponse) Reset() { + *x = StartRemoteServerResponse{} + if protoimpl.UnsafeEnabled { + mi := &file_ssh_ssh_server_host_service_v1_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *StartRemoteServerResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*StartRemoteServerResponse) ProtoMessage() {} + +func (x *StartRemoteServerResponse) ProtoReflect() protoreflect.Message { + mi := &file_ssh_ssh_server_host_service_v1_proto_msgTypes[1] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use StartRemoteServerResponse.ProtoReflect.Descriptor instead. +func (*StartRemoteServerResponse) Descriptor() ([]byte, []int) { + return file_ssh_ssh_server_host_service_v1_proto_rawDescGZIP(), []int{1} +} + +func (x *StartRemoteServerResponse) GetResult() bool { + if x != nil { + return x.Result + } + return false +} + +func (x *StartRemoteServerResponse) GetServerPort() string { + if x != nil { + return x.ServerPort + } + return "" +} + +func (x *StartRemoteServerResponse) GetUser() string { + if x != nil { + return x.User + } + return "" +} + +func (x *StartRemoteServerResponse) GetMessage() string { + if x != nil { + return x.Message + } + return "" +} + +var File_ssh_ssh_server_host_service_v1_proto protoreflect.FileDescriptor + +var file_ssh_ssh_server_host_service_v1_proto_rawDesc = []byte{ + 0x0a, 0x24, 0x73, 0x73, 0x68, 0x2f, 0x73, 0x73, 0x68, 0x5f, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, + 0x5f, 0x68, 0x6f, 0x73, 0x74, 0x5f, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x2e, 0x76, 0x31, + 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x27, 0x43, 0x6f, 0x64, 0x65, 0x73, 0x70, 0x61, 0x63, + 0x65, 0x73, 0x2e, 0x47, 0x72, 0x70, 0x63, 0x2e, 0x53, 0x73, 0x68, 0x53, 0x65, 0x72, 0x76, 0x65, + 0x72, 0x48, 0x6f, 0x73, 0x74, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x2e, 0x76, 0x31, 0x22, + 0x40, 0x0a, 0x18, 0x53, 0x74, 0x61, 0x72, 0x74, 0x52, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x53, 0x65, + 0x72, 0x76, 0x65, 0x72, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x24, 0x0a, 0x0d, 0x55, + 0x73, 0x65, 0x72, 0x50, 0x75, 0x62, 0x6c, 0x69, 0x63, 0x4b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x0d, 0x55, 0x73, 0x65, 0x72, 0x50, 0x75, 0x62, 0x6c, 0x69, 0x63, 0x4b, 0x65, + 0x79, 0x22, 0x81, 0x01, 0x0a, 0x19, 0x53, 0x74, 0x61, 0x72, 0x74, 0x52, 0x65, 0x6d, 0x6f, 0x74, + 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, + 0x16, 0x0a, 0x06, 0x52, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, + 0x06, 0x52, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x12, 0x1e, 0x0a, 0x0a, 0x53, 0x65, 0x72, 0x76, 0x65, + 0x72, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0a, 0x53, 0x65, 0x72, + 0x76, 0x65, 0x72, 0x50, 0x6f, 0x72, 0x74, 0x12, 0x12, 0x0a, 0x04, 0x55, 0x73, 0x65, 0x72, 0x18, + 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x55, 0x73, 0x65, 0x72, 0x12, 0x18, 0x0a, 0x07, 0x4d, + 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x4d, 0x65, + 0x73, 0x73, 0x61, 0x67, 0x65, 0x32, 0xb1, 0x01, 0x0a, 0x0d, 0x53, 0x73, 0x68, 0x53, 0x65, 0x72, + 0x76, 0x65, 0x72, 0x48, 0x6f, 0x73, 0x74, 0x12, 0x9f, 0x01, 0x0a, 0x16, 0x53, 0x74, 0x61, 0x72, + 0x74, 0x52, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x41, 0x73, 0x79, + 0x6e, 0x63, 0x12, 0x41, 0x2e, 0x43, 0x6f, 0x64, 0x65, 0x73, 0x70, 0x61, 0x63, 0x65, 0x73, 0x2e, + 0x47, 0x72, 0x70, 0x63, 0x2e, 0x53, 0x73, 0x68, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x48, 0x6f, + 0x73, 0x74, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x2e, 0x76, 0x31, 0x2e, 0x53, 0x74, 0x61, + 0x72, 0x74, 0x52, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x52, 0x65, + 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x42, 0x2e, 0x43, 0x6f, 0x64, 0x65, 0x73, 0x70, 0x61, 0x63, + 0x65, 0x73, 0x2e, 0x47, 0x72, 0x70, 0x63, 0x2e, 0x53, 0x73, 0x68, 0x53, 0x65, 0x72, 0x76, 0x65, + 0x72, 0x48, 0x6f, 0x73, 0x74, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x2e, 0x76, 0x31, 0x2e, + 0x53, 0x74, 0x61, 0x72, 0x74, 0x52, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, + 0x72, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x42, 0x07, 0x5a, 0x05, 0x2e, 0x2f, 0x73, + 0x73, 0x68, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, +} + +var ( + file_ssh_ssh_server_host_service_v1_proto_rawDescOnce sync.Once + file_ssh_ssh_server_host_service_v1_proto_rawDescData = file_ssh_ssh_server_host_service_v1_proto_rawDesc +) + +func file_ssh_ssh_server_host_service_v1_proto_rawDescGZIP() []byte { + file_ssh_ssh_server_host_service_v1_proto_rawDescOnce.Do(func() { + file_ssh_ssh_server_host_service_v1_proto_rawDescData = protoimpl.X.CompressGZIP(file_ssh_ssh_server_host_service_v1_proto_rawDescData) + }) + return file_ssh_ssh_server_host_service_v1_proto_rawDescData +} + +var file_ssh_ssh_server_host_service_v1_proto_msgTypes = make([]protoimpl.MessageInfo, 2) +var file_ssh_ssh_server_host_service_v1_proto_goTypes = []interface{}{ + (*StartRemoteServerRequest)(nil), // 0: Codespaces.Grpc.SshServerHostService.v1.StartRemoteServerRequest + (*StartRemoteServerResponse)(nil), // 1: Codespaces.Grpc.SshServerHostService.v1.StartRemoteServerResponse +} +var file_ssh_ssh_server_host_service_v1_proto_depIdxs = []int32{ + 0, // 0: Codespaces.Grpc.SshServerHostService.v1.SshServerHost.StartRemoteServerAsync:input_type -> Codespaces.Grpc.SshServerHostService.v1.StartRemoteServerRequest + 1, // 1: Codespaces.Grpc.SshServerHostService.v1.SshServerHost.StartRemoteServerAsync:output_type -> Codespaces.Grpc.SshServerHostService.v1.StartRemoteServerResponse + 1, // [1:2] is the sub-list for method output_type + 0, // [0:1] is the sub-list for method input_type + 0, // [0:0] is the sub-list for extension type_name + 0, // [0:0] is the sub-list for extension extendee + 0, // [0:0] is the sub-list for field type_name +} + +func init() { file_ssh_ssh_server_host_service_v1_proto_init() } +func file_ssh_ssh_server_host_service_v1_proto_init() { + if File_ssh_ssh_server_host_service_v1_proto != nil { + return + } + if !protoimpl.UnsafeEnabled { + file_ssh_ssh_server_host_service_v1_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*StartRemoteServerRequest); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_ssh_ssh_server_host_service_v1_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*StartRemoteServerResponse); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: file_ssh_ssh_server_host_service_v1_proto_rawDesc, + NumEnums: 0, + NumMessages: 2, + NumExtensions: 0, + NumServices: 1, + }, + GoTypes: file_ssh_ssh_server_host_service_v1_proto_goTypes, + DependencyIndexes: file_ssh_ssh_server_host_service_v1_proto_depIdxs, + MessageInfos: file_ssh_ssh_server_host_service_v1_proto_msgTypes, + }.Build() + File_ssh_ssh_server_host_service_v1_proto = out.File + file_ssh_ssh_server_host_service_v1_proto_rawDesc = nil + file_ssh_ssh_server_host_service_v1_proto_goTypes = nil + file_ssh_ssh_server_host_service_v1_proto_depIdxs = nil +} diff --git a/internal/codespaces/rpc/ssh/ssh_server_host_service.v1.proto b/internal/codespaces/rpc/ssh/ssh_server_host_service.v1.proto new file mode 100644 index 000000000..322086b15 --- /dev/null +++ b/internal/codespaces/rpc/ssh/ssh_server_host_service.v1.proto @@ -0,0 +1,20 @@ +syntax = "proto3"; + +option go_package = "./ssh"; + +package Codespaces.Grpc.SshServerHostService.v1; + +service SshServerHost { + rpc StartRemoteServerAsync (StartRemoteServerRequest) returns (StartRemoteServerResponse); +} + +message StartRemoteServerRequest { + string UserPublicKey = 1; +} + +message StartRemoteServerResponse { + bool Result = 1; + string ServerPort = 2; + string User = 3; + string Message = 4; +} diff --git a/internal/codespaces/rpc/ssh/ssh_server_host_service.v1_grpc.pb.go b/internal/codespaces/rpc/ssh/ssh_server_host_service.v1_grpc.pb.go new file mode 100644 index 000000000..a111656e8 --- /dev/null +++ b/internal/codespaces/rpc/ssh/ssh_server_host_service.v1_grpc.pb.go @@ -0,0 +1,105 @@ +// Code generated by protoc-gen-go-grpc. DO NOT EDIT. +// versions: +// - protoc-gen-go-grpc v1.2.0 +// - protoc v3.21.12 +// source: ssh/ssh_server_host_service.v1.proto + +package ssh + +import ( + context "context" + grpc "google.golang.org/grpc" + codes "google.golang.org/grpc/codes" + status "google.golang.org/grpc/status" +) + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the grpc package it is being compiled against. +// Requires gRPC-Go v1.32.0 or later. +const _ = grpc.SupportPackageIsVersion7 + +// SshServerHostClient is the client API for SshServerHost service. +// +// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream. +type SshServerHostClient interface { + StartRemoteServerAsync(ctx context.Context, in *StartRemoteServerRequest, opts ...grpc.CallOption) (*StartRemoteServerResponse, error) +} + +type sshServerHostClient struct { + cc grpc.ClientConnInterface +} + +func NewSshServerHostClient(cc grpc.ClientConnInterface) SshServerHostClient { + return &sshServerHostClient{cc} +} + +func (c *sshServerHostClient) StartRemoteServerAsync(ctx context.Context, in *StartRemoteServerRequest, opts ...grpc.CallOption) (*StartRemoteServerResponse, error) { + out := new(StartRemoteServerResponse) + err := c.cc.Invoke(ctx, "/Codespaces.Grpc.SshServerHostService.v1.SshServerHost/StartRemoteServerAsync", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +// SshServerHostServer is the server API for SshServerHost service. +// All implementations must embed UnimplementedSshServerHostServer +// for forward compatibility +type SshServerHostServer interface { + StartRemoteServerAsync(context.Context, *StartRemoteServerRequest) (*StartRemoteServerResponse, error) + mustEmbedUnimplementedSshServerHostServer() +} + +// UnimplementedSshServerHostServer must be embedded to have forward compatible implementations. +type UnimplementedSshServerHostServer struct { +} + +func (UnimplementedSshServerHostServer) StartRemoteServerAsync(context.Context, *StartRemoteServerRequest) (*StartRemoteServerResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method StartRemoteServerAsync not implemented") +} +func (UnimplementedSshServerHostServer) mustEmbedUnimplementedSshServerHostServer() {} + +// UnsafeSshServerHostServer may be embedded to opt out of forward compatibility for this service. +// Use of this interface is not recommended, as added methods to SshServerHostServer will +// result in compilation errors. +type UnsafeSshServerHostServer interface { + mustEmbedUnimplementedSshServerHostServer() +} + +func RegisterSshServerHostServer(s grpc.ServiceRegistrar, srv SshServerHostServer) { + s.RegisterService(&SshServerHost_ServiceDesc, srv) +} + +func _SshServerHost_StartRemoteServerAsync_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(StartRemoteServerRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(SshServerHostServer).StartRemoteServerAsync(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/Codespaces.Grpc.SshServerHostService.v1.SshServerHost/StartRemoteServerAsync", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(SshServerHostServer).StartRemoteServerAsync(ctx, req.(*StartRemoteServerRequest)) + } + return interceptor(ctx, in, info, handler) +} + +// SshServerHost_ServiceDesc is the grpc.ServiceDesc for SshServerHost service. +// It's only intended for direct use with grpc.RegisterService, +// and not to be introspected or modified (even as a copy) +var SshServerHost_ServiceDesc = grpc.ServiceDesc{ + ServiceName: "Codespaces.Grpc.SshServerHostService.v1.SshServerHost", + HandlerType: (*SshServerHostServer)(nil), + Methods: []grpc.MethodDesc{ + { + MethodName: "StartRemoteServerAsync", + Handler: _SshServerHost_StartRemoteServerAsync_Handler, + }, + }, + Streams: []grpc.StreamDesc{}, + Metadata: "ssh/ssh_server_host_service.v1.proto", +} diff --git a/internal/codespaces/rpc/test/server.go b/internal/codespaces/rpc/test/server.go index 465583679..d2dc9f590 100644 --- a/internal/codespaces/rpc/test/server.go +++ b/internal/codespaces/rpc/test/server.go @@ -8,6 +8,7 @@ import ( "github.com/cli/cli/v2/internal/codespaces/rpc/codespace" "github.com/cli/cli/v2/internal/codespaces/rpc/jupyter" + "github.com/cli/cli/v2/internal/codespaces/rpc/ssh" "google.golang.org/grpc" ) @@ -28,9 +29,18 @@ var ( RebuildContainer = true ) +// Mock responses for the `StartRemoteServerAsync` RPC method +var ( + SshServerPort = 1234 + SshUser = "test" + SshMessage = "" + SshResult = true +) + type server struct { jupyter.UnimplementedJupyterServerHostServer codespace.CodespaceHostServer + ssh.SshServerHostServer } func (s *server) GetRunningServer(ctx context.Context, in *jupyter.GetRunningServerRequest) (*jupyter.GetRunningServerResponse, error) { @@ -48,6 +58,15 @@ func (s *server) RebuildContainerAsync(ctx context.Context, in *codespace.Rebuil }, nil } +func (s *server) StartRemoteServerAsync(ctx context.Context, in *ssh.StartRemoteServerRequest) (*ssh.StartRemoteServerResponse, error) { + return &ssh.StartRemoteServerResponse{ + ServerPort: strconv.Itoa(SshServerPort), + User: SshUser, + Message: SshMessage, + Result: SshResult, + }, nil +} + // Starts the mock gRPC server listening on port 50051 func StartServer(ctx context.Context) error { listener, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", ServerPort)) @@ -59,6 +78,7 @@ func StartServer(ctx context.Context) error { s := grpc.NewServer() jupyter.RegisterJupyterServerHostServer(s, &server{}) codespace.RegisterCodespaceHostServer(s, &server{}) + ssh.RegisterSshServerHostServer(s, &server{}) ch := make(chan error, 1) go func() { diff --git a/internal/codespaces/rpc/test/session.go b/internal/codespaces/rpc/test/session.go index 360b3464a..89d66a912 100644 --- a/internal/codespaces/rpc/test/session.go +++ b/internal/codespaces/rpc/test/session.go @@ -21,18 +21,6 @@ func (*Session) GetSharedServers(context.Context) ([]*liveshare.Port, error) { panic("unimplemented") } -func (*Session) RebuildContainer(context.Context, bool) error { - panic("unimplemented") -} - -func (*Session) StartSSHServer(context.Context) (int, string, error) { - panic("unimplemented") -} - -func (*Session) StartSSHServerWithOptions(context.Context, liveshare.StartSSHServerOptions) (int, string, error) { - panic("unimplemented") -} - func (s *Session) KeepAlive(reason string) { } diff --git a/pkg/cmd/codespace/ssh.go b/pkg/cmd/codespace/ssh.go index 15d201510..3a3fdc86a 100644 --- a/pkg/cmd/codespace/ssh.go +++ b/pkg/cmd/codespace/ssh.go @@ -148,7 +148,7 @@ func (a *App) SSH(ctx context.Context, sshArgs []string, opts sshOptions) (err e } sshContext := ssh.Context{} - startSSHOptions := liveshare.StartSSHServerOptions{} + startSSHOptions := rpc.StartSSHServerOptions{} keyPair, shouldAddArg, err := selectSSHKeys(ctx, sshContext, args, opts) if err != nil { diff --git a/pkg/liveshare/session.go b/pkg/liveshare/session.go index 35a8e69a8..697659021 100644 --- a/pkg/liveshare/session.go +++ b/pkg/liveshare/session.go @@ -3,9 +3,6 @@ package liveshare import ( "context" "fmt" - "os" - "strconv" - "strings" "time" "github.com/opentracing/opentracing-go" @@ -26,8 +23,6 @@ type LiveshareSession interface { KeepAlive(string) OpenStreamingChannel(context.Context, ChannelID) (ssh.Channel, error) StartSharing(context.Context, string, int) (ChannelID, error) - StartSSHServer(context.Context) (int, string, error) - StartSSHServerWithOptions(context.Context, StartSSHServerOptions) (int, string, error) } // A Session represents the session between a connected Live Share client and server. @@ -40,10 +35,6 @@ type Session struct { logger logger } -type StartSSHServerOptions struct { - UserPublicKeyFile string -} - // Close should be called by users to clean up RPC and SSH resources whenever the session // is no longer active. func (s *Session) Close() error { @@ -63,51 +54,6 @@ func (s *Session) registerRequestHandler(requestType string, h handler) func() { return s.rpc.register(requestType, h) } -// StartSSHServer starts an SSH server in the container, installing sshd if necessary, applies specified -// options, and returns the port on which it listens and the user name clients should provide. -func (s *Session) StartSSHServer(ctx context.Context) (int, string, error) { - return s.StartSSHServerWithOptions(ctx, StartSSHServerOptions{}) -} - -// StartSSHServerWithOptions starts an SSH server in the container, installing sshd if necessary, applies specified -// options, and returns the port on which it listens and the user name clients should provide. -func (s *Session) StartSSHServerWithOptions(ctx context.Context, options StartSSHServerOptions) (int, string, error) { - var params struct { - UserPublicKey string `json:"userPublicKey"` - } - - var response struct { - Result bool `json:"result"` - ServerPort string `json:"serverPort"` - User string `json:"user"` - Message string `json:"message"` - } - - if options.UserPublicKeyFile != "" { - publicKeyBytes, err := os.ReadFile(options.UserPublicKeyFile) - if err != nil { - return 0, "", fmt.Errorf("failed to read public key file: %w", err) - } - - params.UserPublicKey = strings.TrimSpace(string(publicKeyBytes)) - } - - if err := s.rpc.do(ctx, "ISshServerHostService.startRemoteServerWithOptions", params, &response); err != nil { - return 0, "", err - } - - if !response.Result { - return 0, "", fmt.Errorf("failed to start server: %s", response.Message) - } - - port, err := strconv.Atoi(response.ServerPort) - if err != nil { - return 0, "", fmt.Errorf("failed to parse port: %w", err) - } - - return port, response.User, nil -} - // heartbeat runs until context cancellation, periodically checking whether there is a // reason to keep the connection alive, and if so, notifying the Live Share host to do so. // Heartbeat ensures it does not send more than one request every "interval" to ratelimit