Checking point after continuing to flesh out mock server

This commit is contained in:
Jose Garcia 2021-07-23 19:15:54 +00:00 committed by GitHub
parent b9cd9af7fa
commit 9132a28e9c
4 changed files with 299 additions and 55 deletions

View file

@ -2,6 +2,7 @@ package liveshare
import (
"context"
"crypto/tls"
"fmt"
"golang.org/x/crypto/ssh"
@ -10,6 +11,7 @@ import (
// A Client capable of joining a liveshare connection
type Client struct {
connection Connection
tlsConfig *tls.Config
ssh *sshSession
rpc *rpcClient
@ -43,9 +45,16 @@ func WithConnection(connection Connection) ClientOption {
}
}
func WithTLSConfig(tlsConfig *tls.Config) ClientOption {
return func(c *Client) error {
c.tlsConfig = tlsConfig
return nil
}
}
// Join is a method that joins the client to the liveshare session
func (c *Client) Join(ctx context.Context) (err error) {
clientSocket := newSocket(c.connection)
clientSocket := newSocket(c.connection, c.tlsConfig)
if err := clientSocket.connect(ctx); err != nil {
return fmt.Errorf("error connecting websocket: %v", err)
}

View file

@ -1,12 +1,14 @@
package liveshare
import (
"context"
"crypto/tls"
"fmt"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/gorilla/websocket"
livesharetest "github.com/github/go-liveshare/test"
"github.com/sourcegraph/jsonrpc2"
)
func TestNewClient(t *testing.T) {
@ -39,53 +41,51 @@ func TestNewClientWithInvalidConnection(t *testing.T) {
}
}
var upgrader = websocket.Upgrader{}
func newMockLiveShareServer() *httptest.Server {
endpoint := func(w http.ResponseWriter, req *http.Request) {
c, err := upgrader.Upgrade(w, req, nil)
if err != nil {
fmt.Println(err)
return
}
defer c.Close()
for {
mt, message, err := c.ReadMessage()
if err != nil {
fmt.Println(err)
break
}
err = c.WriteMessage(mt, message)
if err != nil {
fmt.Println(err)
break
}
}
func TestClientJoin(t *testing.T) {
sessionToken := "session-token"
joinWorkspace := func(req *jsonrpc2.Request) (interface{}, error) {
return 1, nil
}
return httptest.NewTLSServer(http.HandlerFunc(endpoint))
}
func TestClientJoin(t *testing.T) {
// server := newMockLiveShareServer()
// defer server.Close()
// connection := Connection{
// SessionID: "session-id",
// SessionToken: "session-token",
// RelaySAS: "relay-sas",
// RelayEndpoint: "sb" + strings.TrimPrefix(server.URL, "https"),
// }
// client, err := NewClient(WithConnection(connection))
// if err != nil {
// t.Errorf("error creating new client: %v", err)
// }
// ctx := context.Background()
// if err := client.Join(ctx); err != nil {
// t.Errorf("error joining client: %v", err)
// }
server, err := livesharetest.NewServer(
livesharetest.WithPassword(sessionToken),
livesharetest.WithService("workspace.joinWorkspace", joinWorkspace),
)
if err != nil {
t.Errorf("error creating liveshare server: %v", err)
}
defer server.Close()
ctx := context.Background()
connection := Connection{
SessionID: "session-id",
SessionToken: sessionToken,
RelaySAS: "relay-sas",
RelayEndpoint: "sb" + strings.TrimPrefix(server.URL(), "https"),
}
tlsConfig := WithTLSConfig(&tls.Config{InsecureSkipVerify: true})
client, err := NewClient(WithConnection(connection), tlsConfig)
if err != nil {
t.Errorf("error creating new client: %v", err)
}
clientErr := make(chan error)
go func() {
if err := client.Join(ctx); err != nil {
clientErr <- fmt.Errorf("error joining client: %v", err)
return
}
ctx.Done()
}()
select {
case err := <-server.Err():
t.Errorf("error from server: %v", err)
case err := <-clientErr:
t.Errorf("error from client: %v", err)
case <-ctx.Done():
return
}
}

View file

@ -2,9 +2,11 @@ package liveshare
import (
"context"
"crypto/tls"
"errors"
"io"
"net"
"net/http"
"sync"
"time"
@ -12,19 +14,26 @@ import (
)
type socket struct {
addr string
addr string
tlsConfig *tls.Config
conn *websocket.Conn
readMutex sync.Mutex
writeMutex sync.Mutex
reader io.Reader
}
func newSocket(clientConn Connection) *socket {
return &socket{addr: clientConn.uri("connect")}
func newSocket(clientConn Connection, tlsConfig *tls.Config) *socket {
return &socket{addr: clientConn.uri("connect"), tlsConfig: tlsConfig}
}
func (s *socket) connect(ctx context.Context) error {
ws, _, err := websocket.DefaultDialer.Dial(s.addr, nil)
dialer := websocket.Dialer{
Proxy: http.ProxyFromEnvironment,
HandshakeTimeout: 45 * time.Second,
TLSClientConfig: s.tlsConfig,
}
ws, _, err := dialer.Dial(s.addr, nil)
if err != nil {
return err
}

226
test/server.go Normal file
View file

@ -0,0 +1,226 @@
package livesharetest
import (
"context"
"errors"
"fmt"
"io"
"io/ioutil"
"net/http"
"net/http/httptest"
"path/filepath"
"sync"
"time"
"github.com/gorilla/websocket"
"github.com/sourcegraph/jsonrpc2"
"golang.org/x/crypto/ssh"
)
type Server struct {
password string
services map[string]RpcHandleFunc
sshConfig *ssh.ServerConfig
httptestServer *httptest.Server
errCh chan error
}
func NewServer(opts ...ServerOption) (*Server, error) {
server := new(Server)
for _, o := range opts {
if err := o(server); err != nil {
return nil, err
}
}
server.sshConfig = &ssh.ServerConfig{
PasswordCallback: sshPasswordCallback(server.password),
}
b, err := ioutil.ReadFile(filepath.Join("test", "private.key"))
if err != nil {
return nil, fmt.Errorf("error reading private.key: %v", err)
}
privateKey, err := ssh.ParsePrivateKey(b)
if err != nil {
return nil, fmt.Errorf("error parsing key: %v", err)
}
server.sshConfig.AddHostKey(privateKey)
server.errCh = make(chan error)
server.httptestServer = httptest.NewTLSServer(http.HandlerFunc(newConnection(server)))
return server, nil
}
type ServerOption func(*Server) error
func WithPassword(password string) ServerOption {
return func(s *Server) error {
s.password = password
return nil
}
}
func WithService(serviceName string, handler RpcHandleFunc) ServerOption {
return func(s *Server) error {
if s.services == nil {
s.services = make(map[string]RpcHandleFunc)
}
s.services[serviceName] = handler
return nil
}
}
func sshPasswordCallback(serverPassword string) func(ssh.ConnMetadata, []byte) (*ssh.Permissions, error) {
return func(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) {
if string(password) == serverPassword {
return nil, nil
}
return nil, errors.New("password rejected")
}
}
func (s *Server) Close() {
s.httptestServer.Close()
}
func (s *Server) URL() string {
return s.httptestServer.URL
}
func (s *Server) Err() <-chan error {
return s.errCh
}
var upgrader = websocket.Upgrader{}
func newConnection(server *Server) http.HandlerFunc {
return func(w http.ResponseWriter, req *http.Request) {
c, err := upgrader.Upgrade(w, req, nil)
if err != nil {
server.errCh <- fmt.Errorf("error upgrading connection: %v", err)
return
}
defer c.Close()
socketConn := newSocketConn(c)
_, chans, reqs, err := ssh.NewServerConn(socketConn, server.sshConfig)
if err != nil {
server.errCh <- fmt.Errorf("error creating new ssh conn: %v", err)
return
}
go ssh.DiscardRequests(reqs)
for newChannel := range chans {
ch, reqs, err := newChannel.Accept()
if err != nil {
server.errCh <- fmt.Errorf("error accepting new channel: %v", err)
return
}
go ssh.DiscardRequests(reqs)
go handleNewChannel(server, ch)
}
}
}
func handleNewChannel(server *Server, channel ssh.Channel) {
stream := jsonrpc2.NewBufferedStream(channel, jsonrpc2.VSCodeObjectCodec{})
jsonrpc2.NewConn(context.Background(), stream, newRpcHandler(server))
}
type RpcHandleFunc func(req *jsonrpc2.Request) (interface{}, error)
type rpcHandler struct {
server *Server
}
func newRpcHandler(server *Server) *rpcHandler {
return &rpcHandler{server}
}
func (r *rpcHandler) Handle(ctx context.Context, conn *jsonrpc2.Conn, req *jsonrpc2.Request) {
handler, found := r.server.services[req.Method]
if !found {
r.server.errCh <- fmt.Errorf("RPC Method: '%v' not serviced", req.Method)
return
}
result, err := handler(req)
if err != nil {
r.server.errCh <- fmt.Errorf("error handling: '%v': %v", req.Method, err)
return
}
if err := conn.Reply(ctx, req.ID, result); err != nil {
r.server.errCh <- fmt.Errorf("error replying: %v", err)
}
}
type socketConn struct {
*websocket.Conn
reader io.Reader
writeMutex sync.Mutex
readMutex sync.Mutex
}
func newSocketConn(conn *websocket.Conn) *socketConn {
return &socketConn{Conn: conn}
}
func (s *socketConn) Read(b []byte) (int, error) {
s.readMutex.Lock()
defer s.readMutex.Unlock()
if s.reader == nil {
msgType, r, err := s.Conn.NextReader()
if err != nil {
return 0, fmt.Errorf("error getting next reader: %v", err)
}
if msgType != websocket.BinaryMessage {
return 0, fmt.Errorf("invalid message type")
}
s.reader = r
}
bytesRead, err := s.reader.Read(b)
if err != nil {
s.reader = nil
if err == io.EOF {
err = nil
}
}
return bytesRead, err
}
func (s *socketConn) Write(b []byte) (int, error) {
s.writeMutex.Lock()
defer s.writeMutex.Unlock()
w, err := s.Conn.NextWriter(websocket.BinaryMessage)
if err != nil {
return 0, fmt.Errorf("error getting next writer: %v", err)
}
n, err := w.Write(b)
if err != nil {
return 0, fmt.Errorf("error writing: %v", err)
}
if err := w.Close(); err != nil {
return 0, fmt.Errorf("error closing writer: %v", err)
}
return n, nil
}
func (s *socketConn) SetDeadline(deadline time.Time) error {
if err := s.Conn.SetReadDeadline(deadline); err != nil {
return err
}
return s.Conn.SetWriteDeadline(deadline)
}