ocsp: Add caching staples to disk to store dir

Signed-off-by: Waldemar Quevedo <wally@synadia.com>
This commit is contained in:
Waldemar Quevedo
2021-05-21 18:27:04 -07:00
parent b2e1ff7a7c
commit d78a91836b
6 changed files with 270 additions and 14 deletions

View File

@@ -1464,6 +1464,16 @@ func TestConfigCheck(t *testing.T) {
errorLine: 6,
errorPos: 10,
},
{
name: "ambiguous store dir",
config: `
store_dir: "foo"
jetstream {
store_dir: "bar"
}
`,
err: fmt.Errorf(`Duplicate 'store_dir' configuration`),
},
}
checkConfig := func(config string) error {
@@ -1499,10 +1509,17 @@ func TestConfigCheck(t *testing.T) {
}
if err != nil && expectedErr != nil {
msg := fmt.Sprintf("%s:%d:%d: %s", conf, test.errorLine, test.errorPos, expectedErr.Error())
if test.reason != "" {
msg += ": " + test.reason
var msg string
if test.errorPos > 0 {
msg = fmt.Sprintf("%s:%d:%d: %s", conf, test.errorLine, test.errorPos, expectedErr.Error())
if test.reason != "" {
msg += ": " + test.reason
}
} else {
msg = test.reason
}
if !strings.Contains(err.Error(), msg) {
t.Errorf("Expected:\n%q\ngot:\n%q", msg, err.Error())
}

View File

@@ -14,6 +14,7 @@
package server
import (
"crypto/sha256"
"crypto/tls"
"crypto/x509"
"encoding/asn1"
@@ -22,6 +23,8 @@ import (
"fmt"
"io/ioutil"
"net/http"
"os"
"path/filepath"
"strings"
"sync"
"time"
@@ -30,6 +33,7 @@ import (
)
const (
defaultOCSPStoreDir = "ocsp"
defaultOCSPCheckInterval = 24 * time.Hour
minOCSPCheckInterval = 2 * time.Minute
)
@@ -92,8 +96,17 @@ func (oc *OCSPMonitor) getNextRun() time.Duration {
func (oc *OCSPMonitor) getStatus() ([]byte, *ocsp.Response, error) {
raw, resp := oc.getCacheStatus()
if len(raw) > 0 && resp != nil {
// Check if the OCSP is still valid.
if err := validOCSPResponse(resp); err == nil {
return raw, resp, nil
}
}
var err error
raw, resp, err = oc.getLocalStatus()
if err == nil {
return raw, resp, nil
}
return oc.getRemoteStatus()
}
@@ -103,6 +116,41 @@ func (oc *OCSPMonitor) getCacheStatus() ([]byte, *ocsp.Response) {
return oc.raw, oc.resp
}
func (oc *OCSPMonitor) getLocalStatus() ([]byte, *ocsp.Response, error) {
opts := oc.srv.getOpts()
storeDir := opts.StoreDir
if storeDir == _EMPTY_ {
return nil, nil, fmt.Errorf("store_dir not set")
}
// This key must be based upon the current full certificate, not the public key,
// so MUST be on the full raw certificate and not an SPKI or other reduced form.
key := fmt.Sprintf("%x", sha256.Sum256(oc.Leaf.Raw))
oc.mu.Lock()
raw, err := ioutil.ReadFile(filepath.Join(storeDir, defaultOCSPStoreDir, key))
oc.mu.Unlock()
if err != nil {
return nil, nil, err
}
resp, err := ocsp.ParseResponse(raw, oc.Issuer)
if err != nil {
return nil, nil, err
}
if err := validOCSPResponse(resp); err != nil {
return nil, nil, err
}
// Cache the response.
oc.mu.Lock()
oc.raw = raw
oc.resp = resp
oc.mu.Unlock()
return raw, resp, nil
}
func (oc *OCSPMonitor) getRemoteStatus() ([]byte, *ocsp.Response, error) {
opts := oc.srv.getOpts()
var overrideURLs []string
@@ -163,6 +211,13 @@ func (oc *OCSPMonitor) getRemoteStatus() ([]byte, *ocsp.Response, error) {
return nil, nil, err
}
if storeDir := opts.StoreDir; storeDir != _EMPTY_ {
key := fmt.Sprintf("%x", sha256.Sum256(oc.Leaf.Raw))
if err := oc.writeOCSPStatus(storeDir, key, raw); err != nil {
return nil, nil, fmt.Errorf("failed to write ocsp status: %w", err)
}
}
oc.mu.Lock()
oc.raw = raw
oc.resp = resp
@@ -201,7 +256,7 @@ func (oc *OCSPMonitor) run() {
return
}
for s.Running() {
for {
// On reload, if the certificate changes then need to stop this monitor.
select {
case <-time.After(nextRun):
@@ -294,6 +349,10 @@ func (srv *Server) NewOCSPMonitor(tc *tls.Config) (*tls.Config, *OCSPMonitor, er
return tc, nil, nil
}
if err := srv.setupOCSPStapleStoreDir(); err != nil {
return nil, nil, err
}
// TODO: Add OCSP 'responder_cert' option in case CA cert not available.
issuer, err := getOCSPIssuer(caFile, cert.Certificate)
if err != nil {
@@ -384,6 +443,35 @@ func hasOCSPStatusRequest(cert *x509.Certificate) bool {
return false
}
// writeOCSPStatus writes an OCSP status to a temporary file then moves it to a
// new path, in an attempt to avoid corrupting existing data.
func (oc *OCSPMonitor) writeOCSPStatus(storeDir, file string, data []byte) error {
storeDir = filepath.Join(storeDir, defaultOCSPStoreDir)
tmp, err := ioutil.TempFile(storeDir, "tmp-cert-status")
if err != nil {
return err
}
if _, err := tmp.Write(data); err != nil {
tmp.Close()
os.Remove(tmp.Name())
return err
}
if err := tmp.Close(); err != nil {
return err
}
oc.mu.Lock()
err = os.Rename(tmp.Name(), filepath.Join(storeDir, file))
oc.mu.Unlock()
if err != nil {
os.Remove(tmp.Name())
return err
}
return nil
}
func parseCertPEM(name string) (*x509.Certificate, error) {
data, err := ioutil.ReadFile(name)
if err != nil {

View File

@@ -793,6 +793,13 @@ func (o *Options) processConfigFileLine(k string, v interface{}, errors *[]error
*errors = append(*errors, err)
return
}
case "store_dir", "storedir":
// Check if JetStream configuration is also setting the storage directory.
if o.StoreDir != "" {
*errors = append(*errors, &configErr{tk, "Duplicate 'store_dir' configuration"})
return
}
o.StoreDir = v.(string)
case "jetstream":
err := parseJetStream(tk, o, errors, warnings)
if err != nil {
@@ -1558,6 +1565,10 @@ func parseJetStream(v interface{}, opts *Options, errors *[]error, warnings *[]e
tk, mv = unwrapValue(mv, &lt)
switch strings.ToLower(mk) {
case "store_dir", "storedir":
// StoreDir can be set at the top level as well so have to prevent ambiguous declarations.
if opts.StoreDir != "" {
return &configErr{tk, "Duplicate 'store_dir' configuration"}
}
opts.StoreDir = mv.(string)
case "max_memory_store", "max_mem_store", "max_mem":
opts.JetStreamMaxMemory = mv.(int64)

View File

@@ -1285,12 +1285,18 @@ func (s *Server) applyOptions(ctx *reloadContext, opts []option) {
func (s *Server) reloadOCSP() error {
opts := s.getOpts()
if err := s.setupOCSPStapleStoreDir(); err != nil {
return err
}
s.mu.Lock()
ocsps := s.ocsps
s.mu.Unlock()
// Stop all OCSP Stapling monitors in case there were any running.
var wasEnabled bool
for _, oc := range ocsps {
wasEnabled = true
oc.stop()
}
@@ -1303,15 +1309,17 @@ func (s *Server) reloadOCSP() error {
}
// Check if an OCSP stapling monitor is required for this certificate.
if mon != nil {
s.Noticef("OCSP Stapling enabled for client connections")
ocspm = append(ocspm, mon)
// Override the TLS config with one that follows OCSP.
// Override the TLS config with one that has OCSP enabled.
s.optsMu.Lock()
s.opts.TLSConfig = tc
s.optsMu.Unlock()
s.startGoRoutine(func() { mon.run() })
} else if wasEnabled {
s.Warnf("OCSP Stapling disabled for client connections")
}
s.Noticef("OCSP Stapling enabled for client connections")
}
// Replace stopped monitors with the new ones.
s.mu.Lock()

View File

@@ -1470,6 +1470,24 @@ func (s *Server) fetchAccount(name string) (*Account, error) {
return acc, nil
}
func (s *Server) setupOCSPStapleStoreDir() error {
opts := s.getOpts()
storeDir := opts.StoreDir
if storeDir == _EMPTY_ {
s.Warnf("OCSP Stapling disk cache is disabled (missing 'store_dir')")
return nil
}
storeDir = filepath.Join(storeDir, defaultOCSPStoreDir)
if stat, err := os.Stat(storeDir); os.IsNotExist(err) {
if err := os.MkdirAll(storeDir, defaultDirPerms); err != nil {
return fmt.Errorf("could not create OCSP storage directory - %v", err)
}
} else if stat == nil || !stat.IsDir() {
return fmt.Errorf("OCSP storage directory is not a directory")
}
return nil
}
func (s *Server) enableOCSP() error {
opts := s.getOpts()
@@ -1481,14 +1499,15 @@ func (s *Server) enableOCSP() error {
}
// Check if an OCSP stapling monitor is required for this certificate.
if mon != nil {
s.Noticef("OCSP Stapling enabled for client connections")
s.ocsps = append(s.ocsps, mon)
// Override the TLS config with one that follows OCSP.
opts.TLSConfig = tc
s.startGoRoutine(func() { mon.run() })
}
s.Noticef("OCSP Stapling enabled for client connections")
}
// FIXME: Add support for leafnodes, routes, MQTT, WebSocket
// FIXME: Add support for leafnodes, routes, gateways, MQTT, WebSocket
return nil
}

View File

@@ -14,6 +14,7 @@
package test
import (
"bytes"
"context"
"crypto/rsa"
"crypto/tls"
@@ -23,6 +24,8 @@ import (
"fmt"
"io/ioutil"
"net/http"
"os"
"path/filepath"
"strconv"
"strings"
"sync"
@@ -50,7 +53,11 @@ func TestOCSPAlwaysMustStapleAndShutdown(t *testing.T) {
addr := fmt.Sprintf("http://%s", ocspr.Addr)
setOCSPStatus(t, addr, serverCert, ocsp.Good)
opts := DefaultTestOptions
opts := server.Options{}
opts.Host = "127.0.0.1"
opts.NoLog = true
opts.NoSigs = true
opts.MaxControlLine = 4096
opts.Port = -1
opts.TLSCert = serverCert
opts.TLSKey = serverKey
@@ -140,7 +147,11 @@ func TestOCSPMustStapleShutdown(t *testing.T) {
addr := fmt.Sprintf("http://%s", ocspr.Addr)
setOCSPStatus(t, addr, serverCert, ocsp.Good)
opts := DefaultTestOptions
opts := server.Options{}
opts.Host = "127.0.0.1"
opts.NoLog = true
opts.NoSigs = true
opts.MaxControlLine = 4096
opts.Port = -1
opts.TLSCert = serverCert
opts.TLSKey = serverKey
@@ -321,7 +332,11 @@ func TestOCSPAutoWithoutMustStapleDoesNotShutdownOnRevoke(t *testing.T) {
addr := fmt.Sprintf("http://%s", ocspr.Addr)
setOCSPStatus(t, addr, serverCert, ocsp.Good)
opts := DefaultTestOptions
opts := server.Options{}
opts.Host = "127.0.0.1"
opts.NoLog = true
opts.NoSigs = true
opts.MaxControlLine = 4096
opts.Port = -1
opts.TLSCert = serverCert
opts.TLSKey = serverKey
@@ -663,9 +678,14 @@ func TestOCSPReloadRotateTLSCertDisableMustStaple(t *testing.T) {
addr := fmt.Sprintf("http://%s", ocspr.Addr)
setOCSPStatus(t, addr, serverCert, ocsp.Good)
content := `
storeDir := createDir(t, "_ocsp")
defer removeDir(t, storeDir)
originalContent := `
port: -1
store_dir: "%s"
tls {
cert_file: "configs/certs/ocsp/server-status-request-url-cert.pem"
key_file: "configs/certs/ocsp/server-status-request-url-key.pem"
@@ -673,14 +693,18 @@ func TestOCSPReloadRotateTLSCertDisableMustStaple(t *testing.T) {
timeout: 5
}
`
content := fmt.Sprintf(originalContent, storeDir)
conf := createConfFile(t, []byte(content))
defer removeFile(t, conf)
s, opts := RunServerWithConfig(conf)
defer s.Shutdown()
var staple []byte
nc, err := nats.Connect(fmt.Sprintf("tls://localhost:%d", opts.Port),
nats.Secure(&tls.Config{
VerifyConnection: func(s tls.ConnectionState) error {
staple = s.OCSPResponse
resp, err := getOCSPStatus(s)
if err != nil {
return err
@@ -710,10 +734,37 @@ func TestOCSPReloadRotateTLSCertDisableMustStaple(t *testing.T) {
}
nc.Close()
files := []string{}
err = filepath.Walk(storeDir+"/ocsp/", func(path string, info os.FileInfo, err error) error {
if info.IsDir() {
return nil
}
files = append(files, path)
return nil
})
if err != nil {
t.Fatal(err)
}
found := false
for _, file := range files {
data, err := ioutil.ReadFile(file)
if err != nil {
t.Error(err)
}
if bytes.Equal(staple, data) {
found = true
}
}
if !found {
t.Error("Could not find OCSP Staple")
}
// Change the contents with another that has OCSP Stapling disabled.
content = `
updatedContent := `
port: -1
store_dir: "%s"
tls {
cert_file: "configs/certs/ocsp/server-cert.pem"
key_file: "configs/certs/ocsp/server-key.pem"
@@ -721,6 +772,7 @@ func TestOCSPReloadRotateTLSCertDisableMustStaple(t *testing.T) {
timeout: 5
}
`
content = fmt.Sprintf(updatedContent, storeDir)
if err := ioutil.WriteFile(conf, []byte(content), 0666); err != nil {
t.Fatalf("Error writing config: %v", err)
}
@@ -729,7 +781,7 @@ func TestOCSPReloadRotateTLSCertDisableMustStaple(t *testing.T) {
}
// The new certificate does not have must staple so they will be missing.
time.Sleep(2 * time.Second)
time.Sleep(4 * time.Second)
nc, err = nats.Connect(fmt.Sprintf("tls://localhost:%d", opts.Port),
nats.Secure(&tls.Config{
@@ -747,6 +799,67 @@ func TestOCSPReloadRotateTLSCertDisableMustStaple(t *testing.T) {
t.Fatal(err)
}
nc.Close()
// Re-enable OCSP Stapling
content = fmt.Sprintf(originalContent, storeDir)
if err := ioutil.WriteFile(conf, []byte(content), 0666); err != nil {
t.Fatalf("Error writing config: %v", err)
}
if err := s.Reload(); err != nil {
t.Fatal(err)
}
var newStaple []byte
nc, err = nats.Connect(fmt.Sprintf("tls://localhost:%d", opts.Port),
nats.Secure(&tls.Config{
VerifyConnection: func(s tls.ConnectionState) error {
newStaple = s.OCSPResponse
resp, err := getOCSPStatus(s)
if err != nil {
return err
}
if resp.Status != ocsp.Good {
t.Errorf("Expected valid OCSP staple status")
}
return nil
},
}),
nats.RootCAs(caCert),
nats.ErrorHandler(noOpErrHandler),
)
if err != nil {
t.Fatal(err)
}
nc.Close()
// Confirm that it got a new staple.
files = []string{}
err = filepath.Walk(storeDir+"/ocsp/", func(path string, info os.FileInfo, err error) error {
if info.IsDir() {
return nil
}
files = append(files, path)
return nil
})
if err != nil {
t.Fatal(err)
}
found = false
for _, file := range files {
data, err := ioutil.ReadFile(file)
if err != nil {
t.Error(err)
}
if bytes.Equal(newStaple, data) {
found = true
}
}
if !found {
t.Error("Could not find OCSP Staple")
}
if bytes.Equal(staple, newStaple) {
t.Error("Expected new OCSP Staple")
}
}
func TestOCSPReloadRotateTLSCertEnableMustStaple(t *testing.T) {