]> git.r.bdr.sh - rbdr/mobius/blobdiff - hotline/client_conn.go
Extensive refactor and clean up
[rbdr/mobius] / hotline / client_conn.go
index e527eba99db5ef1ad101abda698991572a622243..dc589967c7f46ca810790a8cea949119f10dc869 100644 (file)
@@ -7,64 +7,117 @@ import (
        "golang.org/x/crypto/bcrypt"
        "io"
        "log/slog"
-       "slices"
        "strings"
        "sync"
 )
 
+var clientConnSortFunc = func(a, b *ClientConn) int {
+       return cmp.Compare(
+               binary.BigEndian.Uint16(a.ID[:]),
+               binary.BigEndian.Uint16(b.ID[:]),
+       )
+}
+
 // ClientConn represents a client connected to a Server
 type ClientConn struct {
        Connection io.ReadWriteCloser
        RemoteAddr string
-       ID         [2]byte
-       Icon       []byte
-       flagsMU    sync.Mutex
-       Flags      UserFlags
-       UserName   []byte
-       Account    *Account
-       IdleTime   int
-       Server     *Server
-       Version    []byte
-       Idle       bool
-       AutoReply  []byte
-
-       transfersMU sync.Mutex
-       transfers   map[int]map[[4]byte]*FileTransfer
+       ID         ClientID
+       Icon       []byte // TODO: make fixed size of 2
+       Version    []byte // TODO: make fixed size of 2
+
+       flagsMU sync.Mutex // TODO: move into UserFlags struct
+       Flags   UserFlags
+
+       UserName  []byte
+       Account   *Account
+       IdleTime  int
+       Server    *Server // TODO: consider adding methods to interact with server
+       AutoReply []byte
+
+       ClientFileTransferMgr ClientFileTransferMgr
 
        logger *slog.Logger
 
-       sync.Mutex
+       mu sync.RWMutex
+}
+
+type ClientFileTransferMgr struct {
+       transfers map[FileTransferType]map[FileTransferID]*FileTransfer
+
+       mu sync.RWMutex
+}
+
+func NewClientFileTransferMgr() ClientFileTransferMgr {
+       return ClientFileTransferMgr{
+               transfers: map[FileTransferType]map[FileTransferID]*FileTransfer{
+                       FileDownload:   {},
+                       FileUpload:     {},
+                       FolderDownload: {},
+                       FolderUpload:   {},
+                       BannerDownload: {},
+               },
+       }
+}
+
+func (cftm *ClientFileTransferMgr) Add(ftType FileTransferType, ft *FileTransfer) {
+       cftm.mu.Lock()
+       defer cftm.mu.Unlock()
+
+       cftm.transfers[ftType][ft.refNum] = ft
+}
+
+func (cftm *ClientFileTransferMgr) Get(ftType FileTransferType) []FileTransfer {
+       cftm.mu.Lock()
+       defer cftm.mu.Unlock()
+
+       fts := cftm.transfers[ftType]
+
+       var transfers []FileTransfer
+       for _, ft := range fts {
+               transfers = append(transfers, *ft)
+       }
+
+       return transfers
 }
 
