]> git.r.bdr.sh - rbdr/mobius/blobdiff - hotline/client_conn.go
Update README.md
[rbdr/mobius] / hotline / client_conn.go
index e527eba99db5ef1ad101abda698991572a622243..d32459116074aeff5213ff35601a0e15a17283f6 100644 (file)
@@ -7,64 +7,124 @@ import (
        "golang.org/x/crypto/bcrypt"
        "io"
        "log/slog"
-       "slices"
        "strings"
        "sync"
 )
 
+// clientConnSortFunc orders client connections by ascending client ID, reading
+// the first two bytes of each ID as a big-endian uint16. It has the
+// comparator shape used by slices.SortFunc, so it yields a deterministic
+// ordering of clients.
+var clientConnSortFunc = func(left, right *ClientConn) int {
+       leftID := binary.BigEndian.Uint16(left.ID[:])
+       rightID := binary.BigEndian.Uint16(right.ID[:])
+
+       return cmp.Compare(leftID, rightID)
+}
+
 // ClientConn represents a client connected to a Server
+//
+// NOTE(review): instances are reachable from other connections through
+// Server.ClientMgr (SendAll and NotifyOthers range over ClientMgr.List()),
+// so mutable fields need synchronization; see FlagsMU and mu below.
 type ClientConn struct {
        Connection io.ReadWriteCloser
        RemoteAddr string
-       ID         [2]byte
-       Icon       []byte
-       flagsMU    sync.Mutex
-       Flags      UserFlags
-       UserName   []byte
-       Account    *Account
-       IdleTime   int
-       Server     *Server
-       Version    []byte
-       Idle       bool
-       AutoReply  []byte
+       ID         ClientID
+       Icon       []byte // TODO: make fixed size of 2
+       Version    []byte // TODO: make fixed size of 2
+
+       // FlagsMU presumably guards Flags (per its name and TODO). NOTE(review):
+       // handleTransaction reads and clears UserFlagAway while holding mu, not
+       // FlagsMU; confirm which lock is meant to protect Flags.
+       FlagsMU sync.Mutex // TODO: move into UserFlags struct
+       Flags   UserFlags
+
+       // Identity and session state. Account may be nil (Authorize treats a nil
+       // Account as having no permissions). IdleTime is reset to 0 by
+       // handleTransaction on every non-keepalive transaction.
+       UserName  []byte
+       Account   *Account
+       IdleTime  int
+       Server    *Server // TODO: consider adding methods to interact with server
+       AutoReply []byte
+
+       // ClientFileTransferMgr tracks this client's transfers by type; String
+       // reads it to list the client's uploads and downloads.
+       ClientFileTransferMgr ClientFileTransferMgr
+
+       // Logger receives one line naming each non-keepalive transaction handled.
+       Logger *slog.Logger
+
+       // mu is held by handleTransaction while it resets IdleTime and clears
+       // the away flag.
+       mu sync.RWMutex
+}
+
+// FileRoot returns the file root for this client. It is the account's FileRoot
+// when one is set, otherwise the server-wide Config.FileRoot. A client with no
+// Account (cc.Account == nil, a state Authorize also tolerates) falls back to
+// the server root instead of dereferencing a nil pointer.
+func (cc *ClientConn) FileRoot() string {
+       if cc.Account != nil && cc.Account.FileRoot != "" {
+               return cc.Account.FileRoot
+       }
+       return cc.Server.Config.FileRoot
+}
+
+// ClientFileTransferMgr holds one client's file transfers, bucketed first by
+// transfer type and then by transfer ID.
+type ClientFileTransferMgr struct {
+       // transfers maps a transfer type to the transfers of that type, keyed by
+       // FileTransferID.
+       transfers map[FileTransferType]map[FileTransferID]*FileTransfer
+
+       // mu guards transfers; Add, Get and Delete lock it.
+       mu sync.RWMutex
+}
+
+// NewClientFileTransferMgr returns a manager with an empty bucket already
+// allocated for each of FileDownload, FileUpload, FolderDownload,
+// FolderUpload and BannerDownload.
+func NewClientFileTransferMgr() ClientFileTransferMgr {
+       transferTypes := []FileTransferType{
+               FileDownload,
+               FileUpload,
+               FolderDownload,
+               FolderUpload,
+               BannerDownload,
+       }
+
+       buckets := make(map[FileTransferType]map[FileTransferID]*FileTransfer, len(transferTypes))
+       for _, transferType := range transferTypes {
+               buckets[transferType] = make(map[FileTransferID]*FileTransfer)
+       }
+
+       return ClientFileTransferMgr{transfers: buckets}
+}
 
