git.r.bdr.sh - rbdr/mobius/blobdiff - hotline/server.go
Account for the root
[rbdr/mobius] / hotline / server.go
index 58a9209f1574d90387cc4389f21ea569be14e9c9..a521a6bfecd22adc2a28c68ca1a4b6062ca3bbed 100644 (file)
@@ -9,10 +9,12 @@ import (
        "errors"
        "fmt"
        "golang.org/x/text/encoding/charmap"
+       "golang.org/x/time/rate"
        "io"
        "log"
        "log/slog"
        "net"
+       "os"
        "strings"
        "sync"
        "time"
@@ -36,6 +38,8 @@ type Server struct {
        NetInterface string
        Port         int
 
+       rateLimiters map[string]*rate.Limiter
+
        handlers map[TranType]HandlerFunc
 
        Config Config
@@ -97,6 +101,7 @@ func NewServer(options ...Option) (*Server, error) {
        server := Server{
                handlers:        make(map[TranType]HandlerFunc),
                outbox:          make(chan Transaction),
+               rateLimiters:    make(map[string]*rate.Limiter),
                FS:              &OSFileStore{},
                ChatMgr:         NewMemChatManager(),
                ClientMgr:       NewMemClientMgr(),
@@ -169,7 +174,7 @@ func (s *Server) ServeFileTransfers(ctx context.Context, ln net.Listener) error
                        )
 
                        if err != nil {
-                               s.Logger.Error("file transfer error", "reason", err)
+                               s.Logger.Error("file transfer error", "err", err)
                        }
                }()
        }
@@ -201,6 +206,10 @@ func (s *Server) processOutbox() {
        }
 }
 
