Support PostgreSQL for auth
parent
485db9d950
commit
5f27eefc9e
12
auth.go
12
auth.go
|
@ -117,8 +117,8 @@ func CreateUser(name string, verifier, salt []byte) error {
|
||||||
name,
|
name,
|
||||||
password
|
password
|
||||||
) VALUES (
|
) VALUES (
|
||||||
?,
|
$1,
|
||||||
?
|
$2
|
||||||
);`, name, pwd)
|
);`, name, pwd)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -132,11 +132,15 @@ func Password(name string) ([]byte, []byte, error) {
|
||||||
defer db.Close()
|
defer db.Close()
|
||||||
|
|
||||||
var pwd string
|
var pwd string
|
||||||
err = db.QueryRow(`SELECT password FROM auth WHERE name = ?;`, name).Scan(&pwd)
|
err = db.QueryRow(`SELECT password FROM auth WHERE name = $1;`, name).Scan(&pwd)
|
||||||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if pwd == "" {
|
||||||
|
return nil, nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
salt, verifier, err := decodeVerifierAndSalt(pwd)
|
salt, verifier, err := decodeVerifierAndSalt(pwd)
|
||||||
return verifier, salt, err
|
return verifier, salt, err
|
||||||
}
|
}
|
||||||
|
@ -151,7 +155,7 @@ func SetPassword(name string, verifier, salt []byte) error {
|
||||||
|
|
||||||
pwd := encodeVerifierAndSalt(salt, verifier)
|
pwd := encodeVerifierAndSalt(salt, verifier)
|
||||||
|
|
||||||
_, err = db.Exec(`UPDATE auth SET password = ? WHERE name = ?;`, pwd, name)
|
_, err = db.Exec(`UPDATE auth SET password = $1 WHERE name = $2;`, pwd, name)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
8
ban.go
8
ban.go
|
@ -48,7 +48,7 @@ func IsBanned(addr string) (bool, string, error) {
|
||||||
defer db.Close()
|
defer db.Close()
|
||||||
|
|
||||||
var name string
|
var name string
|
||||||
err = db.QueryRow(`SELECT name FROM ban WHERE addr = ?;`, addr).Scan(&name)
|
err = db.QueryRow(`SELECT name FROM ban WHERE addr = $1;`, addr).Scan(&name)
|
||||||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||||
return true, "", err
|
return true, "", err
|
||||||
}
|
}
|
||||||
|
@ -84,8 +84,8 @@ func Ban(addr, name string) error {
|
||||||
addr,
|
addr,
|
||||||
name
|
name
|
||||||
) VALUES (
|
) VALUES (
|
||||||
?,
|
$1,
|
||||||
?
|
$2
|
||||||
);`, addr, name)
|
);`, addr, name)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -118,6 +118,6 @@ func Unban(id string) error {
|
||||||
}
|
}
|
||||||
defer db.Close()
|
defer db.Close()
|
||||||
|
|
||||||
_, err = db.Exec(`DELETE FROM ban WHERE name = ? OR addr = ?;`, id, id)
|
_, err = db.Exec(`DELETE FROM ban WHERE name = $1 OR addr = $2;`, id, id)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
22
db.go
22
db.go
|
@ -4,7 +4,7 @@ import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
"strings"
|
"regexp"
|
||||||
|
|
||||||
_ "github.com/lib/pq"
|
_ "github.com/lib/pq"
|
||||||
_ "github.com/mattn/go-sqlite3"
|
_ "github.com/mattn/go-sqlite3"
|
||||||
|
@ -65,16 +65,26 @@ func (db *DB) Type() int { return db.dbType }
|
||||||
|
|
||||||
// Exec executes a SQL statement
|
// Exec executes a SQL statement
|
||||||
func (db *DB) Exec(sql string, values ...interface{}) (sql.Result, error) {
|
func (db *DB) Exec(sql string, values ...interface{}) (sql.Result, error) {
|
||||||
if db.Type() == DBTypePSQL {
|
if db.Type() == DBTypeSQLite3 {
|
||||||
sql = strings.ReplaceAll(sql, "?", "$x")
|
r, err := regexp.Compile("\\$+[0-9]")
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
sql = r.ReplaceAllString(sql, "?")
|
||||||
}
|
}
|
||||||
return db.DB.Exec(sql, values...)
|
return db.DB.Exec(sql, values...)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Query executes a SQL statement and stores the results
|
// QueryRow executes a SQL statement and stores the results
|
||||||
func (db *DB) QueryRow(sql string, values ...interface{}) *sql.Row {
|
func (db *DB) QueryRow(sql string, values ...interface{}) *sql.Row {
|
||||||
if db.Type() == DBTypePSQL {
|
if db.Type() == DBTypeSQLite3 {
|
||||||
sql = strings.ReplaceAll(sql, "?", "$x")
|
r, err := regexp.Compile("\\$+[0-9]")
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
sql = r.ReplaceAllString(sql, "?")
|
||||||
}
|
}
|
||||||
return db.DB.QueryRow(sql, values...)
|
return db.DB.QueryRow(sql, values...)
|
||||||
}
|
}
|
||||||
|
|
10
privs.go
10
privs.go
|
@ -56,7 +56,7 @@ func Privs(name string) (map[string]bool, error) {
|
||||||
defer db.Close()
|
defer db.Close()
|
||||||
|
|
||||||
var eprivs string
|
var eprivs string
|
||||||
err = db.QueryRow(`SELECT privileges FROM privileges WHERE name = ?;`, name).Scan(&eprivs)
|
err = db.QueryRow(`SELECT privileges FROM privileges WHERE name = $1;`, name).Scan(&eprivs)
|
||||||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||||
return make(map[string]bool), err
|
return make(map[string]bool), err
|
||||||
}
|
}
|
||||||
|
@ -77,13 +77,15 @@ func SetPrivs(name string, privs map[string]bool) error {
|
||||||
}
|
}
|
||||||
defer db.Close()
|
defer db.Close()
|
||||||
|
|
||||||
_, err = db.Exec(`REPLACE INTO privileges (
|
_, err = db.Exec(`INSERT INTO privileges (
|
||||||
name,
|
name,
|
||||||
privileges
|
privileges
|
||||||
) VALUES (
|
) VALUES (
|
||||||
?,
|
$1,
|
||||||
?
|
$2
|
||||||
);`, name, encodePrivs(privs))
|
);`, name, encodePrivs(privs))
|
||||||
|
_, err = db.Exec(`UPDATE privileges SET privileges = $1 WHERE name = $2;`, encodePrivs(privs), name)
|
||||||
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue