Rework ref usage when finding and creating PRs

This commit is contained in:
William Martin 2025-04-02 13:21:10 +02:00 committed by Kynan Ware
parent ebd147b43e
commit a9dbda6913
28 changed files with 2254 additions and 1110 deletions

View file

@ -381,7 +381,6 @@ func (c *Client) lookupCommit(ctx context.Context, sha, format string) ([]byte,
// Downstream consumers of ReadBranchConfig should consider the behavior they desire if this errors,
// as an empty config is not necessarily breaking.
func (c *Client) ReadBranchConfig(ctx context.Context, branch string) (BranchConfig, error) {
prefix := regexp.QuoteMeta(fmt.Sprintf("branch.%s.", branch))
args := []string{"config", "--get-regexp", fmt.Sprintf("^%s(remote|merge|pushremote|%s)$", prefix, MergeBaseConfig)}
cmd, err := c.Command(ctx, args...)
@ -441,18 +440,50 @@ func (c *Client) SetBranchConfig(ctx context.Context, branch, name, value string
return err
}
// PushDefault defines the action git push should take if no refspec is given.
// See: https://git-scm.com/docs/git-config#Documentation/git-config.txt-pushdefault
type PushDefault string
const (
PushDefaultNothing PushDefault = "nothing"
PushDefaultCurrent PushDefault = "current"
PushDefaultUpstream PushDefault = "upstream"
PushDefaultTracking PushDefault = "tracking"
PushDefaultSimple PushDefault = "simple"
PushDefaultMatching PushDefault = "matching"
)
func ParsePushDefault(s string) (PushDefault, error) {
validPushDefaults := map[string]struct{}{
string(PushDefaultNothing): {},
string(PushDefaultCurrent): {},
string(PushDefaultUpstream): {},
string(PushDefaultTracking): {},
string(PushDefaultSimple): {},
string(PushDefaultMatching): {},
}
if _, ok := validPushDefaults[s]; ok {
return PushDefault(s), nil
}
return "", fmt.Errorf("unknown push.default value: %s", s)
}
// PushDefault returns the value of push.default in the config. If the value
// is not set, it returns "simple" (the default git value). See
// https://git-scm.com/docs/git-config#Documentation/git-config.txt-pushdefault
func (c *Client) PushDefault(ctx context.Context) (string, error) {
func (c *Client) PushDefault(ctx context.Context) (PushDefault, error) {
pushDefault, err := c.Config(ctx, "push.default")
if err == nil {
return pushDefault, nil
return ParsePushDefault(pushDefault)
}
// If there is an error that the config key is not set, return the default value
// that git uses since 2.0.
var gitError *GitError
if ok := errors.As(err, &gitError); ok && gitError.ExitCode == 1 {
return "simple", nil
return PushDefaultSimple, nil
}
return "", err
}
@ -473,13 +504,48 @@ func (c *Client) RemotePushDefault(ctx context.Context) (string, error) {
return "", err
}
// ParsePushRevision gets the value of the @{push} revision syntax
// RemoteTrackingRef is the structured form of the string "refs/remotes/<remote>/<branch>".
// For example, the @{push} revision syntax could report "refs/remotes/origin/main" which would
// be parsed into RemoteTrackingRef{Remote: "origin", Branch: "main"}.
type RemoteTrackingRef struct {
Remote string
Branch string
}
func (r RemoteTrackingRef) String() string {
return fmt.Sprintf("refs/remotes/%s/%s", r.Remote, r.Branch)
}
// ParseRemoteTrackingRef parses a string of the form "refs/remotes/<remote>/<branch>" into
// a RemoteTrackingBranch struct. If the string does not match this format, an error is returned.
func ParseRemoteTrackingRef(s string) (RemoteTrackingRef, error) {
parts := strings.Split(s, "/")
if len(parts) != 4 || parts[0] != "refs" || parts[1] != "remotes" {
return RemoteTrackingRef{}, fmt.Errorf("remote tracking branch must have format refs/remotes/<remote>/<branch> but was: %s", s)
}
return RemoteTrackingRef{
Remote: parts[2],
Branch: parts[3],
}, nil
}
// PushRevision gets the value of the @{push} revision syntax
// An error here doesn't necessarily mean something is broken, but may mean that the @{push}
// revision syntax couldn't be resolved, such as in non-centralized workflows with
// push.default = simple. Downstream consumers should consider how to handle this error.
func (c *Client) ParsePushRevision(ctx context.Context, branch string) (string, error) {
revParseOut, err := c.revParse(ctx, "--abbrev-ref", branch+"@{push}")
return firstLine(revParseOut), err
func (c *Client) PushRevision(ctx context.Context, branch string) (RemoteTrackingRef, error) {
revParseOut, err := c.revParse(ctx, "--symbolic-full-name", branch+"@{push}")
if err != nil {
return RemoteTrackingRef{}, err
}
ref, err := ParseRemoteTrackingRef(firstLine(revParseOut))
if err != nil {
return RemoteTrackingRef{}, fmt.Errorf("could not parse push revision: %v", err)
}
return ref, nil
}
func (c *Client) DeleteLocalTag(ctx context.Context, tag string) error {

View file

@ -952,7 +952,7 @@ func TestClientPushDefault(t *testing.T) {
tests := []struct {
name string
commandResult commandResult
wantPushDefault string
wantPushDefault PushDefault
wantError *GitError
}{
{
@ -961,7 +961,7 @@ func TestClientPushDefault(t *testing.T) {
ExitStatus: 1,
Stderr: "error: key does not contain a section: remote.pushDefault",
},
wantPushDefault: "simple",
wantPushDefault: PushDefaultSimple,
wantError: nil,
},
{
@ -970,7 +970,7 @@ func TestClientPushDefault(t *testing.T) {
ExitStatus: 0,
Stdout: "current",
},
wantPushDefault: "current",
wantPushDefault: PushDefaultCurrent,
wantError: nil,
},
{
@ -1077,17 +1077,17 @@ func TestClientParsePushRevision(t *testing.T) {
name string
branch string
commandResult commandResult
wantParsedPushRevision string
wantError *GitError
wantParsedPushRevision RemoteTrackingRef
wantError error
}{
{
name: "@{push} resolves to origin/branchName",
name: "@{push} resolves to refs/remotes/origin/branchName",
branch: "branchName",
commandResult: commandResult{
ExitStatus: 0,
Stdout: "origin/branchName",
Stdout: "refs/remotes/origin/branchName",
},
wantParsedPushRevision: "origin/branchName",
wantParsedPushRevision: RemoteTrackingRef{Remote: "origin", Branch: "branchName"},
},
{
name: "@{push} doesn't resolve",
@ -1095,16 +1095,25 @@ func TestClientParsePushRevision(t *testing.T) {
ExitStatus: 128,
Stderr: "fatal: git error",
},
wantParsedPushRevision: "",
wantParsedPushRevision: RemoteTrackingRef{},
wantError: &GitError{
ExitCode: 128,
Stderr: "fatal: git error",
},
},
{
name: "@{push} resolves to something surprising",
commandResult: commandResult{
ExitStatus: 0,
Stdout: "not/a/valid/remote/ref",
},
wantParsedPushRevision: RemoteTrackingRef{},
wantError: fmt.Errorf("could not parse push revision: remote tracking branch must have format refs/remotes/<remote>/<branch> but was: not/a/valid/remote/ref"),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cmd := fmt.Sprintf("path/to/git rev-parse --abbrev-ref %s@{push}", tt.branch)
cmd := fmt.Sprintf("path/to/git rev-parse --symbolic-full-name %s@{push}", tt.branch)
cmdCtx := createMockedCommandContext(t, mockedCommands{
args(cmd): tt.commandResult,
})
@ -1112,20 +1121,91 @@ func TestClientParsePushRevision(t *testing.T) {
GitPath: "path/to/git",
commandContext: cmdCtx,
}
pushDefault, err := client.ParsePushRevision(context.Background(), tt.branch)
trackingRef, err := client.PushRevision(context.Background(), tt.branch)
if tt.wantError != nil {
var gitError *GitError
require.ErrorAs(t, err, &gitError)
assert.Equal(t, tt.wantError.ExitCode, gitError.ExitCode)
assert.Equal(t, tt.wantError.Stderr, gitError.Stderr)
var wantErrorAsGit *GitError
if errors.As(err, &wantErrorAsGit) {
var gitError *GitError
require.ErrorAs(t, err, &gitError)
assert.Equal(t, wantErrorAsGit.ExitCode, gitError.ExitCode)
assert.Equal(t, wantErrorAsGit.Stderr, gitError.Stderr)
} else {
assert.Equal(t, err, tt.wantError)
}
} else {
require.NoError(t, err)
}
assert.Equal(t, tt.wantParsedPushRevision, pushDefault)
assert.Equal(t, tt.wantParsedPushRevision, trackingRef)
})
}
}
func TestRemoteTrackingRef(t *testing.T) {
t.Run("parsing", func(t *testing.T) {
t.Parallel()
tests := []struct {
name string
remoteTrackingRef string
wantRemoteTrackingRef RemoteTrackingRef
wantError error
}{
{
name: "valid remote tracking ref",
remoteTrackingRef: "refs/remotes/origin/branchName",
wantRemoteTrackingRef: RemoteTrackingRef{
Remote: "origin",
Branch: "branchName",
},
},
{
name: "incorrect parts",
remoteTrackingRef: "refs/remotes/origin",
wantRemoteTrackingRef: RemoteTrackingRef{},
wantError: fmt.Errorf("remote tracking branch must have format refs/remotes/<remote>/<branch> but was: refs/remotes/origin"),
},
{
name: "incorrect prefix type",
remoteTrackingRef: "invalid/remotes/origin/branchName",
wantRemoteTrackingRef: RemoteTrackingRef{},
wantError: fmt.Errorf("remote tracking branch must have format refs/remotes/<remote>/<branch> but was: invalid/remotes/origin/branchName"),
},
{
name: "incorrect ref type",
remoteTrackingRef: "refs/invalid/origin/branchName",
wantRemoteTrackingRef: RemoteTrackingRef{},
wantError: fmt.Errorf("remote tracking branch must have format refs/remotes/<remote>/<branch> but was: refs/invalid/origin/branchName"),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
trackingRef, err := ParseRemoteTrackingRef(tt.remoteTrackingRef)
if tt.wantError != nil {
require.Equal(t, tt.wantError, err)
return
}
require.NoError(t, err)
assert.Equal(t, tt.wantRemoteTrackingRef, trackingRef)
})
}
})
t.Run("stringifying", func(t *testing.T) {
t.Parallel()
remoteTrackingRef := RemoteTrackingRef{
Remote: "origin",
Branch: "branchName",
}
require.Equal(t, "refs/remotes/origin/branchName", remoteTrackingRef.String())
})
}
func TestClientDeleteLocalTag(t *testing.T) {
tests := []struct {
name string
@ -1992,6 +2072,41 @@ func TestCredentialPatternFromHost(t *testing.T) {
}
}
func TestPushDefault(t *testing.T) {
t.Run("it parses valid values correctly", func(t *testing.T) {
t.Parallel()
tests := []struct {
value string
expectedPushDefault PushDefault
}{
{"nothing", PushDefaultNothing},
{"current", PushDefaultCurrent},
{"upstream", PushDefaultUpstream},
{"tracking", PushDefaultTracking},
{"simple", PushDefaultSimple},
{"matching", PushDefaultMatching},
}
for _, test := range tests {
t.Run(test.value, func(t *testing.T) {
t.Parallel()
pushDefault, err := ParsePushDefault(test.value)
require.NoError(t, err)
assert.Equal(t, test.expectedPushDefault, pushDefault)
})
}
})
t.Run("it returns an error for invalid values", func(t *testing.T) {
t.Parallel()
_, err := ParsePushDefault("invalid")
require.Error(t, err)
})
}
func createCommandContext(t *testing.T, exitStatus int, stdout, stderr string) (*exec.Cmd, commandCtx) {
cmd := exec.CommandContext(context.Background(), os.Args[0], "-test.run=TestHelperProcess", "--")
cmd.Env = []string{