-       transfersMU sync.Mutex
-       transfers   map[int]map[[4]byte]*FileTransfer
+// Add registers ft under ftType, keyed by ft.RefNum, replacing any transfer
+// already stored with that RefNum. The nested maps are allocated on demand, so
+// Add does not panic on a zero-value manager (one not built by
+// NewClientFileTransferMgr) or on a transfer type without a pre-made bucket.
+func (cftm *ClientFileTransferMgr) Add(ftType FileTransferType, ft *FileTransfer) {
+       cftm.mu.Lock()
+       defer cftm.mu.Unlock()
 
-       logger *slog.Logger
+       if cftm.transfers == nil {
+               cftm.transfers = make(map[FileTransferType]map[FileTransferID]*FileTransfer)
+       }
+       bucket := cftm.transfers[ftType]
+       if bucket == nil {
+               bucket = make(map[FileTransferID]*FileTransfer)
+               cftm.transfers[ftType] = bucket
+       }
+       bucket[ft.RefNum] = ft
+}
+
+// Get returns copies of the transfers currently registered under ftType, or
+// nil if there are none. Order follows Go map iteration and is not stable
+// between calls. Only the read lock is taken, so concurrent Gets do not
+// serialize against each other.
+func (cftm *ClientFileTransferMgr) Get(ftType FileTransferType) []FileTransfer {
+       cftm.mu.RLock()
+       defer cftm.mu.RUnlock()
+
+       fts := cftm.transfers[ftType]
 
-       sync.Mutex
+       if len(fts) == 0 {
+               return nil
+       }
+
+       // Hand back shallow copies rather than the stored pointers.
+       transfers := make([]FileTransfer, 0, len(fts))
+       for _, ft := range fts {
+               transfers = append(transfers, *ft)
+       }
+
+       return transfers
 }
 
