439 lines
8.2 KiB
Go
439 lines
8.2 KiB
Go
package main
|
|
|
|
import (
|
|
"bytes"
|
|
"encoding/binary"
|
|
"errors"
|
|
"io"
|
|
"log"
|
|
"net"
|
|
"strconv"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/anon55555/mt/rudp"
|
|
)
|
|
|
|
// rpcCh is the name of the mod channel over which RPC messages are
// exchanged with the backend servers.
const rpcCh = "multiserver"
|
|
|
|
// Signal codes carried by mod channel signal packets. handleRpc reads one
// of these from each signal packet it receives; the values are consecutive
// starting at zero.
const (
	ModChSigJoinOk          = iota // 0: joining the channel succeeded
	ModChSigJoinFail               // 1: joining the channel failed
	ModChSigLeaveOk                // 2: leaving the channel succeeded
	ModChSigLeaveFail              // 3: leaving the channel failed
	ModChSigChNotRegistered        // 4: the channel is not registered
	ModChSigSetState               // 5: the channel state changed (see ModChState*)
)
|
|
|
|
// Mod channel states, as delivered together with a ModChSigSetState signal.
// The values are consecutive starting at zero.
const (
	ModChStateInit = iota // 0: initial state
	ModChStateRW          // 1: presumably read-write, going by the name
	ModChStateRO          // 2: read-only; handleRpc disables RPC on this state
)
|
|
|
|
var rpcSrvMu sync.Mutex
|
|
var rpcSrvs map[*Conn]struct{}
|
|
|
|
func (c *Conn) joinRpc() {
|
|
data := make([]byte, 4+len(rpcCh))
|
|
data[0] = uint8(0x00)
|
|
data[1] = uint8(ToServerModChannelJoin)
|
|
binary.BigEndian.PutUint16(data[2:4], uint16(len(rpcCh)))
|
|
copy(data[4:], []byte(rpcCh))
|
|
|
|
ack, err := c.Send(rudp.Pkt{Reader: bytes.NewReader(data)})
|
|
if err != nil {
|
|
return
|
|
}
|
|
<-ack
|
|
}
|
|
|
|
func (c *Conn) leaveRpc() {
|
|
data := make([]byte, 4+len(rpcCh))
|
|
data[0] = uint8(0x00)
|
|
data[1] = uint8(ToServerModChannelLeave)
|
|
binary.BigEndian.PutUint16(data[2:4], uint16(len(rpcCh)))
|
|
copy(data[4:], []byte(rpcCh))
|
|
|
|
ack, err := c.Send(rudp.Pkt{Reader: bytes.NewReader(data)})
|
|
if err != nil {
|
|
return
|
|
}
|
|
<-ack
|
|
}
|
|
|
|
func processRpc(c *Conn, r *bytes.Reader) bool {
|
|
ch := string(ReadBytes16(r))
|
|
sender := string(ReadBytes16(r))
|
|
msg := string(ReadBytes16(r))
|
|
|
|
if ch != rpcCh || sender != "" {
|
|
return false
|
|
}
|
|
|
|
rq := strings.Split(msg, " ")[0]
|
|
|
|
switch cmd := strings.Split(msg, " ")[1]; cmd {
|
|
case "<-ALERT":
|
|
ChatSendAll(strings.Join(strings.Split(msg, " ")[2:], " "))
|
|
case "<-GETDEFSRV":
|
|
defsrv, ok := ConfKey("default_server").(string)
|
|
if !ok {
|
|
return true
|
|
}
|
|
go c.doRpc("->DEFSRV "+defsrv, rq)
|
|
case "<-GETPEERCNT":
|
|
cnt := strconv.Itoa(ConnCount())
|
|
go c.doRpc("->PEERCNT "+cnt, rq)
|
|
case "<-ISONLINE":
|
|
online := "false"
|
|
if IsOnline(strings.Join(strings.Split(msg, " ")[2:], " ")) {
|
|
online = "true"
|
|
}
|
|
go c.doRpc("->ISONLINE "+online, rq)
|
|
case "<-CHECKPRIVS":
|
|
name := strings.Split(msg, " ")[2]
|
|
privs := decodePrivs(strings.Join(strings.Split(msg, " ")[3:], " "))
|
|
hasprivs := "false"
|
|
|
|
has, err := CheckPrivs(name, privs)
|
|
if err == nil && has {
|
|
hasprivs = "true"
|
|
}
|
|
|
|
go c.doRpc("->HASPRIVS "+hasprivs, rq)
|
|
case "<-GETPRIVS":
|
|
name := strings.Split(msg, " ")[2]
|
|
var r string
|
|
|
|
privs, err := Privs(name)
|
|
if err == nil {
|
|
r = strings.Replace(encodePrivs(privs), "|", ",", -1)
|
|
}
|
|
|
|
go c.doRpc("->PRIVS "+r, rq)
|
|
case "<-SETPRIVS":
|
|
name := strings.Split(msg, " ")[2]
|
|
privs := decodePrivs(strings.Join(strings.Split(msg, " ")[3:], " "))
|
|
|
|
SetPrivs(name, privs)
|
|
case "<-GETSRV":
|
|
name := strings.Split(msg, " ")[2]
|
|
var srv string
|
|
if IsOnline(name) {
|
|
srv = ConnByUsername(name).ServerName()
|
|
}
|
|
go c.doRpc("->SRV "+srv, rq)
|
|
case "<-REDIRECT":
|
|
name := strings.Split(msg, " ")[2]
|
|
tosrv := strings.Split(msg, " ")[3]
|
|
if IsOnline(name) {
|
|
go ConnByUsername(name).Redirect(tosrv)
|
|
}
|
|
case "<-GETADDR":
|
|
name := strings.Split(msg, " ")[2]
|
|
var addr string
|
|
if IsOnline(name) {
|
|
addr = ConnByUsername(name).Addr().String()
|
|
}
|
|
go c.doRpc("->ADDR "+addr, rq)
|
|
case "<-ISBANNED":
|
|
target := strings.Split(msg, " ")[2]
|
|
|
|
if net.ParseIP(target) == nil {
|
|
return true
|
|
}
|
|
|
|
banned, _, err := IsBanned(target)
|
|
if err != nil {
|
|
return true
|
|
}
|
|
|
|
r := "false"
|
|
if banned {
|
|
r = "true"
|
|
}
|
|
|
|
go c.doRpc("->ISBANNED "+r, rq)
|
|
case "<-BAN":
|
|
target := strings.Split(msg, " ")[2]
|
|
err := Ban(target, "not known")
|
|
if err != nil {
|
|
c2 := ConnByUsername(target)
|
|
if c2 == nil {
|
|
return true
|
|
}
|
|
|
|
c2.Ban()
|
|
}
|
|
case "<-UNBAN":
|
|
target := strings.Split(msg, " ")[2]
|
|
Unban(target)
|
|
case "<-GETSRVS":
|
|
var srvs string
|
|
|
|
servers := ConfKey("servers").(map[interface{}]interface{})
|
|
for server := range servers {
|
|
srvs += server.(string) + ","
|
|
}
|
|
srvs = srvs[:len(srvs)-1]
|
|
|
|
go c.doRpc("->SRVS "+srvs, rq)
|
|
case "<-MT2MT":
|
|
msg := strings.Join(strings.Split(msg, " ")[2:], " ")
|
|
rpcSrvMu.Lock()
|
|
for srv := range rpcSrvs {
|
|
if srv.Addr().String() != c.Addr().String() {
|
|
go srv.doRpc("->MT2MT true "+msg, "--")
|
|
}
|
|
}
|
|
rpcSrvMu.Unlock()
|
|
case "<-MSG2MT":
|
|
tosrv := strings.Split(msg, " ")[2]
|
|
addr, ok := ConfKey("servers:" + tosrv + ":address").(string)
|
|
if !ok || addr == c.Addr().String() {
|
|
return true
|
|
}
|
|
|
|
msg := strings.Join(strings.Split(msg, " ")[3:], " ")
|
|
rpcSrvMu.Lock()
|
|
for srv := range rpcSrvs {
|
|
if srv.Addr().String() == addr {
|
|
go srv.doRpc("->MT2MT false "+msg, "--")
|
|
}
|
|
}
|
|
rpcSrvMu.Unlock()
|
|
}
|
|
return true
|
|
}
|
|
|
|
func (c *Conn) doRpc(rpc, rq string) {
|
|
if !c.UseRpc() {
|
|
return
|
|
}
|
|
|
|
msg := rq + " " + rpc
|
|
|
|
w := bytes.NewBuffer([]byte{0x00, ToServerModChannelMsg})
|
|
WriteBytes16(w, []byte(rpcCh))
|
|
WriteBytes16(w, []byte(msg))
|
|
|
|
_, err := c.Send(rudp.Pkt{Reader: w})
|
|
if err != nil {
|
|
return
|
|
}
|
|
}
|
|
|
|
func connectRpc() {
|
|
log.Print("Establishing RPC connections")
|
|
|
|
servers := ConfKey("servers").(map[interface{}]interface{})
|
|
for server := range servers {
|
|
clt := &Conn{username: "rpc"}
|
|
|
|
straddr := ConfKey("servers:" + server.(string) + ":address")
|
|
|
|
srvaddr, err := net.ResolveUDPAddr("udp", straddr.(string))
|
|
if err != nil {
|
|
log.Print(err)
|
|
continue
|
|
}
|
|
|
|
conn, err := net.DialUDP("udp", nil, srvaddr)
|
|
if err != nil {
|
|
log.Print(err)
|
|
continue
|
|
}
|
|
|
|
srv, err := Connect(conn)
|
|
if err != nil {
|
|
log.Print(err)
|
|
continue
|
|
}
|
|
|
|
fin := make(chan *Conn) // close-only
|
|
go Init(clt, srv, true, true, fin)
|
|
|
|
go func() {
|
|
<-fin
|
|
|
|
rpcSrvMu.Lock()
|
|
rpcSrvs[srv] = struct{}{}
|
|
rpcSrvMu.Unlock()
|
|
|
|
go srv.joinRpc()
|
|
go handleRpc(srv)
|
|
}()
|
|
}
|
|
}
|
|
|
|
func handleRpc(srv *Conn) {
|
|
srv.MakeRpcOnly()
|
|
for {
|
|
pkt, err := srv.Recv()
|
|
if err != nil {
|
|
if errors.Is(err, net.ErrClosed) {
|
|
rpcSrvMu.Lock()
|
|
delete(rpcSrvs, srv)
|
|
rpcSrvMu.Unlock()
|
|
break
|
|
}
|
|
|
|
log.Print(err)
|
|
continue
|
|
}
|
|
|
|
r := ByteReader(pkt)
|
|
|
|
switch cmd := ReadUint16(r); cmd {
|
|
case ToClientModChannelSignal:
|
|
r.Seek(1, io.SeekCurrent)
|
|
|
|
ch := string(ReadBytes16(r))
|
|
state := ReadUint8(r)
|
|
|
|
if ch == rpcCh {
|
|
r.Seek(2, io.SeekStart)
|
|
|
|
switch sig := ReadUint8(r); sig {
|
|
case ModChSigJoinOk:
|
|
srv.SetUseRpc(true)
|
|
case ModChSigSetState:
|
|
if state == ModChStateRO {
|
|
srv.SetUseRpc(false)
|
|
}
|
|
}
|
|
}
|
|
case ToClientModChannelMSG:
|
|
processRpc(srv, r)
|
|
}
|
|
}
|
|
}
|
|
|
|
func OptimizeRPCConns() {
|
|
rpcSrvMu.Lock()
|
|
defer rpcSrvMu.Unlock()
|
|
|
|
ServerLoop:
|
|
for c := range rpcSrvs {
|
|
for _, c2 := range Conns() {
|
|
if c2.Server() == nil {
|
|
continue
|
|
}
|
|
if c2.Server().Addr().String() == c.Addr().String() {
|
|
if c.NoClt() {
|
|
c.Close()
|
|
} else {
|
|
c.SetUseRpc(false)
|
|
c.leaveRpc()
|
|
}
|
|
|
|
delete(rpcSrvs, c)
|
|
|
|
c3 := c2.Server()
|
|
c3.SetUseRpc(true)
|
|
c3.joinRpc()
|
|
|
|
rpcSrvs[c3] = struct{}{}
|
|
|
|
go func() {
|
|
<-c3.Closed()
|
|
rpcSrvMu.Lock()
|
|
delete(rpcSrvs, c3)
|
|
rpcSrvMu.Unlock()
|
|
|
|
for c2.Server().Addr().String() == c3.Addr().String() {
|
|
}
|
|
OptimizeRPCConns()
|
|
}()
|
|
|
|
continue ServerLoop
|
|
}
|
|
}
|
|
}
|
|
|
|
go reconnectRpc(false)
|
|
}
|
|
|
|
func reconnectRpc(media bool) {
|
|
servers := ConfKey("servers").(map[interface{}]interface{})
|
|
ServerLoop:
|
|
for server := range servers {
|
|
clt := &Conn{username: "rpc"}
|
|
|
|
straddr := ConfKey("servers:" + server.(string) + ":address").(string)
|
|
|
|
rpcSrvMu.Lock()
|
|
for rpcsrv := range rpcSrvs {
|
|
if rpcsrv.Addr().String() == straddr {
|
|
rpcSrvMu.Unlock()
|
|
continue ServerLoop
|
|
}
|
|
}
|
|
rpcSrvMu.Unlock()
|
|
|
|
// Also refetch media in case something has not
|
|
// been downloaded yet
|
|
if media {
|
|
loadMedia(map[string]struct{}{server.(string): {}})
|
|
}
|
|
|
|
srvaddr, err := net.ResolveUDPAddr("udp", straddr)
|
|
if err != nil {
|
|
log.Print(err)
|
|
continue
|
|
}
|
|
|
|
conn, err := net.DialUDP("udp", nil, srvaddr)
|
|
if err != nil {
|
|
log.Print(err)
|
|
continue
|
|
}
|
|
|
|
srv, err := Connect(conn)
|
|
if err != nil {
|
|
log.Print(err)
|
|
continue
|
|
}
|
|
|
|
fin := make(chan *Conn) // close-only
|
|
go Init(clt, srv, true, true, fin)
|
|
|
|
go func() {
|
|
<-fin
|
|
|
|
rpcSrvMu.Lock()
|
|
rpcSrvs[srv] = struct{}{}
|
|
rpcSrvMu.Unlock()
|
|
|
|
go srv.joinRpc()
|
|
go handleRpc(srv)
|
|
}()
|
|
}
|
|
}
|
|
|
|
func init() {
|
|
rpcSrvMu.Lock()
|
|
rpcSrvs = make(map[*Conn]struct{})
|
|
rpcSrvMu.Unlock()
|
|
|
|
reconnect, ok := ConfKey("server_reintegration_interval").(int)
|
|
if !ok {
|
|
reconnect = 600
|
|
}
|
|
|
|
connectRpc()
|
|
|
|
go func() {
|
|
reconnect := time.NewTicker(time.Duration(reconnect) * time.Second)
|
|
for {
|
|
select {
|
|
case <-reconnect.C:
|
|
log.Print("Reintegrating servers")
|
|
reconnectRpc(true)
|
|
}
|
|
}
|
|
}()
|
|
}
|