Use new DB API

master
HimbeerserverDE 2021-04-22 12:33:50 +02:00
parent 2b7abb3e63
commit 485db9d950
No known key found for this signature in database
GPG Key ID: 1A651504791E6A8B
7 changed files with 38 additions and 90 deletions

View File

@ -2,7 +2,9 @@ package main
import (
"crypto/rand"
"database/sql"
"encoding/base64"
"errors"
"log"
"strings"
@ -130,7 +132,8 @@ func Password(name string) ([]byte, []byte, error) {
defer db.Close()
var pwd string
if err = db.QueryRow(`SELECT password FROM auth WHERE name = ?;`, name).Scan(&pwd); err != nil {
err = db.QueryRow(`SELECT password FROM auth WHERE name = ?;`, name).Scan(&pwd)
if err != nil && !errors.Is(err, sql.ErrNoRows) {
return nil, nil, err
}

9
ban.go
View File

@ -1,6 +1,7 @@
package main
import (
"database/sql"
"errors"
"fmt"
"net"
@ -48,7 +49,11 @@ func IsBanned(addr string) (bool, string, error) {
var name string
err = db.QueryRow(`SELECT name FROM ban WHERE addr = ?;`, addr).Scan(&name)
return name != "", name, err
if err != nil && !errors.Is(err, sql.ErrNoRows) {
return true, "", err
}
return name != "", name, nil
}
// IsBanned reports whether a Conn is banned
@ -113,6 +118,6 @@ func Unban(id string) error {
}
defer db.Close()
_, err = db.Exec(`DELETE FROM ban WHERE name = ? OR addr = ?;`, id)
_, err = db.Exec(`DELETE FROM ban WHERE name = ? OR addr = ?;`, id, id)
return err
}

View File

@ -326,21 +326,7 @@ func processPktCommand(src, dst *Conn, pkt *rudp.Pkt) bool {
s := ReadBytes16(r)
v := ReadBytes16(r)
pwd := encodeVerifierAndSalt(s, v)
db, err := initAuthDB()
if err != nil {
log.Print(err)
return true
}
err = modAuthItem(db, src.Username(), pwd)
if err != nil {
log.Print(err)
return true
}
db.Close()
SetPassword(src.Username(), v, s)
} else {
log.Print("User " + src.Username() + " at " + src.Addr().String() + " did not enter sudo mode before attempting to change the password")
}
@ -350,21 +336,7 @@ func processPktCommand(src, dst *Conn, pkt *rudp.Pkt) bool {
if !src.sudoMode {
A := ReadBytes16(r)
db, err := initAuthDB()
if err != nil {
log.Print(err)
return true
}
pwd, err := readAuthItem(db, src.Username())
if err != nil {
log.Print(err)
return true
}
db.Close()
s, v, err := decodeVerifierAndSalt(pwd)
v, s, err := Password(src.Username())
if err != nil {
log.Print(err)
return true

23
db.go
View File

@ -10,8 +10,14 @@ import (
_ "github.com/mattn/go-sqlite3"
)
const (
DBTypeSQLite3 = iota
DBTypePSQL
)
type DB struct {
*sql.DB
dbType int
}
// OpenSQLite3 opens and returns a SQLite3 database
@ -28,7 +34,7 @@ func OpenSQLite3(name, initSQL string) (*DB, error) {
return nil, err
}
return &DB{DB: db}, nil
return &DB{DB: db, dbType: DBTypeSQLite3}, nil
}
// OpenPSQL opens and returns a PostgreSQL database
@ -51,17 +57,24 @@ func OpenPSQL(host, name, user, password, initSQL string, port int) (*DB, error)
return nil, err
}
return &DB{DB: db}, nil
return &DB{DB: db, dbType: DBTypePSQL}, nil
}
// Type returns the type of database that is being interacted with
func (db *DB) Type() int { return db.dbType }
// Exec executes a SQL statement
func (db *DB) Exec(sql string, values ...interface{}) (sql.Result, error) {
sql = strings.ReplaceAll(sql, "?", "$x")
return db.DB.Exec(sql, values)
if db.Type() == DBTypePSQL {
sql = strings.ReplaceAll(sql, "?", "$x")
}
return db.DB.Exec(sql, values...)
}
// Query executes a SQL statement and stores the results
func (db *DB) QueryRow(sql string, values ...interface{}) *sql.Row {
sql = strings.ReplaceAll(sql, "?", "$x")
if db.Type() == DBTypePSQL {
sql = strings.ReplaceAll(sql, "?", "$x")
}
return db.DB.QueryRow(sql, values...)
}

47
init.go
View File

@ -348,21 +348,13 @@ func Init(c, c2 *Conn, ignMedia, noAccessDenied bool, fin chan *Conn) {
return
}
db, err := initAuthDB()
v, s, err := Password(c2.Username())
if err != nil {
log.Print(err)
continue
}
pwd, err := readAuthItem(db, c2.Username())
if err != nil {
log.Print(err)
continue
}
db.Close()
if pwd == "" {
if v == nil || s == nil {
// New player
c2.authMech = AuthMechFirstSRP
binary.BigEndian.PutUint32(data[7:11], uint32(AuthMechFirstSRP))
@ -408,28 +400,11 @@ func Init(c, c2 *Conn, ignMedia, noAccessDenied bool, fin chan *Conn) {
return
}
pwd := encodeVerifierAndSalt(s, v)
db, err := initAuthDB()
if err != nil {
if err := CreateUser(c2.Username(), v, s); err != nil {
log.Print(err)
continue
}
err = addAuthItem(db, c2.Username(), pwd)
if err != nil {
log.Print(err)
continue
}
err = addPrivItem(db, c2.Username())
if err != nil {
log.Print(err)
continue
}
db.Close()
// Send AUTH_ACCEPT
data := []byte{
0, ToClientAuthAccept,
@ -464,21 +439,7 @@ func Init(c, c2 *Conn, ignMedia, noAccessDenied bool, fin chan *Conn) {
A := ReadBytes16(r)
db, err := initAuthDB()
if err != nil {
log.Print(err)
continue
}
pwd, err := readAuthItem(db, c2.Username())
if err != nil {
log.Print(err)
continue
}
db.Close()
s, v, err := decodeVerifierAndSalt(pwd)
v, s, err := Password(c2.Username())
if err != nil {
log.Print(err)
continue

View File

@ -79,7 +79,7 @@ func SetPrivs(name string, privs map[string]bool) error {
_, err = db.Exec(`REPLACE INTO privileges (
name,
priviliges
privileges
) VALUES (
?,
?

10
rpc.go
View File

@ -139,25 +139,19 @@ func processRpc(c *Conn, r *bytes.Reader) bool {
}
go c.doRpc("->ADDR "+addr, rq)
case "<-ISBANNED":
db, err := initAuthDB()
if err != nil {
return true
}
defer db.Close()
target := strings.Split(msg, " ")[2]
if net.ParseIP(target) == nil {
return true
}
name, err := readBanItem(db, target)
banned, _, err := IsBanned(target)
if err != nil {
return true
}
r := "false"
if name != "" {
if banned {
r = "true"
}