-func (cc *ClientConn) sendAll(t [2]byte, fields ...Field) {
-       for _, c := range cc.Server.Clients {
+// Delete forgets the transfer with the given id in the ftType bucket. Removing
+// an id that is not present, or from a type with no bucket, does nothing.
+func (cftm *ClientFileTransferMgr) Delete(ftType FileTransferType, id FileTransferID) {
+       cftm.mu.Lock()
+       defer cftm.mu.Unlock()
+
+       bucket := cftm.transfers[ftType]
+       delete(bucket, id)
+}
+
+// SendAll queues a transaction of type t carrying fields for every client
+// returned by Server.ClientMgr.List(). Unlike NotifyOthers it does not skip
+// cc itself. Each send on Server.outbox may block until the channel accepts it.
+func (cc *ClientConn) SendAll(t [2]byte, fields ...Field) {
+       for _, c := range cc.Server.ClientMgr.List() {
+               // One transaction per recipient; only the target client ID differs.
                cc.Server.outbox <- NewTransaction(t, c.ID, fields...)
        }
 }
 
 func (cc *ClientConn) handleTransaction(transaction Transaction) {
-       if handler, ok := TransactionHandlers[transaction.Type]; ok {
-               cc.logger.Debug("Received Transaction", "RequestType", transaction.Type)
+       if handler, ok := cc.Server.handlers[transaction.Type]; ok {
+               if transaction.Type != TranKeepAlive {
+                       cc.Logger.Info(tranTypeNames[transaction.Type])
+               }
 
                for _, t := range handler(cc, &transaction) {
                        cc.Server.outbox <- t
                }
        }
 
-       cc.Server.mux.Lock()
-       defer cc.Server.mux.Unlock()
-
        if transaction.Type != TranKeepAlive {
+               cc.mu.Lock()
+               defer cc.mu.Unlock()
+
                // reset the user idle timer
                cc.IdleTime = 0
 
                // if user was previously idle, mark as not idle and notify other connected clients that
                // the user is no longer away
-               if cc.Idle {
+               if cc.Flags.IsSet(UserFlagAway) {
                        cc.Flags.Set(UserFlagAway, 0)
-                       cc.Idle = false
 
-                       cc.sendAll(
+                       cc.SendAll(
                                TranNotifyChangeUser,
                                NewField(FieldUserID, cc.ID[:]),
                                NewField(FieldUserFlags, cc.Flags[:]),
@@ -76,7 +136,7 @@ func (cc *ClientConn) handleTransaction(transaction Transaction) {
 }
 
 func (cc *ClientConn) Authenticate(login string, password []byte) bool {
-       if account, ok := cc.Server.Accounts[login]; ok {
+       if account := cc.Server.AccountManager.Get(login); account != nil {
                return bcrypt.CompareHashAndPassword([]byte(account.Password), password) == nil
        }
 
@@ -85,36 +145,30 @@ func (cc *ClientConn) Authenticate(login string, password []byte) bool {
 
 // Authorize checks if the user account has the specified permission
+//
+// A client with no Account is denied every permission. No lock is taken.
+// NOTE(review): the previous version locked cc around this check; confirm
+// cc.Account is never reassigned concurrently with calls to Authorize.
 func (cc *ClientConn) Authorize(access int) bool {
-       cc.Lock()
-       defer cc.Unlock()
        if cc.Account == nil {
                return false
        }
        return cc.Account.Access.IsSet(access)
 }
 
-// Disconnect notifies other clients that a client has disconnected
+// Disconnect removes cc from the server's client manager, queues a
+// TranNotifyDeleteUser carrying cc's user ID for every other connected client,
+// and closes the underlying connection. A failure to close is not returned; it
+// is logged at debug level together with the error.
 func (cc *ClientConn) Disconnect() {
-       cc.Server.mux.Lock()
-       delete(cc.Server.Clients, cc.ID)
-       cc.Server.mux.Unlock()
+       cc.Server.ClientMgr.Delete(cc.ID)
 
-       for _, t := range cc.notifyOthers(NewTransaction(TranNotifyDeleteUser, [2]byte{}, NewField(FieldUserID, cc.ID[:]))) {
+       for _, t := range cc.NotifyOthers(NewTransaction(TranNotifyDeleteUser, [2]byte{}, NewField(FieldUserID, cc.ID[:]))) {
                cc.Server.outbox <- t
        }
 
        if err := cc.Connection.Close(); err != nil {
-               cc.Server.Logger.Error("error closing client connection", "RemoteAddr", cc.RemoteAddr)
+               cc.Server.Logger.Debug("error closing client connection", "RemoteAddr", cc.RemoteAddr, "err", err)
        }
 }
 
-// notifyOthers sends transaction t to other clients connected to the server
-func (cc *ClientConn) notifyOthers(t Transaction) (trans []Transaction) {
-       cc.Server.mux.Lock()
-       defer cc.Server.mux.Unlock()
-       for _, c := range cc.Server.Clients {
+// NotifyOthers sends transaction t to other clients connected to the server
+func (cc *ClientConn) NotifyOthers(t Transaction) (trans []Transaction) {
+       for _, c := range cc.Server.ClientMgr.List() {
                if c.ID != cc.ID {
-                       t.clientID = c.ID
+                       t.ClientID = c.ID
                        trans = append(trans, t)
                }
        }
@@ -126,7 +180,7 @@ func (cc *ClientConn) NewReply(t *Transaction, fields ...Field) Transaction {
        return Transaction{
                IsReply:  1,
                ID:       t.ID,
-               clientID: cc.ID,
+               ClientID: cc.ID,
                Fields:   fields,
        }
 }
@@ -135,7 +189,7 @@ func (cc *ClientConn) NewReply(t *Transaction, fields ...Field) Transaction {
 func (cc *ClientConn) NewErrReply(t *Transaction, errMsg string) []Transaction {
        return []Transaction{
                {
-                       clientID:  cc.ID,
+                       ClientID:  cc.ID,
                        IsReply:   1,
                        ID:        t.ID,
                        ErrorCode: [4]byte{0, 0, 0, 1},
@@ -146,25 +200,6 @@ func (cc *ClientConn) NewErrReply(t *Transaction, errMsg string) []Transaction {
        }
 }
 
-var clientSortFunc = func(a, b *ClientConn) int {
-       return cmp.Compare(
-               binary.BigEndian.Uint16(a.ID[:]),
-               binary.BigEndian.Uint16(b.ID[:]),
-       )
-}
-
-// sortedClients is a utility function that takes a map of *ClientConn and returns a sorted slice of the values.
-// The purpose of this is to ensure that the ordering of client connections is deterministic so that test assertions work.
-func sortedClients(unsortedClients map[[2]byte]*ClientConn) (clients []*ClientConn) {
-       for _, c := range unsortedClients {
-               clients = append(clients, c)
-       }
-
-       slices.SortFunc(clients, clientSortFunc)
-
-       return clients
-}
-
 const userInfoTemplate = `Nickname:   %s
 Name:       %s
 Account:    %s
@@ -187,7 +222,7 @@ Address:    %s
 %s
 `
 
-func formatDownloadList(fts map[[4]byte]*FileTransfer) (s string) {
+func formatDownloadList(fts []FileTransfer) (s string) {
        if len(fts) == 0 {
                return "None.\n"
        }
@@ -200,18 +235,16 @@ func formatDownloadList(fts map[[4]byte]*FileTransfer) (s string) {
 }
 
 func (cc *ClientConn) String() string {
-       cc.transfersMU.Lock()
-       defer cc.transfersMU.Unlock()
        template := fmt.Sprintf(
                userInfoTemplate,
                cc.UserName,
                cc.Account.Name,
                cc.Account.Login,
                cc.RemoteAddr,
-               formatDownloadList(cc.transfers[FileDownload]),
-               formatDownloadList(cc.transfers[FolderDownload]),
-               formatDownloadList(cc.transfers[FileUpload]),
-               formatDownloadList(cc.transfers[FolderUpload]),
+               formatDownloadList(cc.ClientFileTransferMgr.Get(FileDownload)),
+               formatDownloadList(cc.ClientFileTransferMgr.Get(FolderDownload)),
+               formatDownloadList(cc.ClientFileTransferMgr.Get(FileUpload)),
+               formatDownloadList(cc.ClientFileTransferMgr.Get(FolderUpload)),
                "None.\n",
        )