Use new DB API
parent
2b7abb3e63
commit
485db9d950
5
auth.go
5
auth.go
|
@ -2,7 +2,9 @@ package main
|
|||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"database/sql"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"log"
|
||||
"strings"
|
||||
|
||||
|
@ -130,7 +132,8 @@ func Password(name string) ([]byte, []byte, error) {
|
|||
defer db.Close()
|
||||
|
||||
var pwd string
|
||||
if err = db.QueryRow(`SELECT password FROM auth WHERE name = ?;`, name).Scan(&pwd); err != nil {
|
||||
err = db.QueryRow(`SELECT password FROM auth WHERE name = ?;`, name).Scan(&pwd)
|
||||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
|
|
9
ban.go
9
ban.go
|
@ -1,6 +1,7 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
|
@ -48,7 +49,11 @@ func IsBanned(addr string) (bool, string, error) {
|
|||
|
||||
var name string
|
||||
err = db.QueryRow(`SELECT name FROM ban WHERE addr = ?;`, addr).Scan(&name)
|
||||
return name != "", name, err
|
||||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||
return true, "", err
|
||||
}
|
||||
|
||||
return name != "", name, nil
|
||||
}
|
||||
|
||||
// IsBanned reports whether a Conn is banned
|
||||
|
@ -113,6 +118,6 @@ func Unban(id string) error {
|
|||
}
|
||||
defer db.Close()
|
||||
|
||||
_, err = db.Exec(`DELETE FROM ban WHERE name = ? OR addr = ?;`, id)
|
||||
_, err = db.Exec(`DELETE FROM ban WHERE name = ? OR addr = ?;`, id, id)
|
||||
return err
|
||||
}
|
||||
|
|
32
command.go
32
command.go
|
@ -326,21 +326,7 @@ func processPktCommand(src, dst *Conn, pkt *rudp.Pkt) bool {
|
|||
s := ReadBytes16(r)
|
||||
v := ReadBytes16(r)
|
||||
|
||||
pwd := encodeVerifierAndSalt(s, v)
|
||||
|
||||
db, err := initAuthDB()
|
||||
if err != nil {
|
||||
log.Print(err)
|
||||
return true
|
||||
}
|
||||
|
||||
err = modAuthItem(db, src.Username(), pwd)
|
||||
if err != nil {
|
||||
log.Print(err)
|
||||
return true
|
||||
}
|
||||
|
||||
db.Close()
|
||||
SetPassword(src.Username(), v, s)
|
||||
} else {
|
||||
log.Print("User " + src.Username() + " at " + src.Addr().String() + " did not enter sudo mode before attempting to change the password")
|
||||
}
|
||||
|
@ -350,21 +336,7 @@ func processPktCommand(src, dst *Conn, pkt *rudp.Pkt) bool {
|
|||
if !src.sudoMode {
|
||||
A := ReadBytes16(r)
|
||||
|
||||
db, err := initAuthDB()
|
||||
if err != nil {
|
||||
log.Print(err)
|
||||
return true
|
||||
}
|
||||
|
||||
pwd, err := readAuthItem(db, src.Username())
|
||||
if err != nil {
|
||||
log.Print(err)
|
||||
return true
|
||||
}
|
||||
|
||||
db.Close()
|
||||
|
||||
s, v, err := decodeVerifierAndSalt(pwd)
|
||||
v, s, err := Password(src.Username())
|
||||
if err != nil {
|
||||
log.Print(err)
|
||||
return true
|
||||
|
|
19
db.go
19
db.go
|
@ -10,8 +10,14 @@ import (
|
|||
_ "github.com/mattn/go-sqlite3"
|
||||
)
|
||||
|
||||
const (
|
||||
DBTypeSQLite3 = iota
|
||||
DBTypePSQL
|
||||
)
|
||||
|
||||
type DB struct {
|
||||
*sql.DB
|
||||
dbType int
|
||||
}
|
||||
|
||||
// OpenSQLite3 opens and returns a SQLite3 database
|
||||
|
@ -28,7 +34,7 @@ func OpenSQLite3(name, initSQL string) (*DB, error) {
|
|||
return nil, err
|
||||
}
|
||||
|
||||
return &DB{DB: db}, nil
|
||||
return &DB{DB: db, dbType: DBTypeSQLite3}, nil
|
||||
}
|
||||
|
||||
// OpenPSQL opens and returns a PostgreSQL database
|
||||
|
@ -51,17 +57,24 @@ func OpenPSQL(host, name, user, password, initSQL string, port int) (*DB, error)
|
|||
return nil, err
|
||||
}
|
||||
|
||||
return &DB{DB: db}, nil
|
||||
return &DB{DB: db, dbType: DBTypePSQL}, nil
|
||||
}
|
||||
|
||||
// Type returns the type of database that is being interacted with
|
||||
func (db *DB) Type() int { return db.dbType }
|
||||
|
||||
// Exec executes a SQL statement
|
||||
func (db *DB) Exec(sql string, values ...interface{}) (sql.Result, error) {
|
||||
if db.Type() == DBTypePSQL {
|
||||
sql = strings.ReplaceAll(sql, "?", "$x")
|
||||
return db.DB.Exec(sql, values)
|
||||
}
|
||||
return db.DB.Exec(sql, values...)
|
||||
}
|
||||
|
||||
// Query executes a SQL statement and stores the results
|
||||
func (db *DB) QueryRow(sql string, values ...interface{}) *sql.Row {
|
||||
if db.Type() == DBTypePSQL {
|
||||
sql = strings.ReplaceAll(sql, "?", "$x")
|
||||
}
|
||||
return db.DB.QueryRow(sql, values...)
|
||||
}
|
||||
|
|
47
init.go
47
init.go
|
@ -348,21 +348,13 @@ func Init(c, c2 *Conn, ignMedia, noAccessDenied bool, fin chan *Conn) {
|
|||
return
|
||||
}
|
||||
|
||||
db, err := initAuthDB()
|
||||
v, s, err := Password(c2.Username())
|
||||
if err != nil {
|
||||
log.Print(err)
|
||||
continue
|
||||
}
|
||||
|
||||
pwd, err := readAuthItem(db, c2.Username())
|
||||
if err != nil {
|
||||
log.Print(err)
|
||||
continue
|
||||
}
|
||||
|
||||
db.Close()
|
||||
|
||||
if pwd == "" {
|
||||
if v == nil || s == nil {
|
||||
// New player
|
||||
c2.authMech = AuthMechFirstSRP
|
||||
binary.BigEndian.PutUint32(data[7:11], uint32(AuthMechFirstSRP))
|
||||
|
@ -408,28 +400,11 @@ func Init(c, c2 *Conn, ignMedia, noAccessDenied bool, fin chan *Conn) {
|
|||
return
|
||||
}
|
||||
|
||||
pwd := encodeVerifierAndSalt(s, v)
|
||||
|
||||
db, err := initAuthDB()
|
||||
if err != nil {
|
||||
if err := CreateUser(c2.Username(), v, s); err != nil {
|
||||
log.Print(err)
|
||||
continue
|
||||
}
|
||||
|
||||
err = addAuthItem(db, c2.Username(), pwd)
|
||||
if err != nil {
|
||||
log.Print(err)
|
||||
continue
|
||||
}
|
||||
|
||||
err = addPrivItem(db, c2.Username())
|
||||
if err != nil {
|
||||
log.Print(err)
|
||||
continue
|
||||
}
|
||||
|
||||
db.Close()
|
||||
|
||||
// Send AUTH_ACCEPT
|
||||
data := []byte{
|
||||
0, ToClientAuthAccept,
|
||||
|
@ -464,21 +439,7 @@ func Init(c, c2 *Conn, ignMedia, noAccessDenied bool, fin chan *Conn) {
|
|||
|
||||
A := ReadBytes16(r)
|
||||
|
||||
db, err := initAuthDB()
|
||||
if err != nil {
|
||||
log.Print(err)
|
||||
continue
|
||||
}
|
||||
|
||||
pwd, err := readAuthItem(db, c2.Username())
|
||||
if err != nil {
|
||||
log.Print(err)
|
||||
continue
|
||||
}
|
||||
|
||||
db.Close()
|
||||
|
||||
s, v, err := decodeVerifierAndSalt(pwd)
|
||||
v, s, err := Password(c2.Username())
|
||||
if err != nil {
|
||||
log.Print(err)
|
||||
continue
|
||||
|
|
2
privs.go
2
privs.go
|
@ -79,7 +79,7 @@ func SetPrivs(name string, privs map[string]bool) error {
|
|||
|
||||
_, err = db.Exec(`REPLACE INTO privileges (
|
||||
name,
|
||||
priviliges
|
||||
privileges
|
||||
) VALUES (
|
||||
?,
|
||||
?
|
||||
|
|
10
rpc.go
10
rpc.go
|
@ -139,25 +139,19 @@ func processRpc(c *Conn, r *bytes.Reader) bool {
|
|||
}
|
||||
go c.doRpc("->ADDR "+addr, rq)
|
||||
case "<-ISBANNED":
|
||||
db, err := initAuthDB()
|
||||
if err != nil {
|
||||
return true
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
target := strings.Split(msg, " ")[2]
|
||||
|
||||
if net.ParseIP(target) == nil {
|
||||
return true
|
||||
}
|
||||
|
||||
name, err := readBanItem(db, target)
|
||||
banned, _, err := IsBanned(target)
|
||||
if err != nil {
|
||||
return true
|
||||
}
|
||||
|
||||
r := "false"
|
||||
if name != "" {
|
||||
if banned {
|
||||
r = "true"
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue