]> git.r.bdr.sh - rbdr/mobius/blobdiff - hotline/server.go
Add flag to enable optional stats HTTP endpoint
[rbdr/mobius] / hotline / server.go
index fb06001f3aaa16fb453f8db262bccbbf47a42844..76cc248ba57088ebf950a449c4af9551d49361d2 100644 (file)
@@ -1,10 +1,12 @@
 package hotline
 
 import (
+       "bytes"
        "context"
        "encoding/binary"
        "errors"
        "fmt"
+       "github.com/go-playground/validator/v10"
        "go.uber.org/zap"
        "io"
        "io/fs"
@@ -30,6 +32,8 @@ const (
        trackerUpdateFrequency = 300 // time in seconds between tracker re-registration
 )
 
+var nostalgiaVersion = []byte{0, 0, 2, 0x2c} // version ID used by the Nostalgia client
+
 type Server struct {
        Port          int
        Accounts      map[string]*Account
@@ -43,11 +47,10 @@ type Server struct {
        Logger        *zap.SugaredLogger
        PrivateChats  map[uint32]*PrivateChat
        NextGuestID   *uint16
-       TrackerPassID []byte
+       TrackerPassID [4]byte
        Stats         *Stats
 
-       APIListener  net.Listener
-       FileListener net.Listener
+       FS FileStore
 
        // newsReader io.Reader
        // newsWriter io.WriteCloser
@@ -64,27 +67,41 @@ type PrivateChat struct {
 }
 
 func (s *Server) ListenAndServe(ctx context.Context, cancelRoot context.CancelFunc) error {
-       s.Logger.Infow("Hotline server started", "version", VERSION)
+       s.Logger.Infow("Hotline server started",
+               "version", VERSION,
+               "API port", fmt.Sprintf(":%v", s.Port),
+               "Transfer port", fmt.Sprintf(":%v", s.Port+1),
+       )
+
        var wg sync.WaitGroup
 
        wg.Add(1)
-       go func() { s.Logger.Fatal(s.Serve(ctx, cancelRoot, s.APIListener)) }()
+       go func() {
+               ln, err := net.Listen("tcp", fmt.Sprintf("%s:%v", "", s.Port))
+               if err != nil {
+                       s.Logger.Fatal(err)
+               }
+
+               s.Logger.Fatal(s.Serve(ctx, cancelRoot, ln))
+       }()
 
        wg.Add(1)
-       go func() { s.Logger.Fatal(s.ServeFileTransfers(s.FileListener)) }()
+       go func() {
+               ln, err := net.Listen("tcp", fmt.Sprintf("%s:%v", "", s.Port+1))
+               if err != nil {
+                       s.Logger.Fatal(err)
+
+               }
+
+               s.Logger.Fatal(s.ServeFileTransfers(ln))
+       }()
 
        wg.Wait()
 
        return nil
 }
 
-func (s *Server) APIPort() int {
-       return s.APIListener.Addr().(*net.TCPAddr).Port
-}
-
 func (s *Server) ServeFileTransfers(ln net.Listener) error {
-       s.Logger.Infow("Hotline file transfer server started", "Addr", fmt.Sprintf(":%v", s.Port+1))
-
        for {
                conn, err := ln.Accept()
                if err != nil {
@@ -137,7 +154,6 @@ func (s *Server) sendTransaction(t Transaction) error {
 }
 
 func (s *Server) Serve(ctx context.Context, cancelRoot context.CancelFunc, ln net.Listener) error {
-       s.Logger.Infow("Hotline server started", "Addr", fmt.Sprintf(":%v", s.Port))
 
        for {
                conn, err := ln.Accept()
@@ -172,7 +188,7 @@ const (
 )
 
 // NewServer constructs a new Server from a config dir
-func NewServer(configDir, netInterface string, netPort int, logger *zap.SugaredLogger) (*Server, error) {
+func NewServer(configDir, netInterface string, netPort int, logger *zap.SugaredLogger, FS FileStore) (*Server, error) {
        server := Server{
                Port:          netPort,
                Accounts:      make(map[string]*Account),
@@ -186,32 +202,18 @@ func NewServer(configDir, netInterface string, netPort int, logger *zap.SugaredL
                outbox:        make(chan Transaction),
                Stats:         &Stats{StartTime: time.Now()},
                ThreadedNews:  &ThreadedNews{},
-               TrackerPassID: make([]byte, 4),
+               FS:            FS,
        }
 
-       ln, err := net.Listen("tcp", fmt.Sprintf("%s:%v", netInterface, netPort))
-       if err != nil {
-               return nil, err
-       }
-       server.APIListener = ln
-
-       if netPort != 0 {
-               netPort += 1
-       }
-
-       ln2, err := net.Listen("tcp", fmt.Sprintf("%s:%v", netInterface, netPort))
-       server.FileListener = ln2
-       if err != nil {
-               return nil, err
-       }
+       var err error
 
        // generate a new random passID for tracker registration
-       if _, err := rand.Read(server.TrackerPassID); err != nil {
+       if _, err := rand.Read(server.TrackerPassID[:]); err != nil {
                return nil, err
        }
 
-       server.Logger.Debugw("Loading Agreement", "path", configDir+agreementFile)
-       if server.Agreement, err = os.ReadFile(configDir + agreementFile); err != nil {
+       server.Agreement, err = os.ReadFile(configDir + agreementFile)
+       if err != nil {
                return nil, err
        }
 
@@ -245,12 +247,12 @@ func NewServer(configDir, netInterface string, netPort int, logger *zap.SugaredL
                go func() {
                        for {
                                tr := &TrackerRegistration{
-                                       Port:        []byte{0x15, 0x7c},
                                        UserCount:   server.userCount(),
-                                       PassID:      server.TrackerPassID,
+                                       PassID:      server.TrackerPassID[:],
                                        Name:        server.Config.Name,
                                        Description: server.Config.Description,
                                }
+                               binary.BigEndian.PutUint16(tr.Port[:], uint16(server.Port))
                                for _, t := range server.Config.Trackers {
                                        if err := register(t, tr); err != nil {
                                                server.Logger.Errorw("unable to register with tracker %v", "error", err)
@@ -362,15 +364,13 @@ func (s *Server) NewUser(login, name, password string, access []byte) error {
        }
        s.Accounts[login] = &account
 
-       return FS.WriteFile(s.ConfigDir+"Users/"+login+".yaml", out, 0666)
+       return s.FS.WriteFile(s.ConfigDir+"Users/"+login+".yaml", out, 0666)
 }
 
 func (s *Server) UpdateUser(login, newLogin, name, password string, access []byte) error {
        s.mux.Lock()
        defer s.mux.Unlock()
 
-       fmt.Printf("login: %v, newLogin: %v: ", login, newLogin)
-
        // update renames the user login
        if login != newLogin {
                err := os.Rename(s.ConfigDir+"Users/"+login+".yaml", s.ConfigDir+"Users/"+newLogin+".yaml")
@@ -404,7 +404,7 @@ func (s *Server) DeleteUser(login string) error {
 
        delete(s.Accounts, login)
 
-       return FS.Remove(s.ConfigDir + "Users/" + login + ".yaml")
+       return s.FS.Remove(s.ConfigDir + "Users/" + login + ".yaml")
 }
 
 func (s *Server) connectedUsers() []Field {
@@ -450,7 +450,7 @@ func (s *Server) loadAccounts(userDir string) error {
        }
 
        for _, file := range matches {
-               fh, err := FS.Open(file)
+               fh, err := s.FS.Open(file)
                if err != nil {
                        return err
                }
@@ -467,7 +467,7 @@ func (s *Server) loadAccounts(userDir string) error {
 }
 
 func (s *Server) loadConfig(path string) error {
-       fh, err := FS.Open(path)
+       fh, err := s.FS.Open(path)
        if err != nil {
                return err
        }
@@ -477,6 +477,12 @@ func (s *Server) loadConfig(path string) error {
        if err != nil {
                return err
        }
+
+       validate := validator.New()
+       err = validate.Struct(s.Config)
+       if err != nil {
+               return err
+       }
        return nil
 }
 
@@ -572,8 +578,8 @@ func (s *Server) handleNewConnection(conn net.Conn, remoteAddr string) error {
        // Show agreement to client
        c.Server.outbox <- *NewTransaction(tranShowAgreement, c.ID, NewField(fieldData, s.Agreement))
 
-       // assume simplified hotline v1.2.3 login flow that does not require agreement
-       if *c.Version == nil {
+       // Used simplified hotline v1.2.3 login flow for clients that do not send login info in tranAgreed
+       if *c.Version == nil || bytes.Equal(*c.Version, nostalgiaVersion) {
                c.Agreed = true
 
                c.notifyOthers(
@@ -687,6 +693,8 @@ func (s *Server) handleFileTransfer(conn io.ReadWriteCloser) error {
 
        switch fileTransfer.Type {
        case FileDownload:
+               s.Stats.DownloadCounter += 1
+
                fullFilePath, err := readPath(s.Config.FileRoot, fileTransfer.FilePath, fileTransfer.FileName)
                if err != nil {
                        return err
@@ -711,7 +719,7 @@ func (s *Server) handleFileTransfer(conn io.ReadWriteCloser) error {
                        }
                }
 
-               file, err := FS.Open(fullFilePath)
+               file, err := s.FS.Open(fullFilePath)
                if err != nil {
                        return err
                }
@@ -738,6 +746,8 @@ func (s *Server) handleFileTransfer(conn io.ReadWriteCloser) error {
                        }
                }
        case FileUpload:
+               s.Stats.UploadCounter += 1
+
                destinationFile := s.Config.FileRoot + ReadFilePath(fileTransfer.FilePath) + "/" + string(fileTransfer.FileName)
 
                var file *os.File
@@ -823,6 +833,8 @@ func (s *Server) handleFileTransfer(conn io.ReadWriteCloser) error {
 
                i := 0
                err = filepath.Walk(fullFilePath+"/", func(path string, info os.FileInfo, err error) error {
+                       s.Stats.DownloadCounter += 1
+
                        if err != nil {
                                return err
                        }
@@ -903,7 +915,7 @@ func (s *Server) handleFileTransfer(conn io.ReadWriteCloser) error {
                                return err
                        }
 
-                       file, err := FS.Open(path)
+                       file, err := s.FS.Open(path)
                        if err != nil {
                                return err
                        }
@@ -960,8 +972,8 @@ func (s *Server) handleFileTransfer(conn io.ReadWriteCloser) error {
                )
 
                // Check if the target folder exists.  If not, create it.
-               if _, err := FS.Stat(dstPath); os.IsNotExist(err) {
-                       if err := FS.Mkdir(dstPath, 0777); err != nil {
+               if _, err := s.FS.Stat(dstPath); os.IsNotExist(err) {
+                       if err := s.FS.Mkdir(dstPath, 0777); err != nil {
                                return err
                        }
                }
@@ -972,15 +984,26 @@ func (s *Server) handleFileTransfer(conn io.ReadWriteCloser) error {
                }
 
                fileSize := make([]byte, 4)
-               readBuffer := make([]byte, 1024)
 
                for i := 0; i < fileTransfer.ItemCount(); i++ {
-                       // TODO: fix potential short read with io.ReadFull
-                       _, err := conn.Read(readBuffer)
-                       if err != nil {
+                       s.Stats.UploadCounter += 1
+
+                       var fu folderUpload
+                       if _, err := io.ReadFull(conn, fu.DataSize[:]); err != nil {
+                               return err
+                       }
+
+                       if _, err := io.ReadFull(conn, fu.IsFolder[:]); err != nil {
+                               return err
+                       }
+                       if _, err := io.ReadFull(conn, fu.PathItemCount[:]); err != nil {
+                               return err
+                       }
+                       fu.FileNamePath = make([]byte, binary.BigEndian.Uint16(fu.DataSize[:])-4)
+
+                       if _, err := io.ReadFull(conn, fu.FileNamePath); err != nil {
                                return err
                        }
-                       fu := readFolderUpload(readBuffer)
 
                        s.Logger.Infow(
                                "Folder upload continued",
@@ -1022,8 +1045,6 @@ func (s *Server) handleFileTransfer(conn io.ReadWriteCloser) error {
                                        nextAction = dlFldrActionResumeFile
                                }
 
-                               fmt.Printf("Next Action: %v\n", nextAction)
-
                                if _, err := conn.Write([]byte{0, uint8(nextAction)}); err != nil {
                                        return err
                                }
@@ -1067,14 +1088,14 @@ func (s *Server) handleFileTransfer(conn io.ReadWriteCloser) error {
                                        }
 
                                case dlFldrActionSendFile:
-                                       if _, err := conn.Read(fileSize); err != nil {
+                                       if _, err := io.ReadFull(conn, fileSize); err != nil {
                                                return err
                                        }
 
                                        filePath := dstPath + "/" + fu.FormattedPath()
                                        s.Logger.Infow("Starting file transfer", "path", filePath, "fileNum", i+1, "totalFiles", "zz", "fileSize", binary.BigEndian.Uint32(fileSize))
 
-                                       newFile, err := FS.Create(filePath + ".incomplete")
+                                       newFile, err := s.FS.Create(filePath + ".incomplete")
                                        if err != nil {
                                                return err
                                        }