+// perIPRateLimit controls how frequently an IP address can connect before being throttled.
+// 0.5 = 1 connection every 2 seconds
+const perIPRateLimit = rate.Limit(0.5)
+
 func (s *Server) Serve(ctx context.Context, ln net.Listener) error {
        for {
                select {
@@ -215,13 +224,29 @@ func (s *Server) Serve(ctx context.Context, ln net.Listener) error {
                        }
 
                        go func() {
+                               ipAddr := strings.Split(conn.RemoteAddr().(*net.TCPAddr).String(), ":")[0]
+
                                connCtx := context.WithValue(ctx, contextKeyReq, requestCtx{
                                        remoteAddr: conn.RemoteAddr().String(),
                                })
 
-                               s.Logger.Info("Connection established", "addr", conn.RemoteAddr())
+                               s.Logger.Info("Connection established", "ip", ipAddr)
                                defer conn.Close()
 
+                               // Check if we have an existing rate limit for the IP and create one if we do not.
+                               rl, ok := s.rateLimiters[ipAddr]
+                               if !ok {
+                                       rl = rate.NewLimiter(perIPRateLimit, 1)
+                                       s.rateLimiters[ipAddr] = rl
+                               }
+
+                               // Check if the rate limit is exceeded and close the connection if so.
+                               if !rl.Allow() {
+                                       s.Logger.Info("Rate limit exceeded", "RemoteAddr", conn.RemoteAddr())
+                                       conn.Close()
+                                       return
+                               }
+
                                if err := s.handleNewConnection(connCtx, conn, conn.RemoteAddr().String()); err != nil {
                                        if err == io.EOF {
                                                s.Logger.Info("Client disconnected", "RemoteAddr", conn.RemoteAddr())
@@ -240,15 +265,13 @@ const trackerUpdateFrequency = 300
 // registerWithTrackers runs every trackerUpdateFrequency seconds to update the server's tracker entry on all configured
 // trackers.
 func (s *Server) registerWithTrackers(ctx context.Context) {
-       ticker := time.NewTicker(trackerUpdateFrequency * time.Second)
-       defer ticker.Stop()
+       if s.Config.EnableTrackerRegistration {
+               s.Logger.Info("Tracker registration enabled", "trackers", s.Config.Trackers)
+       }
 
        for {
-               select {
-               case <-ctx.Done():
-                       return
-               case <-ticker.C:
-                       if s.Config.EnableTrackerRegistration {
+               if s.Config.EnableTrackerRegistration {
+                       for _, t := range s.Config.Trackers {
                                tr := &TrackerRegistration{
                                        UserCount:   len(s.ClientMgr.List()),
                                        PassID:      s.TrackerPassID,
@@ -257,15 +280,23 @@ func (s *Server) registerWithTrackers(ctx context.Context) {
                                }
                                binary.BigEndian.PutUint16(tr.Port[:], uint16(s.Port))
 
-                               for _, t := range s.Config.Trackers {
-                                       if err := register(&RealDialer{}, t, tr); err != nil {
-                                               s.Logger.Error(fmt.Sprintf("Unable to register with tracker %v", t), "error", err)
-                                       }
+                               // Check the tracker string for a password.  This is janky but avoids a breaking change to the Config
+                               // Trackers field.  NOTE(review): strings.Split takes (s, sep); the arguments below look swapped.
+                               splitAddr := strings.Split(":", t)
+                               if len(splitAddr) == 3 {
+                                       tr.Password = splitAddr[2]
+                               }
+
+                               if err := register(&RealDialer{}, t, tr); err != nil {
+                                       s.Logger.Error(fmt.Sprintf("Unable to register with tracker %v", t), "error", err)
                                }
                        }
                }
+               // Using time.Ticker with for/select would be more idiomatic, but it's super annoying that it doesn't tick on
+               // first pass.  Revisit, maybe.
+               // https://github.com/golang/go/issues/17601
+               time.Sleep(trackerUpdateFrequency * time.Second)
        }
-
 }
 
 const (
@@ -397,7 +428,7 @@ func (s *Server) handleNewConnection(ctx context.Context, rwc io.ReadWriteCloser
                        return err
                }
 
-               c.Logger.Info("Login failed", "clientVersion", fmt.Sprintf("%x", c.Version))
+               c.Logger.Info("Incorrect login")
 
                return nil
        }
@@ -523,7 +554,7 @@ func (s *Server) handleFileTransfer(ctx context.Context, rwc io.ReadWriter) erro
                "Name", string(fileTransfer.ClientConn.UserName),
        )
 
-       fullPath, err := ReadPath(s.Config.FileRoot, fileTransfer.FilePath, fileTransfer.FileName)
+       fullPath, err := ReadPath(fileTransfer.FileRoot, fileTransfer.FilePath, fileTransfer.FileName)
        if err != nil {
                return err
        }
@@ -531,7 +562,7 @@ func (s *Server) handleFileTransfer(ctx context.Context, rwc io.ReadWriter) erro
        switch fileTransfer.Type {
        case BannerDownload:
                if _, err := io.Copy(rwc, bytes.NewBuffer(s.Banner)); err != nil {
-                       return fmt.Errorf("error sending Banner: %w", err)
+                       return fmt.Errorf("banner download: %w", err)
                }
        case FileDownload:
                s.Stats.Increment(StatDownloadCounter, StatDownloadsInProgress)
@@ -552,7 +583,7 @@ func (s *Server) handleFileTransfer(ctx context.Context, rwc io.ReadWriter) erro
 
                err = UploadHandler(rwc, fullPath, fileTransfer, s.FS, rLogger, s.Config.PreserveResourceForks)
                if err != nil {
-                       return fmt.Errorf("file upload error: %w", err)
+                       return fmt.Errorf("file upload: %w", err)
                }
 
        case FolderDownload:
@@ -563,7 +594,7 @@ func (s *Server) handleFileTransfer(ctx context.Context, rwc io.ReadWriter) erro
 
                err = DownloadFolderHandler(rwc, fullPath, fileTransfer, s.FS, rLogger, s.Config.PreserveResourceForks)
                if err != nil {
-                       return fmt.Errorf("file upload error: %w", err)
+                       return fmt.Errorf("folder download: %w", err)
                }
 
        case FolderUpload:
@@ -572,17 +603,42 @@ func (s *Server) handleFileTransfer(ctx context.Context, rwc io.ReadWriter) erro
                        s.Stats.Decrement(StatUploadsInProgress)
                }()
 
+               var transferSizeValue uint32
+               switch len(fileTransfer.TransferSize) {
+                       case 2: // 16-bit
+                               transferSizeValue = uint32(binary.BigEndian.Uint16(fileTransfer.TransferSize))
+                       case 4: // 32-bit
+                               transferSizeValue = binary.BigEndian.Uint32(fileTransfer.TransferSize)
+                       default:
+                               rLogger.Warn("Unexpected TransferSize length: %d bytes", len(fileTransfer.TransferSize))
+               }
+
                rLogger.Info(
                        "Folder upload started",
                        "dstPath", fullPath,
-                       "TransferSize", binary.BigEndian.Uint32(fileTransfer.TransferSize),
+                       "TransferSize", transferSizeValue,
                        "FolderItemCount", fileTransfer.FolderItemCount,
                )
 
                err = UploadFolderHandler(rwc, fullPath, fileTransfer, s.FS, rLogger, s.Config.PreserveResourceForks)
                if err != nil {
-                       return fmt.Errorf("file upload error: %w", err)
+                       return fmt.Errorf("folder upload: %w", err)
                }
        }
        return nil
 }
+
+func (s *Server) SendAll(t TranType, fields ...Field) {
+       for _, c := range s.ClientMgr.List() {
+               s.outbox <- NewTransaction(t, c.ID, fields...)
+       }
+}
+
+func (s *Server) Shutdown(msg []byte) {
+       s.Logger.Info("Shutdown signal received")
+       s.SendAll(TranDisconnectMsg, NewField(FieldData, msg))
+
+       time.Sleep(3 * time.Second)
+
+       os.Exit(0)
+}