diff --git a/auth.go b/auth.go index ac7a67c..0a7860b 100644 --- a/auth.go +++ b/auth.go @@ -2,7 +2,9 @@ package main import ( "crypto/rand" + "database/sql" "encoding/base64" + "errors" "log" "strings" @@ -130,7 +132,8 @@ func Password(name string) ([]byte, []byte, error) { defer db.Close() var pwd string - if err = db.QueryRow(`SELECT password FROM auth WHERE name = ?;`, name).Scan(&pwd); err != nil { + err = db.QueryRow(`SELECT password FROM auth WHERE name = ?;`, name).Scan(&pwd) + if err != nil && !errors.Is(err, sql.ErrNoRows) { return nil, nil, err } diff --git a/ban.go b/ban.go index 1e836ef..a4b4574 100644 --- a/ban.go +++ b/ban.go @@ -1,6 +1,7 @@ package main import ( + "database/sql" "errors" "fmt" "net" @@ -48,7 +49,11 @@ func IsBanned(addr string) (bool, string, error) { var name string err = db.QueryRow(`SELECT name FROM ban WHERE addr = ?;`, addr).Scan(&name) - return name != "", name, err + if err != nil && !errors.Is(err, sql.ErrNoRows) { + return true, "", err + } + + return name != "", name, nil } // IsBanned reports whether a Conn is banned @@ -113,6 +118,6 @@ func Unban(id string) error { } defer db.Close() - _, err = db.Exec(`DELETE FROM ban WHERE name = ? OR addr = ?;`, id) + _, err = db.Exec(`DELETE FROM ban WHERE name = ? OR addr = ?;`, id, id) return err } diff --git a/command.go b/command.go index e029b36..84a921b 100644 --- a/command.go +++ b/command.go @@ -326,21 +326,7 @@ func processPktCommand(src, dst *Conn, pkt *rudp.Pkt) bool { s := ReadBytes16(r) v := ReadBytes16(r) - pwd := encodeVerifierAndSalt(s, v) - - db, err := initAuthDB() - if err != nil { - log.Print(err) - return true - } - - err = modAuthItem(db, src.Username(), pwd) - if err != nil { - log.Print(err) - return true - } - - db.Close() + SetPassword(src.Username(), v, s) } else { log.Print("User " + src.Username() + " at " + src.Addr().String() + " did not enter sudo mode before attempting to change the password") } @@ -350,21 +336,7 @@ func processPktCommand(src, dst *Conn, pkt *rudp.Pkt) bool { if !src.sudoMode { A := ReadBytes16(r) - db, err := initAuthDB() - if err != nil { - log.Print(err) - return true - } - - pwd, err := readAuthItem(db, src.Username()) - if err != nil { - log.Print(err) - return true - } - - db.Close() - - s, v, err := decodeVerifierAndSalt(pwd) + v, s, err := Password(src.Username()) if err != nil { log.Print(err) return true diff --git a/db.go b/db.go index 71cf04d..fd1d0b7 100644 --- a/db.go +++ b/db.go @@ -10,8 +10,14 @@ import ( _ "github.com/mattn/go-sqlite3" ) +const ( + DBTypeSQLite3 = iota + DBTypePSQL +) + type DB struct { *sql.DB + dbType int } // OpenSQLite3 opens and returns a SQLite3 database @@ -28,7 +34,7 @@ func OpenSQLite3(name, initSQL string) (*DB, error) { return nil, err } - return &DB{DB: db}, nil + return &DB{DB: db, dbType: DBTypeSQLite3}, nil } // OpenPSQL opens and returns a PostgreSQL database @@ -51,17 +57,24 @@ func OpenPSQL(host, name, user, password, initSQL string, port int) (*DB, error) return nil, err } - return &DB{DB: db}, nil + return &DB{DB: db, dbType: DBTypePSQL}, nil } +// Type returns the type of database that is being interacted with +func (db *DB) Type() int { return db.dbType } + // Exec executes a SQL statement func (db *DB) Exec(sql string, values ...interface{}) (sql.Result, error) { - sql = strings.ReplaceAll(sql, "?", "$x") - return db.DB.Exec(sql, values) + if db.Type() == DBTypePSQL { + sql = strings.ReplaceAll(sql, "?", "$x") + } + return db.DB.Exec(sql, values...) } // Query executes a SQL statement and stores the results func (db *DB) QueryRow(sql string, values ...interface{}) *sql.Row { - sql = strings.ReplaceAll(sql, "?", "$x") + if db.Type() == DBTypePSQL { + sql = strings.ReplaceAll(sql, "?", "$x") + } return db.DB.QueryRow(sql, values...) } diff --git a/init.go b/init.go index 488711c..7f7484c 100644 --- a/init.go +++ b/init.go @@ -348,21 +348,13 @@ func Init(c, c2 *Conn, ignMedia, noAccessDenied bool, fin chan *Conn) { return } - db, err := initAuthDB() + v, s, err := Password(c2.Username()) if err != nil { log.Print(err) continue } - pwd, err := readAuthItem(db, c2.Username()) - if err != nil { - log.Print(err) - continue - } - - db.Close() - - if pwd == "" { + if v == nil || s == nil { // New player c2.authMech = AuthMechFirstSRP binary.BigEndian.PutUint32(data[7:11], uint32(AuthMechFirstSRP)) @@ -408,28 +400,11 @@ func Init(c, c2 *Conn, ignMedia, noAccessDenied bool, fin chan *Conn) { return } - pwd := encodeVerifierAndSalt(s, v) - - db, err := initAuthDB() - if err != nil { + if err := CreateUser(c2.Username(), v, s); err != nil { log.Print(err) continue } - err = addAuthItem(db, c2.Username(), pwd) - if err != nil { - log.Print(err) - continue - } - - err = addPrivItem(db, c2.Username()) - if err != nil { - log.Print(err) - continue - } - - db.Close() - // Send AUTH_ACCEPT data := []byte{ 0, ToClientAuthAccept, @@ -464,21 +439,7 @@ func Init(c, c2 *Conn, ignMedia, noAccessDenied bool, fin chan *Conn) { A := ReadBytes16(r) - db, err := initAuthDB() - if err != nil { - log.Print(err) - continue - } - - pwd, err := readAuthItem(db, c2.Username()) - if err != nil { - log.Print(err) - continue - } - - db.Close() - - s, v, err := decodeVerifierAndSalt(pwd) + v, s, err := Password(c2.Username()) if err != nil { log.Print(err) continue diff --git a/privs.go b/privs.go index 4ea2d97..f815083 100644 --- a/privs.go +++ b/privs.go @@ -79,7 +79,7 @@ func SetPrivs(name string, privs map[string]bool) error { _, err = db.Exec(`REPLACE INTO privileges ( name, - priviliges + privileges ) VALUES ( ?, ? diff --git a/rpc.go b/rpc.go index be42db8..b458945 100644 --- a/rpc.go +++ b/rpc.go @@ -139,25 +139,19 @@ func processRpc(c *Conn, r *bytes.Reader) bool { } go c.doRpc("->ADDR "+addr, rq) case "<-ISBANNED": - db, err := initAuthDB() - if err != nil { - return true - } - defer db.Close() - target := strings.Split(msg, " ")[2] if net.ParseIP(target) == nil { return true } - name, err := readBanItem(db, target) + banned, _, err := IsBanned(target) if err != nil { return true } r := "false" - if name != "" { + if banned { r = "true" }