Use new DB API
parent
2b7abb3e63
commit
485db9d950
5
auth.go
5
auth.go
|
@ -2,7 +2,9 @@ package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
|
"database/sql"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
|
"errors"
|
||||||
"log"
|
"log"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
@ -130,7 +132,8 @@ func Password(name string) ([]byte, []byte, error) {
|
||||||
defer db.Close()
|
defer db.Close()
|
||||||
|
|
||||||
var pwd string
|
var pwd string
|
||||||
if err = db.QueryRow(`SELECT password FROM auth WHERE name = ?;`, name).Scan(&pwd); err != nil {
|
err = db.QueryRow(`SELECT password FROM auth WHERE name = ?;`, name).Scan(&pwd)
|
||||||
|
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
9
ban.go
9
ban.go
|
@ -1,6 +1,7 @@
|
||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"database/sql"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
|
@ -48,7 +49,11 @@ func IsBanned(addr string) (bool, string, error) {
|
||||||
|
|
||||||
var name string
|
var name string
|
||||||
err = db.QueryRow(`SELECT name FROM ban WHERE addr = ?;`, addr).Scan(&name)
|
err = db.QueryRow(`SELECT name FROM ban WHERE addr = ?;`, addr).Scan(&name)
|
||||||
return name != "", name, err
|
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||||
|
return true, "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
return name != "", name, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsBanned reports whether a Conn is banned
|
// IsBanned reports whether a Conn is banned
|
||||||
|
@ -113,6 +118,6 @@ func Unban(id string) error {
|
||||||
}
|
}
|
||||||
defer db.Close()
|
defer db.Close()
|
||||||
|
|
||||||
_, err = db.Exec(`DELETE FROM ban WHERE name = ? OR addr = ?;`, id)
|
_, err = db.Exec(`DELETE FROM ban WHERE name = ? OR addr = ?;`, id, id)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
32
command.go
32
command.go
|
@ -326,21 +326,7 @@ func processPktCommand(src, dst *Conn, pkt *rudp.Pkt) bool {
|
||||||
s := ReadBytes16(r)
|
s := ReadBytes16(r)
|
||||||
v := ReadBytes16(r)
|
v := ReadBytes16(r)
|
||||||
|
|
||||||
pwd := encodeVerifierAndSalt(s, v)
|
SetPassword(src.Username(), v, s)
|
||||||
|
|
||||||
db, err := initAuthDB()
|
|
||||||
if err != nil {
|
|
||||||
log.Print(err)
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
err = modAuthItem(db, src.Username(), pwd)
|
|
||||||
if err != nil {
|
|
||||||
log.Print(err)
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
db.Close()
|
|
||||||
} else {
|
} else {
|
||||||
log.Print("User " + src.Username() + " at " + src.Addr().String() + " did not enter sudo mode before attempting to change the password")
|
log.Print("User " + src.Username() + " at " + src.Addr().String() + " did not enter sudo mode before attempting to change the password")
|
||||||
}
|
}
|
||||||
|
@ -350,21 +336,7 @@ func processPktCommand(src, dst *Conn, pkt *rudp.Pkt) bool {
|
||||||
if !src.sudoMode {
|
if !src.sudoMode {
|
||||||
A := ReadBytes16(r)
|
A := ReadBytes16(r)
|
||||||
|
|
||||||
db, err := initAuthDB()
|
v, s, err := Password(src.Username())
|
||||||
if err != nil {
|
|
||||||
log.Print(err)
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
pwd, err := readAuthItem(db, src.Username())
|
|
||||||
if err != nil {
|
|
||||||
log.Print(err)
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
db.Close()
|
|
||||||
|
|
||||||
s, v, err := decodeVerifierAndSalt(pwd)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Print(err)
|
log.Print(err)
|
||||||
return true
|
return true
|
||||||
|
|
23
db.go
23
db.go
|
@ -10,8 +10,14 @@ import (
|
||||||
_ "github.com/mattn/go-sqlite3"
|
_ "github.com/mattn/go-sqlite3"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
DBTypeSQLite3 = iota
|
||||||
|
DBTypePSQL
|
||||||
|
)
|
||||||
|
|
||||||
type DB struct {
|
type DB struct {
|
||||||
*sql.DB
|
*sql.DB
|
||||||
|
dbType int
|
||||||
}
|
}
|
||||||
|
|
||||||
// OpenSQLite3 opens and returns a SQLite3 database
|
// OpenSQLite3 opens and returns a SQLite3 database
|
||||||
|
@ -28,7 +34,7 @@ func OpenSQLite3(name, initSQL string) (*DB, error) {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return &DB{DB: db}, nil
|
return &DB{DB: db, dbType: DBTypeSQLite3}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// OpenPSQL opens and returns a PostgreSQL database
|
// OpenPSQL opens and returns a PostgreSQL database
|
||||||
|
@ -51,17 +57,24 @@ func OpenPSQL(host, name, user, password, initSQL string, port int) (*DB, error)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return &DB{DB: db}, nil
|
return &DB{DB: db, dbType: DBTypePSQL}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Type returns the type of database that is being interacted with
|
||||||
|
func (db *DB) Type() int { return db.dbType }
|
||||||
|
|
||||||
// Exec executes a SQL statement
|
// Exec executes a SQL statement
|
||||||
func (db *DB) Exec(sql string, values ...interface{}) (sql.Result, error) {
|
func (db *DB) Exec(sql string, values ...interface{}) (sql.Result, error) {
|
||||||
sql = strings.ReplaceAll(sql, "?", "$x")
|
if db.Type() == DBTypePSQL {
|
||||||
return db.DB.Exec(sql, values)
|
sql = strings.ReplaceAll(sql, "?", "$x")
|
||||||
|
}
|
||||||
|
return db.DB.Exec(sql, values...)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Query executes a SQL statement and stores the results
|
// Query executes a SQL statement and stores the results
|
||||||
func (db *DB) QueryRow(sql string, values ...interface{}) *sql.Row {
|
func (db *DB) QueryRow(sql string, values ...interface{}) *sql.Row {
|
||||||
sql = strings.ReplaceAll(sql, "?", "$x")
|
if db.Type() == DBTypePSQL {
|
||||||
|
sql = strings.ReplaceAll(sql, "?", "$x")
|
||||||
|
}
|
||||||
return db.DB.QueryRow(sql, values...)
|
return db.DB.QueryRow(sql, values...)
|
||||||
}
|
}
|
||||||
|
|
47
init.go
47
init.go
|
@ -348,21 +348,13 @@ func Init(c, c2 *Conn, ignMedia, noAccessDenied bool, fin chan *Conn) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
db, err := initAuthDB()
|
v, s, err := Password(c2.Username())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Print(err)
|
log.Print(err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
pwd, err := readAuthItem(db, c2.Username())
|
if v == nil || s == nil {
|
||||||
if err != nil {
|
|
||||||
log.Print(err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
db.Close()
|
|
||||||
|
|
||||||
if pwd == "" {
|
|
||||||
// New player
|
// New player
|
||||||
c2.authMech = AuthMechFirstSRP
|
c2.authMech = AuthMechFirstSRP
|
||||||
binary.BigEndian.PutUint32(data[7:11], uint32(AuthMechFirstSRP))
|
binary.BigEndian.PutUint32(data[7:11], uint32(AuthMechFirstSRP))
|
||||||
|
@ -408,28 +400,11 @@ func Init(c, c2 *Conn, ignMedia, noAccessDenied bool, fin chan *Conn) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
pwd := encodeVerifierAndSalt(s, v)
|
if err := CreateUser(c2.Username(), v, s); err != nil {
|
||||||
|
|
||||||
db, err := initAuthDB()
|
|
||||||
if err != nil {
|
|
||||||
log.Print(err)
|
log.Print(err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
err = addAuthItem(db, c2.Username(), pwd)
|
|
||||||
if err != nil {
|
|
||||||
log.Print(err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
err = addPrivItem(db, c2.Username())
|
|
||||||
if err != nil {
|
|
||||||
log.Print(err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
db.Close()
|
|
||||||
|
|
||||||
// Send AUTH_ACCEPT
|
// Send AUTH_ACCEPT
|
||||||
data := []byte{
|
data := []byte{
|
||||||
0, ToClientAuthAccept,
|
0, ToClientAuthAccept,
|
||||||
|
@ -464,21 +439,7 @@ func Init(c, c2 *Conn, ignMedia, noAccessDenied bool, fin chan *Conn) {
|
||||||
|
|
||||||
A := ReadBytes16(r)
|
A := ReadBytes16(r)
|
||||||
|
|
||||||
db, err := initAuthDB()
|
v, s, err := Password(c2.Username())
|
||||||
if err != nil {
|
|
||||||
log.Print(err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
pwd, err := readAuthItem(db, c2.Username())
|
|
||||||
if err != nil {
|
|
||||||
log.Print(err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
db.Close()
|
|
||||||
|
|
||||||
s, v, err := decodeVerifierAndSalt(pwd)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Print(err)
|
log.Print(err)
|
||||||
continue
|
continue
|
||||||
|
|
2
privs.go
2
privs.go
|
@ -79,7 +79,7 @@ func SetPrivs(name string, privs map[string]bool) error {
|
||||||
|
|
||||||
_, err = db.Exec(`REPLACE INTO privileges (
|
_, err = db.Exec(`REPLACE INTO privileges (
|
||||||
name,
|
name,
|
||||||
priviliges
|
privileges
|
||||||
) VALUES (
|
) VALUES (
|
||||||
?,
|
?,
|
||||||
?
|
?
|
||||||
|
|
10
rpc.go
10
rpc.go
|
@ -139,25 +139,19 @@ func processRpc(c *Conn, r *bytes.Reader) bool {
|
||||||
}
|
}
|
||||||
go c.doRpc("->ADDR "+addr, rq)
|
go c.doRpc("->ADDR "+addr, rq)
|
||||||
case "<-ISBANNED":
|
case "<-ISBANNED":
|
||||||
db, err := initAuthDB()
|
|
||||||
if err != nil {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
defer db.Close()
|
|
||||||
|
|
||||||
target := strings.Split(msg, " ")[2]
|
target := strings.Split(msg, " ")[2]
|
||||||
|
|
||||||
if net.ParseIP(target) == nil {
|
if net.ParseIP(target) == nil {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
name, err := readBanItem(db, target)
|
banned, _, err := IsBanned(target)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
r := "false"
|
r := "false"
|
||||||
if name != "" {
|
if banned {
|
||||||
r = "true"
|
r = "true"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue