X-Git-Url: https://git.r.bdr.sh/rbdr/mobius/blobdiff_plain/fd740bc499ebc6d3a381479316f74cdc736d02de..e59041902fbc6fb62b5c9ff8e9b4849d4bf853ea:/hotline/server.go?ds=sidebyside diff --git a/hotline/server.go b/hotline/server.go index 58a9209..a521a6b 100644 --- a/hotline/server.go +++ b/hotline/server.go @@ -9,10 +9,12 @@ import ( "errors" "fmt" "golang.org/x/text/encoding/charmap" + "golang.org/x/time/rate" "io" "log" "log/slog" "net" + "os" "strings" "sync" "time" @@ -36,6 +38,8 @@ type Server struct { NetInterface string Port int + rateLimiters map[string]*rate.Limiter + handlers map[TranType]HandlerFunc Config Config @@ -97,6 +101,7 @@ func NewServer(options ...Option) (*Server, error) { server := Server{ handlers: make(map[TranType]HandlerFunc), outbox: make(chan Transaction), + rateLimiters: make(map[string]*rate.Limiter), FS: &OSFileStore{}, ChatMgr: NewMemChatManager(), ClientMgr: NewMemClientMgr(), @@ -169,7 +174,7 @@ func (s *Server) ServeFileTransfers(ctx context.Context, ln net.Listener) error ) if err != nil { - s.Logger.Error("file transfer error", "reason", err) + s.Logger.Error("file transfer error", "err", err) } }() } @@ -201,6 +206,10 @@ func (s *Server) processOutbox() { } } +// perIPRateLimit controls how frequently an IP address can connect before being throttled. +// 0.5 = 1 connection every 2 seconds +const perIPRateLimit = rate.Limit(0.5) + func (s *Server) Serve(ctx context.Context, ln net.Listener) error { for { select { @@ -215,13 +224,29 @@ func (s *Server) Serve(ctx context.Context, ln net.Listener) error { } go func() { + ipAddr := strings.Split(conn.RemoteAddr().(*net.TCPAddr).String(), ":")[0] + connCtx := context.WithValue(ctx, contextKeyReq, requestCtx{ remoteAddr: conn.RemoteAddr().String(), }) - s.Logger.Info("Connection established", "addr", conn.RemoteAddr()) + s.Logger.Info("Connection established", "ip", ipAddr) defer conn.Close() + // Check if we have an existing rate limit for the IP and create one if we do not. + rl, ok := s.rateLimiters[ipAddr] + if !ok { + rl = rate.NewLimiter(perIPRateLimit, 1) + s.rateLimiters[ipAddr] = rl + } + + // Check if the rate limit is exceeded and close the connection if so. + if !rl.Allow() { + s.Logger.Info("Rate limit exceeded", "RemoteAddr", conn.RemoteAddr()) + conn.Close() + return + } + if err := s.handleNewConnection(connCtx, conn, conn.RemoteAddr().String()); err != nil { if err == io.EOF { s.Logger.Info("Client disconnected", "RemoteAddr", conn.RemoteAddr()) @@ -240,15 +265,13 @@ const trackerUpdateFrequency = 300 // registerWithTrackers runs every trackerUpdateFrequency seconds to update the server's tracker entry on all configured // trackers. func (s *Server) registerWithTrackers(ctx context.Context) { - ticker := time.NewTicker(trackerUpdateFrequency * time.Second) - defer ticker.Stop() + if s.Config.EnableTrackerRegistration { + s.Logger.Info("Tracker registration enabled", "trackers", s.Config.Trackers) + } for { - select { - case <-ctx.Done(): - return - case <-ticker.C: - if s.Config.EnableTrackerRegistration { + if s.Config.EnableTrackerRegistration { + for _, t := range s.Config.Trackers { tr := &TrackerRegistration{ UserCount: len(s.ClientMgr.List()), PassID: s.TrackerPassID, @@ -257,15 +280,23 @@ func (s *Server) registerWithTrackers(ctx context.Context) { } binary.BigEndian.PutUint16(tr.Port[:], uint16(s.Port)) - for _, t := range s.Config.Trackers { - if err := register(&RealDialer{}, t, tr); err != nil { - s.Logger.Error(fmt.Sprintf("Unable to register with tracker %v", t), "error", err) - } + // Check the tracker string for a password. This is janky but avoids a breaking change to the Config + // Trackers field. + splitAddr := strings.Split(":", t) + if len(splitAddr) == 3 { + tr.Password = splitAddr[2] + } + + if err := register(&RealDialer{}, t, tr); err != nil { + s.Logger.Error(fmt.Sprintf("Unable to register with tracker %v", t), "error", err) } } } + // Using time.Ticker with for/select would be more idiomatic, but it's super annoying that it doesn't tick on + // first pass. Revist, maybe. + // https://github.com/golang/go/issues/17601 + time.Sleep(trackerUpdateFrequency * time.Second) } - } const ( @@ -397,7 +428,7 @@ func (s *Server) handleNewConnection(ctx context.Context, rwc io.ReadWriteCloser return err } - c.Logger.Info("Login failed", "clientVersion", fmt.Sprintf("%x", c.Version)) + c.Logger.Info("Incorrect login") return nil } @@ -523,7 +554,7 @@ func (s *Server) handleFileTransfer(ctx context.Context, rwc io.ReadWriter) erro "Name", string(fileTransfer.ClientConn.UserName), ) - fullPath, err := ReadPath(s.Config.FileRoot, fileTransfer.FilePath, fileTransfer.FileName) + fullPath, err := ReadPath(fileTransfer.FileRoot, fileTransfer.FilePath, fileTransfer.FileName) if err != nil { return err } @@ -531,7 +562,7 @@ func (s *Server) handleFileTransfer(ctx context.Context, rwc io.ReadWriter) erro switch fileTransfer.Type { case BannerDownload: if _, err := io.Copy(rwc, bytes.NewBuffer(s.Banner)); err != nil { - return fmt.Errorf("error sending Banner: %w", err) + return fmt.Errorf("banner download: %w", err) } case FileDownload: s.Stats.Increment(StatDownloadCounter, StatDownloadsInProgress) @@ -552,7 +583,7 @@ func (s *Server) handleFileTransfer(ctx context.Context, rwc io.ReadWriter) erro err = UploadHandler(rwc, fullPath, fileTransfer, s.FS, rLogger, s.Config.PreserveResourceForks) if err != nil { - return fmt.Errorf("file upload error: %w", err) + return fmt.Errorf("file upload: %w", err) } case FolderDownload: @@ -563,7 +594,7 @@ func (s *Server) handleFileTransfer(ctx context.Context, rwc io.ReadWriter) erro err = DownloadFolderHandler(rwc, fullPath, fileTransfer, s.FS, rLogger, s.Config.PreserveResourceForks) if err != nil { - return fmt.Errorf("file upload error: %w", err) + return fmt.Errorf("folder download: %w", err) } case FolderUpload: @@ -572,17 +603,42 @@ func (s *Server) handleFileTransfer(ctx context.Context, rwc io.ReadWriter) erro s.Stats.Decrement(StatUploadsInProgress) }() + var transferSizeValue uint32 + switch len(fileTransfer.TransferSize) { + case 2: // 16-bit + transferSizeValue = uint32(binary.BigEndian.Uint16(fileTransfer.TransferSize)) + case 4: // 32-bit + transferSizeValue = binary.BigEndian.Uint32(fileTransfer.TransferSize) + default: + rLogger.Warn("Unexpected TransferSize length: %d bytes", len(fileTransfer.TransferSize)) + } + rLogger.Info( "Folder upload started", "dstPath", fullPath, - "TransferSize", binary.BigEndian.Uint32(fileTransfer.TransferSize), + "TransferSize", transferSizeValue, "FolderItemCount", fileTransfer.FolderItemCount, ) err = UploadFolderHandler(rwc, fullPath, fileTransfer, s.FS, rLogger, s.Config.PreserveResourceForks) if err != nil { - return fmt.Errorf("file upload error: %w", err) + return fmt.Errorf("folder upload: %w", err) } } return nil } + +func (s *Server) SendAll(t TranType, fields ...Field) { + for _, c := range s.ClientMgr.List() { + s.outbox <- NewTransaction(t, c.ID, fields...) + } +} + +func (s *Server) Shutdown(msg []byte) { + s.Logger.Info("Shutdown signal received") + s.SendAll(TranDisconnectMsg, NewField(FieldData, msg)) + + time.Sleep(3 * time.Second) + + os.Exit(0) +}