X-Git-Url: https://git.r.bdr.sh/rbdr/mobius/blobdiff_plain/80aed6b19ff0b0927670e459ce5cc7a16ef047ec..fd740bc499ebc6d3a381479316f74cdc736d02de:/hotline/server.go diff --git a/hotline/server.go b/hotline/server.go index 6232253..58a9209 100644 --- a/hotline/server.go +++ b/hotline/server.go @@ -9,13 +9,10 @@ import ( "errors" "fmt" "golang.org/x/text/encoding/charmap" - "gopkg.in/yaml.v3" "io" "log" "log/slog" "net" - "os" - "path/filepath" "strings" "sync" "time" @@ -39,9 +36,10 @@ type Server struct { NetInterface string Port int - Config Config - ConfigDir string - Logger *slog.Logger + handlers map[TranType]HandlerFunc + + Config Config + Logger *slog.Logger TrackerPassID [4]byte @@ -51,10 +49,8 @@ type Server struct { outbox chan Transaction - // TODO - Agreement []byte - banner []byte - // END TODO + Agreement io.ReadSeeker + Banner []byte FileTransferMgr FileTransferMgr ChatMgr ChatManager @@ -66,60 +62,57 @@ type Server struct { MessageBoard io.ReadWriteSeeker } -// NewServer constructs a new Server from a config dir -func NewServer(config Config, configDir, netInterface string, netPort int, logger *slog.Logger, fs FileStore) (*Server, error) { - server := Server{ - NetInterface: netInterface, - Port: netPort, - Config: config, - ConfigDir: configDir, - Logger: logger, - outbox: make(chan Transaction), - Stats: NewStats(), - FS: fs, - ChatMgr: NewMemChatManager(), - ClientMgr: NewMemClientMgr(), - FileTransferMgr: NewMemFileTransferMgr(), - } +type Option = func(s *Server) - // generate a new random passID for tracker registration - _, err := rand.Read(server.TrackerPassID[:]) - if err != nil { - return nil, err +func WithConfig(config Config) func(s *Server) { + return func(s *Server) { + s.Config = config } +} - server.Agreement, err = os.ReadFile(filepath.Join(configDir, agreementFile)) - if err != nil { - return nil, err +func WithLogger(logger *slog.Logger) func(s *Server) { + return func(s *Server) { + s.Logger = logger } +} - server.AccountManager, err = NewYAMLAccountManager(filepath.Join(configDir, "Users/")) - if err != nil { - return nil, fmt.Errorf("error loading accounts: %w", err) +// WithPort optionally overrides the default TCP port. +func WithPort(port int) func(s *Server) { + return func(s *Server) { + s.Port = port } +} - // If the FileRoot is an absolute path, use it, otherwise treat as a relative path to the config dir. - if !filepath.IsAbs(server.Config.FileRoot) { - server.Config.FileRoot = filepath.Join(configDir, server.Config.FileRoot) +// WithInterface optionally sets a specific interface to listen on. +func WithInterface(netInterface string) func(s *Server) { + return func(s *Server) { + s.NetInterface = netInterface } +} - server.banner, err = os.ReadFile(filepath.Join(server.ConfigDir, server.Config.BannerFile)) - if err != nil { - return nil, fmt.Errorf("error opening banner: %w", err) - } +type ServerConfig struct { +} - if server.Config.EnableTrackerRegistration { - server.Logger.Info( - "Tracker registration enabled", - "frequency", fmt.Sprintf("%vs", trackerUpdateFrequency), - "trackers", server.Config.Trackers, - ) +func NewServer(options ...Option) (*Server, error) { + server := Server{ + handlers: make(map[TranType]HandlerFunc), + outbox: make(chan Transaction), + FS: &OSFileStore{}, + ChatMgr: NewMemChatManager(), + ClientMgr: NewMemClientMgr(), + FileTransferMgr: NewMemFileTransferMgr(), + Stats: NewStats(), + } - go server.registerWithTrackers() + for _, opt := range options { + opt(&server) } - // Start Client Keepalive go routine - go server.keepaliveHandler() + // generate a new random passID for tracker registration + _, err := rand.Read(server.TrackerPassID[:]) + if err != nil { + return nil, err + } return &server, nil } @@ -129,6 +122,10 @@ func (s *Server) CurrentStats() map[string]interface{} { } func (s *Server) ListenAndServe(ctx context.Context) error { + go s.registerWithTrackers(ctx) + go s.keepaliveHandler(ctx) + go s.processOutbox() + var wg sync.WaitGroup wg.Add(1) @@ -179,7 +176,7 @@ func (s *Server) ServeFileTransfers(ctx context.Context, ln net.Listener) error } func (s *Server) sendTransaction(t Transaction) error { - client := s.ClientMgr.Get(t.clientID) + client := s.ClientMgr.Get(t.ClientID) if client == nil { return nil @@ -187,7 +184,7 @@ func (s *Server) sendTransaction(t Transaction) error { _, err := io.Copy(client.Connection, &t) if err != nil { - return fmt.Errorf("failed to send transaction to client %v: %v", t.clientID, err) + return fmt.Errorf("failed to send transaction to client %v: %v", t.ClientID, err) } return nil @@ -205,78 +202,107 @@ func (s *Server) processOutbox() { } func (s *Server) Serve(ctx context.Context, ln net.Listener) error { - go s.processOutbox() - for { - conn, err := ln.Accept() - if err != nil { - s.Logger.Error("error accepting connection", "err", err) - } - connCtx := context.WithValue(ctx, contextKeyReq, requestCtx{ - remoteAddr: conn.RemoteAddr().String(), - }) + select { + case <-ctx.Done(): + s.Logger.Info("Server shutting down") + return ctx.Err() + default: + conn, err := ln.Accept() + if err != nil { + s.Logger.Error("Error accepting connection", "err", err) + continue + } - go func() { - s.Logger.Info("Connection established", "RemoteAddr", conn.RemoteAddr()) - - defer conn.Close() - if err := s.handleNewConnection(connCtx, conn, conn.RemoteAddr().String()); err != nil { - if err == io.EOF { - s.Logger.Info("Client disconnected", "RemoteAddr", conn.RemoteAddr()) - } else { - s.Logger.Error("error serving request", "RemoteAddr", conn.RemoteAddr(), "err", err) + go func() { + connCtx := context.WithValue(ctx, contextKeyReq, requestCtx{ + remoteAddr: conn.RemoteAddr().String(), + }) + + s.Logger.Info("Connection established", "addr", conn.RemoteAddr()) + defer conn.Close() + + if err := s.handleNewConnection(connCtx, conn, conn.RemoteAddr().String()); err != nil { + if err == io.EOF { + s.Logger.Info("Client disconnected", "RemoteAddr", conn.RemoteAddr()) + } else { + s.Logger.Error("Error serving request", "RemoteAddr", conn.RemoteAddr(), "err", err) + } } - } - }() + }() + } } } -const ( - agreementFile = "Agreement.txt" -) +// time in seconds between tracker re-registration +const trackerUpdateFrequency = 300 + +// registerWithTrackers runs every trackerUpdateFrequency seconds to update the server's tracker entry on all configured +// trackers. +func (s *Server) registerWithTrackers(ctx context.Context) { + ticker := time.NewTicker(trackerUpdateFrequency * time.Second) + defer ticker.Stop() -func (s *Server) registerWithTrackers() { for { - tr := &TrackerRegistration{ - UserCount: len(s.ClientMgr.List()), - PassID: s.TrackerPassID, - Name: s.Config.Name, - Description: s.Config.Description, - } - binary.BigEndian.PutUint16(tr.Port[:], uint16(s.Port)) - for _, t := range s.Config.Trackers { - if err := register(&RealDialer{}, t, tr); err != nil { - s.Logger.Error(fmt.Sprintf("unable to register with tracker %v", t), "error", err) + select { + case <-ctx.Done(): + return + case <-ticker.C: + if s.Config.EnableTrackerRegistration { + tr := &TrackerRegistration{ + UserCount: len(s.ClientMgr.List()), + PassID: s.TrackerPassID, + Name: s.Config.Name, + Description: s.Config.Description, + } + binary.BigEndian.PutUint16(tr.Port[:], uint16(s.Port)) + + for _, t := range s.Config.Trackers { + if err := register(&RealDialer{}, t, tr); err != nil { + s.Logger.Error(fmt.Sprintf("Unable to register with tracker %v", t), "error", err) + } + } } } - - time.Sleep(trackerUpdateFrequency * time.Second) } -} -// keepaliveHandler -func (s *Server) keepaliveHandler() { - for { - time.Sleep(idleCheckInterval * time.Second) - - for _, c := range s.ClientMgr.List() { - c.mu.Lock() +} - c.IdleTime += idleCheckInterval +const ( + userIdleSeconds = 300 // time in seconds before an inactive user is marked idle + idleCheckInterval = 10 // time in seconds to check for idle users +) - // Check if the user - if c.IdleTime > userIdleSeconds && !c.Flags.IsSet(UserFlagAway) { - c.Flags.Set(UserFlagAway, 1) +// keepaliveHandler runs every idleCheckInterval seconds and increments a user's idle time by idleCheckInterval seconds. +// If the updated idle time exceeds userIdleSeconds and the user was not previously idle, we notify all connected clients +// that the user has gone idle. For most clients, this turns the user grey in the user list. +func (s *Server) keepaliveHandler(ctx context.Context) { + ticker := time.NewTicker(idleCheckInterval * time.Second) + defer ticker.Stop() - c.SendAll( - TranNotifyChangeUser, - NewField(FieldUserID, c.ID[:]), - NewField(FieldUserFlags, c.Flags[:]), - NewField(FieldUserName, c.UserName), - NewField(FieldUserIconID, c.Icon), - ) + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + for _, c := range s.ClientMgr.List() { + c.mu.Lock() + c.IdleTime += idleCheckInterval + + // Check if the user + if c.IdleTime > userIdleSeconds && !c.Flags.IsSet(UserFlagAway) { + c.Flags.Set(UserFlagAway, 1) + + c.SendAll( + TranNotifyChangeUser, + NewField(FieldUserID, c.ID[:]), + NewField(FieldUserFlags, c.Flags[:]), + NewField(FieldUserName, c.UserName), + NewField(FieldUserIconID, c.Icon), + ) + } + c.mu.Unlock() } - c.mu.Unlock() } } } @@ -296,18 +322,6 @@ func (s *Server) NewClientConn(conn io.ReadWriteCloser, remoteAddr string) *Clie return clientConn } -// loadFromYAMLFile loads data from a YAML file into the provided data structure. -func loadFromYAMLFile(path string, data interface{}) error { - fh, err := os.Open(path) - if err != nil { - return err - } - defer fh.Close() - - decoder := yaml.NewDecoder(fh) - return decoder.Decode(data) -} - func sendBanMessage(rwc io.Writer, message string) { t := NewTransaction( TranServerMsg, @@ -323,6 +337,10 @@ func sendBanMessage(rwc io.Writer, message string) { func (s *Server) handleNewConnection(ctx context.Context, rwc io.ReadWriteCloser, remoteAddr string) error { defer dontPanic(s.Logger) + if err := performHandshake(rwc); err != nil { + return fmt.Errorf("perform handshake: %w", err) + } + // Check if remoteAddr is present in the ban list ipAddr := strings.Split(remoteAddr, ":")[0] if isBanned, banUntil := s.BanList.IsBanned(ipAddr); isBanned { @@ -341,10 +359,6 @@ func (s *Server) handleNewConnection(ctx context.Context, rwc io.ReadWriteCloser } } - if err := performHandshake(rwc); err != nil { - return fmt.Errorf("error performing handshake: %w", err) - } - // Create a new scanner for parsing incoming bytes into transaction tokens scanner := bufio.NewScanner(rwc) scanner.Split(transactionScanner) @@ -372,7 +386,7 @@ func (s *Server) handleNewConnection(ctx context.Context, rwc io.ReadWriteCloser login = GuestAccount } - c.logger = s.Logger.With("ip", ipAddr, "login", login) + c.Logger = s.Logger.With("ip", ipAddr, "login", login) // If authentication fails, send error reply and close connection if !c.Authenticate(login, encodedPassword) { @@ -383,7 +397,7 @@ func (s *Server) handleNewConnection(ctx context.Context, rwc io.ReadWriteCloser return err } - c.logger.Info("Login failed", "clientVersion", fmt.Sprintf("%x", c.Version)) + c.Logger.Info("Login failed", "clientVersion", fmt.Sprintf("%x", c.Version)) return nil } @@ -427,7 +441,10 @@ func (s *Server) handleNewConnection(ctx context.Context, rwc io.ReadWriteCloser c.Server.outbox <- NewTransaction(TranShowAgreement, c.ID, NewField(FieldNoServerAgreement, []byte{1})) } } else { - c.Server.outbox <- NewTransaction(TranShowAgreement, c.ID, NewField(FieldData, s.Agreement)) + _, _ = c.Server.Agreement.Seek(0, 0) + data, _ := io.ReadAll(c.Server.Agreement) + + c.Server.outbox <- NewTransaction(TranShowAgreement, c.ID, NewField(FieldData, data)) } // If the client has provided a username as part of the login, we can infer that it is using the 1.2.3 login @@ -435,8 +452,8 @@ func (s *Server) handleNewConnection(ctx context.Context, rwc io.ReadWriteCloser if len(c.UserName) != 0 { // Add the client username to the logger. For 1.5+ clients, we don't have this information yet as it comes as // part of TranAgreed - c.logger = c.logger.With("Name", string(c.UserName)) - c.logger.Info("Login successful", "clientVersion", "Not sent (probably 1.2.3)") + c.Logger = c.Logger.With("name", string(c.UserName)) + c.Logger.Info("Login successful") // Notify other clients on the server that the new user has logged in. For 1.5+ clients we don't have this // information yet, so we do it in TranAgreed instead @@ -463,11 +480,11 @@ func (s *Server) handleNewConnection(ctx context.Context, rwc io.ReadWriteCloser // Scan for new transactions and handle them as they come in. for scanner.Scan() { // Copy the scanner bytes to a new slice to it to avoid a data race when the scanner re-uses the buffer. - buf := make([]byte, len(scanner.Bytes())) - copy(buf, scanner.Bytes()) + tmpBuf := make([]byte, len(scanner.Bytes())) + copy(tmpBuf, scanner.Bytes()) var t Transaction - if _, err := t.Write(buf); err != nil { + if _, err := t.Write(tmpBuf); err != nil { return err } @@ -506,15 +523,15 @@ func (s *Server) handleFileTransfer(ctx context.Context, rwc io.ReadWriter) erro "Name", string(fileTransfer.ClientConn.UserName), ) - fullPath, err := readPath(s.Config.FileRoot, fileTransfer.FilePath, fileTransfer.FileName) + fullPath, err := ReadPath(s.Config.FileRoot, fileTransfer.FilePath, fileTransfer.FileName) if err != nil { return err } switch fileTransfer.Type { case BannerDownload: - if _, err := io.Copy(rwc, bytes.NewBuffer(s.banner)); err != nil { - return fmt.Errorf("error sending banner: %w", err) + if _, err := io.Copy(rwc, bytes.NewBuffer(s.Banner)); err != nil { + return fmt.Errorf("error sending Banner: %w", err) } case FileDownload: s.Stats.Increment(StatDownloadCounter, StatDownloadsInProgress)