]> git.r.bdr.sh - rbdr/mobius/blobdiff - hotline/server.go
Fix Nostalgia ghost user bug
[rbdr/mobius] / hotline / server.go
index f164df1c8ff9e2b760f390f3242e8d45baa19246..014824b6a7d0c9a6942f6531d764b782dfbd7ba9 100644 (file)
@@ -1,10 +1,12 @@
 package hotline
 
 import (
 package hotline
 
 import (
+       "bytes"
        "context"
        "encoding/binary"
        "errors"
        "fmt"
        "context"
        "encoding/binary"
        "errors"
        "fmt"
+       "github.com/go-playground/validator/v10"
        "go.uber.org/zap"
        "io"
        "io/fs"
        "go.uber.org/zap"
        "io"
        "io/fs"
@@ -30,6 +32,8 @@ const (
        trackerUpdateFrequency = 300 // time in seconds between tracker re-registration
 )
 
        trackerUpdateFrequency = 300 // time in seconds between tracker re-registration
 )
 
+var nostalgiaVersion = []byte{0, 0, 2, 0x2c} // version ID used by the Nostalgia client
+
 type Server struct {
        Port          int
        Accounts      map[string]*Account
 type Server struct {
        Port          int
        Accounts      map[string]*Account
@@ -43,7 +47,7 @@ type Server struct {
        Logger        *zap.SugaredLogger
        PrivateChats  map[uint32]*PrivateChat
        NextGuestID   *uint16
        Logger        *zap.SugaredLogger
        PrivateChats  map[uint32]*PrivateChat
        NextGuestID   *uint16
-       TrackerPassID []byte
+       TrackerPassID [4]byte
        Stats         *Stats
 
        APIListener  net.Listener
        Stats         *Stats
 
        APIListener  net.Listener
@@ -64,7 +68,12 @@ type PrivateChat struct {
 }
 
 func (s *Server) ListenAndServe(ctx context.Context, cancelRoot context.CancelFunc) error {
 }
 
 func (s *Server) ListenAndServe(ctx context.Context, cancelRoot context.CancelFunc) error {
-       s.Logger.Infow("Hotline server started", "version", VERSION)
+       s.Logger.Infow("Hotline server started",
+               "version", VERSION,
+               "API port", fmt.Sprintf(":%v", s.Port),
+               "Transfer port", fmt.Sprintf(":%v", s.Port+1),
+       )
+
        var wg sync.WaitGroup
 
        wg.Add(1)
        var wg sync.WaitGroup
 
        wg.Add(1)
@@ -78,13 +87,7 @@ func (s *Server) ListenAndServe(ctx context.Context, cancelRoot context.CancelFu
        return nil
 }
 
        return nil
 }
 
-func (s *Server) APIPort() int {
-       return s.APIListener.Addr().(*net.TCPAddr).Port
-}
-
 func (s *Server) ServeFileTransfers(ln net.Listener) error {
 func (s *Server) ServeFileTransfers(ln net.Listener) error {
-       s.Logger.Infow("Hotline file transfer server started", "Addr", fmt.Sprintf(":%v", s.Port+1))
-
        for {
                conn, err := ln.Accept()
                if err != nil {
        for {
                conn, err := ln.Accept()
                if err != nil {
@@ -131,13 +134,12 @@ func (s *Server) sendTransaction(t Transaction) error {
                "IsReply", t.IsReply,
                "type", handler.Name,
                "sentBytes", n,
                "IsReply", t.IsReply,
                "type", handler.Name,
                "sentBytes", n,
-               "remoteAddr", client.Connection.RemoteAddr(),
+               "remoteAddr", client.RemoteAddr,
        )
        return nil
 }
 
 func (s *Server) Serve(ctx context.Context, cancelRoot context.CancelFunc, ln net.Listener) error {
        )
        return nil
 }
 
 func (s *Server) Serve(ctx context.Context, cancelRoot context.CancelFunc, ln net.Listener) error {
-       s.Logger.Infow("Hotline server started", "Addr", fmt.Sprintf(":%v", s.Port))
 
        for {
                conn, err := ln.Accept()
 
        for {
                conn, err := ln.Accept()
@@ -156,7 +158,7 @@ func (s *Server) Serve(ctx context.Context, cancelRoot context.CancelFunc, ln ne
                        }
                }()
                go func() {
                        }
                }()
                go func() {
-                       if err := s.handleNewConnection(conn); err != nil {
+                       if err := s.handleNewConnection(conn, conn.RemoteAddr().String()); err != nil {
                                if err == io.EOF {
                                        s.Logger.Infow("Client disconnected", "RemoteAddr", conn.RemoteAddr())
                                } else {
                                if err == io.EOF {
                                        s.Logger.Infow("Client disconnected", "RemoteAddr", conn.RemoteAddr())
                                } else {
@@ -186,7 +188,6 @@ func NewServer(configDir, netInterface string, netPort int, logger *zap.SugaredL
                outbox:        make(chan Transaction),
                Stats:         &Stats{StartTime: time.Now()},
                ThreadedNews:  &ThreadedNews{},
                outbox:        make(chan Transaction),
                Stats:         &Stats{StartTime: time.Now()},
                ThreadedNews:  &ThreadedNews{},
-               TrackerPassID: make([]byte, 4),
        }
 
        ln, err := net.Listen("tcp", fmt.Sprintf("%s:%v", netInterface, netPort))
        }
 
        ln, err := net.Listen("tcp", fmt.Sprintf("%s:%v", netInterface, netPort))
@@ -195,22 +196,17 @@ func NewServer(configDir, netInterface string, netPort int, logger *zap.SugaredL
        }
        server.APIListener = ln
 
        }
        server.APIListener = ln
 
-       if netPort != 0 {
-               netPort += 1
-       }
-
-       ln2, err := net.Listen("tcp", fmt.Sprintf("%s:%v", netInterface, netPort))
+       ln2, err := net.Listen("tcp", fmt.Sprintf("%s:%v", netInterface, netPort+1))
        server.FileListener = ln2
        if err != nil {
                return nil, err
        }
 
        // generate a new random passID for tracker registration
        server.FileListener = ln2
        if err != nil {
                return nil, err
        }
 
        // generate a new random passID for tracker registration
-       if _, err := rand.Read(server.TrackerPassID); err != nil {
+       if _, err := rand.Read(server.TrackerPassID[:]); err != nil {
                return nil, err
        }
 
                return nil, err
        }
 
-       server.Logger.Debugw("Loading Agreement", "path", configDir+agreementFile)
        if server.Agreement, err = os.ReadFile(configDir + agreementFile); err != nil {
                return nil, err
        }
        if server.Agreement, err = os.ReadFile(configDir + agreementFile); err != nil {
                return nil, err
        }
@@ -245,12 +241,12 @@ func NewServer(configDir, netInterface string, netPort int, logger *zap.SugaredL
                go func() {
                        for {
                                tr := &TrackerRegistration{
                go func() {
                        for {
                                tr := &TrackerRegistration{
-                                       Port:        []byte{0x15, 0x7c},
                                        UserCount:   server.userCount(),
                                        UserCount:   server.userCount(),
-                                       PassID:      server.TrackerPassID,
+                                       PassID:      server.TrackerPassID[:],
                                        Name:        server.Config.Name,
                                        Description: server.Config.Description,
                                }
                                        Name:        server.Config.Name,
                                        Description: server.Config.Description,
                                }
+                               binary.BigEndian.PutUint16(tr.Port[:], uint16(server.Port))
                                for _, t := range server.Config.Trackers {
                                        if err := register(t, tr); err != nil {
                                                server.Logger.Errorw("unable to register with tracker %v", "error", err)
                                for _, t := range server.Config.Trackers {
                                        if err := register(t, tr); err != nil {
                                                server.Logger.Errorw("unable to register with tracker %v", "error", err)
@@ -319,7 +315,7 @@ func (s *Server) writeThreadedNews() error {
        return err
 }
 
        return err
 }
 
-func (s *Server) NewClientConn(conn net.Conn) *ClientConn {
+func (s *Server) NewClientConn(conn net.Conn, remoteAddr string) *ClientConn {
        s.mux.Lock()
        defer s.mux.Unlock()
 
        s.mux.Lock()
        defer s.mux.Unlock()
 
@@ -334,6 +330,7 @@ func (s *Server) NewClientConn(conn net.Conn) *ClientConn {
                AutoReply:  []byte{},
                Transfers:  make(map[int][]*FileTransfer),
                Agreed:     false,
                AutoReply:  []byte{},
                Transfers:  make(map[int][]*FileTransfer),
                Agreed:     false,
+               RemoteAddr: remoteAddr,
        }
        *s.NextGuestID++
        ID := *s.NextGuestID
        }
        *s.NextGuestID++
        ID := *s.NextGuestID
@@ -368,8 +365,6 @@ func (s *Server) UpdateUser(login, newLogin, name, password string, access []byt
        s.mux.Lock()
        defer s.mux.Unlock()
 
        s.mux.Lock()
        defer s.mux.Unlock()
 
-       fmt.Printf("login: %v, newLogin: %v: ", login, newLogin)
-
        // update renames the user login
        if login != newLogin {
                err := os.Rename(s.ConfigDir+"Users/"+login+".yaml", s.ConfigDir+"Users/"+newLogin+".yaml")
        // update renames the user login
        if login != newLogin {
                err := os.Rename(s.ConfigDir+"Users/"+login+".yaml", s.ConfigDir+"Users/"+newLogin+".yaml")
@@ -476,6 +471,12 @@ func (s *Server) loadConfig(path string) error {
        if err != nil {
                return err
        }
        if err != nil {
                return err
        }
+
+       validate := validator.New()
+       err = validate.Struct(s.Config)
+       if err != nil {
+               return err
+       }
        return nil
 }
 
        return nil
 }
 
@@ -493,14 +494,10 @@ func dontPanic(logger *zap.SugaredLogger) {
 }
 
 // handleNewConnection takes a new net.Conn and performs the initial login sequence
 }
 
 // handleNewConnection takes a new net.Conn and performs the initial login sequence
-func (s *Server) handleNewConnection(conn net.Conn) error {
+func (s *Server) handleNewConnection(conn net.Conn, remoteAddr string) error {
        defer dontPanic(s.Logger)
 
        defer dontPanic(s.Logger)
 
-       handshakeBuf := make([]byte, 12)
-       if _, err := io.ReadFull(conn, handshakeBuf); err != nil {
-               return err
-       }
-       if err := Handshake(conn, handshakeBuf); err != nil {
+       if err := Handshake(conn); err != nil {
                return err
        }
 
                return err
        }
 
@@ -519,7 +516,7 @@ func (s *Server) handleNewConnection(conn net.Conn) error {
                return err
        }
 
                return err
        }
 
-       c := s.NewClientConn(conn)
+       c := s.NewClientConn(conn, remoteAddr)
        defer c.Disconnect()
 
        encodedLogin := clientLogin.GetField(fieldUserLogin).Data
        defer c.Disconnect()
 
        encodedLogin := clientLogin.GetField(fieldUserLogin).Data
@@ -561,7 +558,7 @@ func (s *Server) handleNewConnection(conn net.Conn) error {
                *c.Flags = []byte{0, 2}
        }
 
                *c.Flags = []byte{0, 2}
        }
 
-       s.Logger.Infow("Client connection received", "login", login, "version", *c.Version, "RemoteAddr", conn.RemoteAddr().String())
+       s.Logger.Infow("Client connection received", "login", login, "version", *c.Version, "RemoteAddr", remoteAddr)
 
        s.outbox <- c.NewReply(clientLogin,
                NewField(fieldVersion, []byte{0x00, 0xbe}),
 
        s.outbox <- c.NewReply(clientLogin,
                NewField(fieldVersion, []byte{0x00, 0xbe}),
@@ -575,8 +572,8 @@ func (s *Server) handleNewConnection(conn net.Conn) error {
        // Show agreement to client
        c.Server.outbox <- *NewTransaction(tranShowAgreement, c.ID, NewField(fieldData, s.Agreement))
 
        // Show agreement to client
        c.Server.outbox <- *NewTransaction(tranShowAgreement, c.ID, NewField(fieldData, s.Agreement))
 
-       // assume simplified hotline v1.2.3 login flow that does not require agreement
-       if *c.Version == nil {
+       // Used simplified hotline v1.2.3 login flow for clients that do not send login info in tranAgreed
+       if *c.Version == nil || bytes.Equal(*c.Version, nostalgiaVersion) {
                c.Agreed = true
 
                c.notifyOthers(
                c.Agreed = true
 
                c.notifyOthers(