mirror of
https://github.com/gogrlx/nats-server.git
synced 2026-04-02 03:38:42 -07:00
ocsp: Add caching staples to disk to store dir
Signed-off-by: Waldemar Quevedo <wally@synadia.com>
This commit is contained in:
@@ -1464,6 +1464,16 @@ func TestConfigCheck(t *testing.T) {
|
||||
errorLine: 6,
|
||||
errorPos: 10,
|
||||
},
|
||||
{
|
||||
name: "ambiguous store dir",
|
||||
config: `
|
||||
store_dir: "foo"
|
||||
jetstream {
|
||||
store_dir: "bar"
|
||||
}
|
||||
`,
|
||||
err: fmt.Errorf(`Duplicate 'store_dir' configuration`),
|
||||
},
|
||||
}
|
||||
|
||||
checkConfig := func(config string) error {
|
||||
@@ -1499,10 +1509,17 @@ func TestConfigCheck(t *testing.T) {
|
||||
}
|
||||
|
||||
if err != nil && expectedErr != nil {
|
||||
msg := fmt.Sprintf("%s:%d:%d: %s", conf, test.errorLine, test.errorPos, expectedErr.Error())
|
||||
if test.reason != "" {
|
||||
msg += ": " + test.reason
|
||||
var msg string
|
||||
|
||||
if test.errorPos > 0 {
|
||||
msg = fmt.Sprintf("%s:%d:%d: %s", conf, test.errorLine, test.errorPos, expectedErr.Error())
|
||||
if test.reason != "" {
|
||||
msg += ": " + test.reason
|
||||
}
|
||||
} else {
|
||||
msg = test.reason
|
||||
}
|
||||
|
||||
if !strings.Contains(err.Error(), msg) {
|
||||
t.Errorf("Expected:\n%q\ngot:\n%q", msg, err.Error())
|
||||
}
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/asn1"
|
||||
@@ -22,6 +23,8 @@ import (
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -30,6 +33,7 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
defaultOCSPStoreDir = "ocsp"
|
||||
defaultOCSPCheckInterval = 24 * time.Hour
|
||||
minOCSPCheckInterval = 2 * time.Minute
|
||||
)
|
||||
@@ -92,8 +96,17 @@ func (oc *OCSPMonitor) getNextRun() time.Duration {
|
||||
func (oc *OCSPMonitor) getStatus() ([]byte, *ocsp.Response, error) {
|
||||
raw, resp := oc.getCacheStatus()
|
||||
if len(raw) > 0 && resp != nil {
|
||||
// Check if the OCSP is still valid.
|
||||
if err := validOCSPResponse(resp); err == nil {
|
||||
return raw, resp, nil
|
||||
}
|
||||
}
|
||||
var err error
|
||||
raw, resp, err = oc.getLocalStatus()
|
||||
if err == nil {
|
||||
return raw, resp, nil
|
||||
}
|
||||
|
||||
return oc.getRemoteStatus()
|
||||
}
|
||||
|
||||
@@ -103,6 +116,41 @@ func (oc *OCSPMonitor) getCacheStatus() ([]byte, *ocsp.Response) {
|
||||
return oc.raw, oc.resp
|
||||
}
|
||||
|
||||
func (oc *OCSPMonitor) getLocalStatus() ([]byte, *ocsp.Response, error) {
|
||||
opts := oc.srv.getOpts()
|
||||
storeDir := opts.StoreDir
|
||||
if storeDir == _EMPTY_ {
|
||||
return nil, nil, fmt.Errorf("store_dir not set")
|
||||
}
|
||||
|
||||
// This key must be based upon the current full certificate, not the public key,
|
||||
// so MUST be on the full raw certificate and not an SPKI or other reduced form.
|
||||
key := fmt.Sprintf("%x", sha256.Sum256(oc.Leaf.Raw))
|
||||
|
||||
oc.mu.Lock()
|
||||
raw, err := ioutil.ReadFile(filepath.Join(storeDir, defaultOCSPStoreDir, key))
|
||||
oc.mu.Unlock()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
resp, err := ocsp.ParseResponse(raw, oc.Issuer)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
if err := validOCSPResponse(resp); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// Cache the response.
|
||||
oc.mu.Lock()
|
||||
oc.raw = raw
|
||||
oc.resp = resp
|
||||
oc.mu.Unlock()
|
||||
|
||||
return raw, resp, nil
|
||||
}
|
||||
|
||||
func (oc *OCSPMonitor) getRemoteStatus() ([]byte, *ocsp.Response, error) {
|
||||
opts := oc.srv.getOpts()
|
||||
var overrideURLs []string
|
||||
@@ -163,6 +211,13 @@ func (oc *OCSPMonitor) getRemoteStatus() ([]byte, *ocsp.Response, error) {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if storeDir := opts.StoreDir; storeDir != _EMPTY_ {
|
||||
key := fmt.Sprintf("%x", sha256.Sum256(oc.Leaf.Raw))
|
||||
if err := oc.writeOCSPStatus(storeDir, key, raw); err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to write ocsp status: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
oc.mu.Lock()
|
||||
oc.raw = raw
|
||||
oc.resp = resp
|
||||
@@ -201,7 +256,7 @@ func (oc *OCSPMonitor) run() {
|
||||
return
|
||||
}
|
||||
|
||||
for s.Running() {
|
||||
for {
|
||||
// On reload, if the certificate changes then need to stop this monitor.
|
||||
select {
|
||||
case <-time.After(nextRun):
|
||||
@@ -294,6 +349,10 @@ func (srv *Server) NewOCSPMonitor(tc *tls.Config) (*tls.Config, *OCSPMonitor, er
|
||||
return tc, nil, nil
|
||||
}
|
||||
|
||||
if err := srv.setupOCSPStapleStoreDir(); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// TODO: Add OCSP 'responder_cert' option in case CA cert not available.
|
||||
issuer, err := getOCSPIssuer(caFile, cert.Certificate)
|
||||
if err != nil {
|
||||
@@ -384,6 +443,35 @@ func hasOCSPStatusRequest(cert *x509.Certificate) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// writeOCSPStatus writes an OCSP status to a temporary file then moves it to a
|
||||
// new path, in an attempt to avoid corrupting existing data.
|
||||
func (oc *OCSPMonitor) writeOCSPStatus(storeDir, file string, data []byte) error {
|
||||
storeDir = filepath.Join(storeDir, defaultOCSPStoreDir)
|
||||
tmp, err := ioutil.TempFile(storeDir, "tmp-cert-status")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err := tmp.Write(data); err != nil {
|
||||
tmp.Close()
|
||||
os.Remove(tmp.Name())
|
||||
return err
|
||||
}
|
||||
if err := tmp.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
oc.mu.Lock()
|
||||
err = os.Rename(tmp.Name(), filepath.Join(storeDir, file))
|
||||
oc.mu.Unlock()
|
||||
if err != nil {
|
||||
os.Remove(tmp.Name())
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseCertPEM(name string) (*x509.Certificate, error) {
|
||||
data, err := ioutil.ReadFile(name)
|
||||
if err != nil {
|
||||
|
||||
@@ -793,6 +793,13 @@ func (o *Options) processConfigFileLine(k string, v interface{}, errors *[]error
|
||||
*errors = append(*errors, err)
|
||||
return
|
||||
}
|
||||
case "store_dir", "storedir":
|
||||
// Check if JetStream configuration is also setting the storage directory.
|
||||
if o.StoreDir != "" {
|
||||
*errors = append(*errors, &configErr{tk, "Duplicate 'store_dir' configuration"})
|
||||
return
|
||||
}
|
||||
o.StoreDir = v.(string)
|
||||
case "jetstream":
|
||||
err := parseJetStream(tk, o, errors, warnings)
|
||||
if err != nil {
|
||||
@@ -1558,6 +1565,10 @@ func parseJetStream(v interface{}, opts *Options, errors *[]error, warnings *[]e
|
||||
tk, mv = unwrapValue(mv, <)
|
||||
switch strings.ToLower(mk) {
|
||||
case "store_dir", "storedir":
|
||||
// StoreDir can be set at the top level as well so have to prevent ambiguous declarations.
|
||||
if opts.StoreDir != "" {
|
||||
return &configErr{tk, "Duplicate 'store_dir' configuration"}
|
||||
}
|
||||
opts.StoreDir = mv.(string)
|
||||
case "max_memory_store", "max_mem_store", "max_mem":
|
||||
opts.JetStreamMaxMemory = mv.(int64)
|
||||
|
||||
@@ -1285,12 +1285,18 @@ func (s *Server) applyOptions(ctx *reloadContext, opts []option) {
|
||||
func (s *Server) reloadOCSP() error {
|
||||
opts := s.getOpts()
|
||||
|
||||
if err := s.setupOCSPStapleStoreDir(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
ocsps := s.ocsps
|
||||
s.mu.Unlock()
|
||||
|
||||
// Stop all OCSP Stapling monitors in case there were any running.
|
||||
var wasEnabled bool
|
||||
for _, oc := range ocsps {
|
||||
wasEnabled = true
|
||||
oc.stop()
|
||||
}
|
||||
|
||||
@@ -1303,15 +1309,17 @@ func (s *Server) reloadOCSP() error {
|
||||
}
|
||||
// Check if an OCSP stapling monitor is required for this certificate.
|
||||
if mon != nil {
|
||||
s.Noticef("OCSP Stapling enabled for client connections")
|
||||
ocspm = append(ocspm, mon)
|
||||
|
||||
// Override the TLS config with one that follows OCSP.
|
||||
// Override the TLS config with one that has OCSP enabled.
|
||||
s.optsMu.Lock()
|
||||
s.opts.TLSConfig = tc
|
||||
s.optsMu.Unlock()
|
||||
s.startGoRoutine(func() { mon.run() })
|
||||
} else if wasEnabled {
|
||||
s.Warnf("OCSP Stapling disabled for client connections")
|
||||
}
|
||||
s.Noticef("OCSP Stapling enabled for client connections")
|
||||
}
|
||||
// Replace stopped monitors with the new ones.
|
||||
s.mu.Lock()
|
||||
|
||||
@@ -1470,6 +1470,24 @@ func (s *Server) fetchAccount(name string) (*Account, error) {
|
||||
return acc, nil
|
||||
}
|
||||
|
||||
func (s *Server) setupOCSPStapleStoreDir() error {
|
||||
opts := s.getOpts()
|
||||
storeDir := opts.StoreDir
|
||||
if storeDir == _EMPTY_ {
|
||||
s.Warnf("OCSP Stapling disk cache is disabled (missing 'store_dir')")
|
||||
return nil
|
||||
}
|
||||
storeDir = filepath.Join(storeDir, defaultOCSPStoreDir)
|
||||
if stat, err := os.Stat(storeDir); os.IsNotExist(err) {
|
||||
if err := os.MkdirAll(storeDir, defaultDirPerms); err != nil {
|
||||
return fmt.Errorf("could not create OCSP storage directory - %v", err)
|
||||
}
|
||||
} else if stat == nil || !stat.IsDir() {
|
||||
return fmt.Errorf("OCSP storage directory is not a directory")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) enableOCSP() error {
|
||||
opts := s.getOpts()
|
||||
|
||||
@@ -1481,14 +1499,15 @@ func (s *Server) enableOCSP() error {
|
||||
}
|
||||
// Check if an OCSP stapling monitor is required for this certificate.
|
||||
if mon != nil {
|
||||
s.Noticef("OCSP Stapling enabled for client connections")
|
||||
|
||||
s.ocsps = append(s.ocsps, mon)
|
||||
// Override the TLS config with one that follows OCSP.
|
||||
opts.TLSConfig = tc
|
||||
s.startGoRoutine(func() { mon.run() })
|
||||
}
|
||||
s.Noticef("OCSP Stapling enabled for client connections")
|
||||
}
|
||||
// FIXME: Add support for leafnodes, routes, MQTT, WebSocket
|
||||
// FIXME: Add support for leafnodes, routes, gateways, MQTT, WebSocket
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
package test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/rsa"
|
||||
"crypto/tls"
|
||||
@@ -23,6 +24,8 @@ import (
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -50,7 +53,11 @@ func TestOCSPAlwaysMustStapleAndShutdown(t *testing.T) {
|
||||
addr := fmt.Sprintf("http://%s", ocspr.Addr)
|
||||
setOCSPStatus(t, addr, serverCert, ocsp.Good)
|
||||
|
||||
opts := DefaultTestOptions
|
||||
opts := server.Options{}
|
||||
opts.Host = "127.0.0.1"
|
||||
opts.NoLog = true
|
||||
opts.NoSigs = true
|
||||
opts.MaxControlLine = 4096
|
||||
opts.Port = -1
|
||||
opts.TLSCert = serverCert
|
||||
opts.TLSKey = serverKey
|
||||
@@ -140,7 +147,11 @@ func TestOCSPMustStapleShutdown(t *testing.T) {
|
||||
addr := fmt.Sprintf("http://%s", ocspr.Addr)
|
||||
setOCSPStatus(t, addr, serverCert, ocsp.Good)
|
||||
|
||||
opts := DefaultTestOptions
|
||||
opts := server.Options{}
|
||||
opts.Host = "127.0.0.1"
|
||||
opts.NoLog = true
|
||||
opts.NoSigs = true
|
||||
opts.MaxControlLine = 4096
|
||||
opts.Port = -1
|
||||
opts.TLSCert = serverCert
|
||||
opts.TLSKey = serverKey
|
||||
@@ -321,7 +332,11 @@ func TestOCSPAutoWithoutMustStapleDoesNotShutdownOnRevoke(t *testing.T) {
|
||||
addr := fmt.Sprintf("http://%s", ocspr.Addr)
|
||||
setOCSPStatus(t, addr, serverCert, ocsp.Good)
|
||||
|
||||
opts := DefaultTestOptions
|
||||
opts := server.Options{}
|
||||
opts.Host = "127.0.0.1"
|
||||
opts.NoLog = true
|
||||
opts.NoSigs = true
|
||||
opts.MaxControlLine = 4096
|
||||
opts.Port = -1
|
||||
opts.TLSCert = serverCert
|
||||
opts.TLSKey = serverKey
|
||||
@@ -663,9 +678,14 @@ func TestOCSPReloadRotateTLSCertDisableMustStaple(t *testing.T) {
|
||||
addr := fmt.Sprintf("http://%s", ocspr.Addr)
|
||||
setOCSPStatus(t, addr, serverCert, ocsp.Good)
|
||||
|
||||
content := `
|
||||
storeDir := createDir(t, "_ocsp")
|
||||
defer removeDir(t, storeDir)
|
||||
|
||||
originalContent := `
|
||||
port: -1
|
||||
|
||||
store_dir: "%s"
|
||||
|
||||
tls {
|
||||
cert_file: "configs/certs/ocsp/server-status-request-url-cert.pem"
|
||||
key_file: "configs/certs/ocsp/server-status-request-url-key.pem"
|
||||
@@ -673,14 +693,18 @@ func TestOCSPReloadRotateTLSCertDisableMustStaple(t *testing.T) {
|
||||
timeout: 5
|
||||
}
|
||||
`
|
||||
|
||||
content := fmt.Sprintf(originalContent, storeDir)
|
||||
conf := createConfFile(t, []byte(content))
|
||||
defer removeFile(t, conf)
|
||||
s, opts := RunServerWithConfig(conf)
|
||||
defer s.Shutdown()
|
||||
|
||||
var staple []byte
|
||||
nc, err := nats.Connect(fmt.Sprintf("tls://localhost:%d", opts.Port),
|
||||
nats.Secure(&tls.Config{
|
||||
VerifyConnection: func(s tls.ConnectionState) error {
|
||||
staple = s.OCSPResponse
|
||||
resp, err := getOCSPStatus(s)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -710,10 +734,37 @@ func TestOCSPReloadRotateTLSCertDisableMustStaple(t *testing.T) {
|
||||
}
|
||||
nc.Close()
|
||||
|
||||
files := []string{}
|
||||
err = filepath.Walk(storeDir+"/ocsp/", func(path string, info os.FileInfo, err error) error {
|
||||
if info.IsDir() {
|
||||
return nil
|
||||
}
|
||||
files = append(files, path)
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
found := false
|
||||
for _, file := range files {
|
||||
data, err := ioutil.ReadFile(file)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if bytes.Equal(staple, data) {
|
||||
found = true
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Error("Could not find OCSP Staple")
|
||||
}
|
||||
|
||||
// Change the contents with another that has OCSP Stapling disabled.
|
||||
content = `
|
||||
updatedContent := `
|
||||
port: -1
|
||||
|
||||
store_dir: "%s"
|
||||
|
||||
tls {
|
||||
cert_file: "configs/certs/ocsp/server-cert.pem"
|
||||
key_file: "configs/certs/ocsp/server-key.pem"
|
||||
@@ -721,6 +772,7 @@ func TestOCSPReloadRotateTLSCertDisableMustStaple(t *testing.T) {
|
||||
timeout: 5
|
||||
}
|
||||
`
|
||||
content = fmt.Sprintf(updatedContent, storeDir)
|
||||
if err := ioutil.WriteFile(conf, []byte(content), 0666); err != nil {
|
||||
t.Fatalf("Error writing config: %v", err)
|
||||
}
|
||||
@@ -729,7 +781,7 @@ func TestOCSPReloadRotateTLSCertDisableMustStaple(t *testing.T) {
|
||||
}
|
||||
|
||||
// The new certificate does not have must staple so they will be missing.
|
||||
time.Sleep(2 * time.Second)
|
||||
time.Sleep(4 * time.Second)
|
||||
|
||||
nc, err = nats.Connect(fmt.Sprintf("tls://localhost:%d", opts.Port),
|
||||
nats.Secure(&tls.Config{
|
||||
@@ -747,6 +799,67 @@ func TestOCSPReloadRotateTLSCertDisableMustStaple(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
nc.Close()
|
||||
|
||||
// Re-enable OCSP Stapling
|
||||
content = fmt.Sprintf(originalContent, storeDir)
|
||||
if err := ioutil.WriteFile(conf, []byte(content), 0666); err != nil {
|
||||
t.Fatalf("Error writing config: %v", err)
|
||||
}
|
||||
if err := s.Reload(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var newStaple []byte
|
||||
nc, err = nats.Connect(fmt.Sprintf("tls://localhost:%d", opts.Port),
|
||||
nats.Secure(&tls.Config{
|
||||
VerifyConnection: func(s tls.ConnectionState) error {
|
||||
newStaple = s.OCSPResponse
|
||||
resp, err := getOCSPStatus(s)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if resp.Status != ocsp.Good {
|
||||
t.Errorf("Expected valid OCSP staple status")
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}),
|
||||
nats.RootCAs(caCert),
|
||||
nats.ErrorHandler(noOpErrHandler),
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
nc.Close()
|
||||
|
||||
// Confirm that it got a new staple.
|
||||
files = []string{}
|
||||
err = filepath.Walk(storeDir+"/ocsp/", func(path string, info os.FileInfo, err error) error {
|
||||
if info.IsDir() {
|
||||
return nil
|
||||
}
|
||||
files = append(files, path)
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
found = false
|
||||
for _, file := range files {
|
||||
data, err := ioutil.ReadFile(file)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if bytes.Equal(newStaple, data) {
|
||||
found = true
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Error("Could not find OCSP Staple")
|
||||
}
|
||||
if bytes.Equal(staple, newStaple) {
|
||||
t.Error("Expected new OCSP Staple")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOCSPReloadRotateTLSCertEnableMustStaple(t *testing.T) {
|
||||
|
||||
Reference in New Issue
Block a user