diff --git a/pkg/cmd/pr/create/create_test.go b/pkg/cmd/pr/create/create_test.go index acb35f557..4785a9d99 100644 --- a/pkg/cmd/pr/create/create_test.go +++ b/pkg/cmd/pr/create/create_test.go @@ -1624,10 +1624,11 @@ func Test_createRun(t *testing.T) { func Test_tryDetermineTrackingRef(t *testing.T) { tests := []struct { - name string - cmdStubs func(*run.CommandStubber) - remotes context.Remotes - assert func(ref git.TrackingRef, found bool, t *testing.T) + name string + cmdStubs func(*run.CommandStubber) + remotes context.Remotes + expectedTrackingRef git.TrackingRef + expectedFound bool }{ { name: "empty", @@ -1635,10 +1636,8 @@ func Test_tryDetermineTrackingRef(t *testing.T) { cs.Register(`git config --get-regexp.+branch\\\.feature\\\.`, 0, "") cs.Register(`git show-ref --verify -- HEAD`, 0, "abc HEAD") }, - assert: func(ref git.TrackingRef, found bool, t *testing.T) { - assert.Zero(t, ref) - assert.False(t, found) - }, + expectedTrackingRef: git.TrackingRef{}, + expectedFound: false, }, { name: "no match", @@ -1656,10 +1655,8 @@ func Test_tryDetermineTrackingRef(t *testing.T) { Repo: ghrepo.New("octocat", "Spoon-Knife"), }, }, - assert: func(ref git.TrackingRef, found bool, t *testing.T) { - assert.Zero(t, ref) - assert.False(t, found) - }, + expectedTrackingRef: git.TrackingRef{}, + expectedFound: false, }, { name: "match", @@ -1681,13 +1678,11 @@ func Test_tryDetermineTrackingRef(t *testing.T) { Repo: ghrepo.New("octocat", "Spoon-Knife"), }, }, - assert: func(ref git.TrackingRef, found bool, t *testing.T) { - assert.Equal(t, git.TrackingRef{ - RemoteName: "upstream", - BranchName: "feature", - }, ref) - assert.True(t, found) + expectedTrackingRef: git.TrackingRef{ + RemoteName: "upstream", + BranchName: "feature", }, + expectedFound: true, }, { name: "respect tracking config", @@ -1707,10 +1702,8 @@ func Test_tryDetermineTrackingRef(t *testing.T) { Repo: ghrepo.New("hubot", "Spoon-Knife"), }, }, - assert: func(ref git.TrackingRef, found bool, t *testing.T) { - assert.Zero(t, ref) - assert.False(t, found) - }, + expectedTrackingRef: git.TrackingRef{}, + expectedFound: false, }, } for _, tt := range tests { @@ -1727,7 +1720,8 @@ func Test_tryDetermineTrackingRef(t *testing.T) { headBranchConfig := gitClient.ReadBranchConfig(ctx.Background(), "feature") ref, found := tryDetermineTrackingRef(gitClient, tt.remotes, "feature", headBranchConfig) - tt.assert(ref, found, t) + assert.Equal(t, tt.expectedTrackingRef, ref) + assert.Equal(t, tt.expectedFound, found) }) } }