-func (cc *ClientConn) sendAll(t [2]byte, fields ...Field) {
-       for _, c := range cc.Server.Clients {
+func (cftm *ClientFileTransferMgr) Delete(ftType FileTransferType, id FileTransferID) {
+       cftm.mu.Lock()
+       defer cftm.mu.Unlock()
+
+       delete(cftm.transfers[ftType], id)
+}
+
+func (cc *ClientConn) SendAll(t [2]byte, fields ...Field) {
+       for _, c := range cc.Server.ClientMgr.List() {
                cc.Server.outbox <- NewTransaction(t, c.ID, fields...)
        }
 }
 
 func (cc *ClientConn) handleTransaction(transaction Transaction) {
        if handler, ok := TransactionHandlers[transaction.Type]; ok {
-               cc.logger.Debug("Received Transaction", "RequestType", transaction.Type)
+               if transaction.Type != TranKeepAlive {
+                       cc.logger.Info(tranTypeNames[transaction.Type])
+               }
 
                for _, t := range handler(cc, &transaction) {
                        cc.Server.outbox <- t
                }
        }
 
-       cc.Server.mux.Lock()
-       defer cc.Server.mux.Unlock()
-
        if transaction.Type != TranKeepAlive {
+               cc.mu.Lock()
+               defer cc.mu.Unlock()
+
                // reset the user idle timer
                cc.IdleTime = 0
 
                // if user was previously idle, mark as not idle and notify other connected clients that
                // the user is no longer away
-               if cc.Idle {
+               if cc.Flags.IsSet(UserFlagAway) {
                        cc.Flags.Set(UserFlagAway, 0)
-                       cc.Idle = false
 
-                       cc.sendAll(
+                       cc.SendAll(
                                TranNotifyChangeUser,
                                NewField(FieldUserID, cc.ID[:]),
                                NewField(FieldUserFlags, cc.Flags[:]),
@@ -76,7 +129,7 @@ func (cc *ClientConn) handleTransaction(transaction Transaction) {
 }
 
 func (cc *ClientConn) Authenticate(login string, password []byte) bool {
-       if account, ok := cc.Server.Accounts[login]; ok {
+       if account := cc.Server.AccountManager.Get(login); account != nil {
                return bcrypt.CompareHashAndPassword([]byte(account.Password), password) == nil
        }
 
@@ -85,8 +138,6 @@ func (cc *ClientConn) Authenticate(login string, password []byte) bool {
 
 // Authorize checks if the user account has the specified permission
 func (cc *ClientConn) Authorize(access int) bool {
-       cc.Lock()
-       defer cc.Unlock()
        if cc.Account == nil {
                return false
        }
@@ -95,11 +146,9 @@ func (cc *ClientConn) Authorize(access int) bool {
 
 // Disconnect notifies other clients that a client has disconnected
 func (cc *ClientConn) Disconnect() {
-       cc.Server.mux.Lock()
-       delete(cc.Server.Clients, cc.ID)
-       cc.Server.mux.Unlock()
+       cc.Server.ClientMgr.Delete(cc.ID)
 
-       for _, t := range cc.notifyOthers(NewTransaction(TranNotifyDeleteUser, [2]byte{}, NewField(FieldUserID, cc.ID[:]))) {
+       for _, t := range cc.NotifyOthers(NewTransaction(TranNotifyDeleteUser, [2]byte{}, NewField(FieldUserID, cc.ID[:]))) {
                cc.Server.outbox <- t
        }
 
@@ -108,11 +157,9 @@ func (cc *ClientConn) Disconnect() {
        }
 }
 
-// notifyOthers sends transaction t to other clients connected to the server
-func (cc *ClientConn) notifyOthers(t Transaction) (trans []Transaction) {
-       cc.Server.mux.Lock()
-       defer cc.Server.mux.Unlock()
-       for _, c := range cc.Server.Clients {
+// NotifyOthers sends transaction t to other clients connected to the server
+func (cc *ClientConn) NotifyOthers(t Transaction) (trans []Transaction) {
+       for _, c := range cc.Server.ClientMgr.List() {
                if c.ID != cc.ID {
                        t.clientID = c.ID
                        trans = append(trans, t)
@@ -146,25 +193,6 @@ func (cc *ClientConn) NewErrReply(t *Transaction, errMsg string) []Transaction {
        }
 }
 
-var clientSortFunc = func(a, b *ClientConn) int {
-       return cmp.Compare(
-               binary.BigEndian.Uint16(a.ID[:]),
-               binary.BigEndian.Uint16(b.ID[:]),
-       )
-}
-
-// sortedClients is a utility function that takes a map of *ClientConn and returns a sorted slice of the values.
-// The purpose of this is to ensure that the ordering of client connections is deterministic so that test assertions work.
-func sortedClients(unsortedClients map[[2]byte]*ClientConn) (clients []*ClientConn) {
-       for _, c := range unsortedClients {
-               clients = append(clients, c)
-       }
-
-       slices.SortFunc(clients, clientSortFunc)
-
-       return clients
-}
-
 const userInfoTemplate = `Nickname:   %s
 Name:       %s
 Account:    %s
@@ -187,7 +215,7 @@ Address:    %s
 %s
 `
 
-func formatDownloadList(fts map[[4]byte]*FileTransfer) (s string) {
+func formatDownloadList(fts []FileTransfer) (s string) {
        if len(fts) == 0 {
                return "None.\n"
        }
@@ -200,18 +228,16 @@ func formatDownloadList(fts map[[4]byte]*FileTransfer) (s string) {
 }
 
 func (cc *ClientConn) String() string {
-       cc.transfersMU.Lock()
-       defer cc.transfersMU.Unlock()
        template := fmt.Sprintf(
                userInfoTemplate,
                cc.UserName,
                cc.Account.Name,
                cc.Account.Login,
                cc.RemoteAddr,
-               formatDownloadList(cc.transfers[FileDownload]),
-               formatDownloadList(cc.transfers[FolderDownload]),
-               formatDownloadList(cc.transfers[FileUpload]),
-               formatDownloadList(cc.transfers[FolderUpload]),
+               formatDownloadList(cc.ClientFileTransferMgr.Get(FileDownload)),
+               formatDownloadList(cc.ClientFileTransferMgr.Get(FolderDownload)),
+               formatDownloadList(cc.ClientFileTransferMgr.Get(FileUpload)),
+               formatDownloadList(cc.ClientFileTransferMgr.Get(FolderUpload)),
                "None.\n",
        )