@@ -10,6 +10,7 @@ import (
10
10
"strings"
11
11
"time"
12
12
13
+ "github.com/fly-examples/postgres-ha/pkg/flypg/admin"
13
14
"github.com/fly-examples/postgres-ha/pkg/privnet"
14
15
"github.com/fly-examples/postgres-ha/pkg/supervisor"
15
16
"github.com/jackc/pgx/v4"
@@ -71,16 +72,8 @@ func Run() error {
71
72
return errors .Wrap (err , "failed opening connection to postgres" )
72
73
}
73
74
74
- if err = setInternalCredential (conn , "flypgadmin" , os .Getenv ("SU_PASSWORD" ), false ); err != nil {
75
- return err
76
- }
77
-
78
- if err = setInternalCredential (conn , "repluser" , os .Getenv ("REPL_PASSWORD" ), false ); err != nil {
79
- return err
80
- }
81
-
82
- if err = setInternalCredential (conn , "postgres" , os .Getenv ("OPERATOR_PASSWORD" ), true ); err != nil {
83
- return err
75
+ if err = createRequiredUsers (conn ); err != nil {
76
+ return errors .Wrap (err , "failed creating required users" )
84
77
}
85
78
86
79
if err := restoreHBAFile (); err != nil {
@@ -164,16 +157,48 @@ func openConn() (*pgx.Conn, error) {
164
157
}
165
158
}
166
159
167
- func setInternalCredential (conn * pgx.Conn , user , password string , optional bool ) error {
168
- sql := fmt .Sprintf ("ALTER USER %s WITH PASSWORD '%s'" , user , password )
169
- _ , err := conn .Exec (context .Background (), sql )
160
+ func createRequiredUsers (conn * pgx.Conn ) error {
161
+ curUsers , err := admin .ListUsers (context .TODO (), conn )
170
162
if err != nil {
171
- if optional {
172
- fmt .Printf ("failed to reset credentials for user: %q. error: %v" , user , err )
173
- return nil
163
+ return errors .Wrap (err , "failed to list current users" )
164
+ }
165
+
166
+ credMap := map [string ]string {
167
+ "flypgadmin" : os .Getenv ("SU_PASSWORD" ),
168
+ "repluser" : os .Getenv ("REPL_PASSWORD" ),
169
+ "postgres" : os .Getenv ("OPERATOR_PASSWORD" ),
170
+ }
171
+
172
+ for user , pass := range credMap {
173
+
174
+ exists := false
175
+ for _ , curUser := range curUsers {
176
+ if user == curUser .Username {
177
+ exists = true
178
+ }
179
+ }
180
+ var sql string
181
+
182
+ if exists {
183
+ sql = fmt .Sprintf ("ALTER USER %s WITH PASSWORD '%s'" , user , pass )
184
+ } else {
185
+ // create user
186
+ switch user {
187
+ case "flypgadmin" :
188
+ sql = fmt .Sprintf (`CREATE USER %s WITH SUPERUSER LOGIN PASSWORD '%s'` , user , pass )
189
+ case "repluser" :
190
+ sql = fmt .Sprintf (`CREATE USER %s WITH REPLICATION PASSWORD '%s'` , user , pass )
191
+ case "postgres" :
192
+ sql = fmt .Sprintf (`CREATE USER %s WITH LOGIN PASSWORD '%s'` , user , pass )
193
+ }
194
+ }
195
+
196
+ _ , err := conn .Exec (context .Background (), sql )
197
+ if err != nil {
198
+ return err
174
199
}
175
- return err
176
200
}
201
+
177
202
return nil
178
203
}
179
204
0 commit comments