]> git.r.bdr.sh - rbdr/mobius/blobdiff - hotline/server.go
Add initial validations of config.yaml
[rbdr/mobius] / hotline / server.go
index bace9865e78a447d0c3b0b89662036bc6055b93a..e6322e04193db80676a76f9e9623dde8e80a4c6c 100644 (file)
@@ -5,6 +5,7 @@ import (
        "encoding/binary"
        "errors"
        "fmt"
        "encoding/binary"
        "errors"
        "fmt"
+       "github.com/go-playground/validator/v10"
        "go.uber.org/zap"
        "io"
        "io/fs"
        "go.uber.org/zap"
        "io"
        "io/fs"
@@ -131,7 +132,7 @@ func (s *Server) sendTransaction(t Transaction) error {
                "IsReply", t.IsReply,
                "type", handler.Name,
                "sentBytes", n,
                "IsReply", t.IsReply,
                "type", handler.Name,
                "sentBytes", n,
-               "remoteAddr", client.Connection.RemoteAddr(),
+               "remoteAddr", client.RemoteAddr,
        )
        return nil
 }
        )
        return nil
 }
@@ -156,7 +157,7 @@ func (s *Server) Serve(ctx context.Context, cancelRoot context.CancelFunc, ln ne
                        }
                }()
                go func() {
                        }
                }()
                go func() {
-                       if err := s.handleNewConnection(conn); err != nil {
+                       if err := s.handleNewConnection(conn, conn.RemoteAddr().String()); err != nil {
                                if err == io.EOF {
                                        s.Logger.Infow("Client disconnected", "RemoteAddr", conn.RemoteAddr())
                                } else {
                                if err == io.EOF {
                                        s.Logger.Infow("Client disconnected", "RemoteAddr", conn.RemoteAddr())
                                } else {
@@ -236,9 +237,15 @@ func NewServer(configDir, netInterface string, netPort int, logger *zap.SugaredL
        *server.NextGuestID = 1
 
        if server.Config.EnableTrackerRegistration {
        *server.NextGuestID = 1
 
        if server.Config.EnableTrackerRegistration {
+               server.Logger.Infow(
+                       "Tracker registration enabled",
+                       "frequency", fmt.Sprintf("%vs", trackerUpdateFrequency),
+                       "trackers", server.Config.Trackers,
+               )
+
                go func() {
                        for {
                go func() {
                        for {
-                               tr := TrackerRegistration{
+                               tr := &TrackerRegistration{
                                        Port:        []byte{0x15, 0x7c},
                                        UserCount:   server.userCount(),
                                        PassID:      server.TrackerPassID,
                                        Port:        []byte{0x15, 0x7c},
                                        UserCount:   server.userCount(),
                                        PassID:      server.TrackerPassID,
@@ -246,11 +253,10 @@ func NewServer(configDir, netInterface string, netPort int, logger *zap.SugaredL
                                        Description: server.Config.Description,
                                }
                                for _, t := range server.Config.Trackers {
                                        Description: server.Config.Description,
                                }
                                for _, t := range server.Config.Trackers {
-                                       server.Logger.Infof("Registering with tracker %v", t)
-
                                        if err := register(t, tr); err != nil {
                                                server.Logger.Errorw("unable to register with tracker %v", "error", err)
                                        }
                                        if err := register(t, tr); err != nil {
                                                server.Logger.Errorw("unable to register with tracker %v", "error", err)
                                        }
+                                       server.Logger.Infow("Sent Tracker registration", "data", tr)
                                }
 
                                time.Sleep(trackerUpdateFrequency * time.Second)
                                }
 
                                time.Sleep(trackerUpdateFrequency * time.Second)
@@ -314,7 +320,7 @@ func (s *Server) writeThreadedNews() error {
        return err
 }
 
        return err
 }
 
-func (s *Server) NewClientConn(conn net.Conn) *ClientConn {
+func (s *Server) NewClientConn(conn net.Conn, remoteAddr string) *ClientConn {
        s.mux.Lock()
        defer s.mux.Unlock()
 
        s.mux.Lock()
        defer s.mux.Unlock()
 
@@ -329,6 +335,7 @@ func (s *Server) NewClientConn(conn net.Conn) *ClientConn {
                AutoReply:  []byte{},
                Transfers:  make(map[int][]*FileTransfer),
                Agreed:     false,
                AutoReply:  []byte{},
                Transfers:  make(map[int][]*FileTransfer),
                Agreed:     false,
+               RemoteAddr: remoteAddr,
        }
        *s.NextGuestID++
        ID := *s.NextGuestID
        }
        *s.NextGuestID++
        ID := *s.NextGuestID
@@ -471,6 +478,12 @@ func (s *Server) loadConfig(path string) error {
        if err != nil {
                return err
        }
        if err != nil {
                return err
        }
+
+       validate := validator.New()
+       err = validate.Struct(s.Config)
+       if err != nil {
+               return err
+       }
        return nil
 }
 
        return nil
 }
 
@@ -488,16 +501,15 @@ func dontPanic(logger *zap.SugaredLogger) {
 }
 
 // handleNewConnection takes a new net.Conn and performs the initial login sequence
 }
 
 // handleNewConnection takes a new net.Conn and performs the initial login sequence
-func (s *Server) handleNewConnection(conn net.Conn) error {
-       handshakeBuf := make([]byte, 12) // handshakes are always 12 bytes in length
-       if _, err := conn.Read(handshakeBuf); err != nil {
-               return err
-       }
-       if err := Handshake(conn, handshakeBuf[:12]); err != nil {
+func (s *Server) handleNewConnection(conn net.Conn, remoteAddr string) error {
+       defer dontPanic(s.Logger)
+
+       if err := Handshake(conn); err != nil {
                return err
        }
 
        buf := make([]byte, 1024)
                return err
        }
 
        buf := make([]byte, 1024)
+       // TODO: fix potential short read with io.ReadFull
        readLen, err := conn.Read(buf)
        if readLen < minTransactionLen {
                return err
        readLen, err := conn.Read(buf)
        if readLen < minTransactionLen {
                return err
@@ -511,10 +523,8 @@ func (s *Server) handleNewConnection(conn net.Conn) error {
                return err
        }
 
                return err
        }
 
-       c := s.NewClientConn(conn)
-
+       c := s.NewClientConn(conn, remoteAddr)
        defer c.Disconnect()
        defer c.Disconnect()
-       defer dontPanic(s.Logger)
 
        encodedLogin := clientLogin.GetField(fieldUserLogin).Data
        encodedPassword := clientLogin.GetField(fieldUserPassword).Data
 
        encodedLogin := clientLogin.GetField(fieldUserLogin).Data
        encodedPassword := clientLogin.GetField(fieldUserPassword).Data
@@ -555,7 +565,7 @@ func (s *Server) handleNewConnection(conn net.Conn) error {
                *c.Flags = []byte{0, 2}
        }
 
                *c.Flags = []byte{0, 2}
        }
 
-       s.Logger.Infow("Client connection received", "login", login, "version", *c.Version, "RemoteAddr", conn.RemoteAddr().String())
+       s.Logger.Infow("Client connection received", "login", login, "version", *c.Version, "RemoteAddr", remoteAddr)
 
        s.outbox <- c.NewReply(clientLogin,
                NewField(fieldVersion, []byte{0x00, 0xbe}),
 
        s.outbox <- c.NewReply(clientLogin,
                NewField(fieldVersion, []byte{0x00, 0xbe}),
@@ -736,11 +746,27 @@ func (s *Server) handleFileTransfer(conn io.ReadWriteCloser) error {
                }
        case FileUpload:
                destinationFile := s.Config.FileRoot + ReadFilePath(fileTransfer.FilePath) + "/" + string(fileTransfer.FileName)
                }
        case FileUpload:
                destinationFile := s.Config.FileRoot + ReadFilePath(fileTransfer.FilePath) + "/" + string(fileTransfer.FileName)
-               tmpFile := destinationFile + ".incomplete"
 
 
-               file, err := effectiveFile(destinationFile)
+               var file *os.File
+
+               // A file upload has three possible cases:
+               // 1) Upload a new file
+               // 2) Resume a partially transferred file
+               // 3) Replace a fully uploaded file
+               // Unfortunately we have to infer which case applies by inspecting what is already on the file system
+
+               // 1) Check for existing file:
+               _, err := os.Stat(destinationFile)
+               if err == nil {
+                       // If found, that means this upload is intended to replace the file
+                       if err = os.Remove(destinationFile); err != nil {
+                               return err
+                       }
+                       file, err = os.Create(destinationFile + incompleteFileSuffix)
+               }
                if errors.Is(err, fs.ErrNotExist) {
                if errors.Is(err, fs.ErrNotExist) {
-                       file, err = FS.Create(tmpFile)
+                       // If not found, open or create a new incomplete file
+                       file, err = os.OpenFile(destinationFile+incompleteFileSuffix, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0644)
                        if err != nil {
                                return err
                        }
                        if err != nil {
                                return err
                        }
@@ -798,7 +824,7 @@ func (s *Server) handleFileTransfer(conn io.ReadWriteCloser) error {
                s.Logger.Infow("Start folder download", "path", fullFilePath, "ReferenceNumber", fileTransfer.ReferenceNumber)
 
                nextAction := make([]byte, 2)
                s.Logger.Infow("Start folder download", "path", fullFilePath, "ReferenceNumber", fileTransfer.ReferenceNumber)
 
                nextAction := make([]byte, 2)
-               if _, err := conn.Read(nextAction); err != nil {
+               if _, err := io.ReadFull(conn, nextAction); err != nil {
                        return err
                }
 
                        return err
                }
 
@@ -824,7 +850,7 @@ func (s *Server) handleFileTransfer(conn io.ReadWriteCloser) error {
                        }
 
                        // Read the client's Next Action request
                        }
 
                        // Read the client's Next Action request
-                       if _, err := conn.Read(nextAction); err != nil {
+                       if _, err := io.ReadFull(conn, nextAction); err != nil {
                                return err
                        }
 
                                return err
                        }
 
@@ -837,13 +863,13 @@ func (s *Server) handleFileTransfer(conn io.ReadWriteCloser) error {
                                // client asked to resume this file
                                var frd FileResumeData
                                // get size of resumeData
                                // client asked to resume this file
                                var frd FileResumeData
                                // get size of resumeData
-                               if _, err := conn.Read(nextAction); err != nil {
+                               if _, err := io.ReadFull(conn, nextAction); err != nil {
                                        return err
                                }
 
                                resumeDataLen := binary.BigEndian.Uint16(nextAction)
                                resumeDataBytes := make([]byte, resumeDataLen)
                                        return err
                                }
 
                                resumeDataLen := binary.BigEndian.Uint16(nextAction)
                                resumeDataBytes := make([]byte, resumeDataLen)
-                               if _, err := conn.Read(resumeDataBytes); err != nil {
+                               if _, err := io.ReadFull(conn, resumeDataBytes); err != nil {
                                        return err
                                }
 
                                        return err
                                }
 
@@ -920,7 +946,7 @@ func (s *Server) handleFileTransfer(conn io.ReadWriteCloser) error {
                        // TODO: optionally send resource fork header and resource fork data
 
                        // Read the client's Next Action request.  This is always 3, I think?
                        // TODO: optionally send resource fork header and resource fork data
 
                        // Read the client's Next Action request.  This is always 3, I think?
-                       if _, err := conn.Read(nextAction); err != nil {
+                       if _, err := io.ReadFull(conn, nextAction); err != nil {
                                return err
                        }
 
                                return err
                        }
 
@@ -956,7 +982,7 @@ func (s *Server) handleFileTransfer(conn io.ReadWriteCloser) error {
                readBuffer := make([]byte, 1024)
 
                for i := 0; i < fileTransfer.ItemCount(); i++ {
                readBuffer := make([]byte, 1024)
 
                for i := 0; i < fileTransfer.ItemCount(); i++ {
-
+                       // TODO: fix potential short read with io.ReadFull
                        _, err := conn.Read(readBuffer)
                        if err != nil {
                                return err
                        _, err := conn.Read(readBuffer)
                        if err != nil {
                                return err
@@ -1034,7 +1060,7 @@ func (s *Server) handleFileTransfer(conn io.ReadWriteCloser) error {
                                                return err
                                        }
 
                                                return err
                                        }
 
-                                       if _, err := conn.Read(fileSize); err != nil {
+                                       if _, err := io.ReadFull(conn, fileSize); err != nil {
                                                return err
                                        }
 
                                                return err
                                        }