X-Git-Url: https://git.r.bdr.sh/rbdr/mobius/blobdiff_plain/a2ef262a164fc735b9b8471ac0c8001eea2b9bf6..dcd23d5355badf66c34ffd63d3c44734e87ebf17:/hotline/server.go diff --git a/hotline/server.go b/hotline/server.go index 0ee2dd7..c4d8f7f 100644 --- a/hotline/server.go +++ b/hotline/server.go @@ -8,19 +8,14 @@ import ( "encoding/binary" "errors" "fmt" - "github.com/go-playground/validator/v10" "golang.org/x/text/encoding/charmap" - "gopkg.in/yaml.v3" "io" "log" "log/slog" "net" "os" - "path" - "path/filepath" "strings" "sync" - "sync/atomic" "time" ) @@ -41,57 +36,97 @@ var txtEncoder = charmap.Macintosh.NewEncoder() type Server struct { NetInterface string Port int - Accounts map[string]*Account - Agreement []byte - Clients map[[2]byte]*ClientConn - fileTransfers map[[4]byte]*FileTransfer + handlers map[TranType]HandlerFunc - Config *Config - ConfigDir string - Logger *slog.Logger - banner []byte + Config Config + Logger *slog.Logger - PrivateChatsMu sync.Mutex - PrivateChats map[[4]byte]*PrivateChat - - nextClientID atomic.Uint32 TrackerPassID [4]byte - statsMu sync.Mutex - Stats *Stats + Stats Counter FS FileStore // Storage backend to use for File storage outbox chan Transaction - mux sync.Mutex - threadedNewsMux sync.Mutex - ThreadedNews *ThreadedNews + Agreement io.ReadSeeker + Banner []byte + + FileTransferMgr FileTransferMgr + ChatMgr ChatManager + ClientMgr ClientManager + AccountManager AccountManager + ThreadedNewsMgr ThreadedNewsMgr + BanList BanMgr + + MessageBoard io.ReadWriteSeeker +} + +type Option = func(s *Server) + +func WithConfig(config Config) func(s *Server) { + return func(s *Server) { + s.Config = config + } +} + +func WithLogger(logger *slog.Logger) func(s *Server) { + return func(s *Server) { + s.Logger = logger + } +} + +// WithPort optionally overrides the default TCP port. +func WithPort(port int) func(s *Server) { + return func(s *Server) { + s.Port = port + } +} - flatNewsMux sync.Mutex - FlatNews []byte +// WithInterface optionally sets a specific interface to listen on. +func WithInterface(netInterface string) func(s *Server) { + return func(s *Server) { + s.NetInterface = netInterface + } +} - banListMU sync.Mutex - banList map[string]*time.Time +type ServerConfig struct { } -func (s *Server) CurrentStats() Stats { - s.statsMu.Lock() - defer s.statsMu.Unlock() +func NewServer(options ...Option) (*Server, error) { + server := Server{ + handlers: make(map[TranType]HandlerFunc), + outbox: make(chan Transaction), + FS: &OSFileStore{}, + ChatMgr: NewMemChatManager(), + ClientMgr: NewMemClientMgr(), + FileTransferMgr: NewMemFileTransferMgr(), + Stats: NewStats(), + } - stats := s.Stats - stats.CurrentlyConnected = len(s.Clients) + for _, opt := range options { + opt(&server) + } + + // generate a new random passID for tracker registration + _, err := rand.Read(server.TrackerPassID[:]) + if err != nil { + return nil, err + } - return *stats + return &server, nil } -type PrivateChat struct { - Subject string - ClientConn map[[2]byte]*ClientConn +func (s *Server) CurrentStats() map[string]interface{} { + return s.Stats.Values() } func (s *Server) ListenAndServe(ctx context.Context) error { + go s.registerWithTrackers(ctx) + go s.keepaliveHandler(ctx) + go s.processOutbox() + var wg sync.WaitGroup wg.Add(1) @@ -142,17 +177,15 @@ func (s *Server) ServeFileTransfers(ctx context.Context, ln net.Listener) error } func (s *Server) sendTransaction(t Transaction) error { - s.mux.Lock() - client, ok := s.Clients[t.clientID] - s.mux.Unlock() + client := s.ClientMgr.Get(t.ClientID) - if !ok || client == nil { + if client == nil { return nil } _, err := io.Copy(client.Connection, &t) if err != nil { - return fmt.Errorf("failed to send transaction to client %v: %v", t.clientID, err) + return fmt.Errorf("failed to send transaction to client %v: %v", t.ClientID, err) } return nil @@ -170,380 +203,126 @@ func (s *Server) processOutbox() { } func (s *Server) Serve(ctx context.Context, ln net.Listener) error { - go s.processOutbox() - for { - conn, err := ln.Accept() - if err != nil { - s.Logger.Error("error accepting connection", "err", err) - } - connCtx := context.WithValue(ctx, contextKeyReq, requestCtx{ - remoteAddr: conn.RemoteAddr().String(), - }) - - go func() { - s.Logger.Info("Connection established", "RemoteAddr", conn.RemoteAddr()) - - defer conn.Close() - if err := s.handleNewConnection(connCtx, conn, conn.RemoteAddr().String()); err != nil { - if err == io.EOF { - s.Logger.Info("Client disconnected", "RemoteAddr", conn.RemoteAddr()) - } else { - s.Logger.Error("error serving request", "RemoteAddr", conn.RemoteAddr(), "err", err) - } + select { + case <-ctx.Done(): + s.Logger.Info("Server shutting down") + return ctx.Err() + default: + conn, err := ln.Accept() + if err != nil { + s.Logger.Error("Error accepting connection", "err", err) + continue } - }() - } -} - -const ( - agreementFile = "Agreement.txt" -) - -// NewServer constructs a new Server from a config dir -// TODO: move config file reads out of this function -func NewServer(configDir, netInterface string, netPort int, logger *slog.Logger, fs FileStore) (*Server, error) { - server := Server{ - NetInterface: netInterface, - Port: netPort, - Accounts: make(map[string]*Account), - Config: new(Config), - Clients: make(map[[2]byte]*ClientConn), - fileTransfers: make(map[[4]byte]*FileTransfer), - PrivateChats: make(map[[4]byte]*PrivateChat), - ConfigDir: configDir, - Logger: logger, - outbox: make(chan Transaction), - Stats: &Stats{Since: time.Now()}, - ThreadedNews: &ThreadedNews{}, - FS: fs, - banList: make(map[string]*time.Time), - } - - var err error - - // generate a new random passID for tracker registration - if _, err := rand.Read(server.TrackerPassID[:]); err != nil { - return nil, err - } - - server.Agreement, err = os.ReadFile(filepath.Join(configDir, agreementFile)) - if err != nil { - return nil, err - } - - if server.FlatNews, err = os.ReadFile(filepath.Join(configDir, "MessageBoard.txt")); err != nil { - return nil, err - } - - // try to load the ban list, but ignore errors as this file may not be present or may be empty - //_ = server.loadBanList(filepath.Join(configDir, "Banlist.yaml")) - _ = loadFromYAMLFile(filepath.Join(configDir, "Banlist.yaml"), &server.banList) + go func() { + connCtx := context.WithValue(ctx, contextKeyReq, requestCtx{ + remoteAddr: conn.RemoteAddr().String(), + }) - err = loadFromYAMLFile(filepath.Join(configDir, "ThreadedNews.yaml"), &server.ThreadedNews) - if err != nil { - return nil, fmt.Errorf("error loading threaded news: %w", err) - } - - err = server.loadConfig(filepath.Join(configDir, "config.yaml")) - if err != nil { - return nil, fmt.Errorf("error loading config: %w", err) - } - - if err := server.loadAccounts(filepath.Join(configDir, "Users/")); err != nil { - return nil, err - } - - // If the FileRoot is an absolute path, use it, otherwise treat as a relative path to the config dir. - if !filepath.IsAbs(server.Config.FileRoot) { - server.Config.FileRoot = filepath.Join(configDir, server.Config.FileRoot) - } + s.Logger.Info("Connection established", "addr", conn.RemoteAddr()) + defer conn.Close() - server.banner, err = os.ReadFile(filepath.Join(server.ConfigDir, server.Config.BannerFile)) - if err != nil { - return nil, fmt.Errorf("error opening banner: %w", err) + if err := s.handleNewConnection(connCtx, conn, conn.RemoteAddr().String()); err != nil { + if err == io.EOF { + s.Logger.Info("Client disconnected", "RemoteAddr", conn.RemoteAddr()) + } else { + s.Logger.Error("Error serving request", "RemoteAddr", conn.RemoteAddr(), "err", err) + } + } + }() + } } +} - if server.Config.EnableTrackerRegistration { - server.Logger.Info( - "Tracker registration enabled", - "frequency", fmt.Sprintf("%vs", trackerUpdateFrequency), - "trackers", server.Config.Trackers, - ) +// time in seconds between tracker re-registration +const trackerUpdateFrequency = 300 - go func() { - for { +// registerWithTrackers runs every trackerUpdateFrequency seconds to update the server's tracker entry on all configured +// trackers. +func (s *Server) registerWithTrackers(ctx context.Context) { + for { + if s.Config.EnableTrackerRegistration { + for _, t := range s.Config.Trackers { tr := &TrackerRegistration{ - UserCount: server.userCount(), - PassID: server.TrackerPassID, - Name: server.Config.Name, - Description: server.Config.Description, + UserCount: len(s.ClientMgr.List()), + PassID: s.TrackerPassID, + Name: s.Config.Name, + Description: s.Config.Description, } - binary.BigEndian.PutUint16(tr.Port[:], uint16(server.Port)) - for _, t := range server.Config.Trackers { - if err := register(&RealDialer{}, t, tr); err != nil { - server.Logger.Error("unable to register with tracker %v", "error", err) - } - server.Logger.Debug("Sent Tracker registration", "addr", t) + binary.BigEndian.PutUint16(tr.Port[:], uint16(s.Port)) + + // Check the tracker string for a password. This is janky but avoids a breaking change to the Config + // Trackers field. + splitAddr := strings.Split(":", t) + if len(splitAddr) == 3 { + tr.Password = splitAddr[2] } - time.Sleep(trackerUpdateFrequency * time.Second) + if err := register(&RealDialer{}, t, tr); err != nil { + s.Logger.Error(fmt.Sprintf("Unable to register with tracker %v", t), "error", err) + } } - }() + } + // Using time.Ticker with for/select would be more idiomatic, but it's super annoying that it doesn't tick on + // first pass. Revist, maybe. + // https://github.com/golang/go/issues/17601 + time.Sleep(trackerUpdateFrequency * time.Second) } - - // Start Client Keepalive go routine - go server.keepaliveHandler() - - return &server, nil } -func (s *Server) userCount() int { - s.mux.Lock() - defer s.mux.Unlock() +const ( + userIdleSeconds = 300 // time in seconds before an inactive user is marked idle + idleCheckInterval = 10 // time in seconds to check for idle users +) - return len(s.Clients) -} +// keepaliveHandler runs every idleCheckInterval seconds and increments a user's idle time by idleCheckInterval seconds. +// If the updated idle time exceeds userIdleSeconds and the user was not previously idle, we notify all connected clients +// that the user has gone idle. For most clients, this turns the user grey in the user list. +func (s *Server) keepaliveHandler(ctx context.Context) { + ticker := time.NewTicker(idleCheckInterval * time.Second) + defer ticker.Stop() -func (s *Server) keepaliveHandler() { for { - time.Sleep(idleCheckInterval * time.Second) - s.mux.Lock() - - for _, c := range s.Clients { - c.IdleTime += idleCheckInterval - if c.IdleTime > userIdleSeconds && !c.Idle { - c.Idle = true - - c.flagsMU.Lock() - c.Flags.Set(UserFlagAway, 1) - c.flagsMU.Unlock() - c.sendAll( - TranNotifyChangeUser, - NewField(FieldUserID, c.ID[:]), - NewField(FieldUserFlags, c.Flags[:]), - NewField(FieldUserName, c.UserName), - NewField(FieldUserIconID, c.Icon), - ) + select { + case <-ctx.Done(): + return + case <-ticker.C: + for _, c := range s.ClientMgr.List() { + c.mu.Lock() + c.IdleTime += idleCheckInterval + + // Check if the user + if c.IdleTime > userIdleSeconds && !c.Flags.IsSet(UserFlagAway) { + c.Flags.Set(UserFlagAway, 1) + + c.SendAll( + TranNotifyChangeUser, + NewField(FieldUserID, c.ID[:]), + NewField(FieldUserFlags, c.Flags[:]), + NewField(FieldUserName, c.UserName), + NewField(FieldUserIconID, c.Icon), + ) + } + c.mu.Unlock() } } - s.mux.Unlock() - } -} - -func (s *Server) writeBanList() error { - s.banListMU.Lock() - defer s.banListMU.Unlock() - - out, err := yaml.Marshal(s.banList) - if err != nil { - return err - } - err = os.WriteFile( - filepath.Join(s.ConfigDir, "Banlist.yaml"), - out, - 0666, - ) - return err -} - -func (s *Server) writeThreadedNews() error { - s.threadedNewsMux.Lock() - defer s.threadedNewsMux.Unlock() - - out, err := yaml.Marshal(s.ThreadedNews) - if err != nil { - return err } - err = s.FS.WriteFile( - filepath.Join(s.ConfigDir, "ThreadedNews.yaml"), - out, - 0666, - ) - return err } func (s *Server) NewClientConn(conn io.ReadWriteCloser, remoteAddr string) *ClientConn { - s.mux.Lock() - defer s.mux.Unlock() - clientConn := &ClientConn{ Icon: []byte{0, 0}, // TODO: make array type Connection: conn, Server: s, RemoteAddr: remoteAddr, - transfers: map[int]map[[4]byte]*FileTransfer{ - FileDownload: {}, - FileUpload: {}, - FolderDownload: {}, - FolderUpload: {}, - bannerDownload: {}, - }, - } - - s.nextClientID.Add(1) - - binary.BigEndian.PutUint16(clientConn.ID[:], uint16(s.nextClientID.Load())) - s.Clients[clientConn.ID] = clientConn - - return clientConn -} - -// NewUser creates a new user account entry in the server map and config file -func (s *Server) NewUser(login, name, password string, access accessBitmap) error { - s.mux.Lock() - defer s.mux.Unlock() - - account := NewAccount(login, name, password, access) - // Create account file, returning an error if one already exists. - file, err := os.OpenFile( - filepath.Join(s.ConfigDir, "Users", path.Join("/", login)+".yaml"), - os.O_CREATE|os.O_EXCL|os.O_WRONLY, 0644, - ) - if err != nil { - return fmt.Errorf("error creating account file: %w", err) + ClientFileTransferMgr: NewClientFileTransferMgr(), } - defer file.Close() - b, err := yaml.Marshal(account) - if err != nil { - return err - } - - _, err = file.Write(b) - if err != nil { - return fmt.Errorf("error writing account file: %w", err) - } - - s.Accounts[login] = account - - return nil -} - -func (s *Server) UpdateUser(login, newLogin, name, password string, access accessBitmap) error { - s.mux.Lock() - defer s.mux.Unlock() - - // If the login has changed, rename the account file. - if login != newLogin { - err := os.Rename( - filepath.Join(s.ConfigDir, "Users", path.Join("/", login)+".yaml"), - filepath.Join(s.ConfigDir, "Users", path.Join("/", newLogin)+".yaml"), - ) - if err != nil { - return fmt.Errorf("error renaming account file: %w", err) - } - s.Accounts[newLogin] = s.Accounts[login] - s.Accounts[newLogin].Login = newLogin - delete(s.Accounts, login) - } - - account := s.Accounts[newLogin] - account.Access = access - account.Name = name - account.Password = password - - out, err := yaml.Marshal(&account) - if err != nil { - return err - } - - if err := os.WriteFile(filepath.Join(s.ConfigDir, "Users", newLogin+".yaml"), out, 0666); err != nil { - return fmt.Errorf("error writing account file: %w", err) - } - - return nil -} - -// DeleteUser deletes the user account -func (s *Server) DeleteUser(login string) error { - s.mux.Lock() - defer s.mux.Unlock() - - err := s.FS.Remove(filepath.Join(s.ConfigDir, "Users", path.Join("/", login)+".yaml")) - if err != nil { - return err - } - - delete(s.Accounts, login) - - return nil -} - -func (s *Server) connectedUsers() []Field { - //s.mux.Lock() - //defer s.mux.Unlock() - - var connectedUsers []Field - for _, c := range sortedClients(s.Clients) { - b, err := io.ReadAll(&User{ - ID: c.ID, - Icon: c.Icon, - Flags: c.Flags[:], - Name: string(c.UserName), - }) - if err != nil { - return nil - } - connectedUsers = append(connectedUsers, NewField(FieldUsernameWithInfo, b)) - } - return connectedUsers -} - -// loadFromYAMLFile loads data from a YAML file into the provided data structure. -func loadFromYAMLFile(path string, data interface{}) error { - fh, err := os.Open(path) - if err != nil { - return err - } - defer fh.Close() - - decoder := yaml.NewDecoder(fh) - return decoder.Decode(data) -} - -// loadAccounts loads account data from disk -func (s *Server) loadAccounts(userDir string) error { - matches, err := filepath.Glob(filepath.Join(userDir, "*.yaml")) - if err != nil { - return err - } - - if len(matches) == 0 { - return fmt.Errorf("no accounts found in directory: %s", userDir) - } + s.ClientMgr.Add(clientConn) - for _, file := range matches { - var account Account - if err = loadFromYAMLFile(file, &account); err != nil { - return fmt.Errorf("error loading account %s: %w", file, err) - } - - s.Accounts[account.Login] = &account - } - return nil -} - -func (s *Server) loadConfig(path string) error { - fh, err := s.FS.Open(path) - if err != nil { - return err - } - - decoder := yaml.NewDecoder(fh) - err = decoder.Decode(s.Config) - if err != nil { - return err - } - - validate := validator.New() - err = validate.Struct(s.Config) - if err != nil { - return err - } - return nil + return clientConn } func sendBanMessage(rwc io.Writer, message string) { @@ -561,9 +340,13 @@ func sendBanMessage(rwc io.Writer, message string) { func (s *Server) handleNewConnection(ctx context.Context, rwc io.ReadWriteCloser, remoteAddr string) error { defer dontPanic(s.Logger) + if err := performHandshake(rwc); err != nil { + return fmt.Errorf("perform handshake: %w", err) + } + // Check if remoteAddr is present in the ban list ipAddr := strings.Split(remoteAddr, ":")[0] - if banUntil, ok := s.banList[ipAddr]; ok { + if isBanned, banUntil := s.BanList.IsBanned(ipAddr); isBanned { // permaban if banUntil == nil { sendBanMessage(rwc, "You are permanently banned on this server") @@ -579,10 +362,6 @@ func (s *Server) handleNewConnection(ctx context.Context, rwc io.ReadWriteCloser } } - if err := performHandshake(rwc); err != nil { - return fmt.Errorf("error performing handshake: %w", err) - } - // Create a new scanner for parsing incoming bytes into transaction tokens scanner := bufio.NewScanner(rwc) scanner.Split(transactionScanner) @@ -605,12 +384,12 @@ func (s *Server) handleNewConnection(ctx context.Context, rwc io.ReadWriteCloser encodedPassword := clientLogin.GetField(FieldUserPassword).Data c.Version = clientLogin.GetField(FieldVersion).Data - login := string(encodeString(clientLogin.GetField(FieldUserLogin).Data)) + login := clientLogin.GetField(FieldUserLogin).DecodeObfuscatedString() if login == "" { login = GuestAccount } - c.logger = s.Logger.With("remoteAddr", remoteAddr, "login", login) + c.Logger = s.Logger.With("ip", ipAddr, "login", login) // If authentication fails, send error reply and close connection if !c.Authenticate(login, encodedPassword) { @@ -621,7 +400,7 @@ func (s *Server) handleNewConnection(ctx context.Context, rwc io.ReadWriteCloser return err } - c.logger.Info("Login failed", "clientVersion", fmt.Sprintf("%x", c.Version)) + c.Logger.Info("Login failed", "clientVersion", fmt.Sprintf("%x", c.Version)) return nil } @@ -630,19 +409,20 @@ func (s *Server) handleNewConnection(ctx context.Context, rwc io.ReadWriteCloser c.Icon = clientLogin.GetField(FieldUserIconID).Data } - c.Lock() - c.Account = c.Server.Accounts[login] - c.Unlock() + c.Account = c.Server.AccountManager.Get(login) + if c.Account == nil { + return nil + } if clientLogin.GetField(FieldUserName).Data != nil { - if c.Authorize(accessAnyName) { + if c.Authorize(AccessAnyName) { c.UserName = clientLogin.GetField(FieldUserName).Data } else { c.UserName = []byte(c.Account.Name) } } - if c.Authorize(accessDisconUser) { + if c.Authorize(AccessDisconUser) { c.Flags.Set(UserFlagAdmin, 1) } @@ -655,16 +435,19 @@ func (s *Server) handleNewConnection(ctx context.Context, rwc io.ReadWriteCloser // Send user access privs so client UI knows how to behave c.Server.outbox <- NewTransaction(TranUserAccess, c.ID, NewField(FieldUserAccess, c.Account.Access[:])) - // Accounts with accessNoAgreement do not receive the server agreement on login. The behavior is different between + // Accounts with AccessNoAgreement do not receive the server agreement on login. The behavior is different between // client versions. For 1.2.3 client, we do not send TranShowAgreement. For other client versions, we send // TranShowAgreement but with the NoServerAgreement field set to 1. - if c.Authorize(accessNoAgreement) { + if c.Authorize(AccessNoAgreement) { // If client version is nil, then the client uses the 1.2.3 login behavior if c.Version != nil { c.Server.outbox <- NewTransaction(TranShowAgreement, c.ID, NewField(FieldNoServerAgreement, []byte{1})) } } else { - c.Server.outbox <- NewTransaction(TranShowAgreement, c.ID, NewField(FieldData, s.Agreement)) + _, _ = c.Server.Agreement.Seek(0, 0) + data, _ := io.ReadAll(c.Server.Agreement) + + c.Server.outbox <- NewTransaction(TranShowAgreement, c.ID, NewField(FieldData, data)) } // If the client has provided a username as part of the login, we can infer that it is using the 1.2.3 login @@ -672,12 +455,12 @@ func (s *Server) handleNewConnection(ctx context.Context, rwc io.ReadWriteCloser if len(c.UserName) != 0 { // Add the client username to the logger. For 1.5+ clients, we don't have this information yet as it comes as // part of TranAgreed - c.logger = c.logger.With("Name", string(c.UserName)) - c.logger.Info("Login successful", "clientVersion", "Not sent (probably 1.2.3)") + c.Logger = c.Logger.With("name", string(c.UserName)) + c.Logger.Info("Login successful") // Notify other clients on the server that the new user has logged in. For 1.5+ clients we don't have this // information yet, so we do it in TranAgreed instead - for _, t := range c.notifyOthers( + for _, t := range c.NotifyOthers( NewTransaction( TranNotifyChangeUser, [2]byte{0, 0}, NewField(FieldUserName, c.UserName), @@ -690,21 +473,21 @@ func (s *Server) handleNewConnection(ctx context.Context, rwc io.ReadWriteCloser } } - c.Server.mux.Lock() - c.Server.Stats.ConnectionCounter += 1 - if len(s.Clients) > c.Server.Stats.ConnectionPeak { - c.Server.Stats.ConnectionPeak = len(s.Clients) + c.Server.Stats.Increment(StatConnectionCounter, StatCurrentlyConnected) + defer c.Server.Stats.Decrement(StatCurrentlyConnected) + + if len(s.ClientMgr.List()) > c.Server.Stats.Get(StatConnectionPeak) { + c.Server.Stats.Set(StatConnectionPeak, len(s.ClientMgr.List())) } - c.Server.mux.Unlock() // Scan for new transactions and handle them as they come in. for scanner.Scan() { // Copy the scanner bytes to a new slice to it to avoid a data race when the scanner re-uses the buffer. - buf := make([]byte, len(scanner.Bytes())) - copy(buf, scanner.Bytes()) + tmpBuf := make([]byte, len(scanner.Bytes())) + copy(tmpBuf, scanner.Bytes()) var t Transaction - if _, err := t.Write(buf); err != nil { + if _, err := t.Write(tmpBuf); err != nil { return err } @@ -713,25 +496,6 @@ func (s *Server) handleNewConnection(ctx context.Context, rwc io.ReadWriteCloser return nil } -func (s *Server) NewPrivateChat(cc *ClientConn) [4]byte { - s.PrivateChatsMu.Lock() - defer s.PrivateChatsMu.Unlock() - - var randID [4]byte - _, _ = rand.Read(randID[:]) - - s.PrivateChats[randID] = &PrivateChat{ - ClientConn: make(map[[2]byte]*ClientConn), - } - s.PrivateChats[randID].ClientConn[cc.ID] = cc - - return randID -} - -const dlFldrActionSendFile = 1 -const dlFldrActionResumeFile = 2 -const dlFldrActionNextFile = 3 - // handleFileTransfer receives a client net.Conn from the file transfer server, performs the requested transfer type, then closes the connection func (s *Server) handleFileTransfer(ctx context.Context, rwc io.ReadWriter) error { defer dontPanic(s.Logger) @@ -742,10 +506,13 @@ func (s *Server) handleFileTransfer(ctx context.Context, rwc io.ReadWriter) erro return fmt.Errorf("error reading file transfer: %w", err) } + fileTransfer := s.FileTransferMgr.Get(t.ReferenceNumber) + if fileTransfer == nil { + return errors.New("invalid transaction ID") + } + defer func() { - s.mux.Lock() - delete(s.fileTransfers, t.ReferenceNumber) - s.mux.Unlock() + s.FileTransferMgr.Delete(t.ReferenceNumber) // Wait a few seconds before closing the connection: this is a workaround for problems // observed with Windows clients where the client must initiate close of the TCP connection before @@ -753,71 +520,61 @@ func (s *Server) handleFileTransfer(ctx context.Context, rwc io.ReadWriter) erro time.Sleep(3 * time.Second) }() - s.mux.Lock() - fileTransfer, ok := s.fileTransfers[t.ReferenceNumber] - s.mux.Unlock() - if !ok { - return errors.New("invalid transaction ID") - } - - defer func() { - fileTransfer.ClientConn.transfersMU.Lock() - delete(fileTransfer.ClientConn.transfers[fileTransfer.Type], t.ReferenceNumber) - fileTransfer.ClientConn.transfersMU.Unlock() - }() - rLogger := s.Logger.With( "remoteAddr", ctx.Value(contextKeyReq).(requestCtx).remoteAddr, "login", fileTransfer.ClientConn.Account.Login, "Name", string(fileTransfer.ClientConn.UserName), ) - fullPath, err := readPath(s.Config.FileRoot, fileTransfer.FilePath, fileTransfer.FileName) + fullPath, err := ReadPath(fileTransfer.FileRoot, fileTransfer.FilePath, fileTransfer.FileName) if err != nil { return err } switch fileTransfer.Type { - case bannerDownload: - if _, err := io.Copy(rwc, bytes.NewBuffer(s.banner)); err != nil { - return fmt.Errorf("error sending banner: %w", err) + case BannerDownload: + if _, err := io.Copy(rwc, bytes.NewBuffer(s.Banner)); err != nil { + return fmt.Errorf("banner download: %w", err) } case FileDownload: - s.Stats.DownloadCounter += 1 - s.Stats.DownloadsInProgress += 1 + s.Stats.Increment(StatDownloadCounter, StatDownloadsInProgress) defer func() { - s.Stats.DownloadsInProgress -= 1 + s.Stats.Decrement(StatDownloadsInProgress) }() err = DownloadHandler(rwc, fullPath, fileTransfer, s.FS, rLogger, true) if err != nil { - return fmt.Errorf("file download error: %w", err) + return fmt.Errorf("file download: %w", err) } case FileUpload: - s.Stats.UploadCounter += 1 - s.Stats.UploadsInProgress += 1 - defer func() { s.Stats.UploadsInProgress -= 1 }() + s.Stats.Increment(StatUploadCounter, StatUploadsInProgress) + defer func() { + s.Stats.Decrement(StatUploadsInProgress) + }() err = UploadHandler(rwc, fullPath, fileTransfer, s.FS, rLogger, s.Config.PreserveResourceForks) if err != nil { - return fmt.Errorf("file upload error: %w", err) + return fmt.Errorf("file upload: %w", err) } case FolderDownload: - s.Stats.DownloadCounter += 1 - s.Stats.DownloadsInProgress += 1 - defer func() { s.Stats.DownloadsInProgress -= 1 }() + s.Stats.Increment(StatDownloadCounter, StatDownloadsInProgress) + defer func() { + s.Stats.Decrement(StatDownloadsInProgress) + }() err = DownloadFolderHandler(rwc, fullPath, fileTransfer, s.FS, rLogger, s.Config.PreserveResourceForks) if err != nil { - return fmt.Errorf("file upload error: %w", err) + return fmt.Errorf("folder download: %w", err) } case FolderUpload: - s.Stats.UploadCounter += 1 - s.Stats.UploadsInProgress += 1 - defer func() { s.Stats.UploadsInProgress -= 1 }() + s.Stats.Increment(StatUploadCounter, StatUploadsInProgress) + defer func() { + s.Stats.Decrement(StatUploadsInProgress) + }() + rLogger.Info( "Folder upload started", "dstPath", fullPath, @@ -827,8 +584,23 @@ func (s *Server) handleFileTransfer(ctx context.Context, rwc io.ReadWriter) erro err = UploadFolderHandler(rwc, fullPath, fileTransfer, s.FS, rLogger, s.Config.PreserveResourceForks) if err != nil { - return fmt.Errorf("file upload error: %w", err) + return fmt.Errorf("folder upload: %w", err) } } return nil } + +func (s *Server) SendAll(t TranType, fields ...Field) { + for _, c := range s.ClientMgr.List() { + s.outbox <- NewTransaction(t, c.ID, fields...) + } +} + +func (s *Server) Shutdown(msg []byte) { + s.Logger.Info("Shutdown signal received") + s.SendAll(TranDisconnectMsg, NewField(FieldData, msg)) + + time.Sleep(3 * time.Second) + + os.Exit(0) +}