multiserver/db.go

91 lines
1.8 KiB
Go

package main
import (
"database/sql"
"fmt"
"os"
"regexp"
_ "github.com/lib/pq"
_ "github.com/mattn/go-sqlite3"
)
// Database backends a DB can wrap. The values are stored in DB.dbType and
// reported by DB.Type.
const (
	// DBTypeSQLite3 marks a DB backed by SQLite3.
	DBTypeSQLite3 = iota

	// DBTypePSQL marks a DB backed by PostgreSQL.
	DBTypePSQL
)
// DB couples a database/sql handle with a marker for the backend it talks to.
//
// The handle is embedded, so every *sql.DB method is available on DB directly;
// Exec and QueryRow are shadowed by DB's own methods, which adapt placeholder
// syntax for SQLite3 before delegating to the embedded handle.
type DB struct {
	// Embedded handle returned by sql.Open.
	*sql.DB

	// Backend marker: DBTypeSQLite3 or DBTypePSQL.
	dbType int
}
// OpenSQLite3 opens the SQLite3 database file storage/<name>, creating the
// storage directory first when it does not exist, and runs initSQL against it.
//
// An error is returned if the storage directory cannot be created, if the
// database cannot be opened, or if initSQL fails. When initSQL fails the handle
// is closed before returning.
func OpenSQLite3(name, initSQL string) (*DB, error) {
	// MkdirAll returns nil when the directory already exists, so only genuine
	// failures (permissions, a regular file named "storage", ...) are reported
	// here instead of surfacing later as a confusing driver error.
	if err := os.MkdirAll("storage", 0777); err != nil {
		return nil, fmt.Errorf("creating storage directory: %w", err)
	}

	db, err := sql.Open("sqlite3", "storage/"+name)
	if err != nil {
		return nil, err
	}

	if _, err := db.Exec(initSQL); err != nil {
		db.Close()
		return nil, err
	}

	return &DB{DB: db, dbType: DBTypeSQLite3}, nil
}
// OpenPSQL connects to the PostgreSQL database name on host:port as user,
// checks the connection with a ping and runs initSQL against it.
//
// host, user, password and name are single-quoted and escaped in the
// connection string, so values that are empty or contain spaces, quotes or
// backslashes reach the driver intact instead of corrupting the string or
// being read as additional settings. TLS is disabled (sslmode=disable).
//
// An error is returned if the handle cannot be opened, the ping fails, or
// initSQL fails. In the last two cases the handle is closed before returning.
func OpenPSQL(host, name, user, password, initSQL string, port int) (*DB, error) {
	psqlconn := fmt.Sprintf(
		"host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
		quoteConnValue(host), port, quoteConnValue(user), quoteConnValue(password), quoteConnValue(name),
	)

	db, err := sql.Open("postgres", psqlconn)
	if err != nil {
		return nil, err
	}

	// sql.Open may not contact the server; Ping forces a real connection so
	// bad credentials or an unreachable host are reported here.
	if err := db.Ping(); err != nil {
		db.Close()
		return nil, err
	}

	if _, err := db.Exec(initSQL); err != nil {
		db.Close()
		return nil, err
	}

	return &DB{DB: db, dbType: DBTypePSQL}, nil
}

// quoteConnValue wraps v in single quotes for use as a value in a key=value
// connection string, backslash-escaping any single quote or backslash in v.
func quoteConnValue(v string) string {
	// Both escaped characters are ASCII, so scanning bytes is safe for UTF-8.
	quoted := make([]byte, 0, len(v)+2)
	quoted = append(quoted, '\'')
	for i := 0; i < len(v); i++ {
		if c := v[i]; c == '\'' || c == '\\' {
			quoted = append(quoted, '\\')
		}
		quoted = append(quoted, v[i])
	}
	quoted = append(quoted, '\'')
	return string(quoted)
}
// Type reports which backend this DB wraps, as recorded when it was opened:
// DBTypeSQLite3 or DBTypePSQL.
func (db *DB) Type() int {
	return db.dbType
}
// pgPlaceholder matches a PostgreSQL-style positional placeholder such as $1
// or $12. It is compiled once at package initialisation rather than on every
// call, and matches the whole number so multi-digit placeholders are replaced
// in full.
var pgPlaceholder = regexp.MustCompile(`\$+[0-9]+`)

// translatePlaceholders returns query in the placeholder syntax of db's
// backend: on SQLite3 every $N becomes ?, otherwise query is returned as is.
func (db *DB) translatePlaceholders(query string) string {
	if db.Type() != DBTypeSQLite3 {
		return query
	}
	return pgPlaceholder.ReplaceAllString(query, "?")
}

// Exec executes a SQL statement that returns no rows.
//
// Statements are written with PostgreSQL-style placeholders ($1, $2, ...);
// on SQLite3 each one is rewritten to ? before the statement is run. Since
// every placeholder becomes a plain ?, values are matched to placeholders by
// order of appearance, so the statement should list $1, $2, ... in ascending
// order and use each once if it is to behave the same on both backends.
func (db *DB) Exec(sql string, values ...interface{}) (sql.Result, error) {
	return db.DB.Exec(db.translatePlaceholders(sql), values...)
}

// QueryRow executes a SQL statement expected to return at most one row and
// returns that row for scanning. The result is never nil; any error is
// reported when the returned row is scanned.
//
// Placeholders are handled as in Exec: write $1, $2, ... and they are
// rewritten to ? on SQLite3.
func (db *DB) QueryRow(sql string, values ...interface{}) *sql.Row {
	return db.DB.QueryRow(db.translatePlaceholders(sql), values...)
}