diff --git a/server/config_check_test.go b/server/config_check_test.go index 94e0afbe..21582afd 100644 --- a/server/config_check_test.go +++ b/server/config_check_test.go @@ -1464,6 +1464,16 @@ func TestConfigCheck(t *testing.T) { errorLine: 6, errorPos: 10, }, + { + name: "ambiguous store dir", + config: ` + store_dir: "foo" + jetstream { + store_dir: "bar" + } + `, + err: fmt.Errorf(`Duplicate 'store_dir' configuration`), + }, } checkConfig := func(config string) error { @@ -1499,10 +1509,17 @@ func TestConfigCheck(t *testing.T) { } if err != nil && expectedErr != nil { - msg := fmt.Sprintf("%s:%d:%d: %s", conf, test.errorLine, test.errorPos, expectedErr.Error()) - if test.reason != "" { - msg += ": " + test.reason + var msg string + + if test.errorPos > 0 { + msg = fmt.Sprintf("%s:%d:%d: %s", conf, test.errorLine, test.errorPos, expectedErr.Error()) + if test.reason != "" { + msg += ": " + test.reason + } + } else { + msg = test.reason } + if !strings.Contains(err.Error(), msg) { t.Errorf("Expected:\n%q\ngot:\n%q", msg, err.Error()) } diff --git a/server/ocsp.go b/server/ocsp.go index 6c1c9091..a4556f0b 100644 --- a/server/ocsp.go +++ b/server/ocsp.go @@ -14,6 +14,7 @@ package server import ( + "crypto/sha256" "crypto/tls" "crypto/x509" "encoding/asn1" @@ -22,6 +23,8 @@ import ( "fmt" "io/ioutil" "net/http" + "os" + "path/filepath" "strings" "sync" "time" @@ -30,6 +33,7 @@ import ( ) const ( + defaultOCSPStoreDir = "ocsp" defaultOCSPCheckInterval = 24 * time.Hour minOCSPCheckInterval = 2 * time.Minute ) @@ -92,8 +96,17 @@ func (oc *OCSPMonitor) getNextRun() time.Duration { func (oc *OCSPMonitor) getStatus() ([]byte, *ocsp.Response, error) { raw, resp := oc.getCacheStatus() if len(raw) > 0 && resp != nil { + // Check if the OCSP is still valid. + if err := validOCSPResponse(resp); err == nil { + return raw, resp, nil + } + } + var err error + raw, resp, err = oc.getLocalStatus() + if err == nil { return raw, resp, nil } + return oc.getRemoteStatus() } @@ -103,6 +116,41 @@ func (oc *OCSPMonitor) getCacheStatus() ([]byte, *ocsp.Response) { return oc.raw, oc.resp } +func (oc *OCSPMonitor) getLocalStatus() ([]byte, *ocsp.Response, error) { + opts := oc.srv.getOpts() + storeDir := opts.StoreDir + if storeDir == _EMPTY_ { + return nil, nil, fmt.Errorf("store_dir not set") + } + + // This key must be based upon the current full certificate, not the public key, + // so MUST be on the full raw certificate and not an SPKI or other reduced form. + key := fmt.Sprintf("%x", sha256.Sum256(oc.Leaf.Raw)) + + oc.mu.Lock() + raw, err := ioutil.ReadFile(filepath.Join(storeDir, defaultOCSPStoreDir, key)) + oc.mu.Unlock() + if err != nil { + return nil, nil, err + } + + resp, err := ocsp.ParseResponse(raw, oc.Issuer) + if err != nil { + return nil, nil, err + } + if err := validOCSPResponse(resp); err != nil { + return nil, nil, err + } + + // Cache the response. + oc.mu.Lock() + oc.raw = raw + oc.resp = resp + oc.mu.Unlock() + + return raw, resp, nil +} + func (oc *OCSPMonitor) getRemoteStatus() ([]byte, *ocsp.Response, error) { opts := oc.srv.getOpts() var overrideURLs []string @@ -163,6 +211,13 @@ func (oc *OCSPMonitor) getRemoteStatus() ([]byte, *ocsp.Response, error) { return nil, nil, err } + if storeDir := opts.StoreDir; storeDir != _EMPTY_ { + key := fmt.Sprintf("%x", sha256.Sum256(oc.Leaf.Raw)) + if err := oc.writeOCSPStatus(storeDir, key, raw); err != nil { + return nil, nil, fmt.Errorf("failed to write ocsp status: %w", err) + } + } + oc.mu.Lock() oc.raw = raw oc.resp = resp @@ -201,7 +256,7 @@ func (oc *OCSPMonitor) run() { return } - for s.Running() { + for { // On reload, if the certificate changes then need to stop this monitor. select { case <-time.After(nextRun): @@ -294,6 +349,10 @@ func (srv *Server) NewOCSPMonitor(tc *tls.Config) (*tls.Config, *OCSPMonitor, er return tc, nil, nil } + if err := srv.setupOCSPStapleStoreDir(); err != nil { + return nil, nil, err + } + // TODO: Add OCSP 'responder_cert' option in case CA cert not available. issuer, err := getOCSPIssuer(caFile, cert.Certificate) if err != nil { @@ -384,6 +443,35 @@ func hasOCSPStatusRequest(cert *x509.Certificate) bool { return false } +// writeOCSPStatus writes an OCSP status to a temporary file then moves it to a +// new path, in an attempt to avoid corrupting existing data. +func (oc *OCSPMonitor) writeOCSPStatus(storeDir, file string, data []byte) error { + storeDir = filepath.Join(storeDir, defaultOCSPStoreDir) + tmp, err := ioutil.TempFile(storeDir, "tmp-cert-status") + if err != nil { + return err + } + + if _, err := tmp.Write(data); err != nil { + tmp.Close() + os.Remove(tmp.Name()) + return err + } + if err := tmp.Close(); err != nil { + return err + } + + oc.mu.Lock() + err = os.Rename(tmp.Name(), filepath.Join(storeDir, file)) + oc.mu.Unlock() + if err != nil { + os.Remove(tmp.Name()) + return err + } + + return nil +} + func parseCertPEM(name string) (*x509.Certificate, error) { data, err := ioutil.ReadFile(name) if err != nil { diff --git a/server/opts.go b/server/opts.go index e79541fc..68d86953 100644 --- a/server/opts.go +++ b/server/opts.go @@ -793,6 +793,13 @@ func (o *Options) processConfigFileLine(k string, v interface{}, errors *[]error *errors = append(*errors, err) return } + case "store_dir", "storedir": + // Check if JetStream configuration is also setting the storage directory. + if o.StoreDir != "" { + *errors = append(*errors, &configErr{tk, "Duplicate 'store_dir' configuration"}) + return + } + o.StoreDir = v.(string) case "jetstream": err := parseJetStream(tk, o, errors, warnings) if err != nil { @@ -1558,6 +1565,10 @@ func parseJetStream(v interface{}, opts *Options, errors *[]error, warnings *[]e tk, mv = unwrapValue(mv, <) switch strings.ToLower(mk) { case "store_dir", "storedir": + // StoreDir can be set at the top level as well so have to prevent ambiguous declarations. + if opts.StoreDir != "" { + return &configErr{tk, "Duplicate 'store_dir' configuration"} + } opts.StoreDir = mv.(string) case "max_memory_store", "max_mem_store", "max_mem": opts.JetStreamMaxMemory = mv.(int64) diff --git a/server/reload.go b/server/reload.go index 5704e90e..ccfec13b 100644 --- a/server/reload.go +++ b/server/reload.go @@ -1285,12 +1285,18 @@ func (s *Server) applyOptions(ctx *reloadContext, opts []option) { func (s *Server) reloadOCSP() error { opts := s.getOpts() + if err := s.setupOCSPStapleStoreDir(); err != nil { + return err + } + s.mu.Lock() ocsps := s.ocsps s.mu.Unlock() // Stop all OCSP Stapling monitors in case there were any running. + var wasEnabled bool for _, oc := range ocsps { + wasEnabled = true oc.stop() } @@ -1303,15 +1309,17 @@ func (s *Server) reloadOCSP() error { } // Check if an OCSP stapling monitor is required for this certificate. if mon != nil { + s.Noticef("OCSP Stapling enabled for client connections") ocspm = append(ocspm, mon) - // Override the TLS config with one that follows OCSP. + // Override the TLS config with one that has OCSP enabled. s.optsMu.Lock() s.opts.TLSConfig = tc s.optsMu.Unlock() s.startGoRoutine(func() { mon.run() }) + } else if wasEnabled { + s.Warnf("OCSP Stapling disabled for client connections") } - s.Noticef("OCSP Stapling enabled for client connections") } // Replace stopped monitors with the new ones. s.mu.Lock() diff --git a/server/server.go b/server/server.go index a2ef244f..ff1bebb5 100644 --- a/server/server.go +++ b/server/server.go @@ -1470,6 +1470,24 @@ func (s *Server) fetchAccount(name string) (*Account, error) { return acc, nil } +func (s *Server) setupOCSPStapleStoreDir() error { + opts := s.getOpts() + storeDir := opts.StoreDir + if storeDir == _EMPTY_ { + s.Warnf("OCSP Stapling disk cache is disabled (missing 'store_dir')") + return nil + } + storeDir = filepath.Join(storeDir, defaultOCSPStoreDir) + if stat, err := os.Stat(storeDir); os.IsNotExist(err) { + if err := os.MkdirAll(storeDir, defaultDirPerms); err != nil { + return fmt.Errorf("could not create OCSP storage directory - %v", err) + } + } else if stat == nil || !stat.IsDir() { + return fmt.Errorf("OCSP storage directory is not a directory") + } + return nil +} + func (s *Server) enableOCSP() error { opts := s.getOpts() @@ -1481,14 +1499,15 @@ func (s *Server) enableOCSP() error { } // Check if an OCSP stapling monitor is required for this certificate. if mon != nil { + s.Noticef("OCSP Stapling enabled for client connections") + s.ocsps = append(s.ocsps, mon) // Override the TLS config with one that follows OCSP. opts.TLSConfig = tc s.startGoRoutine(func() { mon.run() }) } - s.Noticef("OCSP Stapling enabled for client connections") } - // FIXME: Add support for leafnodes, routes, MQTT, WebSocket + // FIXME: Add support for leafnodes, routes, gateways, MQTT, WebSocket return nil } diff --git a/test/ocsp_test.go b/test/ocsp_test.go index d852a042..204fb571 100644 --- a/test/ocsp_test.go +++ b/test/ocsp_test.go @@ -14,6 +14,7 @@ package test import ( + "bytes" "context" "crypto/rsa" "crypto/tls" @@ -23,6 +24,8 @@ import ( "fmt" "io/ioutil" "net/http" + "os" + "path/filepath" "strconv" "strings" "sync" @@ -50,7 +53,11 @@ func TestOCSPAlwaysMustStapleAndShutdown(t *testing.T) { addr := fmt.Sprintf("http://%s", ocspr.Addr) setOCSPStatus(t, addr, serverCert, ocsp.Good) - opts := DefaultTestOptions + opts := server.Options{} + opts.Host = "127.0.0.1" + opts.NoLog = true + opts.NoSigs = true + opts.MaxControlLine = 4096 opts.Port = -1 opts.TLSCert = serverCert opts.TLSKey = serverKey @@ -140,7 +147,11 @@ func TestOCSPMustStapleShutdown(t *testing.T) { addr := fmt.Sprintf("http://%s", ocspr.Addr) setOCSPStatus(t, addr, serverCert, ocsp.Good) - opts := DefaultTestOptions + opts := server.Options{} + opts.Host = "127.0.0.1" + opts.NoLog = true + opts.NoSigs = true + opts.MaxControlLine = 4096 opts.Port = -1 opts.TLSCert = serverCert opts.TLSKey = serverKey @@ -321,7 +332,11 @@ func TestOCSPAutoWithoutMustStapleDoesNotShutdownOnRevoke(t *testing.T) { addr := fmt.Sprintf("http://%s", ocspr.Addr) setOCSPStatus(t, addr, serverCert, ocsp.Good) - opts := DefaultTestOptions + opts := server.Options{} + opts.Host = "127.0.0.1" + opts.NoLog = true + opts.NoSigs = true + opts.MaxControlLine = 4096 opts.Port = -1 opts.TLSCert = serverCert opts.TLSKey = serverKey @@ -663,9 +678,14 @@ func TestOCSPReloadRotateTLSCertDisableMustStaple(t *testing.T) { addr := fmt.Sprintf("http://%s", ocspr.Addr) setOCSPStatus(t, addr, serverCert, ocsp.Good) - content := ` + storeDir := createDir(t, "_ocsp") + defer removeDir(t, storeDir) + + originalContent := ` port: -1 + store_dir: "%s" + tls { cert_file: "configs/certs/ocsp/server-status-request-url-cert.pem" key_file: "configs/certs/ocsp/server-status-request-url-key.pem" @@ -673,14 +693,18 @@ func TestOCSPReloadRotateTLSCertDisableMustStaple(t *testing.T) { timeout: 5 } ` + + content := fmt.Sprintf(originalContent, storeDir) conf := createConfFile(t, []byte(content)) defer removeFile(t, conf) s, opts := RunServerWithConfig(conf) defer s.Shutdown() + var staple []byte nc, err := nats.Connect(fmt.Sprintf("tls://localhost:%d", opts.Port), nats.Secure(&tls.Config{ VerifyConnection: func(s tls.ConnectionState) error { + staple = s.OCSPResponse resp, err := getOCSPStatus(s) if err != nil { return err @@ -710,10 +734,37 @@ func TestOCSPReloadRotateTLSCertDisableMustStaple(t *testing.T) { } nc.Close() + files := []string{} + err = filepath.Walk(storeDir+"/ocsp/", func(path string, info os.FileInfo, err error) error { + if info.IsDir() { + return nil + } + files = append(files, path) + return nil + }) + if err != nil { + t.Fatal(err) + } + found := false + for _, file := range files { + data, err := ioutil.ReadFile(file) + if err != nil { + t.Error(err) + } + if bytes.Equal(staple, data) { + found = true + } + } + if !found { + t.Error("Could not find OCSP Staple") + } + // Change the contents with another that has OCSP Stapling disabled. - content = ` + updatedContent := ` port: -1 + store_dir: "%s" + tls { cert_file: "configs/certs/ocsp/server-cert.pem" key_file: "configs/certs/ocsp/server-key.pem" @@ -721,6 +772,7 @@ func TestOCSPReloadRotateTLSCertDisableMustStaple(t *testing.T) { timeout: 5 } ` + content = fmt.Sprintf(updatedContent, storeDir) if err := ioutil.WriteFile(conf, []byte(content), 0666); err != nil { t.Fatalf("Error writing config: %v", err) } @@ -729,7 +781,7 @@ func TestOCSPReloadRotateTLSCertDisableMustStaple(t *testing.T) { } // The new certificate does not have must staple so they will be missing. - time.Sleep(2 * time.Second) + time.Sleep(4 * time.Second) nc, err = nats.Connect(fmt.Sprintf("tls://localhost:%d", opts.Port), nats.Secure(&tls.Config{ @@ -747,6 +799,67 @@ func TestOCSPReloadRotateTLSCertDisableMustStaple(t *testing.T) { t.Fatal(err) } nc.Close() + + // Re-enable OCSP Stapling + content = fmt.Sprintf(originalContent, storeDir) + if err := ioutil.WriteFile(conf, []byte(content), 0666); err != nil { + t.Fatalf("Error writing config: %v", err) + } + if err := s.Reload(); err != nil { + t.Fatal(err) + } + + var newStaple []byte + nc, err = nats.Connect(fmt.Sprintf("tls://localhost:%d", opts.Port), + nats.Secure(&tls.Config{ + VerifyConnection: func(s tls.ConnectionState) error { + newStaple = s.OCSPResponse + resp, err := getOCSPStatus(s) + if err != nil { + return err + } + if resp.Status != ocsp.Good { + t.Errorf("Expected valid OCSP staple status") + } + return nil + }, + }), + nats.RootCAs(caCert), + nats.ErrorHandler(noOpErrHandler), + ) + if err != nil { + t.Fatal(err) + } + nc.Close() + + // Confirm that it got a new staple. + files = []string{} + err = filepath.Walk(storeDir+"/ocsp/", func(path string, info os.FileInfo, err error) error { + if info.IsDir() { + return nil + } + files = append(files, path) + return nil + }) + if err != nil { + t.Fatal(err) + } + found = false + for _, file := range files { + data, err := ioutil.ReadFile(file) + if err != nil { + t.Error(err) + } + if bytes.Equal(newStaple, data) { + found = true + } + } + if !found { + t.Error("Could not find OCSP Staple") + } + if bytes.Equal(staple, newStaple) { + t.Error("Expected new OCSP Staple") + } } func TestOCSPReloadRotateTLSCertEnableMustStaple(t *testing.T) {