multiserver/rpc.go

439 lines
8.2 KiB
Go

package main
import (
"bytes"
"encoding/binary"
"errors"
"io"
"log"
"net"
"strconv"
"strings"
"sync"
"time"
"github.com/anon55555/mt/rudp"
)
// rpcCh is the mod channel name used for RPC traffic with the configured
// servers; only messages on this channel are handled by processRpc.
const rpcCh = "multiserver"

// Mod channel signal codes, as compared against the signal byte of a
// ToClientModChannelSignal packet in handleRpc.
const (
	ModChSigJoinOk          = 0
	ModChSigJoinFail        = 1
	ModChSigLeaveOk         = 2
	ModChSigLeaveFail       = 3
	ModChSigChNotRegistered = 4
	ModChSigSetState        = 5
)

// Mod channel states, as carried by a ModChSigSetState signal.
const (
	ModChStateInit = 0
	ModChStateRW   = 1
	ModChStateRO   = 2
)
// rpcSrvs is the set of server connections currently used for RPC.
// Every access must hold rpcSrvMu.
var (
	rpcSrvMu sync.Mutex
	rpcSrvs  map[*Conn]struct{}
)
// joinRpc sends a ToServerModChannelJoin request for rpcCh on c and then
// blocks on the channel returned by Send (presumably signalled once the
// packet is acknowledged). A send error is silently ignored.
func (c *Conn) joinRpc() {
	// Layout: 0x00, command byte, big-endian u16 name length, name.
	pkt := make([]byte, 4, 4+len(rpcCh))
	pkt[1] = uint8(ToServerModChannelJoin)
	binary.BigEndian.PutUint16(pkt[2:], uint16(len(rpcCh)))
	pkt = append(pkt, rpcCh...)

	if ack, err := c.Send(rudp.Pkt{Reader: bytes.NewReader(pkt)}); err == nil {
		<-ack
	}
}
// leaveRpc sends a ToServerModChannelLeave request for rpcCh on c and then
// blocks on the channel returned by Send (presumably signalled once the
// packet is acknowledged). A send error is silently ignored.
func (c *Conn) leaveRpc() {
	// Layout: 0x00, command byte, big-endian u16 name length, name.
	pkt := make([]byte, 4, 4+len(rpcCh))
	pkt[1] = uint8(ToServerModChannelLeave)
	binary.BigEndian.PutUint16(pkt[2:], uint16(len(rpcCh)))
	pkt = append(pkt, rpcCh...)

	if ack, err := c.Send(rudp.Pkt{Reader: bytes.NewReader(pkt)}); err == nil {
		<-ack
	}
}
// processRpc handles one mod channel message read from r, which is expected
// to be positioned just after the packet command (as handleRpc does).
//
// It returns false when the message is not on rpcCh or has a non-empty
// sender, and true otherwise, including for malformed requests, which are
// dropped instead of panicking.
//
// Requests have the form "<request id> <command> [args...]" with fields
// separated by single spaces. Replies are sent asynchronously through
// doRpc, tagged with the same request id.
func processRpc(c *Conn, r *bytes.Reader) bool {
	ch := string(ReadBytes16(r))
	sender := string(ReadBytes16(r))
	msg := string(ReadBytes16(r))
	if ch != rpcCh || sender != "" {
		return false
	}

	// Split once; each command checks it has the fields it indexes.
	fields := strings.Split(msg, " ")
	if len(fields) < 2 {
		return true
	}
	rq, cmd, args := fields[0], fields[1], fields[2:]
	// rest rejoins args[i:]; callers guarantee i <= len(args).
	rest := func(i int) string { return strings.Join(args[i:], " ") }

	switch cmd {
	case "<-ALERT":
		ChatSendAll(rest(0))
	case "<-GETDEFSRV":
		defsrv, ok := ConfKey("default_server").(string)
		if !ok {
			return true
		}
		go c.doRpc("->DEFSRV "+defsrv, rq)
	case "<-GETPEERCNT":
		go c.doRpc("->PEERCNT "+strconv.Itoa(ConnCount()), rq)
	case "<-ISONLINE":
		go c.doRpc("->ISONLINE "+strconv.FormatBool(IsOnline(rest(0))), rq)
	case "<-CHECKPRIVS":
		if len(args) < 1 {
			return true
		}
		has, err := CheckPrivs(args[0], decodePrivs(rest(1)))
		go c.doRpc("->HASPRIVS "+strconv.FormatBool(err == nil && has), rq)
	case "<-GETPRIVS":
		if len(args) < 1 {
			return true
		}
		var list string
		if privs, err := Privs(args[0]); err == nil {
			list = strings.ReplaceAll(encodePrivs(privs), "|", ",")
		}
		go c.doRpc("->PRIVS "+list, rq)
	case "<-SETPRIVS":
		if len(args) < 1 {
			return true
		}
		SetPrivs(args[0], decodePrivs(rest(1)))
	case "<-GETSRV":
		if len(args) < 1 {
			return true
		}
		var srv string
		// The player may disconnect between IsOnline and the lookup, so
		// ConnByUsername can still return nil here.
		if IsOnline(args[0]) {
			if c2 := ConnByUsername(args[0]); c2 != nil {
				srv = c2.ServerName()
			}
		}
		go c.doRpc("->SRV "+srv, rq)
	case "<-REDIRECT":
		if len(args) < 2 {
			return true
		}
		if IsOnline(args[0]) {
			if c2 := ConnByUsername(args[0]); c2 != nil {
				go c2.Redirect(args[1])
			}
		}
	case "<-GETADDR":
		if len(args) < 1 {
			return true
		}
		var addr string
		if IsOnline(args[0]) {
			if c2 := ConnByUsername(args[0]); c2 != nil {
				addr = c2.Addr().String()
			}
		}
		go c.doRpc("->ADDR "+addr, rq)
	case "<-ISBANNED":
		if len(args) < 1 || net.ParseIP(args[0]) == nil {
			return true
		}
		banned, _, err := IsBanned(args[0])
		if err != nil {
			return true
		}
		go c.doRpc("->ISBANNED "+strconv.FormatBool(banned), rq)
	case "<-BAN":
		if len(args) < 1 {
			return true
		}
		// Try Ban on the target directly; if that fails, treat the target
		// as a username and ban that player's connection instead.
		if err := Ban(args[0], "not known"); err != nil {
			if c2 := ConnByUsername(args[0]); c2 != nil {
				c2.Ban()
			}
		}
	case "<-UNBAN":
		if len(args) < 1 {
			return true
		}
		Unban(args[0])
	case "<-GETSRVS":
		servers, ok := ConfKey("servers").(map[interface{}]interface{})
		if !ok {
			return true
		}
		names := make([]string, 0, len(servers))
		for server := range servers {
			if name, ok := server.(string); ok {
				names = append(names, name)
			}
		}
		// Join copes with an empty list, unlike trimming a trailing comma.
		go c.doRpc("->SRVS "+strings.Join(names, ","), rq)
	case "<-MT2MT":
		payload := rest(0)
		self := c.Addr().String()
		rpcSrvMu.Lock()
		for srv := range rpcSrvs {
			if srv.Addr().String() != self {
				go srv.doRpc("->MT2MT true "+payload, "--")
			}
		}
		rpcSrvMu.Unlock()
	case "<-MSG2MT":
		if len(args) < 1 {
			return true
		}
		addr, ok := ConfKey("servers:" + args[0] + ":address").(string)
		if !ok || addr == c.Addr().String() {
			return true
		}
		payload := rest(1)
		rpcSrvMu.Lock()
		for srv := range rpcSrvs {
			if srv.Addr().String() == addr {
				go srv.doRpc("->MT2MT false "+payload, "--")
			}
		}
		rpcSrvMu.Unlock()
	}
	return true
}
// doRpc sends "<rq> <rpc>" to c as a ToServerModChannelMsg on rpcCh, but
// only while RPC use is enabled on c. It does not wait on the channel
// returned by Send, and any send error is discarded.
func (c *Conn) doRpc(rpc, rq string) {
	if c.UseRpc() {
		var pkt bytes.Buffer
		pkt.Write([]byte{0x00, ToServerModChannelMsg})
		WriteBytes16(&pkt, []byte(rpcCh))
		WriteBytes16(&pkt, []byte(rq+" "+rpc))
		// Best effort: callers run this in its own goroutine and have no
		// way to react to a failed send.
		_, _ = c.Send(rudp.Pkt{Reader: &pkt})
	}
}
func connectRpc() {
log.Print("Establishing RPC connections")
servers := ConfKey("servers").(map[interface{}]interface{})
for server := range servers {
clt := &Conn{username: "rpc"}
straddr := ConfKey("servers:" + server.(string) + ":address")
srvaddr, err := net.ResolveUDPAddr("udp", straddr.(string))
if err != nil {
log.Print(err)
continue
}
conn, err := net.DialUDP("udp", nil, srvaddr)
if err != nil {
log.Print(err)
continue
}
srv, err := Connect(conn)
if err != nil {
log.Print(err)
continue
}
fin := make(chan *Conn) // close-only
go Init(clt, srv, true, true, fin)
go func() {
<-fin
rpcSrvMu.Lock()
rpcSrvs[srv] = struct{}{}
rpcSrvMu.Unlock()
go srv.joinRpc()
go handleRpc(srv)
}()
}
}
// handleRpc puts srv into RPC-only mode, then reads packets from it until
// Recv reports net.ErrClosed. At that point srv is removed from rpcSrvs and
// the function returns. Other receive errors are logged and reading
// continues.
//
// Mod channel signals for rpcCh toggle RPC use on srv:
//   - ModChSigJoinOk enables it.
//   - ModChSigSetState with state ModChStateRO disables it.
//
// Mod channel messages are handed to processRpc.
func handleRpc(srv *Conn) {
	srv.MakeRpcOnly()
	for {
		pkt, err := srv.Recv()
		switch {
		case errors.Is(err, net.ErrClosed):
			rpcSrvMu.Lock()
			delete(rpcSrvs, srv)
			rpcSrvMu.Unlock()
			return
		case err != nil:
			log.Print(err)
			continue
		}

		r := ByteReader(pkt)
		cmd := ReadUint16(r)
		if cmd == ToClientModChannelSignal {
			// Read as: u8 signal, u16-prefixed channel name, u8 state.
			// Skip the signal first; it is re-read only for rpcCh.
			r.Seek(1, io.SeekCurrent)
			ch := string(ReadBytes16(r))
			state := ReadUint8(r)
			if ch != rpcCh {
				continue
			}
			r.Seek(2, io.SeekStart) // back to the signal byte
			sig := ReadUint8(r)
			if sig == ModChSigJoinOk {
				srv.SetUseRpc(true)
			} else if sig == ModChSigSetState && state == ModChStateRO {
				srv.SetUseRpc(false)
			}
		} else if cmd == ToClientModChannelMSG {
			processRpc(srv, r)
		}
	}
}
func OptimizeRPCConns() {
rpcSrvMu.Lock()
defer rpcSrvMu.Unlock()
ServerLoop:
for c := range rpcSrvs {
for _, c2 := range Conns() {
if c2.Server() == nil {
continue
}
if c2.Server().Addr().String() == c.Addr().String() {
if c.NoClt() {
c.Close()
} else {
c.SetUseRpc(false)
c.leaveRpc()
}
delete(rpcSrvs, c)
c3 := c2.Server()
c3.SetUseRpc(true)
c3.joinRpc()
rpcSrvs[c3] = struct{}{}
go func() {
<-c3.Closed()
rpcSrvMu.Lock()
delete(rpcSrvs, c3)
rpcSrvMu.Unlock()
for c2.Server().Addr().String() == c3.Addr().String() {
}
OptimizeRPCConns()
}()
continue ServerLoop
}
}
}
go reconnectRpc(false)
}
// reconnectRpc dials an RPC connection to every configured server whose
// address has no connection in rpcSrvs yet.
//
// If media is true, the server's media is also refetched via loadMedia
// before dialing, in case something has not been downloaded yet. Servers
// with a non-string name or a missing or non-string address are logged and
// skipped instead of causing a panic.
func reconnectRpc(media bool) {
	servers, ok := ConfKey("servers").(map[interface{}]interface{})
	if !ok {
		log.Print("rpc: no servers configured")
		return
	}
	for server := range servers {
		name, ok := server.(string)
		if !ok {
			log.Printf("rpc: ignoring non-string server name %v", server)
			continue
		}
		straddr, ok := ConfKey("servers:" + name + ":address").(string)
		if !ok {
			log.Printf("rpc: server %q has no address", name)
			continue
		}
		if rpcSrvConnected(straddr) {
			continue
		}
		if media {
			loadMedia(map[string]struct{}{name: {}})
		}
		dialRpcSrv(straddr)
	}
}

