136 lines
4.2 KiB
Go
136 lines
4.2 KiB
Go
package main
|
|
|
|
import (
|
|
"context"
|
|
"crypto/ecdsa"
|
|
"crypto/elliptic"
|
|
"crypto/rand"
|
|
"crypto/sha256"
|
|
"crypto/x509"
|
|
"encoding/hex"
|
|
"encoding/pem"
|
|
"os"
|
|
"os/user"
|
|
"path"
|
|
"strconv"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/jmhodges/clock"
|
|
"google.golang.org/grpc"
|
|
"google.golang.org/protobuf/types/known/emptypb"
|
|
|
|
"github.com/letsencrypt/boulder/core"
|
|
blog "github.com/letsencrypt/boulder/log"
|
|
"github.com/letsencrypt/boulder/mocks"
|
|
sapb "github.com/letsencrypt/boulder/sa/proto"
|
|
"github.com/letsencrypt/boulder/test"
|
|
)
|
|
|
|
func TestSPKIHashFromPrivateKey(t *testing.T) {
|
|
privKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
|
test.AssertNotError(t, err, "creating test private key")
|
|
keyHash, err := core.KeyDigest(privKey.Public())
|
|
test.AssertNotError(t, err, "computing test SPKI hash")
|
|
|
|
keyBytes, err := x509.MarshalPKCS8PrivateKey(privKey)
|
|
test.AssertNotError(t, err, "marshalling test private key bytes")
|
|
keyFile := path.Join(t.TempDir(), "key.pem")
|
|
keyPEM := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: keyBytes})
|
|
err = os.WriteFile(keyFile, keyPEM, os.ModeAppend)
|
|
test.AssertNotError(t, err, "writing test private key file")
|
|
|
|
a := admin{}
|
|
|
|
res, err := a.spkiHashFromPrivateKey(keyFile)
|
|
test.AssertNotError(t, err, "")
|
|
test.AssertByteEquals(t, res, keyHash[:])
|
|
}
|
|
|
|
func TestSPKIHashesFromFile(t *testing.T) {
|
|
var spkiHexes []string
|
|
for i := range 10 {
|
|
h := sha256.Sum256([]byte(strconv.Itoa(i)))
|
|
spkiHexes = append(spkiHexes, hex.EncodeToString(h[:]))
|
|
}
|
|
|
|
spkiFile := path.Join(t.TempDir(), "spkis.txt")
|
|
err := os.WriteFile(spkiFile, []byte(strings.Join(spkiHexes, "\n")), os.ModeAppend)
|
|
test.AssertNotError(t, err, "writing test spki file")
|
|
|
|
a := admin{}
|
|
|
|
res, err := a.spkiHashesFromFile(spkiFile)
|
|
test.AssertNotError(t, err, "")
|
|
for i, spkiHash := range res {
|
|
test.AssertEquals(t, hex.EncodeToString(spkiHash), spkiHexes[i])
|
|
}
|
|
}
|
|
|
|
// mockSARecordingBlocks is a mock which only implements the AddBlockedKey gRPC
|
|
// method.
|
|
type mockSARecordingBlocks struct {
|
|
sapb.StorageAuthorityClient
|
|
blockRequests []*sapb.AddBlockedKeyRequest
|
|
}
|
|
|
|
// AddBlockedKey is a mock which always succeeds and records the request it
|
|
// received.
|
|
func (msa *mockSARecordingBlocks) AddBlockedKey(ctx context.Context, req *sapb.AddBlockedKeyRequest, _ ...grpc.CallOption) (*emptypb.Empty, error) {
|
|
msa.blockRequests = append(msa.blockRequests, req)
|
|
return &emptypb.Empty{}, nil
|
|
}
|
|
|
|
func (msa *mockSARecordingBlocks) reset() {
|
|
msa.blockRequests = nil
|
|
}
|
|
|
|
type mockSARO struct {
|
|
sapb.StorageAuthorityReadOnlyClient
|
|
}
|
|
|
|
func (sa *mockSARO) GetSerialsByKey(ctx context.Context, _ *sapb.SPKIHash, _ ...grpc.CallOption) (grpc.ServerStreamingClient[sapb.Serial], error) {
|
|
return &mocks.ServerStreamClient[sapb.Serial]{}, nil
|
|
}
|
|
|
|
func (sa *mockSARO) KeyBlocked(ctx context.Context, req *sapb.SPKIHash, _ ...grpc.CallOption) (*sapb.Exists, error) {
|
|
return &sapb.Exists{Exists: false}, nil
|
|
}
|
|
|
|
func TestBlockSPKIHash(t *testing.T) {
|
|
fc := clock.NewFake()
|
|
fc.Set(time.Now())
|
|
log := blog.NewMock()
|
|
msa := mockSARecordingBlocks{}
|
|
|
|
privKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
|
test.AssertNotError(t, err, "creating test private key")
|
|
keyHash, err := core.KeyDigest(privKey.Public())
|
|
test.AssertNotError(t, err, "computing test SPKI hash")
|
|
|
|
a := admin{saroc: &mockSARO{}, sac: &msa, clk: fc, log: log}
|
|
u := &user.User{}
|
|
|
|
// A full run should result in one request with the right fields.
|
|
msa.reset()
|
|
log.Clear()
|
|
a.dryRun = false
|
|
err = a.blockSPKIHash(context.Background(), keyHash[:], u, "hello world")
|
|
test.AssertNotError(t, err, "")
|
|
test.AssertEquals(t, len(log.GetAllMatching("Found 0 unexpired certificates")), 1)
|
|
test.AssertEquals(t, len(msa.blockRequests), 1)
|
|
test.AssertByteEquals(t, msa.blockRequests[0].KeyHash, keyHash[:])
|
|
test.AssertContains(t, msa.blockRequests[0].Comment, "hello world")
|
|
|
|
// A dry-run should result in zero requests and two log lines.
|
|
msa.reset()
|
|
log.Clear()
|
|
a.dryRun = true
|
|
a.sac = dryRunSAC{log: log}
|
|
err = a.blockSPKIHash(context.Background(), keyHash[:], u, "")
|
|
test.AssertNotError(t, err, "")
|
|
test.AssertEquals(t, len(log.GetAllMatching("Found 0 unexpired certificates")), 1)
|
|
test.AssertEquals(t, len(log.GetAllMatching("dry-run:")), 1)
|
|
test.AssertEquals(t, len(msa.blockRequests), 0)
|
|
}
|