cli/third-party/github.com/letsencrypt/boulder/cmd/admin/key_test.go
2025-05-30 12:50:20 -04:00

136 lines
4.2 KiB
Go

package main
import (
"context"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/sha256"
"crypto/x509"
"encoding/hex"
"encoding/pem"
"os"
"os/user"
"path"
"strconv"
"strings"
"testing"
"time"
"github.com/jmhodges/clock"
"google.golang.org/grpc"
"google.golang.org/protobuf/types/known/emptypb"
"github.com/letsencrypt/boulder/core"
blog "github.com/letsencrypt/boulder/log"
"github.com/letsencrypt/boulder/mocks"
sapb "github.com/letsencrypt/boulder/sa/proto"
"github.com/letsencrypt/boulder/test"
)
// TestSPKIHashFromPrivateKey checks that spkiHashFromPrivateKey reads a
// PEM-encoded private key from disk and returns the same SPKI hash that
// core.KeyDigest computes for the key's public half.
func TestSPKIHashFromPrivateKey(t *testing.T) {
	privKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	test.AssertNotError(t, err, "creating test private key")
	keyHash, err := core.KeyDigest(privKey.Public())
	test.AssertNotError(t, err, "computing test SPKI hash")

	keyBytes, err := x509.MarshalPKCS8PrivateKey(privKey)
	test.AssertNotError(t, err, "marshalling test private key bytes")

	keyFile := path.Join(t.TempDir(), "key.pem")
	// MarshalPKCS8PrivateKey produces a PKCS #8 structure, whose PEM label is
	// "PRIVATE KEY" (RFC 7468). "RSA PRIVATE KEY" denotes a PKCS #1 RSA key
	// and would mislabel this ECDSA key.
	keyPEM := pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: keyBytes})
	// The final argument to os.WriteFile is the permission set for a newly
	// created file. os.ModeAppend carries no permission bits, which would
	// create a mode-0000 file that a non-root user cannot read back.
	err = os.WriteFile(keyFile, keyPEM, 0o600)
	test.AssertNotError(t, err, "writing test private key file")

	a := admin{}
	res, err := a.spkiHashFromPrivateKey(keyFile)
	test.AssertNotError(t, err, "computing SPKI hash from private key file")
	test.AssertByteEquals(t, res, keyHash[:])
}
func TestSPKIHashesFromFile(t *testing.T) {
var spkiHexes []string
for i := range 10 {
h := sha256.Sum256([]byte(strconv.Itoa(i)))
spkiHexes = append(spkiHexes, hex.EncodeToString(h[:]))
}
spkiFile := path.Join(t.TempDir(), "spkis.txt")
err := os.WriteFile(spkiFile, []byte(strings.Join(spkiHexes, "\n")), os.ModeAppend)
test.AssertNotError(t, err, "writing test spki file")
a := admin{}
res, err := a.spkiHashesFromFile(spkiFile)
test.AssertNotError(t, err, "")
for i, spkiHash := range res {
test.AssertEquals(t, hex.EncodeToString(spkiHash), spkiHexes[i])
}
}
// mockSARecordingBlocks stands in for a sapb.StorageAuthorityClient and
// overrides only AddBlockedKey. Every other method is promoted from the
// embedded interface; when that interface is nil (as in the zero value),
// calling any of them panics.
type mockSARecordingBlocks struct {
	sapb.StorageAuthorityClient
	// blockRequests holds every AddBlockedKey request received, oldest first.
	blockRequests []*sapb.AddBlockedKeyRequest
}

// AddBlockedKey appends req to the recorded requests and always reports
// success.
func (m *mockSARecordingBlocks) AddBlockedKey(_ context.Context, req *sapb.AddBlockedKeyRequest, _ ...grpc.CallOption) (*emptypb.Empty, error) {
	m.blockRequests = append(m.blockRequests, req)
	return new(emptypb.Empty), nil
}

// reset forgets all previously recorded requests.
func (m *mockSARecordingBlocks) reset() {
	m.blockRequests = nil
}
// mockSARO is a read-only storage authority stub. It overrides only
// GetSerialsByKey and KeyBlocked. Any other method is promoted from the
// embedded sapb.StorageAuthorityReadOnlyClient, which panics if that
// interface is nil.
type mockSARO struct {
	sapb.StorageAuthorityReadOnlyClient
}

// GetSerialsByKey ignores its arguments and returns a zero-valued
// mocks.ServerStreamClient. TestBlockSPKIHash's "Found 0 unexpired
// certificates" expectation relies on this stream yielding no serials.
func (m *mockSARO) GetSerialsByKey(_ context.Context, _ *sapb.SPKIHash, _ ...grpc.CallOption) (grpc.ServerStreamingClient[sapb.Serial], error) {
	stream := &mocks.ServerStreamClient[sapb.Serial]{}
	return stream, nil
}

// KeyBlocked reports that no key is blocked, regardless of the hash it is
// given.
func (m *mockSARO) KeyBlocked(_ context.Context, _ *sapb.SPKIHash, _ ...grpc.CallOption) (*sapb.Exists, error) {
	notBlocked := &sapb.Exists{Exists: false}
	return notBlocked, nil
}
// TestBlockSPKIHash drives blockSPKIHash twice: first in normal mode against
// a recording SA client, then in dry-run mode with dryRunSAC. Both runs use a
// read-only SA stub that yields no serials for the key.
func TestBlockSPKIHash(t *testing.T) {
	clk := clock.NewFake()
	clk.Set(time.Now())
	mockLog := blog.NewMock()
	recorder := mockSARecordingBlocks{}

	key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	test.AssertNotError(t, err, "creating test private key")
	spkiHash, err := core.KeyDigest(key.Public())
	test.AssertNotError(t, err, "computing test SPKI hash")

	a := admin{saroc: &mockSARO{}, sac: &recorder, clk: clk, log: mockLog}
	operator := &user.User{}
	ctx := context.Background()

	// Normal mode: expect one certificate-count log line and exactly one
	// AddBlockedKey request carrying the hash and the supplied comment.
	recorder.reset()
	mockLog.Clear()
	a.dryRun = false
	err = a.blockSPKIHash(ctx, spkiHash[:], operator, "hello world")
	test.AssertNotError(t, err, "")
	test.AssertEquals(t, len(mockLog.GetAllMatching("Found 0 unexpired certificates")), 1)
	test.AssertEquals(t, len(recorder.blockRequests), 1)
	got := recorder.blockRequests[0]
	test.AssertByteEquals(t, got.KeyHash, spkiHash[:])
	test.AssertContains(t, got.Comment, "hello world")

	// Dry-run mode: the SA client is swapped for dryRunSAC, so the recorder
	// must see nothing. Expect the certificate-count line plus one "dry-run:"
	// line.
	recorder.reset()
	mockLog.Clear()
	a.dryRun = true
	a.sac = dryRunSAC{log: mockLog}
	err = a.blockSPKIHash(ctx, spkiHash[:], operator, "")
	test.AssertNotError(t, err, "")
	test.AssertEquals(t, len(mockLog.GetAllMatching("Found 0 unexpired certificates")), 1)
	test.AssertEquals(t, len(mockLog.GetAllMatching("dry-run:")), 1)
	test.AssertEquals(t, len(recorder.blockRequests), 0)
}