Skip to content
This repository was archived by the owner on Jun 12, 2024. It is now read-only.

Commit 96447ad

Browse files
authored
feat: support PGSSL* env variables for configuring the ssl mode (#187)
When creating database connections, lookup the SSL related flags from the standard PGSSL* env varibles. This allows configuring the connection via the same values that can be used to configure psql or other postgres tools. Signed-off-by: Lucas Roesler <[email protected]>
1 parent 09a25fd commit 96447ad

File tree

2 files changed

+94
-1
lines changed

2 files changed

+94
-1
lines changed

pkg/config/database.go

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package config
33
import (
44
"fmt"
55
"io/ioutil"
6+
"os"
67
"strings"
78

89
"github.com/pkg/errors"
@@ -12,7 +13,28 @@ var defaultPorts = map[string]uint32{
1213
"postgres": 5432,
1314
}
1415

16+
const (
17+
// these environment variables can be used to control the SSL
18+
// validation behavior and behave the same as the documented
19+
// postegres environment variables
20+
PGSSLModeEnvKey = "PGSSLMODE"
21+
PGSSLCertPathEnvKey = "PGSSLCERT"
22+
PGSSLKeyPathEnvKey = "PGSSLKEY"
23+
PGSSLRootCertPathEnvKey = "PGSSLROOTCERT"
24+
)
25+
26+
const (
27+
pgSSLDisabled = "disable"
28+
pgSSLRequire = "require"
29+
pgSSLVerifyCA = "verify-ca"
30+
pgSSLVerifyFull = "verify-full"
31+
)
32+
1533
// Database contains all the configuration parameters for a database
34+
// SSL mode and options can be configured via the standard Postgres env variables
35+
// documented here https://www.postgresql.org/docs/current/libpq-envars.html
36+
//
37+
// Specifically, it supports: PGSSLMODE, PGSSLCERT, PGSSLKEY, PGSSLROOTCERT.
1638
type Database struct {
1739
// Host of the database server
1840
Host string `json:"host"`
@@ -65,7 +87,30 @@ func (cfg *Database) GetPort() uint32 {
6587

6688
// GetConnectionString returns the formed connection string
6789
func (cfg *Database) GetConnectionString() (connStr string, err error) {
68-
connStr = "sslmode=disable "
90+
sslMode, found := os.LookupEnv(PGSSLModeEnvKey)
91+
if !found {
92+
sslMode = pgSSLDisabled
93+
}
94+
95+
switch sslMode {
96+
case pgSSLDisabled, pgSSLRequire, pgSSLVerifyCA, pgSSLVerifyFull:
97+
// nothing to do
98+
default:
99+
return "", fmt.Errorf("unknown or unsupported ssl mode: %q", sslMode)
100+
}
101+
102+
connStr = fmt.Sprintf("sslmode=%s ", sslMode)
103+
if value, found := os.LookupEnv(PGSSLCertPathEnvKey); found {
104+
connStr += fmt.Sprintf("sslcert=%s ", value)
105+
}
106+
107+
if value, found := os.LookupEnv(PGSSLKeyPathEnvKey); found {
108+
connStr += fmt.Sprintf("sslkey=%s ", value)
109+
}
110+
111+
if value, found := os.LookupEnv(PGSSLRootCertPathEnvKey); found {
112+
connStr += fmt.Sprintf("sslrootcert=%s ", value)
113+
}
69114

70115
if cfg.Host != "" {
71116
connStr += fmt.Sprintf("host=%s ", cfg.Host)

pkg/config/database_test.go

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package config
22

33
import (
4+
"os"
45
"testing"
56

67
"github.com/stretchr/testify/require"
@@ -61,6 +62,7 @@ func TestDatabaseGetConnectionString(t *testing.T) {
6162
cases := []struct {
6263
name string
6364
cfg Database
65+
env map[string]string
6466
expected string
6567
err string
6668
}{
@@ -75,6 +77,23 @@ func TestDatabaseGetConnectionString(t *testing.T) {
7577
},
7678
expected: "sslmode=disable host=example.com port=80 dbname=database user=root password=password",
7779
},
80+
{
81+
name: "Can set sslmode to require via env variables",
82+
cfg: Database{
83+
Host: "example.com",
84+
Port: 80,
85+
Name: "database",
86+
Username: "root",
87+
PasswordPath: "./testdata/password",
88+
},
89+
env: map[string]string{
90+
"PGSSLMODE": "require",
91+
"PGSSLCERT": "/cert/path",
92+
"PGSSLKEY": "/key/path",
93+
"PGSSLROOTCERT": "/root/cert/path",
94+
},
95+
expected: "sslmode=require sslcert=/cert/path sslkey=/key/path sslrootcert=/root/cert/path host=example.com port=80 dbname=database user=root password=password",
96+
},
7897
{
7998
name: "Returns a connection string when Host is not set",
8099
cfg: Database{
@@ -140,6 +159,9 @@ func TestDatabaseGetConnectionString(t *testing.T) {
140159

141160
for _, tc := range cases {
142161
t.Run(tc.name, func(t *testing.T) {
162+
reset := envSetup(tc.env)
163+
defer reset()
164+
143165
str, err := tc.cfg.GetConnectionString()
144166
if tc.err != "" {
145167
require.Error(t, err)
@@ -151,3 +173,29 @@ func TestDatabaseGetConnectionString(t *testing.T) {
151173
})
152174
}
153175
}
176+
177+
func envSetup(envs map[string]string) (resetter func()) {
178+
if len(envs) == 0 {
179+
return func() {}
180+
}
181+
182+
originalEnvs := map[string]string{}
183+
184+
for name, value := range envs {
185+
if originalValue, ok := os.LookupEnv(name); ok {
186+
originalEnvs[name] = originalValue
187+
}
188+
_ = os.Setenv(name, value)
189+
}
190+
191+
return func() {
192+
for name := range envs {
193+
origValue, has := originalEnvs[name]
194+
if has {
195+
_ = os.Setenv(name, origValue)
196+
} else {
197+
_ = os.Unsetenv(name)
198+
}
199+
}
200+
}
201+
}

0 commit comments

Comments
 (0)