// rpcSrvConnected reports whether rpcSrvs holds a connection to addr.
func rpcSrvConnected(addr string) bool {
	rpcSrvMu.Lock()
	defer rpcSrvMu.Unlock()
	for srv := range rpcSrvs {
		if srv.Addr().String() == addr {
			return true
		}
	}
	return false
}

// dialRpcSrv connects to the server at addr as the pseudo-user "rpc".
// Once Init signals on fin, it registers the connection in rpcSrvs, joins
// the RPC mod channel and starts handleRpc on it. Failures are logged.
func dialRpcSrv(addr string) {
	srvaddr, err := net.ResolveUDPAddr("udp", addr)
	if err != nil {
		log.Print(err)
		return
	}
	conn, err := net.DialUDP("udp", nil, srvaddr)
	if err != nil {
		log.Print(err)
		return
	}
	srv, err := Connect(conn)
	if err != nil {
		log.Print(err)
		// Don't leak the socket. A second Close is harmless if Connect
		// already closed it.
		conn.Close()
		return
	}

	clt := &Conn{username: "rpc"}
	fin := make(chan *Conn) // close-only
	go Init(clt, srv, true, true, fin)
	go func() {
		<-fin
		rpcSrvMu.Lock()
		rpcSrvs[srv] = struct{}{}
		rpcSrvMu.Unlock()
		go srv.joinRpc()
		go handleRpc(srv)
	}()
}
// init creates rpcSrvs, opens the initial RPC connections via connectRpc,
// and starts a background loop. That loop calls reconnectRpc(true) every
// server_reintegration_interval seconds (600 when unset, not an int, or
// not positive) for the lifetime of the process.
func init() {
	rpcSrvMu.Lock()
	rpcSrvs = make(map[*Conn]struct{})
	rpcSrvMu.Unlock()

	// time.NewTicker panics on a non-positive duration, so a bad config
	// value falls back to the default instead of crashing the proxy.
	interval, ok := ConfKey("server_reintegration_interval").(int)
	if !ok || interval <= 0 {
		interval = 600
	}

	connectRpc()

	go func() {
		ticker := time.NewTicker(time.Duration(interval) * time.Second)
		for range ticker.C {
			log.Print("Reintegrating servers")
			reconnectRpc(true)
		}
	}()
}