]> git.r.bdr.sh - rbdr/mobius/blobdiff - hotline/server.go
Fix ghost login bug in client by using 1.2.3 behavior
[rbdr/mobius] / hotline / server.go
index 2bd5c04bc01dbe6cd05bb6d9f27848c0cd6b2685..768433bc60b253f853c3b944580467e150344fc5 100644 (file)
@@ -5,6 +5,7 @@ import (
        "encoding/binary"
        "errors"
        "fmt"
+       "github.com/go-playground/validator/v10"
        "go.uber.org/zap"
        "io"
        "io/fs"
@@ -43,7 +44,7 @@ type Server struct {
        Logger        *zap.SugaredLogger
        PrivateChats  map[uint32]*PrivateChat
        NextGuestID   *uint16
-       TrackerPassID []byte
+       TrackerPassID [4]byte
        Stats         *Stats
 
        APIListener  net.Listener
@@ -64,7 +65,12 @@ type PrivateChat struct {
 }
 
 func (s *Server) ListenAndServe(ctx context.Context, cancelRoot context.CancelFunc) error {
-       s.Logger.Infow("Hotline server started", "version", VERSION)
+       s.Logger.Infow("Hotline server started",
+               "version", VERSION,
+               "API port", fmt.Sprintf(":%v", s.Port),
+               "Transfer port", fmt.Sprintf(":%v", s.Port+1),
+       )
+
        var wg sync.WaitGroup
 
        wg.Add(1)
@@ -78,13 +84,7 @@ func (s *Server) ListenAndServe(ctx context.Context, cancelRoot context.CancelFu
        return nil
 }
 
-func (s *Server) APIPort() int {
-       return s.APIListener.Addr().(*net.TCPAddr).Port
-}
-
 func (s *Server) ServeFileTransfers(ln net.Listener) error {
-       s.Logger.Infow("Hotline file transfer server started", "Addr", fmt.Sprintf(":%v", s.Port+1))
-
        for {
                conn, err := ln.Accept()
                if err != nil {
@@ -131,13 +131,12 @@ func (s *Server) sendTransaction(t Transaction) error {
                "IsReply", t.IsReply,
                "type", handler.Name,
                "sentBytes", n,
-               "remoteAddr", client.Connection.RemoteAddr(),
+               "remoteAddr", client.RemoteAddr,
        )
        return nil
 }
 
 func (s *Server) Serve(ctx context.Context, cancelRoot context.CancelFunc, ln net.Listener) error {
-       s.Logger.Infow("Hotline server started", "Addr", fmt.Sprintf(":%v", s.Port))
 
        for {
                conn, err := ln.Accept()
@@ -156,7 +155,7 @@ func (s *Server) Serve(ctx context.Context, cancelRoot context.CancelFunc, ln ne
                        }
                }()
                go func() {
-                       if err := s.handleNewConnection(conn); err != nil {
+                       if err := s.handleNewConnection(conn, conn.RemoteAddr().String()); err != nil {
                                if err == io.EOF {
                                        s.Logger.Infow("Client disconnected", "RemoteAddr", conn.RemoteAddr())
                                } else {
@@ -186,7 +185,6 @@ func NewServer(configDir, netInterface string, netPort int, logger *zap.SugaredL
                outbox:        make(chan Transaction),
                Stats:         &Stats{StartTime: time.Now()},
                ThreadedNews:  &ThreadedNews{},
-               TrackerPassID: make([]byte, 4),
        }
 
        ln, err := net.Listen("tcp", fmt.Sprintf("%s:%v", netInterface, netPort))
@@ -195,22 +193,17 @@ func NewServer(configDir, netInterface string, netPort int, logger *zap.SugaredL
        }
        server.APIListener = ln
 
-       if netPort != 0 {
-               netPort += 1
-       }
-
-       ln2, err := net.Listen("tcp", fmt.Sprintf("%s:%v", netInterface, netPort))
+       ln2, err := net.Listen("tcp", fmt.Sprintf("%s:%v", netInterface, netPort+1))
        server.FileListener = ln2
        if err != nil {
                return nil, err
        }
 
        // generate a new random passID for tracker registration
-       if _, err := rand.Read(server.TrackerPassID); err != nil {
+       if _, err := rand.Read(server.TrackerPassID[:]); err != nil {
                return nil, err
        }
 
-       server.Logger.Debugw("Loading Agreement", "path", configDir+agreementFile)
        if server.Agreement, err = os.ReadFile(configDir + agreementFile); err != nil {
                return nil, err
        }
@@ -236,21 +229,26 @@ func NewServer(configDir, netInterface string, netPort int, logger *zap.SugaredL
        *server.NextGuestID = 1
 
        if server.Config.EnableTrackerRegistration {
+               server.Logger.Infow(
+                       "Tracker registration enabled",
+                       "frequency", fmt.Sprintf("%vs", trackerUpdateFrequency),
+                       "trackers", server.Config.Trackers,
+               )
+
                go func() {
                        for {
-                               tr := TrackerRegistration{
-                                       Port:        []byte{0x15, 0x7c},
+                               tr := &TrackerRegistration{
                                        UserCount:   server.userCount(),
-                                       PassID:      server.TrackerPassID,
+                                       PassID:      server.TrackerPassID[:],
                                        Name:        server.Config.Name,
                                        Description: server.Config.Description,
                                }
+                               binary.BigEndian.PutUint16(tr.Port[:], uint16(server.Port))
                                for _, t := range server.Config.Trackers {
-                                       server.Logger.Infof("Registering with tracker %v", t)
-
                                        if err := register(t, tr); err != nil {
                                                server.Logger.Errorw("unable to register with tracker %v", "error", err)
                                        }
+                                       server.Logger.Infow("Sent Tracker registration", "data", tr)
                                }
 
                                time.Sleep(trackerUpdateFrequency * time.Second)
@@ -314,7 +312,7 @@ func (s *Server) writeThreadedNews() error {
        return err
 }
 
-func (s *Server) NewClientConn(conn net.Conn) *ClientConn {
+func (s *Server) NewClientConn(conn net.Conn, remoteAddr string) *ClientConn {
        s.mux.Lock()
        defer s.mux.Unlock()
 
@@ -329,6 +327,7 @@ func (s *Server) NewClientConn(conn net.Conn) *ClientConn {
                AutoReply:  []byte{},
                Transfers:  make(map[int][]*FileTransfer),
                Agreed:     false,
+               RemoteAddr: remoteAddr,
        }
        *s.NextGuestID++
        ID := *s.NextGuestID
@@ -471,6 +470,12 @@ func (s *Server) loadConfig(path string) error {
        if err != nil {
                return err
        }
+
+       validate := validator.New()
+       err = validate.Struct(s.Config)
+       if err != nil {
+               return err
+       }
        return nil
 }
 
@@ -488,14 +493,10 @@ func dontPanic(logger *zap.SugaredLogger) {
 }
 
 // handleNewConnection takes a new net.Conn and performs the initial login sequence
-func (s *Server) handleNewConnection(conn net.Conn) error {
+func (s *Server) handleNewConnection(conn net.Conn, remoteAddr string) error {
        defer dontPanic(s.Logger)
 
-       handshakeBuf := make([]byte, 12)
-       if _, err := io.ReadFull(conn, handshakeBuf); err != nil {
-               return err
-       }
-       if err := Handshake(conn, handshakeBuf); err != nil {
+       if err := Handshake(conn); err != nil {
                return err
        }
 
@@ -514,7 +515,7 @@ func (s *Server) handleNewConnection(conn net.Conn) error {
                return err
        }
 
-       c := s.NewClientConn(conn)
+       c := s.NewClientConn(conn, remoteAddr)
        defer c.Disconnect()
 
        encodedLogin := clientLogin.GetField(fieldUserLogin).Data
@@ -556,7 +557,7 @@ func (s *Server) handleNewConnection(conn net.Conn) error {
                *c.Flags = []byte{0, 2}
        }
 
-       s.Logger.Infow("Client connection received", "login", login, "version", *c.Version, "RemoteAddr", conn.RemoteAddr().String())
+       s.Logger.Infow("Client connection received", "login", login, "version", *c.Version, "RemoteAddr", remoteAddr)
 
        s.outbox <- c.NewReply(clientLogin,
                NewField(fieldVersion, []byte{0x00, 0xbe}),
@@ -737,11 +738,27 @@ func (s *Server) handleFileTransfer(conn io.ReadWriteCloser) error {
                }
        case FileUpload:
                destinationFile := s.Config.FileRoot + ReadFilePath(fileTransfer.FilePath) + "/" + string(fileTransfer.FileName)
-               tmpFile := destinationFile + ".incomplete"
 
-               file, err := effectiveFile(destinationFile)
+               var file *os.File
+
+               // A file upload has three possible cases:
+               // 1) Upload a new file
+               // 2) Resume a partially transferred file
+               // 3) Replace a fully uploaded file
+               // Unfortunately we have to infer which case applies by inspecting what is already on the file system
+
+               // 1) Check for existing file:
+               _, err := os.Stat(destinationFile)
+               if err == nil {
+                       // If found, that means this upload is intended to replace the file
+                       if err = os.Remove(destinationFile); err != nil {
+                               return err
+                       }
+                       file, err = os.Create(destinationFile + incompleteFileSuffix)
+               }
                if errors.Is(err, fs.ErrNotExist) {
-                       file, err = FS.Create(tmpFile)
+                       // If not found, open or create a new incomplete file
+                       file, err = os.OpenFile(destinationFile+incompleteFileSuffix, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0644)
                        if err != nil {
                                return err
                        }