437 lines
14 KiB
Go
437 lines
14 KiB
Go
package portforwarder
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"strings"
|
|
|
|
"github.com/cli/cli/v2/internal/codespaces/connection"
|
|
"github.com/microsoft/dev-tunnels/go/tunnels"
|
|
)
|
|
|
|
// Package-level string constants shared by the port forwarder.
const (
	// githubSubjectId is the subject used in the organization-typed access
	// control entries built by visibilityToAccessControlEntries.
	githubSubjectId = "1"

	// InternalPortLabel is the label put on ports forwarded with
	// ForwardPortOpts.Internal set.
	InternalPortLabel = "InternalPort"
	// UserForwardedPortLabel is the label put on every other forwarded port.
	UserForwardedPortLabel = "UserForwardedPort"

	// Friendly names for the supported port visibilities.
	PrivatePortVisibility = "private"
	OrgPortVisibility     = "org"
	PublicPortVisibility  = "public"

	// Keep-alive reasons reported by the traffic monitor: "input" on reads,
	// "output" on writes.
	trafficTypeInput  = "input"
	trafficTypeOutput = "output"
)
|
|
|
|
// ForwardPortOpts describes a single port forwarding request.
type ForwardPortOpts struct {
	// Port is the port number to forward. It must fit in a uint16
	// (0-65535), otherwise the forwarding calls return an error.
	Port int

	// Internal selects the label attached to the tunnel port:
	// InternalPortLabel when true, UserForwardedPortLabel when false.
	Internal bool

	// KeepAlive, when true, wraps the connection in a traffic monitor that
	// records a keep-alive reason on every read and write.
	KeepAlive bool

	// Visibility is the requested port visibility. When empty, no access
	// control is set on the port; when non-empty, it must be one of the
	// connection's allowed port privacy settings.
	Visibility string
}
|
|
|
|
type CodespacesPortForwarder struct {
|
|
connection connection.CodespaceConnection
|
|
keepAliveReason chan string
|
|
}
|
|
|
|
type PortForwarder interface {
|
|
ForwardPortToListener(ctx context.Context, opts ForwardPortOpts, listener *net.TCPListener) error
|
|
ForwardPort(ctx context.Context, opts ForwardPortOpts) error
|
|
ConnectToForwardedPort(ctx context.Context, conn io.ReadWriteCloser, opts ForwardPortOpts) error
|
|
ListPorts(ctx context.Context) ([]*tunnels.TunnelPort, error)
|
|
UpdatePortVisibility(ctx context.Context, remotePort int, visibility string) error
|
|
KeepAlive(reason string)
|
|
GetKeepAliveReason() string
|
|
Close() error
|
|
}
|
|
|
|
// NewPortForwarder returns a new PortForwarder for the specified codespace.
|
|
func NewPortForwarder(ctx context.Context, codespaceConnection *connection.CodespaceConnection) (fwd PortForwarder, err error) {
|
|
return &CodespacesPortForwarder{
|
|
connection: *codespaceConnection,
|
|
keepAliveReason: make(chan string, 1),
|
|
}, nil
|
|
}
|
|
|
|
// ForwardPortToListener forwards the specified port to the given TCP listener.
|
|
func (fwd *CodespacesPortForwarder) ForwardPortToListener(ctx context.Context, opts ForwardPortOpts, listener *net.TCPListener) error {
|
|
err := fwd.ForwardPort(ctx, opts)
|
|
if err != nil {
|
|
return fmt.Errorf("error forwarding port: %w", err)
|
|
}
|
|
|
|
done := make(chan error)
|
|
go func() {
|
|
// Convert the port number to a uint16
|
|
port, err := convertIntToUint16(opts.Port)
|
|
if err != nil {
|
|
done <- fmt.Errorf("error converting port: %w", err)
|
|
return
|
|
}
|
|
|
|
// Ensure the port is forwarded before connecting
|
|
err = fwd.connection.TunnelClient.WaitForForwardedPort(ctx, port)
|
|
if err != nil {
|
|
done <- fmt.Errorf("wait for forwarded port failed: %v", err)
|
|
return
|
|
}
|
|
|
|
// Connect to the forwarded port
|
|
err = fwd.connectListenerToForwardedPort(ctx, opts, listener)
|
|
if err != nil {
|
|
done <- fmt.Errorf("connect to forwarded port failed: %v", err)
|
|
}
|
|
}()
|
|
|
|
select {
|
|
case err := <-done:
|
|
if err != nil {
|
|
return fmt.Errorf("error connecting to tunnel: %w", err)
|
|
}
|
|
return nil
|
|
case <-ctx.Done():
|
|
return nil
|
|
}
|
|
}
|
|
|
|
// ForwardPort informs the host that we would like to forward the given port.
|
|
func (fwd *CodespacesPortForwarder) ForwardPort(ctx context.Context, opts ForwardPortOpts) error {
|
|
// Convert the port number to a uint16
|
|
port, err := convertIntToUint16(opts.Port)
|
|
if err != nil {
|
|
return fmt.Errorf("error converting port: %w", err)
|
|
}
|
|
|
|
// In v0.0.25 of dev-tunnels, the dev-tunnel manager `CreateTunnelPort` would "accept" requests that
|
|
// change the port protocol but they would not result in any actual change. This has changed, resulting in
|
|
// an error `Invalid arguments. The tunnel port protocol cannot be changed.`. It's not clear why the previous
|
|
// behaviour existed, whether it was truly the API version, or whether the `If-Not-Match` header being set inside
|
|
// `CreateTunnelPort` avoided the server accepting the request to change the protocol and that has since regressed.
|
|
//
|
|
// In any case, now we check whether a port exists with the given port number, if it does, we use the existing protocol.
|
|
// If it doesn't exist, we default to HTTP, which was the previous behaviour for all ports.
|
|
protocol := tunnels.TunnelProtocolHttp
|
|
|
|
existingPort, err := fwd.connection.TunnelManager.GetTunnelPort(ctx, fwd.connection.Tunnel, opts.Port, fwd.connection.Options)
|
|
if err != nil && !strings.Contains(err.Error(), "404") {
|
|
return fmt.Errorf("error checking whether tunnel port already exists: %v", err)
|
|
}
|
|
|
|
if existingPort != nil {
|
|
protocol = tunnels.TunnelProtocol(existingPort.Protocol)
|
|
}
|
|
|
|
tunnelPort := tunnels.NewTunnelPort(port, "", "", protocol)
|
|
|
|
// If no visibility is provided, Dev Tunnels will use the default (private)
|
|
if opts.Visibility != "" {
|
|
// Check if the requested visibility is allowed
|
|
allowed := false
|
|
for _, allowedVisibility := range fwd.connection.AllowedPortPrivacySettings {
|
|
if allowedVisibility == opts.Visibility {
|
|
allowed = true
|
|
break
|
|
}
|
|
}
|
|
|
|
// If the requested visibility is not allowed, return an error
|
|
if !allowed {
|
|
return fmt.Errorf("visibility %s is not allowed", opts.Visibility)
|
|
}
|
|
|
|
accessControlEntries := visibilityToAccessControlEntries(opts.Visibility)
|
|
if len(accessControlEntries) > 0 {
|
|
tunnelPort.AccessControl = &tunnels.TunnelAccessControl{
|
|
Entries: accessControlEntries,
|
|
}
|
|
}
|
|
}
|
|
|
|
// Tag the port as internal or user forwarded so we know if it needs to be shown in the UI
|
|
if opts.Internal {
|
|
tunnelPort.Labels = []string{InternalPortLabel}
|
|
} else {
|
|
tunnelPort.Labels = []string{UserForwardedPortLabel}
|
|
}
|
|
|
|
// Create the tunnel port
|
|
_, err = fwd.connection.TunnelManager.CreateTunnelPort(ctx, fwd.connection.Tunnel, tunnelPort, fwd.connection.Options)
|
|
if err != nil && !strings.Contains(err.Error(), "409") {
|
|
return fmt.Errorf("create tunnel port failed: %v", err)
|
|
}
|
|
|
|
// Connect to the tunnel
|
|
err = fwd.connection.Connect(ctx)
|
|
if err != nil {
|
|
return fmt.Errorf("connect failed: %v", err)
|
|
}
|
|
|
|
// Inform the host that we've forwarded the port locally
|
|
err = fwd.connection.TunnelClient.RefreshPorts(ctx)
|
|
if err != nil {
|
|
return fmt.Errorf("refresh ports failed: %v", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// connectListenerToForwardedPort connects to the forwarded port via a local TCP port.
|
|
func (fwd *CodespacesPortForwarder) connectListenerToForwardedPort(ctx context.Context, opts ForwardPortOpts, listener *net.TCPListener) (err error) {
|
|
errc := make(chan error, 1)
|
|
sendError := func(err error) {
|
|
// Use non-blocking send, to avoid goroutines getting
|
|
// stuck in case of concurrent or sequential errors.
|
|
select {
|
|
case errc <- err:
|
|
default:
|
|
}
|
|
}
|
|
go func() {
|
|
for {
|
|
conn, err := listener.AcceptTCP()
|
|
if err != nil {
|
|
sendError(err)
|
|
return
|
|
}
|
|
|
|
// Connect to the forwarded port in a goroutine so we can accept new connections
|
|
go func() {
|
|
if err := fwd.ConnectToForwardedPort(ctx, conn, opts); err != nil {
|
|
sendError(err)
|
|
}
|
|
}()
|
|
}
|
|
}()
|
|
|
|
// Wait for an error or for the context to be cancelled
|
|
select {
|
|
case err := <-errc:
|
|
return err
|
|
case <-ctx.Done():
|
|
return ctx.Err() // canceled
|
|
}
|
|
}
|
|
|
|
// ConnectToForwardedPort connects to the forwarded port via a given ReadWriteCloser.
|
|
// Optionally, it detects traffic over the connection and sends activity signals to the server to keep the codespace from shutting down.
|
|
func (fwd *CodespacesPortForwarder) ConnectToForwardedPort(ctx context.Context, conn io.ReadWriteCloser, opts ForwardPortOpts) error {
|
|
// Create a traffic monitor to keep the session alive
|
|
if opts.KeepAlive {
|
|
conn = newTrafficMonitor(conn, fwd)
|
|
}
|
|
|
|
// Convert the port number to a uint16
|
|
port, err := convertIntToUint16(opts.Port)
|
|
if err != nil {
|
|
return fmt.Errorf("error converting port: %w", err)
|
|
}
|
|
|
|
// Connect to the forwarded port
|
|
err = fwd.connection.TunnelClient.ConnectToForwardedPort(ctx, conn, port)
|
|
if err != nil {
|
|
return fmt.Errorf("error connecting to forwarded port: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// ListPorts fetches the list of ports that are currently forwarded.
|
|
func (fwd *CodespacesPortForwarder) ListPorts(ctx context.Context) (ports []*tunnels.TunnelPort, err error) {
|
|
ports, err = fwd.connection.TunnelManager.ListTunnelPorts(ctx, fwd.connection.Tunnel, fwd.connection.Options)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("error listing ports: %w", err)
|
|
}
|
|
|
|
return ports, nil
|
|
}
|
|
|
|
// UpdatePortVisibility changes the visibility (private, org, public) of the specified port.
|
|
func (fwd *CodespacesPortForwarder) UpdatePortVisibility(ctx context.Context, remotePort int, visibility string) error {
|
|
tunnelPort, err := fwd.connection.TunnelManager.GetTunnelPort(ctx, fwd.connection.Tunnel, remotePort, fwd.connection.Options)
|
|
if err != nil {
|
|
return fmt.Errorf("error getting tunnel port: %w", err)
|
|
}
|
|
|
|
// If the port visibility isn't changing, don't do anything
|
|
if AccessControlEntriesToVisibility(tunnelPort.AccessControl.Entries) == visibility {
|
|
return nil
|
|
}
|
|
|
|
// Delete the existing tunnel port to update
|
|
port, err := convertIntToUint16(remotePort)
|
|
if err != nil {
|
|
return fmt.Errorf("error converting port: %w", err)
|
|
}
|
|
err = fwd.connection.TunnelManager.DeleteTunnelPort(ctx, fwd.connection.Tunnel, port, fwd.connection.Options)
|
|
if err != nil {
|
|
return fmt.Errorf("error deleting tunnel port: %w", err)
|
|
}
|
|
|
|
done := make(chan error)
|
|
go func() {
|
|
// Connect to the tunnel
|
|
err := fwd.connection.Connect(ctx)
|
|
if err != nil {
|
|
done <- fmt.Errorf("connect failed: %v", err)
|
|
return
|
|
}
|
|
|
|
// Inform the host that we've deleted the port
|
|
err = fwd.connection.TunnelClient.RefreshPorts(ctx)
|
|
if err != nil {
|
|
done <- fmt.Errorf("refresh ports failed: %v", err)
|
|
return
|
|
}
|
|
|
|
// Re-forward the port with the updated visibility
|
|
err = fwd.ForwardPort(ctx, ForwardPortOpts{Port: remotePort, Visibility: visibility})
|
|
if err != nil {
|
|
done <- fmt.Errorf("error forwarding port: %w", err)
|
|
return
|
|
}
|
|
|
|
done <- nil
|
|
}()
|
|
|
|
// Wait for the done channel to be closed
|
|
select {
|
|
case err := <-done:
|
|
if err != nil {
|
|
// If we fail to re-forward the port, we need to forward again with the original visibility so the port is still accessible
|
|
_ = fwd.ForwardPort(ctx, ForwardPortOpts{Port: remotePort, Visibility: AccessControlEntriesToVisibility(tunnelPort.AccessControl.Entries)})
|
|
|
|
return fmt.Errorf("error connecting to tunnel: %w", err)
|
|
}
|
|
|
|
return nil
|
|
case <-ctx.Done():
|
|
return nil
|
|
}
|
|
}
|
|
|
|
// KeepAlive accepts a reason that is retained if there is no active reason
|
|
// to send to the server.
|
|
func (fwd *CodespacesPortForwarder) KeepAlive(reason string) {
|
|
select {
|
|
case fwd.keepAliveReason <- reason:
|
|
default:
|
|
// there is already an active keep alive reason
|
|
// so we can ignore this one
|
|
}
|
|
}
|
|
|
|
// GetKeepAliveReason fetches the keep alive reason from the channel and returns it.
|
|
func (fwd *CodespacesPortForwarder) GetKeepAliveReason() string {
|
|
return <-fwd.keepAliveReason
|
|
}
|
|
|
|
// Close closes the port forwarder's tunnel client connection.
|
|
func (fwd *CodespacesPortForwarder) Close() error {
|
|
return fwd.connection.Close()
|
|
}
|
|
|
|
// AccessControlEntriesToVisibility converts the access control entries used by Dev Tunnels to a friendly visibility value.
|
|
func AccessControlEntriesToVisibility(accessControlEntries []tunnels.TunnelAccessControlEntry) string {
|
|
for _, entry := range accessControlEntries {
|
|
// If we have the anonymous type (and we're not denying it), it's public
|
|
if (entry.Type == tunnels.TunnelAccessControlEntryTypeAnonymous) && (!entry.IsDeny) {
|
|
return PublicPortVisibility
|
|
}
|
|
|
|
// If we have the organizations type (and we're not denying it), it's org
|
|
if (entry.Provider == string(tunnels.TunnelAuthenticationSchemeGitHub)) && (!entry.IsDeny) {
|
|
return OrgPortVisibility
|
|
}
|
|
}
|
|
|
|
// Else, it's private
|
|
return PrivatePortVisibility
|
|
}
|
|
|
|
// visibilityToAccessControlEntries converts the given visibility to access control entries that can be used by Dev Tunnels.
|
|
func visibilityToAccessControlEntries(visibility string) []tunnels.TunnelAccessControlEntry {
|
|
switch visibility {
|
|
case PublicPortVisibility:
|
|
return []tunnels.TunnelAccessControlEntry{{
|
|
Type: tunnels.TunnelAccessControlEntryTypeAnonymous,
|
|
Subjects: []string{},
|
|
Scopes: []string{string(tunnels.TunnelAccessScopeConnect)},
|
|
}}
|
|
case OrgPortVisibility:
|
|
return []tunnels.TunnelAccessControlEntry{{
|
|
Type: tunnels.TunnelAccessControlEntryTypeOrganizations,
|
|
Subjects: []string{githubSubjectId},
|
|
Scopes: []string{
|
|
string(tunnels.TunnelAccessScopeConnect),
|
|
},
|
|
Provider: string(tunnels.TunnelAuthenticationSchemeGitHub),
|
|
}}
|
|
default:
|
|
// The tunnel manager doesn't accept empty access control entries, so we need to return a deny entry
|
|
return []tunnels.TunnelAccessControlEntry{{
|
|
Type: tunnels.TunnelAccessControlEntryTypeOrganizations,
|
|
Subjects: []string{githubSubjectId},
|
|
Scopes: []string{},
|
|
IsDeny: true,
|
|
}}
|
|
}
|
|
}
|
|
|
|
// IsInternalPort returns true if the port is internal.
|
|
func IsInternalPort(port *tunnels.TunnelPort) bool {
|
|
for _, label := range port.Labels {
|
|
if strings.EqualFold(label, InternalPortLabel) {
|
|
return true
|
|
}
|
|
}
|
|
|
|
return false
|
|
}
|
|
|
|
// convertIntToUint16 converts port to a uint16. Values outside the range
// 0-65535 are rejected with an "invalid port number" error and a zero result.
func convertIntToUint16(port int) (uint16, error) {
	if port < 0 || port > 65535 {
		return 0, fmt.Errorf("invalid port number: %d", port)
	}
	return uint16(port), nil
}
|
|
|
|
// trafficMonitor implements io.Reader. It keeps the session alive by notifying
|
|
// it of the traffic type during Read operations.
|
|
type trafficMonitor struct {
|
|
rwc io.ReadWriteCloser
|
|
fwd PortForwarder
|
|
}
|
|
|
|
// newTrafficMonitor returns a trafficMonitor for the specified codespace connection.
|
|
// It wraps the provided io.ReaderWriteCloser with its own Read/Write/Close methods.
|
|
func newTrafficMonitor(rwc io.ReadWriteCloser, fwd PortForwarder) *trafficMonitor {
|
|
return &trafficMonitor{rwc, fwd}
|
|
}
|
|
|
|
// Read wraps the underlying ReadWriteCloser's Read method and keeps the session alive with the "input" traffic type.
|
|
func (t *trafficMonitor) Read(p []byte) (n int, err error) {
|
|
t.fwd.KeepAlive(trafficTypeInput)
|
|
return t.rwc.Read(p)
|
|
}
|
|
|
|
// Write wraps the underlying ReadWriteCloser's Write method and keeps the session alive with the "output" traffic type.
|
|
func (t *trafficMonitor) Write(p []byte) (n int, err error) {
|
|
t.fwd.KeepAlive(trafficTypeOutput)
|
|
return t.rwc.Write(p)
|
|
}
|
|
|
|
// Close closes the underlying ReadWriteCloser.
|
|
func (t *trafficMonitor) Close() error {
|
|
return t.rwc.Close()
|
|
}
|