95 lines
2.3 KiB
Go
95 lines
2.3 KiB
Go
package liveshare
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"testing"
|
|
"time"
|
|
|
|
livesharetest "github.com/github/go-liveshare/test"
|
|
"github.com/sourcegraph/jsonrpc2"
|
|
)
|
|
|
|
func TestNewPortForwarder(t *testing.T) {
|
|
testServer, session, err := makeMockSession()
|
|
if err != nil {
|
|
t.Errorf("create mock client: %v", err)
|
|
}
|
|
defer testServer.Close()
|
|
pf := NewPortForwarder(session, "ssh", 80)
|
|
if pf == nil {
|
|
t.Error("port forwarder is nil")
|
|
}
|
|
}
|
|
|
|
func TestPortForwarderStart(t *testing.T) {
|
|
streamName, streamCondition := "stream-name", "stream-condition"
|
|
serverSharing := func(req *jsonrpc2.Request) (interface{}, error) {
|
|
return Port{StreamName: streamName, StreamCondition: streamCondition}, nil
|
|
}
|
|
getStream := func(req *jsonrpc2.Request) (interface{}, error) {
|
|
return "stream-id", nil
|
|
}
|
|
|
|
stream := bytes.NewBufferString("stream-data")
|
|
testServer, session, err := makeMockSession(
|
|
livesharetest.WithService("serverSharing.startSharing", serverSharing),
|
|
livesharetest.WithService("streamManager.getStream", getStream),
|
|
livesharetest.WithStream("stream-id", stream),
|
|
)
|
|
if err != nil {
|
|
t.Errorf("create mock session: %v", err)
|
|
}
|
|
defer testServer.Close()
|
|
|
|
listen, err := net.Listen("tcp", ":8000")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
defer listen.Close()
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
defer cancel()
|
|
|
|
done := make(chan error)
|
|
go func() {
|
|
const name, remote = "ssh", 8000
|
|
done <- NewPortForwarder(session, name, remote).ForwardToListener(ctx, listen)
|
|
}()
|
|
|
|
go func() {
|
|
var conn net.Conn
|
|
retries := 0
|
|
for conn == nil && retries < 2 {
|
|
conn, err = net.DialTimeout("tcp", ":8000", 2*time.Second)
|
|
time.Sleep(1 * time.Second)
|
|
}
|
|
if conn == nil {
|
|
done <- errors.New("failed to connect to forwarded port")
|
|
}
|
|
b := make([]byte, len("stream-data"))
|
|
if _, err := conn.Read(b); err != nil && err != io.EOF {
|
|
done <- fmt.Errorf("reading stream: %v", err)
|
|
}
|
|
if string(b) != "stream-data" {
|
|
done <- fmt.Errorf("stream data is not expected value, got: %v", string(b))
|
|
}
|
|
if _, err := conn.Write([]byte("new-data")); err != nil {
|
|
done <- fmt.Errorf("writing to stream: %v", err)
|
|
}
|
|
done <- nil
|
|
}()
|
|
|
|
select {
|
|
case err := <-testServer.Err():
|
|
t.Errorf("error from server: %v", err)
|
|
case err := <-done:
|
|
if err != nil {
|
|
t.Errorf("error from client: %v", err)
|
|
}
|
|
}
|
|
}
|