]> git.r.bdr.sh - rbdr/mobius/blobdiff - hotline/server.go
Merge pull request #35 from aptonline/patch-1
[rbdr/mobius] / hotline / server.go
index 64cc5277c7512abd43452ea97437f4256a6b9b0a..76cc248ba57088ebf950a449c4af9551d49361d2 100644 (file)
@@ -1,6 +1,7 @@
 package hotline
 
 import (
+       "bytes"
        "context"
        "encoding/binary"
        "errors"
@@ -31,6 +32,8 @@ const (
        trackerUpdateFrequency = 300 // time in seconds between tracker re-registration
 )
 
+var nostalgiaVersion = []byte{0, 0, 2, 0x2c} // version ID used by the Nostalgia client
+
 type Server struct {
        Port          int
        Accounts      map[string]*Account
@@ -47,8 +50,7 @@ type Server struct {
        TrackerPassID [4]byte
        Stats         *Stats
 
-       APIListener  net.Listener
-       FileListener net.Listener
+       FS FileStore
 
        // newsReader io.Reader
        // newsWriter io.WriteCloser
@@ -74,10 +76,25 @@ func (s *Server) ListenAndServe(ctx context.Context, cancelRoot context.CancelFu
        var wg sync.WaitGroup
 
        wg.Add(1)
-       go func() { s.Logger.Fatal(s.Serve(ctx, cancelRoot, s.APIListener)) }()
+       go func() {
+               ln, err := net.Listen("tcp", fmt.Sprintf("%s:%v", "", s.Port))
+               if err != nil {
+                       s.Logger.Fatal(err)
+               }
+
+               s.Logger.Fatal(s.Serve(ctx, cancelRoot, ln))
+       }()
 
        wg.Add(1)
-       go func() { s.Logger.Fatal(s.ServeFileTransfers(s.FileListener)) }()
+       go func() {
+               ln, err := net.Listen("tcp", fmt.Sprintf("%s:%v", "", s.Port+1))
+               if err != nil {
+                       s.Logger.Fatal(err)
+
+               }
+
+               s.Logger.Fatal(s.ServeFileTransfers(ln))
+       }()
 
        wg.Wait()
 
@@ -171,7 +188,7 @@ const (
 )
 
 // NewServer constructs a new Server from a config dir
-func NewServer(configDir, netInterface string, netPort int, logger *zap.SugaredLogger) (*Server, error) {
+func NewServer(configDir, netInterface string, netPort int, logger *zap.SugaredLogger, FS FileStore) (*Server, error) {
        server := Server{
                Port:          netPort,
                Accounts:      make(map[string]*Account),
@@ -185,26 +202,18 @@ func NewServer(configDir, netInterface string, netPort int, logger *zap.SugaredL
                outbox:        make(chan Transaction),
                Stats:         &Stats{StartTime: time.Now()},
                ThreadedNews:  &ThreadedNews{},
+               FS:            FS,
        }
 
-       ln, err := net.Listen("tcp", fmt.Sprintf("%s:%v", netInterface, netPort))
-       if err != nil {
-               return nil, err
-       }
-       server.APIListener = ln
-
-       ln2, err := net.Listen("tcp", fmt.Sprintf("%s:%v", netInterface, netPort+1))
-       server.FileListener = ln2
-       if err != nil {
-               return nil, err
-       }
+       var err error
 
        // generate a new random passID for tracker registration
        if _, err := rand.Read(server.TrackerPassID[:]); err != nil {
                return nil, err
        }
 
-       if server.Agreement, err = os.ReadFile(configDir + agreementFile); err != nil {
+       server.Agreement, err = os.ReadFile(configDir + agreementFile)
+       if err != nil {
                return nil, err
        }
 
@@ -355,7 +364,7 @@ func (s *Server) NewUser(login, name, password string, access []byte) error {
        }
        s.Accounts[login] = &account
 
-       return FS.WriteFile(s.ConfigDir+"Users/"+login+".yaml", out, 0666)
+       return s.FS.WriteFile(s.ConfigDir+"Users/"+login+".yaml", out, 0666)
 }
 
 func (s *Server) UpdateUser(login, newLogin, name, password string, access []byte) error {
@@ -395,7 +404,7 @@ func (s *Server) DeleteUser(login string) error {
 
        delete(s.Accounts, login)
 
-       return FS.Remove(s.ConfigDir + "Users/" + login + ".yaml")
+       return s.FS.Remove(s.ConfigDir + "Users/" + login + ".yaml")
 }
 
 func (s *Server) connectedUsers() []Field {
@@ -441,7 +450,7 @@ func (s *Server) loadAccounts(userDir string) error {
        }
 
        for _, file := range matches {
-               fh, err := FS.Open(file)
+               fh, err := s.FS.Open(file)
                if err != nil {
                        return err
                }
@@ -458,7 +467,7 @@ func (s *Server) loadAccounts(userDir string) error {
 }
 
 func (s *Server) loadConfig(path string) error {
-       fh, err := FS.Open(path)
+       fh, err := s.FS.Open(path)
        if err != nil {
                return err
        }
@@ -569,8 +578,8 @@ func (s *Server) handleNewConnection(conn net.Conn, remoteAddr string) error {
        // Show agreement to client
        c.Server.outbox <- *NewTransaction(tranShowAgreement, c.ID, NewField(fieldData, s.Agreement))
 
-       // assume simplified hotline v1.2.3 login flow that does not require agreement
-       if *c.Version == nil {
+       // Used simplified hotline v1.2.3 login flow for clients that do not send login info in tranAgreed
+       if *c.Version == nil || bytes.Equal(*c.Version, nostalgiaVersion) {
                c.Agreed = true
 
                c.notifyOthers(
@@ -684,6 +693,8 @@ func (s *Server) handleFileTransfer(conn io.ReadWriteCloser) error {
 
        switch fileTransfer.Type {
        case FileDownload:
+               s.Stats.DownloadCounter += 1
+
                fullFilePath, err := readPath(s.Config.FileRoot, fileTransfer.FilePath, fileTransfer.FileName)
                if err != nil {
                        return err
@@ -708,7 +719,7 @@ func (s *Server) handleFileTransfer(conn io.ReadWriteCloser) error {
                        }
                }
 
-               file, err := FS.Open(fullFilePath)
+               file, err := s.FS.Open(fullFilePath)
                if err != nil {
                        return err
                }
@@ -735,6 +746,8 @@ func (s *Server) handleFileTransfer(conn io.ReadWriteCloser) error {
                        }
                }
        case FileUpload:
+               s.Stats.UploadCounter += 1
+
                destinationFile := s.Config.FileRoot + ReadFilePath(fileTransfer.FilePath) + "/" + string(fileTransfer.FileName)
 
                var file *os.File
@@ -820,6 +833,8 @@ func (s *Server) handleFileTransfer(conn io.ReadWriteCloser) error {
 
                i := 0
                err = filepath.Walk(fullFilePath+"/", func(path string, info os.FileInfo, err error) error {
+                       s.Stats.DownloadCounter += 1
+
                        if err != nil {
                                return err
                        }
@@ -900,7 +915,7 @@ func (s *Server) handleFileTransfer(conn io.ReadWriteCloser) error {
                                return err
                        }
 
-                       file, err := FS.Open(path)
+                       file, err := s.FS.Open(path)
                        if err != nil {
                                return err
                        }
@@ -957,8 +972,8 @@ func (s *Server) handleFileTransfer(conn io.ReadWriteCloser) error {
                )
 
                // Check if the target folder exists.  If not, create it.
-               if _, err := FS.Stat(dstPath); os.IsNotExist(err) {
-                       if err := FS.Mkdir(dstPath, 0777); err != nil {
+               if _, err := s.FS.Stat(dstPath); os.IsNotExist(err) {
+                       if err := s.FS.Mkdir(dstPath, 0777); err != nil {
                                return err
                        }
                }
@@ -969,15 +984,26 @@ func (s *Server) handleFileTransfer(conn io.ReadWriteCloser) error {
                }
 
                fileSize := make([]byte, 4)
-               readBuffer := make([]byte, 1024)
 
                for i := 0; i < fileTransfer.ItemCount(); i++ {
-                       // TODO: fix potential short read with io.ReadFull
-                       _, err := conn.Read(readBuffer)
-                       if err != nil {
+                       s.Stats.UploadCounter += 1
+
+                       var fu folderUpload
+                       if _, err := io.ReadFull(conn, fu.DataSize[:]); err != nil {
+                               return err
+                       }
+
+                       if _, err := io.ReadFull(conn, fu.IsFolder[:]); err != nil {
+                               return err
+                       }
+                       if _, err := io.ReadFull(conn, fu.PathItemCount[:]); err != nil {
+                               return err
+                       }
+                       fu.FileNamePath = make([]byte, binary.BigEndian.Uint16(fu.DataSize[:])-4)
+
+                       if _, err := io.ReadFull(conn, fu.FileNamePath); err != nil {
                                return err
                        }
-                       fu := readFolderUpload(readBuffer)
 
                        s.Logger.Infow(
                                "Folder upload continued",
@@ -1019,8 +1045,6 @@ func (s *Server) handleFileTransfer(conn io.ReadWriteCloser) error {
                                        nextAction = dlFldrActionResumeFile
                                }
 
-                               fmt.Printf("Next Action: %v\n", nextAction)
-
                                if _, err := conn.Write([]byte{0, uint8(nextAction)}); err != nil {
                                        return err
                                }
@@ -1064,14 +1088,14 @@ func (s *Server) handleFileTransfer(conn io.ReadWriteCloser) error {
                                        }
 
                                case dlFldrActionSendFile:
-                                       if _, err := conn.Read(fileSize); err != nil {
+                                       if _, err := io.ReadFull(conn, fileSize); err != nil {
                                                return err
                                        }
 
                                        filePath := dstPath + "/" + fu.FormattedPath()
                                        s.Logger.Infow("Starting file transfer", "path", filePath, "fileNum", i+1, "totalFiles", "zz", "fileSize", binary.BigEndian.Uint32(fileSize))
 
-                                       newFile, err := FS.Create(filePath + ".incomplete")
+                                       newFile, err := s.FS.Create(filePath + ".incomplete")
                                        if err != nil {
                                                return err
                                        }