From: Jeff Halter Date: Mon, 24 Jun 2024 23:23:56 +0000 (-0700) Subject: Refactoring, cleanup, test backfilling X-Git-Url: https://git.r.bdr.sh/rbdr/mobius/commitdiff_plain/a2ef262a164fc735b9b8471ac0c8001eea2b9bf6?ds=sidebyside Refactoring, cleanup, test backfilling --- diff --git a/cmd/mobius-hotline-server/main.go b/cmd/mobius-hotline-server/main.go index c34dc71..bf64848 100644 --- a/cmd/mobius-hotline-server/main.go +++ b/cmd/mobius-hotline-server/main.go @@ -28,6 +28,7 @@ var logLevels = map[string]slog.Level{ "error": slog.LevelError, } +// Values swapped in by go-releaser at build time var ( version = "dev" commit = "none" @@ -35,22 +36,22 @@ var ( ) func main() { - ctx, cancel := context.WithCancel(context.Background()) + ctx, _ := context.WithCancel(context.Background()) // TODO: implement graceful shutdown by closing context - // c := make(chan os.Signal, 1) - // signal.Notify(c, os.Interrupt) - // defer func() { - // signal.Stop(c) - // cancel() - // }() - // go func() { - // select { - // case <-c: - // cancel() - // case <-ctx.Done(): - // } - // }() + //c := make(chan os.Signal, 1) + //signal.Notify(c, os.Interrupt) + //defer func() { + // signal.Stop(c) + // cancel() + //}() + //go func() { + // select { + // case <-c: + // cancel() + // case <-ctx.Done(): + // } + //}() netInterface := flag.String("interface", "", "IP addr of interface to listen on. Defaults to all interfaces.") basePort := flag.Int("bind", defaultPort, "Base Hotline server port. File transfer port is base port + 1.") @@ -129,7 +130,7 @@ func main() { ) // Serve Hotline requests until program exit - log.Fatal(srv.ListenAndServe(ctx, cancel)) + log.Fatal(srv.ListenAndServe(ctx)) } type statHandler struct { @@ -166,7 +167,25 @@ func defaultConfigPath() string { return cfgPath } -// TODO: Simplify this mess. Why is it so difficult to recursively copy a directory? +// copyFile copies a file from src to dst. If dst does not exist, it is created. +func copyFile(src, dst string) error { + sourceFile, err := os.Open(src) + if err != nil { + return err + } + defer sourceFile.Close() + + destinationFile, err := os.Create(dst) + if err != nil { + return err + } + defer destinationFile.Close() + + _, err = io.Copy(destinationFile, sourceFile) + return err +} + +// copyDir recursively copies a directory tree, attempting to preserve permissions. func copyDir(src, dst string) error { entries, err := cfgTemplate.ReadDir(src) if err != nil { diff --git a/hotline/account.go b/hotline/account.go index 18965ed..7b2aafe 100644 --- a/hotline/account.go +++ b/hotline/account.go @@ -5,7 +5,6 @@ import ( "fmt" "golang.org/x/crypto/bcrypt" "io" - "log" "slices" ) @@ -20,6 +19,15 @@ type Account struct { readOffset int // Internal offset to track read progress } +func NewAccount(login, name, password string, access accessBitmap) *Account { + return &Account{ + Login: login, + Name: name, + Password: hashAndSalt([]byte(password)), + Access: access, + } +} + // Read implements io.Reader interface for Account func (a *Account) Read(p []byte) (int, error) { fields := []Field{ @@ -57,10 +65,7 @@ func (a *Account) Read(p []byte) (int, error) { // hashAndSalt generates a password hash from a users obfuscated plaintext password func hashAndSalt(pwd []byte) string { - hash, err := bcrypt.GenerateFromPassword(pwd, bcrypt.MinCost) - if err != nil { - log.Println(err) - } + hash, _ := bcrypt.GenerateFromPassword(pwd, bcrypt.MinCost) return string(hash) } diff --git a/hotline/client.go b/hotline/client.go index cd6fa12..2389ee5 100644 --- a/hotline/client.go +++ b/hotline/client.go @@ -29,23 +29,23 @@ type Client struct { Connection net.Conn Logger *slog.Logger Pref *ClientPrefs - Handlers map[uint16]ClientHandler - activeTasks map[uint32]*Transaction + Handlers map[[2]byte]ClientHandler + activeTasks map[[4]byte]*Transaction } type ClientHandler func(context.Context, *Client, *Transaction) ([]Transaction, error) -func (c *Client) HandleFunc(transactionID uint16, handler ClientHandler) { - c.Handlers[transactionID] = handler +func (c *Client) HandleFunc(tranType [2]byte, handler ClientHandler) { + c.Handlers[tranType] = handler } func NewClient(username string, logger *slog.Logger) *Client { c := &Client{ Logger: logger, - activeTasks: make(map[uint32]*Transaction), - Handlers: make(map[uint16]ClientHandler), + activeTasks: make(map[[4]byte]*Transaction), + Handlers: make(map[[2]byte]ClientHandler), + Pref: &ClientPrefs{Username: username}, } - c.Pref = &ClientPrefs{Username: username} return c } @@ -79,8 +79,8 @@ func (c *Client) Connect(address, login, passwd string) (err error) { // Authenticate (send TranLogin 107) err = c.Send( - *NewTransaction( - TranLogin, nil, + NewTransaction( + TranLogin, [2]byte{0, 0}, NewField(FieldUserName, []byte(c.Pref.Username)), NewField(FieldUserIconID, c.Pref.IconBytes()), NewField(FieldUserLogin, encodeString([]byte(login))), @@ -102,7 +102,7 @@ const keepaliveInterval = 300 * time.Second func (c *Client) keepalive() error { for { time.Sleep(keepaliveInterval) - _ = c.Send(*NewTransaction(TranKeepAlive, nil)) + _ = c.Send(NewTransaction(TranKeepAlive, [2]byte{})) } } @@ -146,7 +146,7 @@ func (c *Client) Send(t Transaction) error { // if transaction is NOT reply, add it to the list to transactions we're expecting a response for if t.IsReply == 0 { - c.activeTasks[binary.BigEndian.Uint32(t.ID[:])] = &t + c.activeTasks[t.ID] = &t } n, err := io.Copy(c.Connection, &t) @@ -165,16 +165,15 @@ func (c *Client) Send(t Transaction) error { func (c *Client) HandleTransaction(ctx context.Context, t *Transaction) error { var origT Transaction if t.IsReply == 1 { - requestID := binary.BigEndian.Uint32(t.ID[:]) - origT = *c.activeTasks[requestID] + origT = *c.activeTasks[t.ID] t.Type = origT.Type } - if handler, ok := c.Handlers[binary.BigEndian.Uint16(t.Type[:])]; ok { + if handler, ok := c.Handlers[t.Type]; ok { c.Logger.Debug( "Received transaction", "IsReply", t.IsReply, - "type", binary.BigEndian.Uint16(t.Type[:]), + "type", t.Type[:], ) outT, err := handler(ctx, c, t) if err != nil { @@ -189,7 +188,7 @@ func (c *Client) HandleTransaction(ctx context.Context, t *Transaction) error { c.Logger.Debug( "Unimplemented transaction type", "IsReply", t.IsReply, - "type", binary.BigEndian.Uint16(t.Type[:]), + "type", t.Type[:], ) } diff --git a/hotline/client_conn.go b/hotline/client_conn.go index 0a7768f..e527eba 100644 --- a/hotline/client_conn.go +++ b/hotline/client_conn.go @@ -1,38 +1,25 @@ package hotline import ( + "cmp" "encoding/binary" "fmt" "golang.org/x/crypto/bcrypt" "io" "log/slog" - "math/big" - "sort" + "slices" "strings" "sync" ) -type byClientID []*ClientConn - -func (s byClientID) Len() int { - return len(s) -} - -func (s byClientID) Swap(i, j int) { - s[i], s[j] = s[j], s[i] -} - -func (s byClientID) Less(i, j int) bool { - return s[i].uint16ID() < s[j].uint16ID() -} - // ClientConn represents a client connected to a Server type ClientConn struct { Connection io.ReadWriteCloser RemoteAddr string - ID *[]byte + ID [2]byte Icon []byte - Flags []byte + flagsMU sync.Mutex + Flags UserFlags UserName []byte Account *Account IdleTime int @@ -45,78 +32,47 @@ type ClientConn struct { transfers map[int]map[[4]byte]*FileTransfer logger *slog.Logger + + sync.Mutex } -func (cc *ClientConn) sendAll(t int, fields ...Field) { - for _, c := range sortedClients(cc.Server.Clients) { - cc.Server.outbox <- *NewTransaction(t, c.ID, fields...) +func (cc *ClientConn) sendAll(t [2]byte, fields ...Field) { + for _, c := range cc.Server.Clients { + cc.Server.outbox <- NewTransaction(t, c.ID, fields...) } } -func (cc *ClientConn) handleTransaction(transaction Transaction) error { - requestNum := binary.BigEndian.Uint16(transaction.Type[:]) - if handler, ok := TransactionHandlers[requestNum]; ok { - for _, reqField := range handler.RequiredFields { - field := transaction.GetField(reqField.ID) - - // Validate that required field is present - if field.ID == [2]byte{0, 0} { - cc.logger.Error( - "Missing required field", - "RequestType", handler.Name, "FieldID", reqField.ID, - ) - return nil - } - - if len(field.Data) < reqField.minLen { - cc.logger.Info( - "Field does not meet minLen", - "RequestType", handler.Name, "FieldID", reqField.ID, - ) - return nil - } - } - - cc.logger.Debug("Received Transaction", "RequestType", handler.Name) +func (cc *ClientConn) handleTransaction(transaction Transaction) { + if handler, ok := TransactionHandlers[transaction.Type]; ok { + cc.logger.Debug("Received Transaction", "RequestType", transaction.Type) - transactions, err := handler.Handler(cc, &transaction) - if err != nil { - return fmt.Errorf("error handling transaction: %w", err) - } - for _, t := range transactions { + for _, t := range handler(cc, &transaction) { cc.Server.outbox <- t } - } else { - cc.logger.Error( - "Unimplemented transaction type received", "RequestID", requestNum) } cc.Server.mux.Lock() defer cc.Server.mux.Unlock() - if requestNum != TranKeepAlive { + if transaction.Type != TranKeepAlive { // reset the user idle timer cc.IdleTime = 0 // if user was previously idle, mark as not idle and notify other connected clients that // the user is no longer away if cc.Idle { - flagBitmap := big.NewInt(int64(binary.BigEndian.Uint16(cc.Flags))) - flagBitmap.SetBit(flagBitmap, UserFlagAway, 0) - binary.BigEndian.PutUint16(cc.Flags, uint16(flagBitmap.Int64())) + cc.Flags.Set(UserFlagAway, 0) cc.Idle = false cc.sendAll( TranNotifyChangeUser, - NewField(FieldUserID, *cc.ID), - NewField(FieldUserFlags, cc.Flags), + NewField(FieldUserID, cc.ID[:]), + NewField(FieldUserFlags, cc.Flags[:]), NewField(FieldUserName, cc.UserName), NewField(FieldUserIconID, cc.Icon), ) } } - - return nil } func (cc *ClientConn) Authenticate(login string, password []byte) bool { @@ -127,24 +83,23 @@ func (cc *ClientConn) Authenticate(login string, password []byte) bool { return false } -func (cc *ClientConn) uint16ID() uint16 { - id, _ := byteToInt(*cc.ID) - return uint16(id) -} - // Authorize checks if the user account has the specified permission func (cc *ClientConn) Authorize(access int) bool { + cc.Lock() + defer cc.Unlock() + if cc.Account == nil { + return false + } return cc.Account.Access.IsSet(access) } // Disconnect notifies other clients that a client has disconnected func (cc *ClientConn) Disconnect() { cc.Server.mux.Lock() - defer cc.Server.mux.Unlock() + delete(cc.Server.Clients, cc.ID) + cc.Server.mux.Unlock() - delete(cc.Server.Clients, binary.BigEndian.Uint16(*cc.ID)) - - for _, t := range cc.notifyOthers(*NewTransaction(TranNotifyDeleteUser, nil, NewField(FieldUserID, *cc.ID))) { + for _, t := range cc.notifyOthers(NewTransaction(TranNotifyDeleteUser, [2]byte{}, NewField(FieldUserID, cc.ID[:]))) { cc.Server.outbox <- t } @@ -155,7 +110,9 @@ func (cc *ClientConn) Disconnect() { // notifyOthers sends transaction t to other clients connected to the server func (cc *ClientConn) notifyOthers(t Transaction) (trans []Transaction) { - for _, c := range sortedClients(cc.Server.Clients) { + cc.Server.mux.Lock() + defer cc.Server.mux.Unlock() + for _, c := range cc.Server.Clients { if c.ID != cc.ID { t.clientID = c.ID trans = append(trans, t) @@ -167,36 +124,44 @@ func (cc *ClientConn) notifyOthers(t Transaction) (trans []Transaction) { // NewReply returns a reply Transaction with fields for the ClientConn func (cc *ClientConn) NewReply(t *Transaction, fields ...Field) Transaction { return Transaction{ - IsReply: 0x01, - Type: [2]byte{0x00, 0x00}, - ID: t.ID, - clientID: cc.ID, - ErrorCode: [4]byte{0, 0, 0, 0}, - Fields: fields, + IsReply: 1, + ID: t.ID, + clientID: cc.ID, + Fields: fields, } } // NewErrReply returns an error reply Transaction with errMsg -func (cc *ClientConn) NewErrReply(t *Transaction, errMsg string) Transaction { - return Transaction{ - clientID: cc.ID, - IsReply: 0x01, - Type: [2]byte{0, 0}, - ID: t.ID, - ErrorCode: [4]byte{0, 0, 0, 1}, - Fields: []Field{ - NewField(FieldError, []byte(errMsg)), +func (cc *ClientConn) NewErrReply(t *Transaction, errMsg string) []Transaction { + return []Transaction{ + { + clientID: cc.ID, + IsReply: 1, + ID: t.ID, + ErrorCode: [4]byte{0, 0, 0, 1}, + Fields: []Field{ + NewField(FieldError, []byte(errMsg)), + }, }, } } +var clientSortFunc = func(a, b *ClientConn) int { + return cmp.Compare( + binary.BigEndian.Uint16(a.ID[:]), + binary.BigEndian.Uint16(b.ID[:]), + ) +} + // sortedClients is a utility function that takes a map of *ClientConn and returns a sorted slice of the values. // The purpose of this is to ensure that the ordering of client connections is deterministic so that test assertions work. -func sortedClients(unsortedClients map[uint16]*ClientConn) (clients []*ClientConn) { +func sortedClients(unsortedClients map[[2]byte]*ClientConn) (clients []*ClientConn) { for _, c := range unsortedClients { clients = append(clients, c) } - sort.Sort(byClientID(clients)) + + slices.SortFunc(clients, clientSortFunc) + return clients } diff --git a/hotline/field.go b/hotline/field.go index 2fcb41a..7bba0d7 100644 --- a/hotline/field.go +++ b/hotline/field.go @@ -2,68 +2,71 @@ package hotline import ( "encoding/binary" + "errors" "io" "slices" ) // List of Hotline protocol field types taken from the official 1.9 protocol document -const ( - FieldError = 100 - FieldData = 101 - FieldUserName = 102 - FieldUserID = 103 - FieldUserIconID = 104 - FieldUserLogin = 105 - FieldUserPassword = 106 - FieldRefNum = 107 - FieldTransferSize = 108 - FieldChatOptions = 109 - FieldUserAccess = 110 - FieldUserAlias = 111 // TODO: implement - FieldUserFlags = 112 - FieldOptions = 113 - FieldChatID = 114 - FieldChatSubject = 115 - FieldWaitingCount = 116 - FieldBannerType = 152 - FieldNoServerAgreement = 152 - FieldVersion = 160 - FieldCommunityBannerID = 161 - FieldServerName = 162 - FieldFileNameWithInfo = 200 - FieldFileName = 201 - FieldFilePath = 202 - FieldFileResumeData = 203 - FieldFileTransferOptions = 204 - FieldFileTypeString = 205 - FieldFileCreatorString = 206 - FieldFileSize = 207 - FieldFileCreateDate = 208 - FieldFileModifyDate = 209 - FieldFileComment = 210 - FieldFileNewName = 211 - FieldFileNewPath = 212 - FieldFileType = 213 - FieldQuotingMsg = 214 - FieldAutomaticResponse = 215 - FieldFolderItemCount = 220 - FieldUsernameWithInfo = 300 - FieldNewsArtListData = 321 - FieldNewsCatName = 322 - FieldNewsCatListData15 = 323 - FieldNewsPath = 325 - FieldNewsArtID = 326 - FieldNewsArtDataFlav = 327 - FieldNewsArtTitle = 328 - FieldNewsArtPoster = 329 - FieldNewsArtDate = 330 - FieldNewsArtPrevArt = 331 - FieldNewsArtNextArt = 332 - FieldNewsArtData = 333 - FieldNewsArtFlags = 334 // TODO: what is this used for? - FieldNewsArtParentArt = 335 - FieldNewsArt1stChildArt = 336 - FieldNewsArtRecurseDel = 337 // TODO: implement news article recusive deletion +var ( + FieldError = [2]byte{0x00, 0x64} // 100 + FieldData = [2]byte{0x00, 0x65} // 101 + FieldUserName = [2]byte{0x00, 0x66} // 102 + FieldUserID = [2]byte{0x00, 0x67} // 103 + FieldUserIconID = [2]byte{0x00, 0x68} // 104 + FieldUserLogin = [2]byte{0x00, 0x69} // 105 + FieldUserPassword = [2]byte{0x00, 0x6A} // 106 + FieldRefNum = [2]byte{0x00, 0x6B} // 107 + FieldTransferSize = [2]byte{0x00, 0x6C} // 108 + FieldChatOptions = [2]byte{0x00, 0x6D} // 109 + FieldUserAccess = [2]byte{0x00, 0x6E} // 110 + FieldUserFlags = [2]byte{0x00, 0x70} // 112 + FieldOptions = [2]byte{0x00, 0x71} // 113 + FieldChatID = [2]byte{0x00, 0x72} // 114 + FieldChatSubject = [2]byte{0x00, 0x73} // 115 + FieldWaitingCount = [2]byte{0x00, 0x74} // 116 + FieldBannerType = [2]byte{0x00, 0x98} // 152 + FieldNoServerAgreement = [2]byte{0x00, 0x98} // 152 + FieldVersion = [2]byte{0x00, 0xA0} // 160 + FieldCommunityBannerID = [2]byte{0x00, 0xA1} // 161 + FieldServerName = [2]byte{0x00, 0xA2} // 162 + FieldFileNameWithInfo = [2]byte{0x00, 0xC8} // 200 + FieldFileName = [2]byte{0x00, 0xC9} // 201 + FieldFilePath = [2]byte{0x00, 0xCA} // 202 + FieldFileResumeData = [2]byte{0x00, 0xCB} // 203 + FieldFileTransferOptions = [2]byte{0x00, 0xCC} // 204 + FieldFileTypeString = [2]byte{0x00, 0xCD} // 205 + FieldFileCreatorString = [2]byte{0x00, 0xCE} // 206 + FieldFileSize = [2]byte{0x00, 0xCF} // 207 + FieldFileCreateDate = [2]byte{0x00, 0xD0} // 208 + FieldFileModifyDate = [2]byte{0x00, 0xD1} // 209 + FieldFileComment = [2]byte{0x00, 0xD2} // 210 + FieldFileNewName = [2]byte{0x00, 0xD3} // 211 + FieldFileNewPath = [2]byte{0x00, 0xD4} // 212 + FieldFileType = [2]byte{0x00, 0xD5} // 213 + FieldQuotingMsg = [2]byte{0x00, 0xD6} // 214 + FieldAutomaticResponse = [2]byte{0x00, 0xD7} // 215 + FieldFolderItemCount = [2]byte{0x00, 0xDC} // 220 + FieldUsernameWithInfo = [2]byte{0x01, 0x2C} // 300 + FieldNewsArtListData = [2]byte{0x01, 0x41} // 321 + FieldNewsCatName = [2]byte{0x01, 0x42} // 322 + FieldNewsCatListData15 = [2]byte{0x01, 0x43} // 323 + FieldNewsPath = [2]byte{0x01, 0x45} // 325 + FieldNewsArtID = [2]byte{0x01, 0x46} // 326 + FieldNewsArtDataFlav = [2]byte{0x01, 0x47} // 327 + FieldNewsArtTitle = [2]byte{0x01, 0x48} // 328 + FieldNewsArtPoster = [2]byte{0x01, 0x49} // 329 + FieldNewsArtDate = [2]byte{0x01, 0x4A} // 330 + FieldNewsArtPrevArt = [2]byte{0x01, 0x4B} // 331 + FieldNewsArtNextArt = [2]byte{0x01, 0x4C} // 332 + FieldNewsArtData = [2]byte{0x01, 0x4D} // 333 + FieldNewsArtParentArt = [2]byte{0x01, 0x4F} // 335 + FieldNewsArt1stChildArt = [2]byte{0x01, 0x50} // 336 + + // These fields are documented, but seemingly unused. + // FieldUserAlias = [2]byte{0x00, 0x6F} // 111 + // FieldNewsArtFlags = [2]byte{0x01, 0x4E} // 334 + // FieldNewsArtRecurseDel = [2]byte{0x01, 0x51} // 337 ) type Field struct { @@ -74,16 +77,16 @@ type Field struct { readOffset int // Internal offset to track read progress } -type requiredField struct { - ID int - minLen int -} +func NewField(id [2]byte, data []byte) Field { + f := Field{ + ID: id, + Data: make([]byte, len(data)), + } -func NewField(id uint16, data []byte) Field { - f := Field{Data: data} - binary.BigEndian.PutUint16(f.ID[:], id) - binary.BigEndian.PutUint16(f.FieldSize[:], uint16(len(data))) + // Copy instead of assigning to avoid data race when the field is read in another go routine. + copy(f.Data, data) + binary.BigEndian.PutUint16(f.FieldSize[:], uint16(len(data))) return f } @@ -93,7 +96,7 @@ func fieldScanner(data []byte, _ bool) (advance int, token []byte, err error) { return 0, nil, nil } - // tranLen represents the length of bytes that are part of the transaction + // neededSize represents the length of bytes that are part of the field token. neededSize := minFieldLen + int(binary.BigEndian.Uint16(data[2:4])) if neededSize > len(data) { return 0, nil, nil @@ -118,18 +121,27 @@ func (f *Field) Read(p []byte) (int, error) { // Write implements io.Writer for Field func (f *Field) Write(p []byte) (int, error) { - f.ID = [2]byte(p[0:2]) - f.FieldSize = [2]byte(p[2:4]) + if len(p) < minFieldLen { + return 0, errors.New("input slice too short") + } + + copy(f.ID[:], p[0:2]) + copy(f.FieldSize[:], p[2:4]) + + dataSize := int(binary.BigEndian.Uint16(f.FieldSize[:])) + if len(p) < minFieldLen+dataSize { + return 0, errors.New("input slice too short for data size") + } - i := int(binary.BigEndian.Uint16(f.FieldSize[:])) - f.Data = p[4 : 4+i] + f.Data = make([]byte, dataSize) + copy(f.Data, p[4:4+dataSize]) - return minFieldLen + i, nil + return minFieldLen + dataSize, nil } -func getField(id int, fields *[]Field) *Field { +func getField(id [2]byte, fields *[]Field) *Field { for _, field := range *fields { - if id == int(binary.BigEndian.Uint16(field.ID[:])) { + if id == field.ID { return &field } } diff --git a/hotline/file_header.go b/hotline/file_header.go index 469f321..b18229a 100644 --- a/hotline/file_header.go +++ b/hotline/file_header.go @@ -16,7 +16,6 @@ type FileHeader struct { func NewFileHeader(fileName string, isDir bool) FileHeader { fh := FileHeader{ - Type: [2]byte{0x00, 0x00}, FilePath: EncodeFilePath(fileName), } if isDir { diff --git a/hotline/file_resume_data.go b/hotline/file_resume_data.go index f5b82de..b5d980e 100644 --- a/hotline/file_resume_data.go +++ b/hotline/file_resume_data.go @@ -12,6 +12,8 @@ type FileResumeData struct { RSVD [34]byte // Unused ForkCount [2]byte // Length of ForkInfoList. Either 2 or 3 depending on whether file has a resource fork ForkInfoList []ForkInfoList + + readOffset int } type ForkInfoList struct { @@ -40,6 +42,31 @@ func NewFileResumeData(list []ForkInfoList) *FileResumeData { } } +// +//func (frd *FileResumeData) Read(p []byte) (int, error) { +// buf := slices.Concat( +// frd.Format[:], +// frd.Version[:], +// frd.RSVD[:], +// frd.ForkCount[:], +// ) +// for _, fil := range frd.ForkInfoList { +// buf = append(buf, fil...) +// _ = binary.Write(&buf, binary.LittleEndian, fil) +// } +// +// var buf bytes.Buffer +// _ = binary.Write(&buf, binary.LittleEndian, frd.Format) +// _ = binary.Write(&buf, binary.LittleEndian, frd.Version) +// _ = binary.Write(&buf, binary.LittleEndian, frd.RSVD) +// _ = binary.Write(&buf, binary.LittleEndian, frd.ForkCount) +// for _, fil := range frd.ForkInfoList { +// _ = binary.Write(&buf, binary.LittleEndian, fil) +// } +// +// return buf.Bytes(), nil +//} + func (frd *FileResumeData) BinaryMarshal() ([]byte, error) { var buf bytes.Buffer _ = binary.Write(&buf, binary.LittleEndian, frd.Format) diff --git a/hotline/file_transfer.go b/hotline/file_transfer.go index 194e8b9..c883822 100644 --- a/hotline/file_transfer.go +++ b/hotline/file_transfer.go @@ -1,11 +1,18 @@ package hotline import ( + "bufio" "crypto/rand" "encoding/binary" + "errors" "fmt" + "io" + "io/fs" + "log/slog" "math" + "os" "path/filepath" + "strings" "sync" ) @@ -21,7 +28,6 @@ const ( type FileTransfer struct { FileName []byte FilePath []byte - ReferenceNumber []byte refNum [4]byte Type int TransferSize []byte @@ -50,27 +56,24 @@ func (wc *WriteCounter) Write(p []byte) (int, error) { } func (cc *ClientConn) newFileTransfer(transferType int, fileName, filePath, size []byte) *FileTransfer { - var transactionRef [4]byte - _, _ = rand.Read(transactionRef[:]) - ft := &FileTransfer{ FileName: fileName, FilePath: filePath, - ReferenceNumber: transactionRef[:], - refNum: transactionRef, Type: transferType, TransferSize: size, ClientConn: cc, bytesSentCounter: &WriteCounter{}, } + _, _ = rand.Read(ft.refNum[:]) + cc.transfersMU.Lock() defer cc.transfersMU.Unlock() - cc.transfers[transferType][transactionRef] = ft + cc.transfers[transferType][ft.refNum] = ft cc.Server.mux.Lock() defer cc.Server.mux.Unlock() - cc.Server.fileTransfers[transactionRef] = ft + cc.Server.fileTransfers[ft.refNum] = ft return ft } @@ -94,7 +97,7 @@ func (ft *FileTransfer) percentComplete() string { func (ft *FileTransfer) formattedTransferSize() string { sizeInKB := float32(binary.BigEndian.Uint32(ft.TransferSize)) / 1024 - if sizeInKB > 1024 { + if sizeInKB >= 1024 { return fmt.Sprintf("%.1fM", sizeInKB/1024) } else { return fmt.Sprintf("%.0fK", sizeInKB) @@ -112,6 +115,23 @@ type folderUpload struct { FileNamePath []byte } +//func (fu *folderUpload) Write(p []byte) (int, error) { +// if len(p) < 7 { +// return 0, errors.New("buflen too short") +// } +// copy(fu.DataSize[:], p[0:2]) +// copy(fu.IsFolder[:], p[2:4]) +// copy(fu.PathItemCount[:], p[4:6]) +// +// fu.FileNamePath = make([]byte, binary.BigEndian.Uint16(fu.DataSize[:])-4) // -4 to subtract the path separator bytes TODO: wat +// n, err := io.ReadFull(rwc, fu.FileNamePath) +// if err != nil { +// return 0, err +// } +// +// return n + 6, nil +//} + func (fu *folderUpload) FormattedPath() string { pathItemLen := binary.BigEndian.Uint16(fu.PathItemCount[:]) @@ -127,3 +147,445 @@ func (fu *folderUpload) FormattedPath() string { return filepath.Join(pathSegments...) } + +func DownloadHandler(rwc io.ReadWriter, fullPath string, fileTransfer *FileTransfer, fs FileStore, rLogger *slog.Logger, preserveForks bool) error { + //s.Stats.DownloadCounter += 1 + //s.Stats.DownloadsInProgress += 1 + //defer func() { + // s.Stats.DownloadsInProgress -= 1 + //}() + + var dataOffset int64 + if fileTransfer.fileResumeData != nil { + dataOffset = int64(binary.BigEndian.Uint32(fileTransfer.fileResumeData.ForkInfoList[0].DataSize[:])) + } + + fw, err := newFileWrapper(fs, fullPath, 0) + if err != nil { + //return err + } + + rLogger.Info("File download started", "filePath", fullPath) + + // if file transfer options are included, that means this is a "quick preview" request from a 1.5+ client + if fileTransfer.options == nil { + _, err = io.Copy(rwc, fw.ffo) + if err != nil { + //return err + } + } + + file, err := fw.dataForkReader() + if err != nil { + //return err + } + + br := bufio.NewReader(file) + if _, err := br.Discard(int(dataOffset)); err != nil { + //return err + } + + if _, err = io.Copy(rwc, io.TeeReader(br, fileTransfer.bytesSentCounter)); err != nil { + return err + } + + // if the client requested to resume transfer, do not send the resource fork header, or it will be appended into the fileWrapper data + if fileTransfer.fileResumeData == nil { + err = binary.Write(rwc, binary.BigEndian, fw.rsrcForkHeader()) + if err != nil { + return err + } + } + + rFile, err := fw.rsrcForkFile() + if err != nil { + return nil + } + + if _, err = io.Copy(rwc, io.TeeReader(rFile, fileTransfer.bytesSentCounter)); err != nil { + return err + } + + return nil +} + +func UploadHandler(rwc io.ReadWriter, fullPath string, fileTransfer *FileTransfer, fileStore FileStore, rLogger *slog.Logger, preserveForks bool) error { + var file *os.File + + // A file upload has two possible cases: + // 1) Upload a new file + // 2) Resume a partially transferred file + // We have to infer which case applies by inspecting what is already on the filesystem + + // Check for existing file. If found, do not proceed. This is an invalid scenario, as the file upload transaction + // handler should have returned an error to the client indicating there was an existing file present. + _, err := os.Stat(fullPath) + if err == nil { + return fmt.Errorf("existing file found: %s", fullPath) + } + if errors.Is(err, fs.ErrNotExist) { + // If not found, open or create a new .incomplete file + file, err = os.OpenFile(fullPath+incompleteFileSuffix, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0644) + if err != nil { + return err + } + } + + f, err := newFileWrapper(fileStore, fullPath, 0) + if err != nil { + return err + } + + rLogger.Info("File upload started", "dstFile", fullPath) + + rForkWriter := io.Discard + iForkWriter := io.Discard + if preserveForks { + rForkWriter, err = f.rsrcForkWriter() + if err != nil { + return err + } + + iForkWriter, err = f.infoForkWriter() + if err != nil { + return err + } + } + + if err := receiveFile(rwc, file, rForkWriter, iForkWriter, fileTransfer.bytesSentCounter); err != nil { + rLogger.Error(err.Error()) + } + + if err := file.Close(); err != nil { + return err + } + + if err := fileStore.Rename(fullPath+".incomplete", fullPath); err != nil { + return err + } + + rLogger.Info("File upload complete", "dstFile", fullPath) + + return nil +} + +func DownloadFolderHandler(rwc io.ReadWriter, fullPath string, fileTransfer *FileTransfer, fileStore FileStore, rLogger *slog.Logger, preserveForks bool) error { + // Folder Download flow: + // 1. Get filePath from the transfer + // 2. Iterate over files + // 3. For each fileWrapper: + // Send fileWrapper header to client + // The client can reply in 3 ways: + // + // 1. If type is an odd number (unknown type?), or fileWrapper download for the current fileWrapper is completed: + // client sends []byte{0x00, 0x03} to tell the server to continue to the next fileWrapper + // + // 2. If download of a fileWrapper is to be resumed: + // client sends: + // []byte{0x00, 0x02} // download folder action + // [2]byte // Resume data size + // []byte fileWrapper resume data (see myField_FileResumeData) + // + // 3. Otherwise, download of the fileWrapper is requested and client sends []byte{0x00, 0x01} + // + // When download is requested (case 2 or 3), server replies with: + // [4]byte - fileWrapper size + // []byte - Flattened File Object + // + // After every fileWrapper download, client could request next fileWrapper with: + // []byte{0x00, 0x03} + // + // This notifies the server to send the next item header + + basePathLen := len(fullPath) + + rLogger.Info("Start folder download", "path", fullPath) + + nextAction := make([]byte, 2) + if _, err := io.ReadFull(rwc, nextAction); err != nil { + return err + } + + i := 0 + err := filepath.Walk(fullPath+"/", func(path string, info os.FileInfo, err error) error { + //s.Stats.DownloadCounter += 1 + i += 1 + + if err != nil { + return err + } + + // skip dot files + if strings.HasPrefix(info.Name(), ".") { + return nil + } + + hlFile, err := newFileWrapper(fileStore, path, 0) + if err != nil { + return err + } + + subPath := path[basePathLen+1:] + rLogger.Debug("Sending fileheader", "i", i, "path", path, "fullFilePath", fullPath, "subPath", subPath, "IsDir", info.IsDir()) + + if i == 1 { + return nil + } + + fileHeader := NewFileHeader(subPath, info.IsDir()) + if _, err := io.Copy(rwc, &fileHeader); err != nil { + return fmt.Errorf("error sending file header: %w", err) + } + + // Read the client's Next Action request + if _, err := io.ReadFull(rwc, nextAction); err != nil { + return err + } + + rLogger.Debug("Client folder download action", "action", fmt.Sprintf("%X", nextAction[0:2])) + + var dataOffset int64 + + switch nextAction[1] { + case dlFldrActionResumeFile: + // get size of resumeData + resumeDataByteLen := make([]byte, 2) + if _, err := io.ReadFull(rwc, resumeDataByteLen); err != nil { + return err + } + + resumeDataLen := binary.BigEndian.Uint16(resumeDataByteLen) + resumeDataBytes := make([]byte, resumeDataLen) + if _, err := io.ReadFull(rwc, resumeDataBytes); err != nil { + return err + } + + var frd FileResumeData + if err := frd.UnmarshalBinary(resumeDataBytes); err != nil { + return err + } + dataOffset = int64(binary.BigEndian.Uint32(frd.ForkInfoList[0].DataSize[:])) + case dlFldrActionNextFile: + // client asked to skip this file + return nil + } + + if info.IsDir() { + return nil + } + + rLogger.Info("File download started", + "fileName", info.Name(), + "TransferSize", fmt.Sprintf("%x", hlFile.ffo.TransferSize(dataOffset)), + ) + + // Send file size to client + if _, err := rwc.Write(hlFile.ffo.TransferSize(dataOffset)); err != nil { + rLogger.Error(err.Error()) + return fmt.Errorf("error sending file size: %w", err) + } + + // Send ffo bytes to client + _, err = io.Copy(rwc, hlFile.ffo) + if err != nil { + return fmt.Errorf("error sending flat file object: %w", err) + } + + file, err := fileStore.Open(path) + if err != nil { + return fmt.Errorf("error opening file: %w", err) + } + + // wr := bufio.NewWriterSize(rwc, 1460) + if _, err = io.Copy(rwc, io.TeeReader(file, fileTransfer.bytesSentCounter)); err != nil { + return fmt.Errorf("error sending file: %w", err) + } + + if nextAction[1] != 2 && hlFile.ffo.FlatFileHeader.ForkCount[1] == 3 { + err = binary.Write(rwc, binary.BigEndian, hlFile.rsrcForkHeader()) + if err != nil { + return fmt.Errorf("error sending resource fork header: %w", err) + } + + rFile, err := hlFile.rsrcForkFile() + if err != nil { + return fmt.Errorf("error opening resource fork: %w", err) + } + + if _, err = io.Copy(rwc, io.TeeReader(rFile, fileTransfer.bytesSentCounter)); err != nil { + return fmt.Errorf("error sending resource fork: %w", err) + } + } + + // Read the client's Next Action request. This is always 3, I think? + if _, err := io.ReadFull(rwc, nextAction); err != nil && err != io.EOF { + return fmt.Errorf("error reading client next action: %w", err) + } + + return nil + }) + + if err != nil { + return err + } + + return nil +} + +func UploadFolderHandler(rwc io.ReadWriter, fullPath string, fileTransfer *FileTransfer, fileStore FileStore, rLogger *slog.Logger, preserveForks bool) error { + + // Check if the target folder exists. If not, create it. + if _, err := fileStore.Stat(fullPath); os.IsNotExist(err) { + if err := fileStore.Mkdir(fullPath, 0777); err != nil { + return err + } + } + + // Begin the folder upload flow by sending the "next file action" to client + if _, err := rwc.Write([]byte{0, dlFldrActionNextFile}); err != nil { + return err + } + + fileSize := make([]byte, 4) + + for i := 0; i < fileTransfer.ItemCount(); i++ { + //s.Stats.UploadCounter += 1 + + var fu folderUpload + // TODO: implement io.Writer on folderUpload and replace this + if _, err := io.ReadFull(rwc, fu.DataSize[:]); err != nil { + return err + } + if _, err := io.ReadFull(rwc, fu.IsFolder[:]); err != nil { + return err + } + if _, err := io.ReadFull(rwc, fu.PathItemCount[:]); err != nil { + return err + } + fu.FileNamePath = make([]byte, binary.BigEndian.Uint16(fu.DataSize[:])-4) // -4 to subtract the path separator bytes TODO: wat + if _, err := io.ReadFull(rwc, fu.FileNamePath); err != nil { + return err + } + + if fu.IsFolder == [2]byte{0, 1} { + if _, err := os.Stat(filepath.Join(fullPath, fu.FormattedPath())); os.IsNotExist(err) { + if err := os.Mkdir(filepath.Join(fullPath, fu.FormattedPath()), 0777); err != nil { + return err + } + } + + // Tell client to send next file + if _, err := rwc.Write([]byte{0, dlFldrActionNextFile}); err != nil { + return err + } + } else { + nextAction := dlFldrActionSendFile + + // Check if we have the full file already. If so, send dlFldrAction_NextFile to client to skip. + _, err := os.Stat(filepath.Join(fullPath, fu.FormattedPath())) + if err != nil && !errors.Is(err, fs.ErrNotExist) { + return err + } + if err == nil { + nextAction = dlFldrActionNextFile + } + + // Check if we have a partial file already. If so, send dlFldrAction_ResumeFile to client to resume upload. + incompleteFile, err := os.Stat(filepath.Join(fullPath, fu.FormattedPath()+incompleteFileSuffix)) + if err != nil && !errors.Is(err, fs.ErrNotExist) { + return err + } + if err == nil { + nextAction = dlFldrActionResumeFile + } + + if _, err := rwc.Write([]byte{0, uint8(nextAction)}); err != nil { + return err + } + + switch nextAction { + case dlFldrActionNextFile: + continue + case dlFldrActionResumeFile: + offset := make([]byte, 4) + binary.BigEndian.PutUint32(offset, uint32(incompleteFile.Size())) + + file, err := os.OpenFile(fullPath+"/"+fu.FormattedPath()+incompleteFileSuffix, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) + if err != nil { + return err + } + + fileResumeData := NewFileResumeData([]ForkInfoList{*NewForkInfoList(offset)}) + + b, _ := fileResumeData.BinaryMarshal() + + bs := make([]byte, 2) + binary.BigEndian.PutUint16(bs, uint16(len(b))) + + if _, err := rwc.Write(append(bs, b...)); err != nil { + return err + } + + if _, err := io.ReadFull(rwc, fileSize); err != nil { + return err + } + + if err := receiveFile(rwc, file, io.Discard, io.Discard, fileTransfer.bytesSentCounter); err != nil { + rLogger.Error(err.Error()) + } + + err = os.Rename(fullPath+"/"+fu.FormattedPath()+".incomplete", fullPath+"/"+fu.FormattedPath()) + if err != nil { + return err + } + + case dlFldrActionSendFile: + if _, err := io.ReadFull(rwc, fileSize); err != nil { + return err + } + + filePath := filepath.Join(fullPath, fu.FormattedPath()) + + hlFile, err := newFileWrapper(fileStore, filePath, 0) + if err != nil { + return err + } + + rLogger.Info("Starting file transfer", "path", filePath, "fileNum", i+1, "fileSize", binary.BigEndian.Uint32(fileSize)) + + incWriter, err := hlFile.incFileWriter() + if err != nil { + return err + } + + rForkWriter := io.Discard + iForkWriter := io.Discard + if preserveForks { + iForkWriter, err = hlFile.infoForkWriter() + if err != nil { + return err + } + + rForkWriter, err = hlFile.rsrcForkWriter() + if err != nil { + return err + } + } + if err := receiveFile(rwc, incWriter, rForkWriter, iForkWriter, fileTransfer.bytesSentCounter); err != nil { + return err + } + + if err := os.Rename(filePath+".incomplete", filePath); err != nil { + return err + } + } + + // Tell client to send next fileWrapper + if _, err := rwc.Write([]byte{0, dlFldrActionNextFile}); err != nil { + return err + } + } + } + rLogger.Info("Folder upload complete") + return nil +} diff --git a/hotline/file_transfer_test.go b/hotline/file_transfer_test.go new file mode 100644 index 0000000..213cb9a --- /dev/null +++ b/hotline/file_transfer_test.go @@ -0,0 +1,89 @@ +package hotline + +import ( + "encoding/binary" + "github.com/stretchr/testify/assert" + "testing" +) + +func TestFileTransfer_String(t *testing.T) { + type fields struct { + FileName []byte + FilePath []byte + refNum [4]byte + Type int + TransferSize []byte + FolderItemCount []byte + fileResumeData *FileResumeData + options []byte + bytesSentCounter *WriteCounter + ClientConn *ClientConn + } + tests := []struct { + name string + fields fields + want string + }{ + { + name: "50% complete 198MB file", + fields: fields{ + FileName: []byte("MasterOfOrionII1.4.0."), + TransferSize: func() []byte { b := make([]byte, 4); binary.BigEndian.PutUint32(b, 207618048); return b }(), + bytesSentCounter: &WriteCounter{ + Total: 103809024, + }, + }, + want: "MasterOfOrionII1.4.0. 50% 198.0M\n", + }, + { + name: "25% complete 512KB file", + fields: fields{ + FileName: []byte("ExampleFile.txt"), + TransferSize: func() []byte { b := make([]byte, 4); binary.BigEndian.PutUint32(b, 524288); return b }(), + bytesSentCounter: &WriteCounter{ + Total: 131072, + }, + }, + want: "ExampleFile.txt 25% 512K\n", + }, + { + name: "100% complete 2GB file", + fields: fields{ + FileName: []byte("LargeFile.dat"), + TransferSize: func() []byte { b := make([]byte, 4); binary.BigEndian.PutUint32(b, 2147483648); return b }(), + bytesSentCounter: &WriteCounter{ + Total: 2147483648, + }, + }, + want: "LargeFile.dat 100% 2048.0M\n", + }, + { + name: "0% complete 1MB file", + fields: fields{ + FileName: []byte("NewDocument.docx"), + TransferSize: func() []byte { b := make([]byte, 4); binary.BigEndian.PutUint32(b, 1048576); return b }(), + bytesSentCounter: &WriteCounter{ + Total: 0, + }, + }, + want: "NewDocument.docx 0% 1.0M\n", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ft := &FileTransfer{ + FileName: tt.fields.FileName, + FilePath: tt.fields.FilePath, + refNum: tt.fields.refNum, + Type: tt.fields.Type, + TransferSize: tt.fields.TransferSize, + FolderItemCount: tt.fields.FolderItemCount, + fileResumeData: tt.fields.fileResumeData, + options: tt.fields.options, + bytesSentCounter: tt.fields.bytesSentCounter, + ClientConn: tt.fields.ClientConn, + } + assert.Equalf(t, tt.want, ft.String(), "String()") + }) + } +} diff --git a/hotline/files.go b/hotline/files.go index f2fc171..bc2f8d5 100644 --- a/hotline/files.go +++ b/hotline/files.go @@ -169,15 +169,19 @@ func CalcTotalSize(filePath string) ([]byte, error) { return bs, nil } +// CalcItemCount recurses through a file path and counts the number of non-hidden files. func CalcItemCount(filePath string) ([]byte, error) { - var itemcount uint16 + var itemCount uint16 + + // Walk the directory and count items err := filepath.Walk(filePath, func(path string, info os.FileInfo, err error) error { if err != nil { return err } + // Skip hidden files if !strings.HasPrefix(info.Name(), ".") { - itemcount += 1 + itemCount++ } return nil @@ -187,7 +191,7 @@ func CalcItemCount(filePath string) ([]byte, error) { } bs := make([]byte, 2) - binary.BigEndian.PutUint16(bs, itemcount-1) + binary.BigEndian.PutUint16(bs, itemCount-1) return bs, nil } diff --git a/hotline/files_test.go b/hotline/files_test.go index 90c5f00..4c8ef80 100644 --- a/hotline/files_test.go +++ b/hotline/files_test.go @@ -2,7 +2,9 @@ package hotline import ( "bytes" + "encoding/binary" "os" + "path/filepath" "reflect" "testing" ) @@ -81,3 +83,89 @@ func TestCalcTotalSize(t *testing.T) { }) } } + +func createTestDirStructure(baseDir string, structure map[string]string) error { + // First pass: create directories + for path, content := range structure { + if content == "dir" { + if err := os.MkdirAll(filepath.Join(baseDir, path), 0755); err != nil { + return err + } + } + } + + // Second pass: create files + for path, content := range structure { + if content != "dir" { + fullPath := filepath.Join(baseDir, path) + dir := filepath.Dir(fullPath) + if err := os.MkdirAll(dir, 0755); err != nil { + return err + } + if err := os.WriteFile(fullPath, []byte(content), 0644); err != nil { + return err + } + } + } + return nil +} + +func TestCalcItemCount(t *testing.T) { + tests := []struct { + name string + structure map[string]string + expected uint16 + }{ + { + name: "directory with files", + structure: map[string]string{ + "file1.txt": "content1", + "file2.txt": "content2", + "subdir/": "dir", + "subdir/file3.txt": "content3", + }, + expected: 4, // 3 files and 1 directory, should count 4 items + }, + { + name: "directory with hidden files", + structure: map[string]string{ + ".hiddenfile": "hiddencontent", + "file1.txt": "content1", + }, + expected: 1, // 1 non-hidden file + }, + { + name: "empty directory", + structure: map[string]string{}, + expected: 0, // 0 files + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create a temporary directory for the test + tempDir, err := os.MkdirTemp("", "test") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tempDir) + + // Create the test directory structure + if err := createTestDirStructure(tempDir, tt.structure); err != nil { + t.Fatalf("Failed to create test dir structure: %v", err) + } + + // Calculate item count + result, err := CalcItemCount(tempDir) + if err != nil { + t.Fatalf("CalcItemCount returned an error: %v", err) + } + + // Convert result to uint16 + count := binary.BigEndian.Uint16(result) + if count != tt.expected { + t.Errorf("expected %d, got %d", tt.expected, count) + } + }) + } +} diff --git a/hotline/flattened_file_object.go b/hotline/flattened_file_object.go index 15ce94d..fbc319b 100644 --- a/hotline/flattened_file_object.go +++ b/hotline/flattened_file_object.go @@ -102,8 +102,10 @@ func (ffif *FlatFileInformationFork) Size() [4]byte { } func (ffo *flattenedFileObject) TransferSize(offset int64) []byte { + ffoCopy := *ffo + // get length of the flattenedFileObject, including the info fork - b, _ := io.ReadAll(ffo) + b, _ := io.ReadAll(&ffoCopy) payloadSize := len(b) // length of data fork diff --git a/hotline/handshake.go b/hotline/handshake.go index 9ccdc62..c54359e 100644 --- a/hotline/handshake.go +++ b/hotline/handshake.go @@ -4,19 +4,12 @@ import ( "bytes" "encoding/binary" "errors" + "fmt" "io" ) -type handshake struct { - Protocol [4]byte // Must be 0x54525450 TRTP - SubProtocol [4]byte - Version [2]byte // Always 1 - SubVersion [2]byte -} - -var trtp = [4]byte{0x54, 0x52, 0x54, 0x50} - -// Handshake +// Hotline handshake process +// // After establishing TCP connection, both client and server start the handshake process // in order to confirm that each of them comply with requirements of the other. // The information provided in this initial data exchange identifies protocols, @@ -35,22 +28,60 @@ var trtp = [4]byte{0x54, 0x52, 0x54, 0x50} // Description Size Data Note // Protocol ID 4 TRTP // Error code 4 Error code returned by the server (0 = no error) -func Handshake(rw io.ReadWriter) error { - handshakeBuf := make([]byte, 12) - if _, err := io.ReadFull(rw, handshakeBuf); err != nil { - return err + +type handshake struct { + Protocol [4]byte // Must be 0x54525450 TRTP + SubProtocol [4]byte // Must be 0x484F544C HOTL + Version [2]byte // Always 1 (?) + SubVersion [2]byte // Always 2 (?) +} + +// Write implements the io.Writer interface for handshake. +func (h *handshake) Write(p []byte) (n int, err error) { + if len(p) != handshakeSize { + return 0, errors.New("invalid handshake size") } + _ = binary.Read(bytes.NewBuffer(p), binary.BigEndian, h) + + return len(p), nil +} + +// Valid checks if the handshake contains valid protocol and sub-protocol IDs. +func (h *handshake) Valid() bool { + return h.Protocol == trtp && h.SubProtocol == hotl +} + +var ( + // trtp represents the Protocol ID "TRTP" in hex + trtp = [4]byte{0x54, 0x52, 0x54, 0x50} + + // hotl represents the Sub-protocol ID "HOTL" in hex + hotl = [4]byte{0x48, 0x4F, 0x54, 0x4C} + + // handshakeResponse represents the server's response after a successful handshake + // Response with "TRTP" and no error code + handshakeResponse = [8]byte{0x54, 0x52, 0x54, 0x50, 0x00, 0x00, 0x00, 0x00} +) + +const handshakeSize = 12 + +// performHandshake performs the handshake process. +func performHandshake(rw io.ReadWriter) error { var h handshake - r := bytes.NewReader(handshakeBuf) - if err := binary.Read(r, binary.BigEndian, &h); err != nil { - return err + + // Copy exactly handshakeSize bytes from rw to handshake + if _, err := io.CopyN(&h, rw, handshakeSize); err != nil { + return fmt.Errorf("failed to read handshake data: %w", err) + } + + if !h.Valid() { + return errors.New("invalid protocol or sub-protocol in handshake") } - if h.Protocol != trtp { - return errors.New("invalid handshake") + if _, err := rw.Write(handshakeResponse[:]); err != nil { + return fmt.Errorf("error sending handshake response: %w", err) } - _, err := rw.Write([]byte{84, 82, 84, 80, 0, 0, 0, 0}) - return err + return nil } diff --git a/hotline/handshake_test.go b/hotline/handshake_test.go new file mode 100644 index 0000000..73ebfe4 --- /dev/null +++ b/hotline/handshake_test.go @@ -0,0 +1,229 @@ +package hotline + +import ( + "bytes" + "testing" +) + +func TestHandshakeWrite(t *testing.T) { + tests := []struct { + name string + input []byte + expected handshake + expectedError string + }{ + { + name: "Valid Handshake", + input: []byte{0x54, 0x52, 0x54, 0x50, 0x48, 0x4F, 0x54, 0x4C, 0x00, 0x01, 0x00, 0x02}, + expected: handshake{ + Protocol: [4]byte{0x54, 0x52, 0x54, 0x50}, + SubProtocol: [4]byte{0x48, 0x4F, 0x54, 0x4C}, + Version: [2]byte{0x00, 0x01}, + SubVersion: [2]byte{0x00, 0x02}, + }, + expectedError: "", + }, + { + name: "Invalid Handshake Size", + input: []byte{0x54, 0x52, 0x54, 0x50}, + expected: handshake{}, + expectedError: "invalid handshake size", + }, + { + name: "Empty Handshake Data", + input: []byte{}, + expected: handshake{}, + expectedError: "invalid handshake size", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var h handshake + n, err := h.Write(tt.input) + + if tt.expectedError != "" { + if err == nil || err.Error() != tt.expectedError { + t.Fatalf("expected error %q, got %q", tt.expectedError, err) + } + } else { + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if n != handshakeSize { + t.Fatalf("expected %d bytes written, got %d", handshakeSize, n) + } + if h != tt.expected { + t.Fatalf("expected handshake %+v, got %+v", tt.expected, h) + } + } + }) + } +} + +func TestHandshakeValid(t *testing.T) { + tests := []struct { + name string + input handshake + expected bool + }{ + { + name: "Valid Handshake", + input: handshake{ + Protocol: [4]byte{0x54, 0x52, 0x54, 0x50}, // TRTP + SubProtocol: [4]byte{0x48, 0x4F, 0x54, 0x4C}, // HOTL + Version: [2]byte{0x00, 0x01}, + SubVersion: [2]byte{0x00, 0x02}, + }, + expected: true, + }, + { + name: "Invalid Protocol", + input: handshake{ + Protocol: [4]byte{0x00, 0x00, 0x00, 0x00}, + SubProtocol: [4]byte{0x48, 0x4F, 0x54, 0x4C}, // HOTL + Version: [2]byte{0x00, 0x01}, + SubVersion: [2]byte{0x00, 0x02}, + }, + expected: false, + }, + { + name: "Invalid SubProtocol", + input: handshake{ + Protocol: [4]byte{0x54, 0x52, 0x54, 0x50}, // TRTP + SubProtocol: [4]byte{0x00, 0x00, 0x00, 0x00}, + Version: [2]byte{0x00, 0x01}, + SubVersion: [2]byte{0x00, 0x02}, + }, + expected: false, + }, + { + name: "Invalid Protocol and SubProtocol", + input: handshake{ + Protocol: [4]byte{0x00, 0x00, 0x00, 0x00}, + SubProtocol: [4]byte{0x00, 0x00, 0x00, 0x00}, + Version: [2]byte{0x00, 0x01}, + SubVersion: [2]byte{0x00, 0x02}, + }, + expected: false, + }, + { + name: "Valid Handshake with Different Version", + input: handshake{ + Protocol: [4]byte{0x54, 0x52, 0x54, 0x50}, // TRTP + SubProtocol: [4]byte{0x48, 0x4F, 0x54, 0x4C}, // HOTL + Version: [2]byte{0x00, 0x02}, + SubVersion: [2]byte{0x00, 0x03}, + }, + expected: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := tt.input.Valid() + if result != tt.expected { + t.Fatalf("expected %v, got %v", tt.expected, result) + } + }) + } +} + +// readWriteBuffer combines input and output buffers to implement io.ReadWriter +type readWriteBuffer struct { + input *bytes.Buffer + output *bytes.Buffer +} + +func (rw *readWriteBuffer) Read(p []byte) (int, error) { + return rw.input.Read(p) +} + +func (rw *readWriteBuffer) Write(p []byte) (int, error) { + return rw.output.Write(p) +} + +func TestPerformHandshake(t *testing.T) { + tests := []struct { + name string + input []byte + expectedOutput []byte + expectedError string + }{ + { + name: "Valid Handshake", + input: []byte{ + 0x54, 0x52, 0x54, 0x50, // TRTP + 0x48, 0x4F, 0x54, 0x4C, // HOTL + 0x00, 0x01, 0x00, 0x02, // Version 1, SubVersion 2 + }, + expectedOutput: []byte{0x54, 0x52, 0x54, 0x50, 0x00, 0x00, 0x00, 0x00}, + expectedError: "", + }, + { + name: "Invalid Handshake Size", + input: []byte{ + 0x54, 0x52, 0x54, 0x50, // TRTP + }, + expectedOutput: nil, + expectedError: "failed to read handshake data: invalid handshake size", + }, + { + name: "Invalid Protocol", + input: []byte{ + 0x00, 0x00, 0x00, 0x00, // Invalid protocol + 0x48, 0x4F, 0x54, 0x4C, // HOTL + 0x00, 0x01, 0x00, 0x02, // Version 1, SubVersion 2 + }, + expectedOutput: nil, + expectedError: "invalid protocol or sub-protocol in handshake", + }, + { + name: "Invalid SubProtocol", + input: []byte{ + 0x54, 0x52, 0x54, 0x50, // TRTP + 0x00, 0x00, 0x00, 0x00, // Invalid sub-protocol + 0x00, 0x01, 0x00, 0x02, // Version 1, SubVersion 2 + }, + expectedOutput: nil, + expectedError: "invalid protocol or sub-protocol in handshake", + }, + { + name: "Binary Read Error", + input: []byte{ + 0xFF, 0xFF, 0xFF, 0xFF, // Invalid data + 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, 0xFF, 0xFF, + }, + expectedOutput: nil, + expectedError: "invalid protocol or sub-protocol in handshake", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + inputBuffer := bytes.NewBuffer(tt.input) + outputBuffer := &bytes.Buffer{} + rw := &readWriteBuffer{ + input: inputBuffer, + output: outputBuffer, + } + + err := performHandshake(rw) + + if tt.expectedError != "" { + if err == nil || err.Error() != tt.expectedError { + t.Fatalf("expected error %q, got %q", tt.expectedError, err) + } + } else { + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + output := outputBuffer.Bytes() + if !bytes.Equal(output, tt.expectedOutput) { + t.Fatalf("expected output %v, got %v", tt.expectedOutput, output) + } + } + }) + } +} diff --git a/hotline/news.go b/hotline/news.go index cd3b6af..6e89388 100644 --- a/hotline/news.go +++ b/hotline/news.go @@ -1,7 +1,6 @@ package hotline import ( - "bytes" "encoding/binary" "io" "slices" @@ -29,6 +28,8 @@ type NewsCategoryListData15 struct { GUID [16]byte `yaml:"-"` // What does this do? Undocumented and seeming unused. AddSN [4]byte `yaml:"-"` // What does this do? Undocumented and seeming unused. DeleteSN [4]byte `yaml:"-"` // What does this do? Undocumented and seeming unused. + + readOffset int // Internal offset to track read progress } func (newscat *NewsCategoryListData15) GetNewsArtListData() NewsArtListData { @@ -165,7 +166,7 @@ func (nal *NewsArtList) Read(p []byte) (int, error) { nal.TimeStamp[:], nal.ParentID[:], nal.Flags[:], - []byte{0, 1}, // Flavor Count + []byte{0, 1}, // Flavor Count TODO: make this not hardcoded []byte{uint8(len(nal.Title))}, nal.Title, []byte{uint8(len(nal.Poster))}, @@ -191,23 +192,33 @@ type NewsFlavorList struct { // Article size 2 } -func (newscat *NewsCategoryListData15) MarshalBinary() (data []byte, err error) { +func (newscat *NewsCategoryListData15) Read(p []byte) (int, error) { count := make([]byte, 2) binary.BigEndian.PutUint16(count, uint16(len(newscat.Articles)+len(newscat.SubCats))) - out := append(newscat.Type[:], count...) + out := slices.Concat( + newscat.Type[:], + count, + ) // If type is category - if bytes.Equal(newscat.Type[:], []byte{0, 3}) { - out = append(out, newscat.GUID[:]...) // GUID - out = append(out, newscat.AddSN[:]...) // Add SN - out = append(out, newscat.DeleteSN[:]...) // Delete SN + if newscat.Type == [2]byte{0, 3} { + out = append(out, newscat.GUID[:]...) + out = append(out, newscat.AddSN[:]...) + out = append(out, newscat.DeleteSN[:]...) } out = append(out, newscat.nameLen()...) out = append(out, []byte(newscat.Name)...) - return out, err + if newscat.readOffset >= len(out) { + return 0, io.EOF // All bytes have been read + } + + n := copy(p, out) + newscat.readOffset = n + + return n, nil } func (newscat *NewsCategoryListData15) nameLen() []byte { diff --git a/hotline/news_test.go b/hotline/news_test.go index 3871ec2..d1b043e 100644 --- a/hotline/news_test.go +++ b/hotline/news_test.go @@ -2,6 +2,7 @@ package hotline import ( "github.com/stretchr/testify/assert" + "io" "testing" ) @@ -79,7 +80,7 @@ func TestNewsCategoryListData15_MarshalBinary(t *testing.T) { DeleteSN: tt.fields.DeleteSN, GUID: tt.fields.GUID, } - gotData, err := newscat.MarshalBinary() + gotData, err := io.ReadAll(newscat) if newscat.Type == [2]byte{0, 3} { // zero out the random GUID before comparison for i := 4; i < 20; i++ { diff --git a/hotline/server.go b/hotline/server.go index 5245c7b..0ee2dd7 100644 --- a/hotline/server.go +++ b/hotline/server.go @@ -12,16 +12,15 @@ import ( "golang.org/x/text/encoding/charmap" "gopkg.in/yaml.v3" "io" - "io/fs" "log" "log/slog" - "math/big" "net" "os" "path" "path/filepath" "strings" "sync" + "sync/atomic" "time" ) @@ -40,11 +39,12 @@ var txtDecoder = charmap.Macintosh.NewDecoder() var txtEncoder = charmap.Macintosh.NewEncoder() type Server struct { - NetInterface string - Port int - Accounts map[string]*Account - Agreement []byte - Clients map[uint16]*ClientConn + NetInterface string + Port int + Accounts map[string]*Account + Agreement []byte + + Clients map[[2]byte]*ClientConn fileTransfers map[[4]byte]*FileTransfer Config *Config @@ -53,12 +53,12 @@ type Server struct { banner []byte PrivateChatsMu sync.Mutex - PrivateChats map[uint32]*PrivateChat + PrivateChats map[[4]byte]*PrivateChat - NextGuestID *uint16 + nextClientID atomic.Uint32 TrackerPassID [4]byte - StatsMu sync.Mutex + statsMu sync.Mutex Stats *Stats FS FileStore // Storage backend to use for File storage @@ -77,8 +77,8 @@ type Server struct { } func (s *Server) CurrentStats() Stats { - s.StatsMu.Lock() - defer s.StatsMu.Unlock() + s.statsMu.Lock() + defer s.statsMu.Unlock() stats := s.Stats stats.CurrentlyConnected = len(s.Clients) @@ -88,10 +88,10 @@ func (s *Server) CurrentStats() Stats { type PrivateChat struct { Subject string - ClientConn map[uint16]*ClientConn + ClientConn map[[2]byte]*ClientConn } -func (s *Server) ListenAndServe(ctx context.Context, cancelRoot context.CancelFunc) error { +func (s *Server) ListenAndServe(ctx context.Context) error { var wg sync.WaitGroup wg.Add(1) @@ -130,9 +130,7 @@ func (s *Server) ServeFileTransfers(ctx context.Context, ln net.Listener) error defer func() { _ = conn.Close() }() err = s.handleFileTransfer( - context.WithValue(ctx, contextKeyReq, requestCtx{ - remoteAddr: conn.RemoteAddr().String(), - }), + context.WithValue(ctx, contextKeyReq, requestCtx{remoteAddr: conn.RemoteAddr().String()}), conn, ) @@ -144,21 +142,17 @@ func (s *Server) ServeFileTransfers(ctx context.Context, ln net.Listener) error } func (s *Server) sendTransaction(t Transaction) error { - clientID, err := byteToInt(*t.clientID) - if err != nil { - return fmt.Errorf("invalid client ID: %v", err) - } - s.mux.Lock() - client, ok := s.Clients[uint16(clientID)] + client, ok := s.Clients[t.clientID] s.mux.Unlock() + if !ok || client == nil { - return fmt.Errorf("invalid client id %v", *t.clientID) + return nil } - _, err = io.Copy(client.Connection, &t) + _, err := io.Copy(client.Connection, &t) if err != nil { - return fmt.Errorf("failed to send transaction to client %v: %v", clientID, err) + return fmt.Errorf("failed to send transaction to client %v: %v", t.clientID, err) } return nil @@ -207,18 +201,18 @@ const ( ) // NewServer constructs a new Server from a config dir +// TODO: move config file reads out of this function func NewServer(configDir, netInterface string, netPort int, logger *slog.Logger, fs FileStore) (*Server, error) { server := Server{ NetInterface: netInterface, Port: netPort, Accounts: make(map[string]*Account), Config: new(Config), - Clients: make(map[uint16]*ClientConn), + Clients: make(map[[2]byte]*ClientConn), fileTransfers: make(map[[4]byte]*FileTransfer), - PrivateChats: make(map[uint32]*PrivateChat), + PrivateChats: make(map[[4]byte]*PrivateChat), ConfigDir: configDir, Logger: logger, - NextGuestID: new(uint16), outbox: make(chan Transaction), Stats: &Stats{Since: time.Now()}, ThreadedNews: &ThreadedNews{}, @@ -243,14 +237,18 @@ func NewServer(configDir, netInterface string, netPort int, logger *slog.Logger, } // try to load the ban list, but ignore errors as this file may not be present or may be empty - _ = server.loadBanList(filepath.Join(configDir, "Banlist.yaml")) + //_ = server.loadBanList(filepath.Join(configDir, "Banlist.yaml")) - if err := server.loadThreadedNews(filepath.Join(configDir, "ThreadedNews.yaml")); err != nil { + _ = loadFromYAMLFile(filepath.Join(configDir, "Banlist.yaml"), &server.banList) + + err = loadFromYAMLFile(filepath.Join(configDir, "ThreadedNews.yaml"), &server.ThreadedNews) + if err != nil { return nil, fmt.Errorf("error loading threaded news: %w", err) } - if err := server.loadConfig(filepath.Join(configDir, "config.yaml")); err != nil { - return nil, err + err = server.loadConfig(filepath.Join(configDir, "config.yaml")) + if err != nil { + return nil, fmt.Errorf("error loading config: %w", err) } if err := server.loadAccounts(filepath.Join(configDir, "Users/")); err != nil { @@ -267,8 +265,6 @@ func NewServer(configDir, netInterface string, netPort int, logger *slog.Logger, return nil, fmt.Errorf("error opening banner: %w", err) } - *server.NextGuestID = 1 - if server.Config.EnableTrackerRegistration { server.Logger.Info( "Tracker registration enabled", @@ -286,7 +282,7 @@ func NewServer(configDir, netInterface string, netPort int, logger *slog.Logger, } binary.BigEndian.PutUint16(tr.Port[:], uint16(server.Port)) for _, t := range server.Config.Trackers { - if err := register(t, tr); err != nil { + if err := register(&RealDialer{}, t, tr); err != nil { server.Logger.Error("unable to register with tracker %v", "error", err) } server.Logger.Debug("Sent Tracker registration", "addr", t) @@ -320,14 +316,13 @@ func (s *Server) keepaliveHandler() { if c.IdleTime > userIdleSeconds && !c.Idle { c.Idle = true - flagBitmap := big.NewInt(int64(binary.BigEndian.Uint16(c.Flags))) - flagBitmap.SetBit(flagBitmap, UserFlagAway, 1) - binary.BigEndian.PutUint16(c.Flags, uint16(flagBitmap.Int64())) - + c.flagsMU.Lock() + c.Flags.Set(UserFlagAway, 1) + c.flagsMU.Unlock() c.sendAll( TranNotifyChangeUser, - NewField(FieldUserID, *c.ID), - NewField(FieldUserFlags, c.Flags), + NewField(FieldUserID, c.ID[:]), + NewField(FieldUserFlags, c.Flags[:]), NewField(FieldUserName, c.UserName), NewField(FieldUserIconID, c.Icon), ) @@ -374,14 +369,9 @@ func (s *Server) NewClientConn(conn io.ReadWriteCloser, remoteAddr string) *Clie defer s.mux.Unlock() clientConn := &ClientConn{ - ID: &[]byte{0, 0}, - Icon: []byte{0, 0}, - Flags: []byte{0, 0}, - UserName: []byte{}, + Icon: []byte{0, 0}, // TODO: make array type Connection: conn, Server: s, - Version: []byte{}, - AutoReply: []byte{}, RemoteAddr: remoteAddr, transfers: map[int]map[[4]byte]*FileTransfer{ FileDownload: {}, @@ -392,11 +382,10 @@ func (s *Server) NewClientConn(conn io.ReadWriteCloser, remoteAddr string) *Clie }, } - *s.NextGuestID++ - ID := *s.NextGuestID + s.nextClientID.Add(1) - binary.BigEndian.PutUint16(*clientConn.ID, ID) - s.Clients[ID] = clientConn + binary.BigEndian.PutUint16(clientConn.ID[:], uint16(s.nextClientID.Load())) + s.Clients[clientConn.ID] = clientConn return clientConn } @@ -406,34 +395,29 @@ func (s *Server) NewUser(login, name, password string, access accessBitmap) erro s.mux.Lock() defer s.mux.Unlock() - account := Account{ - Login: login, - Name: name, - Password: hashAndSalt([]byte(password)), - Access: access, - } - out, err := yaml.Marshal(&account) - if err != nil { - return err - } + account := NewAccount(login, name, password, access) // Create account file, returning an error if one already exists. file, err := os.OpenFile( filepath.Join(s.ConfigDir, "Users", path.Join("/", login)+".yaml"), - os.O_CREATE|os.O_EXCL|os.O_WRONLY, - 0644, + os.O_CREATE|os.O_EXCL|os.O_WRONLY, 0644, ) if err != nil { - return err + return fmt.Errorf("error creating account file: %w", err) } defer file.Close() - _, err = file.Write(out) + b, err := yaml.Marshal(account) + if err != nil { + return err + } + + _, err = file.Write(b) if err != nil { return fmt.Errorf("error writing account file: %w", err) } - s.Accounts[login] = &account + s.Accounts[login] = account return nil } @@ -442,11 +426,14 @@ func (s *Server) UpdateUser(login, newLogin, name, password string, access acces s.mux.Lock() defer s.mux.Unlock() - // update renames the user login + // If the login has changed, rename the account file. if login != newLogin { - err := os.Rename(filepath.Join(s.ConfigDir, "Users", path.Join("/", login)+".yaml"), filepath.Join(s.ConfigDir, "Users", path.Join("/", newLogin)+".yaml")) + err := os.Rename( + filepath.Join(s.ConfigDir, "Users", path.Join("/", login)+".yaml"), + filepath.Join(s.ConfigDir, "Users", path.Join("/", newLogin)+".yaml"), + ) if err != nil { - return fmt.Errorf("unable to rename account: %w", err) + return fmt.Errorf("error renaming account file: %w", err) } s.Accounts[newLogin] = s.Accounts[login] s.Accounts[newLogin].Login = newLogin @@ -464,7 +451,7 @@ func (s *Server) UpdateUser(login, newLogin, name, password string, access acces } if err := os.WriteFile(filepath.Join(s.ConfigDir, "Users", newLogin+".yaml"), out, 0666); err != nil { - return err + return fmt.Errorf("error writing account file: %w", err) } return nil @@ -486,15 +473,15 @@ func (s *Server) DeleteUser(login string) error { } func (s *Server) connectedUsers() []Field { - s.mux.Lock() - defer s.mux.Unlock() + //s.mux.Lock() + //defer s.mux.Unlock() var connectedUsers []Field for _, c := range sortedClients(s.Clients) { b, err := io.ReadAll(&User{ - ID: *c.ID, + ID: c.ID, Icon: c.Icon, - Flags: c.Flags, + Flags: c.Flags[:], Name: string(c.UserName), }) if err != nil { @@ -505,25 +492,16 @@ func (s *Server) connectedUsers() []Field { return connectedUsers } -func (s *Server) loadBanList(path string) error { +// loadFromYAMLFile loads data from a YAML file into the provided data structure. +func loadFromYAMLFile(path string, data interface{}) error { fh, err := os.Open(path) if err != nil { return err } - decoder := yaml.NewDecoder(fh) + defer fh.Close() - return decoder.Decode(s.banList) -} - -// loadThreadedNews loads the threaded news data from disk -func (s *Server) loadThreadedNews(threadedNewsPath string) error { - fh, err := os.Open(threadedNewsPath) - if err != nil { - return err - } decoder := yaml.NewDecoder(fh) - - return decoder.Decode(s.ThreadedNews) + return decoder.Decode(data) } // loadAccounts loads account data from disk @@ -534,18 +512,12 @@ func (s *Server) loadAccounts(userDir string) error { } if len(matches) == 0 { - return errors.New("no user accounts found in " + userDir) + return fmt.Errorf("no accounts found in directory: %s", userDir) } for _, file := range matches { - fh, err := s.FS.Open(file) - if err != nil { - return err - } - - account := Account{} - decoder := yaml.NewDecoder(fh) - if err = decoder.Decode(&account); err != nil { + var account Account + if err = loadFromYAMLFile(file, &account); err != nil { return fmt.Errorf("error loading account %s: %w", file, err) } @@ -574,12 +546,41 @@ func (s *Server) loadConfig(path string) error { return nil } +func sendBanMessage(rwc io.Writer, message string) { + t := NewTransaction( + TranServerMsg, + [2]byte{0, 0}, + NewField(FieldData, []byte(message)), + NewField(FieldChatOptions, []byte{0, 0}), + ) + _, _ = io.Copy(rwc, &t) + time.Sleep(1 * time.Second) +} + // handleNewConnection takes a new net.Conn and performs the initial login sequence func (s *Server) handleNewConnection(ctx context.Context, rwc io.ReadWriteCloser, remoteAddr string) error { defer dontPanic(s.Logger) - if err := Handshake(rwc); err != nil { - return err + // Check if remoteAddr is present in the ban list + ipAddr := strings.Split(remoteAddr, ":")[0] + if banUntil, ok := s.banList[ipAddr]; ok { + // permaban + if banUntil == nil { + sendBanMessage(rwc, "You are permanently banned on this server") + s.Logger.Debug("Disconnecting permanently banned IP", "remoteAddr", ipAddr) + return nil + } + + // temporary ban + if time.Now().Before(*banUntil) { + sendBanMessage(rwc, "You are temporarily banned on this server") + s.Logger.Debug("Disconnecting temporarily banned IP", "remoteAddr", ipAddr) + return nil + } + } + + if err := performHandshake(rwc); err != nil { + return fmt.Errorf("error performing handshake: %w", err) } // Create a new scanner for parsing incoming bytes into transaction tokens @@ -595,59 +596,16 @@ func (s *Server) handleNewConnection(ctx context.Context, rwc io.ReadWriteCloser var clientLogin Transaction if _, err := clientLogin.Write(buf); err != nil { - return err - } - - // check if remoteAddr is present in the ban list - if banUntil, ok := s.banList[strings.Split(remoteAddr, ":")[0]]; ok { - // permaban - if banUntil == nil { - t := NewTransaction( - TranServerMsg, - &[]byte{0, 0}, - NewField(FieldData, []byte("You are permanently banned on this server")), - NewField(FieldChatOptions, []byte{0, 0}), - ) - - _, err := io.Copy(rwc, t) - if err != nil { - return err - } - - time.Sleep(1 * time.Second) - return nil - } - - // temporary ban - if time.Now().Before(*banUntil) { - t := NewTransaction( - TranServerMsg, - &[]byte{0, 0}, - NewField(FieldData, []byte("You are temporarily banned on this server")), - NewField(FieldChatOptions, []byte{0, 0}), - ) - - _, err := io.Copy(rwc, t) - if err != nil { - return err - } - - time.Sleep(1 * time.Second) - return nil - } + return fmt.Errorf("error writing login transaction: %w", err) } c := s.NewClientConn(rwc, remoteAddr) defer c.Disconnect() - encodedLogin := clientLogin.GetField(FieldUserLogin).Data encodedPassword := clientLogin.GetField(FieldUserPassword).Data c.Version = clientLogin.GetField(FieldVersion).Data - var login string - for _, char := range encodedLogin { - login += string(rune(255 - uint(char))) - } + login := string(encodeString(clientLogin.GetField(FieldUserLogin).Data)) if login == "" { login = GuestAccount } @@ -656,7 +614,7 @@ func (s *Server) handleNewConnection(ctx context.Context, rwc io.ReadWriteCloser // If authentication fails, send error reply and close connection if !c.Authenticate(login, encodedPassword) { - t := c.NewErrReply(&clientLogin, "Incorrect login.") + t := c.NewErrReply(&clientLogin, "Incorrect login.")[0] _, err := io.Copy(rwc, &t) if err != nil { @@ -672,7 +630,9 @@ func (s *Server) handleNewConnection(ctx context.Context, rwc io.ReadWriteCloser c.Icon = clientLogin.GetField(FieldUserIconID).Data } + c.Lock() c.Account = c.Server.Accounts[login] + c.Unlock() if clientLogin.GetField(FieldUserName).Data != nil { if c.Authorize(accessAnyName) { @@ -683,7 +643,7 @@ func (s *Server) handleNewConnection(ctx context.Context, rwc io.ReadWriteCloser } if c.Authorize(accessDisconUser) { - c.Flags = []byte{0, 2} + c.Flags.Set(UserFlagAdmin, 1) } s.outbox <- c.NewReply(&clientLogin, @@ -693,7 +653,7 @@ func (s *Server) handleNewConnection(ctx context.Context, rwc io.ReadWriteCloser ) // Send user access privs so client UI knows how to behave - c.Server.outbox <- *NewTransaction(TranUserAccess, c.ID, NewField(FieldUserAccess, c.Account.Access[:])) + c.Server.outbox <- NewTransaction(TranUserAccess, c.ID, NewField(FieldUserAccess, c.Account.Access[:])) // Accounts with accessNoAgreement do not receive the server agreement on login. The behavior is different between // client versions. For 1.2.3 client, we do not send TranShowAgreement. For other client versions, we send @@ -701,10 +661,10 @@ func (s *Server) handleNewConnection(ctx context.Context, rwc io.ReadWriteCloser if c.Authorize(accessNoAgreement) { // If client version is nil, then the client uses the 1.2.3 login behavior if c.Version != nil { - c.Server.outbox <- *NewTransaction(TranShowAgreement, c.ID, NewField(FieldNoServerAgreement, []byte{1})) + c.Server.outbox <- NewTransaction(TranShowAgreement, c.ID, NewField(FieldNoServerAgreement, []byte{1})) } } else { - c.Server.outbox <- *NewTransaction(TranShowAgreement, c.ID, NewField(FieldData, s.Agreement)) + c.Server.outbox <- NewTransaction(TranShowAgreement, c.ID, NewField(FieldData, s.Agreement)) } // If the client has provided a username as part of the login, we can infer that it is using the 1.2.3 login @@ -713,33 +673,33 @@ func (s *Server) handleNewConnection(ctx context.Context, rwc io.ReadWriteCloser // Add the client username to the logger. For 1.5+ clients, we don't have this information yet as it comes as // part of TranAgreed c.logger = c.logger.With("Name", string(c.UserName)) - c.logger.Info("Login successful", "clientVersion", "Not sent (probably 1.2.3)") // Notify other clients on the server that the new user has logged in. For 1.5+ clients we don't have this // information yet, so we do it in TranAgreed instead for _, t := range c.notifyOthers( - *NewTransaction( - TranNotifyChangeUser, nil, + NewTransaction( + TranNotifyChangeUser, [2]byte{0, 0}, NewField(FieldUserName, c.UserName), - NewField(FieldUserID, *c.ID), + NewField(FieldUserID, c.ID[:]), NewField(FieldUserIconID, c.Icon), - NewField(FieldUserFlags, c.Flags), + NewField(FieldUserFlags, c.Flags[:]), ), ) { c.Server.outbox <- t } } + c.Server.mux.Lock() c.Server.Stats.ConnectionCounter += 1 if len(s.Clients) > c.Server.Stats.ConnectionPeak { c.Server.Stats.ConnectionPeak = len(s.Clients) } + c.Server.mux.Unlock() // Scan for new transactions and handle them as they come in. for scanner.Scan() { - // Make a new []byte slice and copy the scanner bytes to it. This is critical to avoid a data race as the - // scanner re-uses the buffer for subsequent scans. + // Copy the scanner bytes to a new slice to it to avoid a data race when the scanner re-uses the buffer. buf := make([]byte, len(scanner.Bytes())) copy(buf, scanner.Bytes()) @@ -748,26 +708,22 @@ func (s *Server) handleNewConnection(ctx context.Context, rwc io.ReadWriteCloser return err } - if err := c.handleTransaction(t); err != nil { - c.logger.Error("Error handling transaction", "err", err) - } + c.handleTransaction(t) } return nil } -func (s *Server) NewPrivateChat(cc *ClientConn) []byte { +func (s *Server) NewPrivateChat(cc *ClientConn) [4]byte { s.PrivateChatsMu.Lock() defer s.PrivateChatsMu.Unlock() - randID := make([]byte, 4) - _, _ = rand.Read(randID) + var randID [4]byte + _, _ = rand.Read(randID[:]) - data := binary.BigEndian.Uint32(randID) - - s.PrivateChats[data] = &PrivateChat{ - ClientConn: make(map[uint16]*ClientConn), + s.PrivateChats[randID] = &PrivateChat{ + ClientConn: make(map[[2]byte]*ClientConn), } - s.PrivateChats[data].ClientConn[cc.uint16ID()] = cc + s.PrivateChats[randID].ClientConn[cc.ID] = cc return randID } @@ -780,14 +736,10 @@ const dlFldrActionNextFile = 3 func (s *Server) handleFileTransfer(ctx context.Context, rwc io.ReadWriter) error { defer dontPanic(s.Logger) - txBuf := make([]byte, 16) - if _, err := io.ReadFull(rwc, txBuf); err != nil { - return err - } - + // The first 16 bytes contain the file transfer. var t transfer - if _, err := t.Write(txBuf); err != nil { - return err + if _, err := io.CopyN(&t, rwc, 16); err != nil { + return fmt.Errorf("error reading file transfer: %w", err) } defer func() { @@ -837,55 +789,9 @@ func (s *Server) handleFileTransfer(ctx context.Context, rwc io.ReadWriter) erro s.Stats.DownloadsInProgress -= 1 }() - var dataOffset int64 - if fileTransfer.fileResumeData != nil { - dataOffset = int64(binary.BigEndian.Uint32(fileTransfer.fileResumeData.ForkInfoList[0].DataSize[:])) - } - - fw, err := newFileWrapper(s.FS, fullPath, 0) - if err != nil { - return err - } - - rLogger.Info("File download started", "filePath", fullPath) - - // if file transfer options are included, that means this is a "quick preview" request from a 1.5+ client - if fileTransfer.options == nil { - _, err = io.Copy(rwc, fw.ffo) - if err != nil { - return err - } - } - - file, err := fw.dataForkReader() + err = DownloadHandler(rwc, fullPath, fileTransfer, s.FS, rLogger, true) if err != nil { - return err - } - - br := bufio.NewReader(file) - if _, err := br.Discard(int(dataOffset)); err != nil { - return err - } - - if _, err = io.Copy(rwc, io.TeeReader(br, fileTransfer.bytesSentCounter)); err != nil { - return err - } - - // if the client requested to resume transfer, do not send the resource fork header, or it will be appended into the fileWrapper data - if fileTransfer.fileResumeData == nil { - err = binary.Write(rwc, binary.BigEndian, fw.rsrcForkHeader()) - if err != nil { - return err - } - } - - rFile, err := fw.rsrcForkFile() - if err != nil { - return nil - } - - if _, err = io.Copy(rwc, io.TeeReader(rFile, fileTransfer.bytesSentCounter)); err != nil { - return err + return fmt.Errorf("file download error: %w", err) } case FileUpload: @@ -893,224 +799,19 @@ func (s *Server) handleFileTransfer(ctx context.Context, rwc io.ReadWriter) erro s.Stats.UploadsInProgress += 1 defer func() { s.Stats.UploadsInProgress -= 1 }() - var file *os.File - - // A file upload has three possible cases: - // 1) Upload a new file - // 2) Resume a partially transferred file - // 3) Replace a fully uploaded file - // We have to infer which case applies by inspecting what is already on the filesystem - - // 1) Check for existing file: - _, err = os.Stat(fullPath) - if err == nil { - return errors.New("existing file found at " + fullPath) - } - if errors.Is(err, fs.ErrNotExist) { - // If not found, open or create a new .incomplete file - file, err = os.OpenFile(fullPath+incompleteFileSuffix, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0644) - if err != nil { - return err - } - } - - f, err := newFileWrapper(s.FS, fullPath, 0) + err = UploadHandler(rwc, fullPath, fileTransfer, s.FS, rLogger, s.Config.PreserveResourceForks) if err != nil { - return err - } - - rLogger.Info("File upload started", "dstFile", fullPath) - - rForkWriter := io.Discard - iForkWriter := io.Discard - if s.Config.PreserveResourceForks { - rForkWriter, err = f.rsrcForkWriter() - if err != nil { - return err - } - - iForkWriter, err = f.infoForkWriter() - if err != nil { - return err - } - } - - if err := receiveFile(rwc, file, rForkWriter, iForkWriter, fileTransfer.bytesSentCounter); err != nil { - s.Logger.Error(err.Error()) - } - - if err := file.Close(); err != nil { - return err - } - - if err := s.FS.Rename(fullPath+".incomplete", fullPath); err != nil { - return err + return fmt.Errorf("file upload error: %w", err) } - rLogger.Info("File upload complete", "dstFile", fullPath) - case FolderDownload: s.Stats.DownloadCounter += 1 s.Stats.DownloadsInProgress += 1 defer func() { s.Stats.DownloadsInProgress -= 1 }() - // Folder Download flow: - // 1. Get filePath from the transfer - // 2. Iterate over files - // 3. For each fileWrapper: - // Send fileWrapper header to client - // The client can reply in 3 ways: - // - // 1. If type is an odd number (unknown type?), or fileWrapper download for the current fileWrapper is completed: - // client sends []byte{0x00, 0x03} to tell the server to continue to the next fileWrapper - // - // 2. If download of a fileWrapper is to be resumed: - // client sends: - // []byte{0x00, 0x02} // download folder action - // [2]byte // Resume data size - // []byte fileWrapper resume data (see myField_FileResumeData) - // - // 3. Otherwise, download of the fileWrapper is requested and client sends []byte{0x00, 0x01} - // - // When download is requested (case 2 or 3), server replies with: - // [4]byte - fileWrapper size - // []byte - Flattened File Object - // - // After every fileWrapper download, client could request next fileWrapper with: - // []byte{0x00, 0x03} - // - // This notifies the server to send the next item header - - basePathLen := len(fullPath) - - rLogger.Info("Start folder download", "path", fullPath) - - nextAction := make([]byte, 2) - if _, err := io.ReadFull(rwc, nextAction); err != nil { - return err - } - - i := 0 - err = filepath.Walk(fullPath+"/", func(path string, info os.FileInfo, err error) error { - s.Stats.DownloadCounter += 1 - i += 1 - - if err != nil { - return err - } - - // skip dot files - if strings.HasPrefix(info.Name(), ".") { - return nil - } - - hlFile, err := newFileWrapper(s.FS, path, 0) - if err != nil { - return err - } - - subPath := path[basePathLen+1:] - rLogger.Debug("Sending fileheader", "i", i, "path", path, "fullFilePath", fullPath, "subPath", subPath, "IsDir", info.IsDir()) - - if i == 1 { - return nil - } - - fileHeader := NewFileHeader(subPath, info.IsDir()) - if _, err := io.Copy(rwc, &fileHeader); err != nil { - return fmt.Errorf("error sending file header: %w", err) - } - - // Read the client's Next Action request - if _, err := io.ReadFull(rwc, nextAction); err != nil { - return err - } - - rLogger.Debug("Client folder download action", "action", fmt.Sprintf("%X", nextAction[0:2])) - - var dataOffset int64 - - switch nextAction[1] { - case dlFldrActionResumeFile: - // get size of resumeData - resumeDataByteLen := make([]byte, 2) - if _, err := io.ReadFull(rwc, resumeDataByteLen); err != nil { - return err - } - - resumeDataLen := binary.BigEndian.Uint16(resumeDataByteLen) - resumeDataBytes := make([]byte, resumeDataLen) - if _, err := io.ReadFull(rwc, resumeDataBytes); err != nil { - return err - } - - var frd FileResumeData - if err := frd.UnmarshalBinary(resumeDataBytes); err != nil { - return err - } - dataOffset = int64(binary.BigEndian.Uint32(frd.ForkInfoList[0].DataSize[:])) - case dlFldrActionNextFile: - // client asked to skip this file - return nil - } - - if info.IsDir() { - return nil - } - - rLogger.Info("File download started", - "fileName", info.Name(), - "TransferSize", fmt.Sprintf("%x", hlFile.ffo.TransferSize(dataOffset)), - ) - - // Send file size to client - if _, err := rwc.Write(hlFile.ffo.TransferSize(dataOffset)); err != nil { - s.Logger.Error(err.Error()) - return err - } - - // Send ffo bytes to client - _, err = io.Copy(rwc, hlFile.ffo) - if err != nil { - return err - } - - file, err := s.FS.Open(path) - if err != nil { - return err - } - - // wr := bufio.NewWriterSize(rwc, 1460) - if _, err = io.Copy(rwc, io.TeeReader(file, fileTransfer.bytesSentCounter)); err != nil { - return err - } - - if nextAction[1] != 2 && hlFile.ffo.FlatFileHeader.ForkCount[1] == 3 { - err = binary.Write(rwc, binary.BigEndian, hlFile.rsrcForkHeader()) - if err != nil { - return err - } - - rFile, err := hlFile.rsrcForkFile() - if err != nil { - return err - } - - if _, err = io.Copy(rwc, io.TeeReader(rFile, fileTransfer.bytesSentCounter)); err != nil { - return err - } - } - - // Read the client's Next Action request. This is always 3, I think? - if _, err := io.ReadFull(rwc, nextAction); err != nil { - return err - } - - return nil - }) - + err = DownloadFolderHandler(rwc, fullPath, fileTransfer, s.FS, rLogger, s.Config.PreserveResourceForks) if err != nil { - return err + return fmt.Errorf("file upload error: %w", err) } case FolderUpload: @@ -1124,168 +825,10 @@ func (s *Server) handleFileTransfer(ctx context.Context, rwc io.ReadWriter) erro "FolderItemCount", fileTransfer.FolderItemCount, ) - // Check if the target folder exists. If not, create it. - if _, err := s.FS.Stat(fullPath); os.IsNotExist(err) { - if err := s.FS.Mkdir(fullPath, 0777); err != nil { - return err - } - } - - // Begin the folder upload flow by sending the "next file action" to client - if _, err := rwc.Write([]byte{0, dlFldrActionNextFile}); err != nil { - return err - } - - fileSize := make([]byte, 4) - - for i := 0; i < fileTransfer.ItemCount(); i++ { - s.Stats.UploadCounter += 1 - - var fu folderUpload - if _, err := io.ReadFull(rwc, fu.DataSize[:]); err != nil { - return err - } - if _, err := io.ReadFull(rwc, fu.IsFolder[:]); err != nil { - return err - } - if _, err := io.ReadFull(rwc, fu.PathItemCount[:]); err != nil { - return err - } - - fu.FileNamePath = make([]byte, binary.BigEndian.Uint16(fu.DataSize[:])-4) // -4 to subtract the path separator bytes - - if _, err := io.ReadFull(rwc, fu.FileNamePath); err != nil { - return err - } - - rLogger.Info( - "Folder upload continued", - "FormattedPath", fu.FormattedPath(), - "IsFolder", fmt.Sprintf("%x", fu.IsFolder), - "PathItemCount", binary.BigEndian.Uint16(fu.PathItemCount[:]), - ) - - if fu.IsFolder == [2]byte{0, 1} { - if _, err := os.Stat(filepath.Join(fullPath, fu.FormattedPath())); os.IsNotExist(err) { - if err := os.Mkdir(filepath.Join(fullPath, fu.FormattedPath()), 0777); err != nil { - return err - } - } - - // Tell client to send next file - if _, err := rwc.Write([]byte{0, dlFldrActionNextFile}); err != nil { - return err - } - } else { - nextAction := dlFldrActionSendFile - - // Check if we have the full file already. If so, send dlFldrAction_NextFile to client to skip. - _, err = os.Stat(filepath.Join(fullPath, fu.FormattedPath())) - if err != nil && !errors.Is(err, fs.ErrNotExist) { - return err - } - if err == nil { - nextAction = dlFldrActionNextFile - } - - // Check if we have a partial file already. If so, send dlFldrAction_ResumeFile to client to resume upload. - incompleteFile, err := os.Stat(filepath.Join(fullPath, fu.FormattedPath()+incompleteFileSuffix)) - if err != nil && !errors.Is(err, fs.ErrNotExist) { - return err - } - if err == nil { - nextAction = dlFldrActionResumeFile - } - - if _, err := rwc.Write([]byte{0, uint8(nextAction)}); err != nil { - return err - } - - switch nextAction { - case dlFldrActionNextFile: - continue - case dlFldrActionResumeFile: - offset := make([]byte, 4) - binary.BigEndian.PutUint32(offset, uint32(incompleteFile.Size())) - - file, err := os.OpenFile(fullPath+"/"+fu.FormattedPath()+incompleteFileSuffix, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) - if err != nil { - return err - } - - fileResumeData := NewFileResumeData([]ForkInfoList{*NewForkInfoList(offset)}) - - b, _ := fileResumeData.BinaryMarshal() - - bs := make([]byte, 2) - binary.BigEndian.PutUint16(bs, uint16(len(b))) - - if _, err := rwc.Write(append(bs, b...)); err != nil { - return err - } - - if _, err := io.ReadFull(rwc, fileSize); err != nil { - return err - } - - if err := receiveFile(rwc, file, io.Discard, io.Discard, fileTransfer.bytesSentCounter); err != nil { - s.Logger.Error(err.Error()) - } - - err = os.Rename(fullPath+"/"+fu.FormattedPath()+".incomplete", fullPath+"/"+fu.FormattedPath()) - if err != nil { - return err - } - - case dlFldrActionSendFile: - if _, err := io.ReadFull(rwc, fileSize); err != nil { - return err - } - - filePath := filepath.Join(fullPath, fu.FormattedPath()) - - hlFile, err := newFileWrapper(s.FS, filePath, 0) - if err != nil { - return err - } - - rLogger.Info("Starting file transfer", "path", filePath, "fileNum", i+1, "fileSize", binary.BigEndian.Uint32(fileSize)) - - incWriter, err := hlFile.incFileWriter() - if err != nil { - return err - } - - rForkWriter := io.Discard - iForkWriter := io.Discard - if s.Config.PreserveResourceForks { - iForkWriter, err = hlFile.infoForkWriter() - if err != nil { - return err - } - - rForkWriter, err = hlFile.rsrcForkWriter() - if err != nil { - return err - } - } - if err := receiveFile(rwc, incWriter, rForkWriter, iForkWriter, fileTransfer.bytesSentCounter); err != nil { - return err - } - - if err := os.Rename(filePath+".incomplete", filePath); err != nil { - return err - } - } - - // Tell client to send next fileWrapper - if _, err := rwc.Write([]byte{0, dlFldrActionNextFile}); err != nil { - return err - } - } + err = UploadFolderHandler(rwc, fullPath, fileTransfer, s.FS, rLogger, s.Config.PreserveResourceForks) + if err != nil { + return fmt.Errorf("file upload error: %w", err) } - rLogger.Info("Folder upload complete") } - return nil } diff --git a/hotline/server_blackbox_test.go b/hotline/server_blackbox_test.go index 28aecf0..06e7771 100644 --- a/hotline/server_blackbox_test.go +++ b/hotline/server_blackbox_test.go @@ -1,10 +1,13 @@ package hotline import ( + "cmp" + "encoding/binary" "encoding/hex" "github.com/stretchr/testify/assert" "log/slog" "os" + "slices" "testing" ) @@ -30,6 +33,13 @@ func assertTransferBytesEqual(t *testing.T, wantHexDump string, got []byte) bool return assert.Equal(t, wantHexDump, hex.Dump(clean)) } +var tranSortFunc = func(a, b Transaction) int { + return cmp.Compare( + binary.BigEndian.Uint16(a.clientID[:]), + binary.BigEndian.Uint16(b.clientID[:]), + ) +} + // tranAssertEqual compares equality of transactions slices after stripping out the random transaction ID func tranAssertEqual(t *testing.T, tran1, tran2 []Transaction) bool { var newT1 []Transaction @@ -45,7 +55,7 @@ func tranAssertEqual(t *testing.T, tran1, tran2 []Transaction) bool { if field.ID == [2]byte{0x00, 0x72} { // FieldChatID continue } - trans.Fields = append(trans.Fields, field) + fs = append(fs, field) } trans.Fields = fs newT1 = append(newT1, trans) @@ -61,11 +71,14 @@ func tranAssertEqual(t *testing.T, tran1, tran2 []Transaction) bool { if field.ID == [2]byte{0x00, 0x72} { // FieldChatID continue } - trans.Fields = append(trans.Fields, field) + fs = append(fs, field) } trans.Fields = fs newT2 = append(newT2, trans) } + slices.SortFunc(newT1, tranSortFunc) + slices.SortFunc(newT2, tranSortFunc) + return assert.Equal(t, newT1, newT2) } diff --git a/hotline/server_test.go b/hotline/server_test.go index 6ea8880..a29a4f8 100644 --- a/hotline/server_test.go +++ b/hotline/server_test.go @@ -30,7 +30,7 @@ func TestServer_handleFileTransfer(t *testing.T) { Port int Accounts map[string]*Account Agreement []byte - Clients map[uint16]*ClientConn + Clients map[[2]byte]*ClientConn ThreadedNews *ThreadedNews fileTransfers map[[4]byte]*FileTransfer Config *Config @@ -115,10 +115,10 @@ func TestServer_handleFileTransfer(t *testing.T) { Stats: &Stats{}, fileTransfers: map[[4]byte]*FileTransfer{ {0, 0, 0, 5}: { - ReferenceNumber: []byte{0, 0, 0, 5}, - Type: FileDownload, - FileName: []byte("testfile-8b"), - FilePath: []byte{}, + refNum: [4]byte{0, 0, 0, 5}, + Type: FileDownload, + FileName: []byte("testfile-8b"), + FilePath: []byte{}, ClientConn: &ClientConn{ Account: &Account{ Login: "foo", @@ -190,3 +190,61 @@ func TestServer_handleFileTransfer(t *testing.T) { }) } } + +type TestData struct { + Name string `yaml:"name"` + Value int `yaml:"value"` +} + +func TestLoadFromYAMLFile(t *testing.T) { + tests := []struct { + name string + fileName string + content string + wantData TestData + wantErr bool + }{ + { + name: "Valid YAML file", + fileName: "valid.yaml", + content: "name: Test\nvalue: 123\n", + wantData: TestData{Name: "Test", Value: 123}, + wantErr: false, + }, + { + name: "File not found", + fileName: "nonexistent.yaml", + content: "", + wantData: TestData{}, + wantErr: true, + }, + { + name: "Invalid YAML content", + fileName: "invalid.yaml", + content: "name: Test\nvalue: invalid_int\n", + wantData: TestData{}, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Setup: Create a temporary file with the provided content if content is not empty + if tt.content != "" { + err := os.WriteFile(tt.fileName, []byte(tt.content), 0644) + assert.NoError(t, err) + defer os.Remove(tt.fileName) // Cleanup the file after the test + } + + var data TestData + err := loadFromYAMLFile(tt.fileName, &data) + + if tt.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.wantData, data) + } + }) + } +} diff --git a/hotline/stats.go b/hotline/stats.go index bc06a69..d1a3c5f 100644 --- a/hotline/stats.go +++ b/hotline/stats.go @@ -1,6 +1,7 @@ package hotline import ( + "sync" "time" ) @@ -14,4 +15,6 @@ type Stats struct { DownloadCounter int UploadCounter int Since time.Time + + sync.Mutex } diff --git a/hotline/tracker.go b/hotline/tracker.go index b927986..d0ccff4 100644 --- a/hotline/tracker.go +++ b/hotline/tracker.go @@ -3,12 +3,12 @@ package hotline import ( "bufio" "encoding/binary" + "errors" "fmt" "io" "net" "slices" "strconv" - "time" ) // TrackerRegistration represents the payload a Hotline server sends to a Tracker to register @@ -49,8 +49,20 @@ func (tr *TrackerRegistration) Read(p []byte) (int, error) { return n, nil } -func register(tracker string, tr *TrackerRegistration) error { - conn, err := net.Dial("udp", tracker) +// Dialer interface to abstract the dialing operation +type Dialer interface { + Dial(network, address string) (net.Conn, error) +} + +// RealDialer is the real implementation of the Dialer interface +type RealDialer struct{} + +func (d *RealDialer) Dial(network, address string) (net.Conn, error) { + return net.Dial(network, address) +} + +func register(dialer Dialer, tracker string, tr io.Reader) error { + conn, err := dialer.Dial("udp", tracker) if err != nil { return fmt.Errorf("failed to dial tracker: %w", err) } @@ -63,8 +75,6 @@ func register(tracker string, tr *TrackerRegistration) error { return nil } -const trackerTimeout = 5 * time.Second - // All string values use 8-bit ASCII character set encoding. // Client Interface with Tracker // After establishing a connection with tracker, the following information is sent: @@ -72,23 +82,20 @@ const trackerTimeout = 5 * time.Second // Magic number 4 ‘HTRK’ // Version 2 1 or 2 Old protocol (1) or new (2) -// Reply received from the tracker starts with a header: +// TrackerHeader is sent in reply Reply received from the tracker starts with a header: type TrackerHeader struct { Protocol [4]byte // "HTRK" 0x4854524B Version [2]byte // Old protocol (1) or new (2) } -// Message type 2 1 Sending list of servers -// Message data size 2 Remaining size of this request -// Number of servers 2 Number of servers in the server list -// Number of servers 2 Same as previous field type ServerInfoHeader struct { - MsgType [2]byte // always has value of 1 + MsgType [2]byte // Always has value of 1 MsgDataSize [2]byte // Remaining size of request SrvCount [2]byte // Number of servers in the server list SrvCountDup [2]byte // Same as previous field ¯\_(ツ)_/¯ } +// ServerRecord is a tracker listing for a single server type ServerRecord struct { IPAddr [4]byte Port [2]byte @@ -100,14 +107,10 @@ type ServerRecord struct { Description []byte } -func GetListing(addr string) ([]ServerRecord, error) { - conn, err := net.DialTimeout("tcp", addr, trackerTimeout) - if err != nil { - return []ServerRecord{}, err - } +func GetListing(conn io.ReadWriteCloser) ([]ServerRecord, error) { defer func() { _ = conn.Close() }() - _, err = conn.Write( + _, err := conn.Write( []byte{ 0x48, 0x54, 0x52, 0x4B, // HTRK 0x00, 0x01, // Version @@ -189,9 +192,13 @@ func serverScanner(data []byte, _ bool) (advance int, token []byte, err error) { // Write implements io.Writer for ServerRecord func (s *ServerRecord) Write(b []byte) (n int, err error) { + if len(b) < 13 { + return 0, errors.New("too few bytes") + } copy(s.IPAddr[:], b[0:4]) copy(s.Port[:], b[4:6]) copy(s.NumUsers[:], b[6:8]) + s.NameSize = b[10] nameLen := int(b[10]) s.Name = b[11 : 11+nameLen] diff --git a/hotline/tracker_test.go b/hotline/tracker_test.go index b5c5c55..14c0363 100644 --- a/hotline/tracker_test.go +++ b/hotline/tracker_test.go @@ -1,8 +1,10 @@ package hotline import ( + "bytes" "fmt" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "io" "reflect" "testing" @@ -191,3 +193,157 @@ func Test_serverScanner(t *testing.T) { }) } } + +type mockConn struct { + readBuffer *bytes.Buffer + writeBuffer *bytes.Buffer + closed bool +} + +func (m *mockConn) Read(b []byte) (n int, err error) { + return m.readBuffer.Read(b) +} + +func (m *mockConn) Write(b []byte) (n int, err error) { + return m.writeBuffer.Write(b) +} + +func (m *mockConn) Close() error { + m.closed = true + return nil +} + +func TestGetListing(t *testing.T) { + tests := []struct { + name string + mockConn *mockConn + wantErr bool + wantResult []ServerRecord + }{ + { + name: "Successful retrieval", + mockConn: &mockConn{ + readBuffer: bytes.NewBuffer([]byte{ + // TrackerHeader + 0x48, 0x54, 0x52, 0x4B, // Protocol "HTRK" + 0x00, 0x01, // Version 1 + // ServerInfoHeader + 0x00, 0x01, // MsgType (1) + 0x00, 0x14, // MsgDataSize (20) + 0x00, 0x02, // SrvCount (2) + 0x00, 0x02, // SrvCountDup (2) + // ServerRecord 1 + 192, 168, 1, 1, // IP address + 0x1F, 0x90, // Port 8080 + 0x00, 0x10, // NumUsers 16 + 0x00, 0x00, // Unused + 0x04, // NameSize + 'S', 'e', 'r', 'v', // Name + 0x0B, // DescriptionSize + 'M', 'y', ' ', 'S', 'e', 'r', 'v', 'e', 'r', ' ', '1', // Description + // ServerRecord 2 + 10, 0, 0, 1, // IP address + 0x1F, 0x91, // Port 8081 + 0x00, 0x05, // NumUsers 5 + 0x00, 0x00, // Unused + 0x04, // NameSize + 'S', 'e', 'r', 'v', // Name + 0x0B, // DescriptionSize + 'M', 'y', ' ', 'S', 'e', 'r', 'v', 'e', 'r', ' ', '2', // Description + }), + writeBuffer: &bytes.Buffer{}, + }, + wantErr: false, + wantResult: []ServerRecord{ + { + IPAddr: [4]byte{192, 168, 1, 1}, + Port: [2]byte{0x1F, 0x90}, + NumUsers: [2]byte{0x00, 0x10}, + Unused: [2]byte{0x00, 0x00}, + NameSize: 4, + Name: []byte("Serv"), + DescriptionSize: 11, + Description: []byte("My Server 1"), + }, + { + IPAddr: [4]byte{10, 0, 0, 1}, + Port: [2]byte{0x1F, 0x91}, + NumUsers: [2]byte{0x00, 0x05}, + Unused: [2]byte{0x00, 0x00}, + NameSize: 4, + Name: []byte("Serv"), + DescriptionSize: 11, + Description: []byte("My Server 2"), + }, + }, + }, + { + name: "Write error", + mockConn: &mockConn{ + readBuffer: &bytes.Buffer{}, + writeBuffer: &bytes.Buffer{}, + }, + wantErr: true, + wantResult: nil, + }, + { + name: "Read error on TrackerHeader", + mockConn: &mockConn{ + readBuffer: bytes.NewBuffer([]byte{ + // incomplete data to cause read error + 0x48, + }), + writeBuffer: &bytes.Buffer{}, + }, + wantErr: true, + wantResult: nil, + }, + { + name: "Read error on ServerInfoHeader", + mockConn: &mockConn{ + readBuffer: bytes.NewBuffer([]byte{ + // TrackerHeader + 0x48, 0x54, 0x52, 0x4B, // Protocol "HTRK" + 0x00, 0x01, // Version 1 + // incomplete ServerInfoHeader + 0x00, + }), + writeBuffer: &bytes.Buffer{}, + }, + wantErr: true, + wantResult: nil, + }, + { + name: "Scanner error", + mockConn: &mockConn{ + readBuffer: bytes.NewBuffer([]byte{ + // TrackerHeader + 0x48, 0x54, 0x52, 0x4B, // Protocol "HTRK" + 0x00, 0x01, // Version 1 + // ServerInfoHeader + 0x00, 0x01, // MsgType (1) + 0x00, 0x14, // MsgDataSize (20) + 0x00, 0x01, // SrvCount (1) + 0x00, 0x01, // SrvCountDup (1) + // incomplete ServerRecord to cause scanner error + 192, 168, 1, 1, + }), + writeBuffer: &bytes.Buffer{}, + }, + wantErr: true, + wantResult: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := GetListing(tt.mockConn) + if tt.wantErr { + require.Error(t, err) + } else { + require.NoError(t, err) + assert.Equal(t, tt.wantResult, got) + } + }) + } +} diff --git a/hotline/transaction.go b/hotline/transaction.go index 0caf72f..f8c7dfd 100644 --- a/hotline/transaction.go +++ b/hotline/transaction.go @@ -7,119 +7,189 @@ import ( "errors" "fmt" "io" + "log/slog" "math/rand" "slices" ) -const ( - TranError = 0 - TranGetMsgs = 101 - TranNewMsg = 102 - TranOldPostNews = 103 - TranServerMsg = 104 - TranChatSend = 105 - TranChatMsg = 106 - TranLogin = 107 - TranSendInstantMsg = 108 - TranShowAgreement = 109 - TranDisconnectUser = 110 - TranDisconnectMsg = 111 // TODO: implement server initiated friendly disconnect - TranInviteNewChat = 112 - TranInviteToChat = 113 - TranRejectChatInvite = 114 - TranJoinChat = 115 - TranLeaveChat = 116 - TranNotifyChatChangeUser = 117 - TranNotifyChatDeleteUser = 118 - TranNotifyChatSubject = 119 - TranSetChatSubject = 120 - TranAgreed = 121 - TranServerBanner = 122 - TranGetFileNameList = 200 - TranDownloadFile = 202 - TranUploadFile = 203 - TranNewFolder = 205 - TranDeleteFile = 204 - TranGetFileInfo = 206 - TranSetFileInfo = 207 - TranMoveFile = 208 - TranMakeFileAlias = 209 - TranDownloadFldr = 210 - TranDownloadInfo = 211 // TODO: implement file transfer queue - TranDownloadBanner = 212 - TranUploadFldr = 213 - TranGetUserNameList = 300 - TranNotifyChangeUser = 301 - TranNotifyDeleteUser = 302 - TranGetClientInfoText = 303 - TranSetClientUserInfo = 304 - TranListUsers = 348 - TranUpdateUser = 349 - TranNewUser = 350 - TranDeleteUser = 351 - TranGetUser = 352 - TranSetUser = 353 - TranUserAccess = 354 - TranUserBroadcast = 355 - TranGetNewsCatNameList = 370 - TranGetNewsArtNameList = 371 - TranDelNewsItem = 380 - TranNewNewsFldr = 381 - TranNewNewsCat = 382 - TranGetNewsArtData = 400 - TranPostNewsArt = 410 - TranDelNewsArt = 411 - TranKeepAlive = 500 +var ( + TranError = [2]byte{0x00, 0x00} // 0 + TranGetMsgs = [2]byte{0x00, 0x65} // 101 + TranNewMsg = [2]byte{0x00, 0x66} // 102 + TranOldPostNews = [2]byte{0x00, 0x67} // 103 + TranServerMsg = [2]byte{0x00, 0x68} // 104 + TranChatSend = [2]byte{0x00, 0x69} // 105 + TranChatMsg = [2]byte{0x00, 0x6A} // 106 + TranLogin = [2]byte{0x00, 0x6B} // 107 + TranSendInstantMsg = [2]byte{0x00, 0x6C} // 108 + TranShowAgreement = [2]byte{0x00, 0x6D} // 109 + TranDisconnectUser = [2]byte{0x00, 0x6E} // 110 + TranDisconnectMsg = [2]byte{0x00, 0x6F} // 111 + TranInviteNewChat = [2]byte{0x00, 0x70} // 112 + TranInviteToChat = [2]byte{0x00, 0x71} // 113 + TranRejectChatInvite = [2]byte{0x00, 0x72} // 114 + TranJoinChat = [2]byte{0x00, 0x73} // 115 + TranLeaveChat = [2]byte{0x00, 0x74} // 116 + TranNotifyChatChangeUser = [2]byte{0x00, 0x75} // 117 + TranNotifyChatDeleteUser = [2]byte{0x00, 0x76} // 118 + TranNotifyChatSubject = [2]byte{0x00, 0x77} // 119 + TranSetChatSubject = [2]byte{0x00, 0x78} // 120 + TranAgreed = [2]byte{0x00, 0x79} // 121 + TranServerBanner = [2]byte{0x00, 0x7A} // 122 + TranGetFileNameList = [2]byte{0x00, 0xC8} // 200 + TranDownloadFile = [2]byte{0x00, 0xCA} // 202 + TranUploadFile = [2]byte{0x00, 0xCB} // 203 + TranNewFolder = [2]byte{0x00, 0xCD} // 205 + TranDeleteFile = [2]byte{0x00, 0xCC} // 204 + TranGetFileInfo = [2]byte{0x00, 0xCE} // 206 + TranSetFileInfo = [2]byte{0x00, 0xCF} // 207 + TranMoveFile = [2]byte{0x00, 0xD0} // 208 + TranMakeFileAlias = [2]byte{0x00, 0xD1} // 209 + TranDownloadFldr = [2]byte{0x00, 0xD2} // 210 + TranDownloadInfo = [2]byte{0x00, 0xD3} // 211 + TranDownloadBanner = [2]byte{0x00, 0xD4} // 212 + TranUploadFldr = [2]byte{0x00, 0xD5} // 213 + TranGetUserNameList = [2]byte{0x01, 0x2C} // 300 + TranNotifyChangeUser = [2]byte{0x01, 0x2D} // 301 + TranNotifyDeleteUser = [2]byte{0x01, 0x2E} // 302 + TranGetClientInfoText = [2]byte{0x01, 0x2F} // 303 + TranSetClientUserInfo = [2]byte{0x01, 0x30} // 304 + TranListUsers = [2]byte{0x01, 0x5C} // 348 + TranUpdateUser = [2]byte{0x01, 0x5D} // 349 + TranNewUser = [2]byte{0x01, 0x5E} // 350 + TranDeleteUser = [2]byte{0x01, 0x5F} // 351 + TranGetUser = [2]byte{0x01, 0x60} // 352 + TranSetUser = [2]byte{0x01, 0x61} // 353 + TranUserAccess = [2]byte{0x01, 0x62} // 354 + TranUserBroadcast = [2]byte{0x01, 0x63} // 355 + TranGetNewsCatNameList = [2]byte{0x01, 0x72} // 370 + TranGetNewsArtNameList = [2]byte{0x01, 0x73} // 371 + TranDelNewsItem = [2]byte{0x01, 0x7C} // 380 + TranNewNewsFldr = [2]byte{0x01, 0x7D} // 381 + TranNewNewsCat = [2]byte{0x01, 0x7E} // 382 + TranGetNewsArtData = [2]byte{0x01, 0x90} // 400 + TranPostNewsArt = [2]byte{0x01, 0x9A} // 410 + TranDelNewsArt = [2]byte{0x01, 0x9B} // 411 + TranKeepAlive = [2]byte{0x01, 0xF4} // 500 ) type Transaction struct { - Flags byte // Reserved (should be 0) - IsReply byte // Request (0) or reply (1) - Type [2]byte // Requested operation (user defined) - ID [4]byte // Unique transaction ID (must be != 0) - ErrorCode [4]byte // Used in the reply (user defined, 0 = no error) - TotalSize [4]byte // Total data size for the transaction (all parts) - DataSize [4]byte // Size of data in this transaction part. This allows splitting large transactions into smaller parts. - ParamCount [2]byte // Number of the parameters for this transaction + Flags byte // Reserved (should be 0) + IsReply byte // Request (0) or reply (1) + Type TranType // Requested operation (user defined) + ID [4]byte // Unique transaction ID (must be != 0) + ErrorCode [4]byte // Used in the reply (user defined, 0 = no error) + TotalSize [4]byte // Total data size for the fields in this transaction. + DataSize [4]byte // Size of data in this transaction part. This allows splitting large transactions into smaller parts. + ParamCount [2]byte // Number of the parameters for this transaction Fields []Field - clientID *[]byte // Internal identifier for target client + clientID [2]byte // Internal identifier for target client readOffset int // Internal offset to track read progress } -func NewTransaction(t int, clientID *[]byte, fields ...Field) *Transaction { - typeSlice := make([]byte, 2) - binary.BigEndian.PutUint16(typeSlice, uint16(t)) +type TranType [2]byte + +var tranTypeNames = map[TranType]string{ + TranChatMsg: "Receive Chat", + TranNotifyChangeUser: "TranNotifyChangeUser", + TranError: "TranError", + TranShowAgreement: "TranShowAgreement", + TranUserAccess: "TranUserAccess", + TranNotifyDeleteUser: "TranNotifyDeleteUser", + TranAgreed: "TranAgreed", + TranChatSend: "Send Chat", + TranDelNewsArt: "TranDelNewsArt", + TranDelNewsItem: "TranDelNewsItem", + TranDeleteFile: "TranDeleteFile", + TranDeleteUser: "TranDeleteUser", + TranDisconnectUser: "TranDisconnectUser", + TranDownloadFile: "TranDownloadFile", + TranDownloadFldr: "TranDownloadFldr", + TranGetClientInfoText: "TranGetClientInfoText", + TranGetFileInfo: "TranGetFileInfo", + TranGetFileNameList: "TranGetFileNameList", + TranGetMsgs: "TranGetMsgs", + TranGetNewsArtData: "TranGetNewsArtData", + TranGetNewsArtNameList: "TranGetNewsArtNameList", + TranGetNewsCatNameList: "TranGetNewsCatNameList", + TranGetUser: "TranGetUser", + TranGetUserNameList: "tranHandleGetUserNameList", + TranInviteNewChat: "TranInviteNewChat", + TranInviteToChat: "TranInviteToChat", + TranJoinChat: "TranJoinChat", + TranKeepAlive: "TranKeepAlive", + TranLeaveChat: "TranJoinChat", + TranListUsers: "TranListUsers", + TranMoveFile: "TranMoveFile", + TranNewFolder: "TranNewFolder", + TranNewNewsCat: "TranNewNewsCat", + TranNewNewsFldr: "TranNewNewsFldr", + TranNewUser: "TranNewUser", + TranUpdateUser: "TranUpdateUser", + TranOldPostNews: "TranOldPostNews", + TranPostNewsArt: "TranPostNewsArt", + TranRejectChatInvite: "TranRejectChatInvite", + TranSendInstantMsg: "TranSendInstantMsg", + TranSetChatSubject: "TranSetChatSubject", + TranMakeFileAlias: "TranMakeFileAlias", + TranSetClientUserInfo: "TranSetClientUserInfo", + TranSetFileInfo: "TranSetFileInfo", + TranSetUser: "TranSetUser", + TranUploadFile: "TranUploadFile", + TranUploadFldr: "TranUploadFldr", + TranUserBroadcast: "TranUserBroadcast", + TranDownloadBanner: "TranDownloadBanner", +} - idSlice := make([]byte, 4) - binary.BigEndian.PutUint32(idSlice, rand.Uint32()) +func (t TranType) LogValue() slog.Value { + return slog.StringValue(tranTypeNames[t]) +} - return &Transaction{ +// NewTransaction creates a new Transaction with the specified type, client ID, and optional fields. +func NewTransaction(t, clientID [2]byte, fields ...Field) Transaction { + transaction := Transaction{ + Type: t, clientID: clientID, - Type: [2]byte(typeSlice), - ID: [4]byte(idSlice), Fields: fields, } + + binary.BigEndian.PutUint32(transaction.ID[:], rand.Uint32()) + + return transaction } -// Write implements io.Writer interface for Transaction +// Write implements io.Writer interface for Transaction. +// Transactions read from the network are read as complete tokens with a bufio.Scanner, so +// the arg p is guaranteed to have the full byte payload of a complete transaction. func (t *Transaction) Write(p []byte) (n int, err error) { - totalSize := binary.BigEndian.Uint32(p[12:16]) + // Make sure we have the minimum number of bytes for a transaction. + if len(p) < 22 { + return 0, errors.New("buffer too small") + } - // the buf may include extra bytes that are not part of the transaction - // tranLen represents the length of bytes that are part of the transaction + // Read the total size field. + totalSize := binary.BigEndian.Uint32(p[12:16]) tranLen := int(20 + totalSize) - if tranLen > len(p) { - return n, errors.New("buflen too small for tranLen") - } + paramCount := binary.BigEndian.Uint16(p[20:22]) + + t.Flags = p[0] + t.IsReply = p[1] + copy(t.Type[:], p[2:4]) + copy(t.ID[:], p[4:8]) + copy(t.ErrorCode[:], p[8:12]) + copy(t.TotalSize[:], p[12:16]) + copy(t.DataSize[:], p[16:20]) + copy(t.ParamCount[:], p[20:22]) - // Create a new scanner for parsing incoming bytes into transaction tokens scanner := bufio.NewScanner(bytes.NewReader(p[22:tranLen])) scanner.Split(fieldScanner) - for i := 0; i < int(binary.BigEndian.Uint16(p[20:22])); i++ { - scanner.Scan() + for i := 0; i < int(paramCount); i++ { + if !scanner.Scan() { + return 0, fmt.Errorf("error scanning field: %w", scanner.Err()) + } var field Field if _, err := field.Write(scanner.Bytes()); err != nil { @@ -128,16 +198,11 @@ func (t *Transaction) Write(p []byte) (n int, err error) { t.Fields = append(t.Fields, field) } - t.Flags = p[0] - t.IsReply = p[1] - t.Type = [2]byte(p[2:4]) - t.ID = [4]byte(p[4:8]) - t.ErrorCode = [4]byte(p[8:12]) - t.TotalSize = [4]byte(p[12:16]) - t.DataSize = [4]byte(p[16:20]) - t.ParamCount = [2]byte(p[20:22]) - - return len(p), err + if err := scanner.Err(); err != nil { + return 0, fmt.Errorf("scanner error: %w", err) + } + + return len(p), nil } const tranHeaderLen = 20 // fixed length of transaction fields before the variable length fields @@ -253,16 +318,12 @@ func (t *Transaction) Size() []byte { return bs } -func (t *Transaction) GetField(id int) Field { +func (t *Transaction) GetField(id [2]byte) Field { for _, field := range t.Fields { - if id == int(binary.BigEndian.Uint16(field.ID[:])) { + if id == field.ID { return field } } return Field{} } - -func (t *Transaction) IsError() bool { - return t.ErrorCode == [4]byte{0, 0, 0, 1} -} diff --git a/hotline/transaction_handlers.go b/hotline/transaction_handlers.go index 3d7ba14..718dd59 100644 --- a/hotline/transaction_handlers.go +++ b/hotline/transaction_handlers.go @@ -4,8 +4,8 @@ import ( "bufio" "bytes" "encoding/binary" - "errors" "fmt" + "github.com/davecgh/go-spew/spew" "gopkg.in/yaml.v3" "io" "math/big" @@ -17,233 +17,59 @@ import ( "time" ) -type HandlerFunc func(*ClientConn, *Transaction) ([]Transaction, error) - -type TransactionType struct { - Handler HandlerFunc // function for handling the transaction type - Name string // Name of transaction as it will appear in logging - RequiredFields []requiredField -} - -var TransactionHandlers = map[uint16]TransactionType{ - // Server initiated - TranChatMsg: { - Name: "TranChatMsg", - }, - // Server initiated - TranNotifyChangeUser: { - Name: "TranNotifyChangeUser", - }, - TranError: { - Name: "TranError", - }, - TranShowAgreement: { - Name: "TranShowAgreement", - }, - TranUserAccess: { - Name: "TranUserAccess", - }, - TranNotifyDeleteUser: { - Name: "TranNotifyDeleteUser", - }, - TranAgreed: { - Name: "TranAgreed", - Handler: HandleTranAgreed, - }, - TranChatSend: { - Name: "TranChatSend", - Handler: HandleChatSend, - RequiredFields: []requiredField{ - { - ID: FieldData, - minLen: 0, - }, - }, - }, - TranDelNewsArt: { - Name: "TranDelNewsArt", - Handler: HandleDelNewsArt, - }, - TranDelNewsItem: { - Name: "TranDelNewsItem", - Handler: HandleDelNewsItem, - }, - TranDeleteFile: { - Name: "TranDeleteFile", - Handler: HandleDeleteFile, - }, - TranDeleteUser: { - Name: "TranDeleteUser", - Handler: HandleDeleteUser, - }, - TranDisconnectUser: { - Name: "TranDisconnectUser", - Handler: HandleDisconnectUser, - }, - TranDownloadFile: { - Name: "TranDownloadFile", - Handler: HandleDownloadFile, - }, - TranDownloadFldr: { - Name: "TranDownloadFldr", - Handler: HandleDownloadFolder, - }, - TranGetClientInfoText: { - Name: "TranGetClientInfoText", - Handler: HandleGetClientInfoText, - }, - TranGetFileInfo: { - Name: "TranGetFileInfo", - Handler: HandleGetFileInfo, - }, - TranGetFileNameList: { - Name: "TranGetFileNameList", - Handler: HandleGetFileNameList, - }, - TranGetMsgs: { - Name: "TranGetMsgs", - Handler: HandleGetMsgs, - }, - TranGetNewsArtData: { - Name: "TranGetNewsArtData", - Handler: HandleGetNewsArtData, - }, - TranGetNewsArtNameList: { - Name: "TranGetNewsArtNameList", - Handler: HandleGetNewsArtNameList, - }, - TranGetNewsCatNameList: { - Name: "TranGetNewsCatNameList", - Handler: HandleGetNewsCatNameList, - }, - TranGetUser: { - Name: "TranGetUser", - Handler: HandleGetUser, - }, - TranGetUserNameList: { - Name: "tranHandleGetUserNameList", - Handler: HandleGetUserNameList, - }, - TranInviteNewChat: { - Name: "TranInviteNewChat", - Handler: HandleInviteNewChat, - }, - TranInviteToChat: { - Name: "TranInviteToChat", - Handler: HandleInviteToChat, - }, - TranJoinChat: { - Name: "TranJoinChat", - Handler: HandleJoinChat, - }, - TranKeepAlive: { - Name: "TranKeepAlive", - Handler: HandleKeepAlive, - }, - TranLeaveChat: { - Name: "TranJoinChat", - Handler: HandleLeaveChat, - }, - TranListUsers: { - Name: "TranListUsers", - Handler: HandleListUsers, - }, - TranMoveFile: { - Name: "TranMoveFile", - Handler: HandleMoveFile, - }, - TranNewFolder: { - Name: "TranNewFolder", - Handler: HandleNewFolder, - }, - TranNewNewsCat: { - Name: "TranNewNewsCat", - Handler: HandleNewNewsCat, - }, - TranNewNewsFldr: { - Name: "TranNewNewsFldr", - Handler: HandleNewNewsFldr, - }, - TranNewUser: { - Name: "TranNewUser", - Handler: HandleNewUser, - }, - TranUpdateUser: { - Name: "TranUpdateUser", - Handler: HandleUpdateUser, - }, - TranOldPostNews: { - Name: "TranOldPostNews", - Handler: HandleTranOldPostNews, - }, - TranPostNewsArt: { - Name: "TranPostNewsArt", - Handler: HandlePostNewsArt, - }, - TranRejectChatInvite: { - Name: "TranRejectChatInvite", - Handler: HandleRejectChatInvite, - }, - TranSendInstantMsg: { - Name: "TranSendInstantMsg", - Handler: HandleSendInstantMsg, - RequiredFields: []requiredField{ - { - ID: FieldData, - minLen: 0, - }, - { - ID: FieldUserID, - }, - }, - }, - TranSetChatSubject: { - Name: "TranSetChatSubject", - Handler: HandleSetChatSubject, - }, - TranMakeFileAlias: { - Name: "TranMakeFileAlias", - Handler: HandleMakeAlias, - RequiredFields: []requiredField{ - {ID: FieldFileName, minLen: 1}, - {ID: FieldFilePath, minLen: 1}, - {ID: FieldFileNewPath, minLen: 1}, - }, - }, - TranSetClientUserInfo: { - Name: "TranSetClientUserInfo", - Handler: HandleSetClientUserInfo, - }, - TranSetFileInfo: { - Name: "TranSetFileInfo", - Handler: HandleSetFileInfo, - }, - TranSetUser: { - Name: "TranSetUser", - Handler: HandleSetUser, - }, - TranUploadFile: { - Name: "TranUploadFile", - Handler: HandleUploadFile, - }, - TranUploadFldr: { - Name: "TranUploadFldr", - Handler: HandleUploadFolder, - }, - TranUserBroadcast: { - Name: "TranUserBroadcast", - Handler: HandleUserBroadcast, - }, - TranDownloadBanner: { - Name: "TranDownloadBanner", - Handler: HandleDownloadBanner, - }, +// HandlerFunc is the signature of a func to handle a Hotline transaction. +type HandlerFunc func(*ClientConn, *Transaction) []Transaction + +// TransactionHandlers maps a transaction type to a handler function. +var TransactionHandlers = map[TranType]HandlerFunc{ + TranAgreed: HandleTranAgreed, + TranChatSend: HandleChatSend, + TranDelNewsArt: HandleDelNewsArt, + TranDelNewsItem: HandleDelNewsItem, + TranDeleteFile: HandleDeleteFile, + TranDeleteUser: HandleDeleteUser, + TranDisconnectUser: HandleDisconnectUser, + TranDownloadFile: HandleDownloadFile, + TranDownloadFldr: HandleDownloadFolder, + TranGetClientInfoText: HandleGetClientInfoText, + TranGetFileInfo: HandleGetFileInfo, + TranGetFileNameList: HandleGetFileNameList, + TranGetMsgs: HandleGetMsgs, + TranGetNewsArtData: HandleGetNewsArtData, + TranGetNewsArtNameList: HandleGetNewsArtNameList, + TranGetNewsCatNameList: HandleGetNewsCatNameList, + TranGetUser: HandleGetUser, + TranGetUserNameList: HandleGetUserNameList, + TranInviteNewChat: HandleInviteNewChat, + TranInviteToChat: HandleInviteToChat, + TranJoinChat: HandleJoinChat, + TranKeepAlive: HandleKeepAlive, + TranLeaveChat: HandleLeaveChat, + TranListUsers: HandleListUsers, + TranMoveFile: HandleMoveFile, + TranNewFolder: HandleNewFolder, + TranNewNewsCat: HandleNewNewsCat, + TranNewNewsFldr: HandleNewNewsFldr, + TranNewUser: HandleNewUser, + TranUpdateUser: HandleUpdateUser, + TranOldPostNews: HandleTranOldPostNews, + TranPostNewsArt: HandlePostNewsArt, + TranRejectChatInvite: HandleRejectChatInvite, + TranSendInstantMsg: HandleSendInstantMsg, + TranSetChatSubject: HandleSetChatSubject, + TranMakeFileAlias: HandleMakeAlias, + TranSetClientUserInfo: HandleSetClientUserInfo, + TranSetFileInfo: HandleSetFileInfo, + TranSetUser: HandleSetUser, + TranUploadFile: HandleUploadFile, + TranUploadFldr: HandleUploadFolder, + TranUserBroadcast: HandleUserBroadcast, + TranDownloadBanner: HandleDownloadBanner, } -func HandleChatSend(cc *ClientConn, t *Transaction) (res []Transaction, err error) { +func HandleChatSend(cc *ClientConn, t *Transaction) (res []Transaction) { if !cc.Authorize(accessSendChat) { - res = append(res, cc.NewErrReply(t, "You are not allowed to participate in chat.")) - return res, err + return cc.NewErrReply(t, "You are not allowed to participate in chat.") } // Truncate long usernames @@ -262,31 +88,33 @@ func HandleChatSend(cc *ClientConn, t *Transaction) (res []Transaction, err erro // All clients *except* Frogblast omit this field for public chat, but Frogblast sends a value of 00 00 00 00. chatID := t.GetField(FieldChatID).Data if chatID != nil && !bytes.Equal([]byte{0, 0, 0, 0}, chatID) { - chatInt := binary.BigEndian.Uint32(chatID) - privChat := cc.Server.PrivateChats[chatInt] - - clients := sortedClients(privChat.ClientConn) + privChat := cc.Server.PrivateChats[[4]byte(chatID)] // send the message to all connected clients of the private chat - for _, c := range clients { - res = append(res, *NewTransaction( + for _, c := range privChat.ClientConn { + res = append(res, NewTransaction( TranChatMsg, c.ID, NewField(FieldChatID, chatID), NewField(FieldData, []byte(formattedMsg)), )) } - return res, err + return res } - for _, c := range sortedClients(cc.Server.Clients) { - // Filter out clients that do not have the read chat permission + //cc.Server.mux.Lock() + for _, c := range cc.Server.Clients { + if c == nil || cc.Account == nil { + continue + } + // Skip clients that do not have the read chat permission. if c.Authorize(accessReadChat) { - res = append(res, *NewTransaction(TranChatMsg, c.ID, NewField(FieldData, []byte(formattedMsg)))) + res = append(res, NewTransaction(TranChatMsg, c.ID, NewField(FieldData, []byte(formattedMsg)))) } } + //cc.Server.mux.Unlock() - return res, err + return res } // HandleSendInstantMsg sends instant message to the user on the current server. @@ -304,21 +132,20 @@ func HandleChatSend(cc *ClientConn, t *Transaction) (res []Transaction, err erro // // Fields used in the reply: // None -func HandleSendInstantMsg(cc *ClientConn, t *Transaction) (res []Transaction, err error) { +func HandleSendInstantMsg(cc *ClientConn, t *Transaction) (res []Transaction) { if !cc.Authorize(accessSendPrivMsg) { - res = append(res, cc.NewErrReply(t, "You are not allowed to send private messages.")) - return res, errors.New("user is not allowed to send private messages") + return cc.NewErrReply(t, "You are not allowed to send private messages.") } msg := t.GetField(FieldData) - ID := t.GetField(FieldUserID) + userID := t.GetField(FieldUserID) reply := NewTransaction( TranServerMsg, - &ID.Data, + [2]byte(userID.Data), NewField(FieldData, msg.Data), NewField(FieldUserName, cc.UserName), - NewField(FieldUserID, *cc.ID), + NewField(FieldUserID, cc.ID[:]), NewField(FieldOptions, []byte{0, 1}), ) @@ -328,70 +155,63 @@ func HandleSendInstantMsg(cc *ClientConn, t *Transaction) (res []Transaction, er reply.Fields = append(reply.Fields, NewField(FieldQuotingMsg, t.GetField(FieldQuotingMsg).Data)) } - id, err := byteToInt(ID.Data) - if err != nil { - return res, errors.New("invalid client ID") - } - otherClient, ok := cc.Server.Clients[uint16(id)] + otherClient, ok := cc.Server.Clients[[2]byte(userID.Data)] if !ok { - return res, errors.New("invalid client ID") + return res } // Check if target user has "Refuse private messages" flag - flagBitmap := big.NewInt(int64(binary.BigEndian.Uint16(otherClient.Flags))) - if flagBitmap.Bit(UserFlagRefusePM) == 1 { + if otherClient.Flags.IsSet(UserFlagRefusePM) { res = append(res, - *NewTransaction( + NewTransaction( TranServerMsg, cc.ID, NewField(FieldData, []byte(string(otherClient.UserName)+" does not accept private messages.")), NewField(FieldUserName, otherClient.UserName), - NewField(FieldUserID, *otherClient.ID), + NewField(FieldUserID, otherClient.ID[:]), NewField(FieldOptions, []byte{0, 2}), ), ) } else { - res = append(res, *reply) + res = append(res, reply) } // Respond with auto reply if other client has it enabled if len(otherClient.AutoReply) > 0 { res = append(res, - *NewTransaction( + NewTransaction( TranServerMsg, cc.ID, NewField(FieldData, otherClient.AutoReply), NewField(FieldUserName, otherClient.UserName), - NewField(FieldUserID, *otherClient.ID), + NewField(FieldUserID, otherClient.ID[:]), NewField(FieldOptions, []byte{0, 1}), ), ) } - res = append(res, cc.NewReply(t)) - - return res, err + return append(res, cc.NewReply(t)) } var fileTypeFLDR = [4]byte{0x66, 0x6c, 0x64, 0x72} -func HandleGetFileInfo(cc *ClientConn, t *Transaction) (res []Transaction, err error) { +func HandleGetFileInfo(cc *ClientConn, t *Transaction) (res []Transaction) { fileName := t.GetField(FieldFileName).Data filePath := t.GetField(FieldFilePath).Data fullFilePath, err := readPath(cc.Server.Config.FileRoot, filePath, fileName) if err != nil { - return res, err + return res } fw, err := newFileWrapper(cc.Server.FS, fullFilePath, 0) if err != nil { - return res, err + return res } encodedName, err := txtEncoder.String(fw.name) if err != nil { - return res, fmt.Errorf("invalid filepath encoding: %w", err) + return res } fields := []Field{ @@ -414,7 +234,7 @@ func HandleGetFileInfo(cc *ClientConn, t *Transaction) (res []Transaction, err e } res = append(res, cc.NewReply(t, fields...)) - return res, err + return res } // HandleSetFileInfo updates a file or folder Name and/or comment from the Get Info window @@ -424,54 +244,52 @@ func HandleGetFileInfo(cc *ClientConn, t *Transaction) (res []Transaction, err e // * 211 File new Name Optional // * 210 File comment Optional // Fields used in the reply: None -func HandleSetFileInfo(cc *ClientConn, t *Transaction) (res []Transaction, err error) { +func HandleSetFileInfo(cc *ClientConn, t *Transaction) (res []Transaction) { fileName := t.GetField(FieldFileName).Data filePath := t.GetField(FieldFilePath).Data fullFilePath, err := readPath(cc.Server.Config.FileRoot, filePath, fileName) if err != nil { - return res, err + return res } fi, err := cc.Server.FS.Stat(fullFilePath) if err != nil { - return res, err + return res } hlFile, err := newFileWrapper(cc.Server.FS, fullFilePath, 0) if err != nil { - return res, err + return res } if t.GetField(FieldFileComment).Data != nil { switch mode := fi.Mode(); { case mode.IsDir(): if !cc.Authorize(accessSetFolderComment) { - res = append(res, cc.NewErrReply(t, "You are not allowed to set comments for folders.")) - return res, err + return cc.NewErrReply(t, "You are not allowed to set comments for folders.") } case mode.IsRegular(): if !cc.Authorize(accessSetFileComment) { - res = append(res, cc.NewErrReply(t, "You are not allowed to set comments for files.")) - return res, err + return cc.NewErrReply(t, "You are not allowed to set comments for files.") } } if err := hlFile.ffo.FlatFileInformationFork.setComment(t.GetField(FieldFileComment).Data); err != nil { - return res, err + return res } w, err := hlFile.infoForkWriter() if err != nil { - return res, err + return res } _, err = io.Copy(w, &hlFile.ffo.FlatFileInformationFork) if err != nil { - return res, err + return res } } fullNewFilePath, err := readPath(cc.Server.Config.FileRoot, filePath, t.GetField(FieldFileNewName).Data) if err != nil { - return nil, err + return nil } fileNewName := t.GetField(FieldFileNewName).Data @@ -480,41 +298,38 @@ func HandleSetFileInfo(cc *ClientConn, t *Transaction) (res []Transaction, err e switch mode := fi.Mode(); { case mode.IsDir(): if !cc.Authorize(accessRenameFolder) { - res = append(res, cc.NewErrReply(t, "You are not allowed to rename folders.")) - return res, err + return cc.NewErrReply(t, "You are not allowed to rename folders.") } err = os.Rename(fullFilePath, fullNewFilePath) if os.IsNotExist(err) { - res = append(res, cc.NewErrReply(t, "Cannot rename folder "+string(fileName)+" because it does not exist or cannot be found.")) - return res, err + return cc.NewErrReply(t, "Cannot rename folder "+string(fileName)+" because it does not exist or cannot be found.") + } case mode.IsRegular(): if !cc.Authorize(accessRenameFile) { - res = append(res, cc.NewErrReply(t, "You are not allowed to rename files.")) - return res, err + return cc.NewErrReply(t, "You are not allowed to rename files.") } fileDir, err := readPath(cc.Server.Config.FileRoot, filePath, []byte{}) if err != nil { - return nil, err + return nil } hlFile.name, err = txtDecoder.String(string(fileNewName)) if err != nil { - return res, fmt.Errorf("invalid filepath encoding: %w", err) + return res } err = hlFile.move(fileDir) if os.IsNotExist(err) { - res = append(res, cc.NewErrReply(t, "Cannot rename file "+string(fileName)+" because it does not exist or cannot be found.")) - return res, err + return cc.NewErrReply(t, "Cannot rename file "+string(fileName)+" because it does not exist or cannot be found.") } if err != nil { - return res, err + return res } } } res = append(res, cc.NewReply(t)) - return res, err + return res } // HandleDeleteFile deletes a file or folder @@ -522,98 +337,91 @@ func HandleSetFileInfo(cc *ClientConn, t *Transaction) (res []Transaction, err e // * 201 File Name // * 202 File path // Fields used in the reply: none -func HandleDeleteFile(cc *ClientConn, t *Transaction) (res []Transaction, err error) { +func HandleDeleteFile(cc *ClientConn, t *Transaction) (res []Transaction) { fileName := t.GetField(FieldFileName).Data filePath := t.GetField(FieldFilePath).Data fullFilePath, err := readPath(cc.Server.Config.FileRoot, filePath, fileName) if err != nil { - return res, err + return res } hlFile, err := newFileWrapper(cc.Server.FS, fullFilePath, 0) if err != nil { - return res, err + return res } fi, err := hlFile.dataFile() if err != nil { - res = append(res, cc.NewErrReply(t, "Cannot delete file "+string(fileName)+" because it does not exist or cannot be found.")) - return res, nil + return cc.NewErrReply(t, "Cannot delete file "+string(fileName)+" because it does not exist or cannot be found.") } switch mode := fi.Mode(); { case mode.IsDir(): if !cc.Authorize(accessDeleteFolder) { - res = append(res, cc.NewErrReply(t, "You are not allowed to delete folders.")) - return res, err + return cc.NewErrReply(t, "You are not allowed to delete folders.") } case mode.IsRegular(): if !cc.Authorize(accessDeleteFile) { - res = append(res, cc.NewErrReply(t, "You are not allowed to delete files.")) - return res, err + return cc.NewErrReply(t, "You are not allowed to delete files.") } } if err := hlFile.delete(); err != nil { - return res, err + return res } res = append(res, cc.NewReply(t)) - return res, err + return res } // HandleMoveFile moves files or folders. Note: seemingly not documented -func HandleMoveFile(cc *ClientConn, t *Transaction) (res []Transaction, err error) { +func HandleMoveFile(cc *ClientConn, t *Transaction) (res []Transaction) { fileName := string(t.GetField(FieldFileName).Data) filePath, err := readPath(cc.Server.Config.FileRoot, t.GetField(FieldFilePath).Data, t.GetField(FieldFileName).Data) if err != nil { - return res, err + return res } fileNewPath, err := readPath(cc.Server.Config.FileRoot, t.GetField(FieldFileNewPath).Data, nil) if err != nil { - return res, err + return res } cc.logger.Info("Move file", "src", filePath+"/"+fileName, "dst", fileNewPath+"/"+fileName) hlFile, err := newFileWrapper(cc.Server.FS, filePath, 0) if err != nil { - return res, err + return res } fi, err := hlFile.dataFile() if err != nil { - res = append(res, cc.NewErrReply(t, "Cannot delete file "+fileName+" because it does not exist or cannot be found.")) - return res, err + return cc.NewErrReply(t, "Cannot delete file "+fileName+" because it does not exist or cannot be found.") } switch mode := fi.Mode(); { case mode.IsDir(): if !cc.Authorize(accessMoveFolder) { - res = append(res, cc.NewErrReply(t, "You are not allowed to move folders.")) - return res, err + return cc.NewErrReply(t, "You are not allowed to move folders.") } case mode.IsRegular(): if !cc.Authorize(accessMoveFile) { - res = append(res, cc.NewErrReply(t, "You are not allowed to move files.")) - return res, err + return cc.NewErrReply(t, "You are not allowed to move files.") } } if err := hlFile.move(fileNewPath); err != nil { - return res, err + return res } // TODO: handle other possible errors; e.g. fileWrapper delete fails due to fileWrapper permission issue res = append(res, cc.NewReply(t)) - return res, err + return res } -func HandleNewFolder(cc *ClientConn, t *Transaction) (res []Transaction, err error) { +func HandleNewFolder(cc *ClientConn, t *Transaction) (res []Transaction) { if !cc.Authorize(accessCreateFolder) { - res = append(res, cc.NewErrReply(t, "You are not allowed to create folders.")) - return res, err + return cc.NewErrReply(t, "You are not allowed to create folders.") } folderName := string(t.GetField(FieldFileName).Data) @@ -626,7 +434,7 @@ func HandleNewFolder(cc *ClientConn, t *Transaction) (res []Transaction, err err var newFp FilePath _, err := newFp.Write(t.GetField(FieldFilePath).Data) if err != nil { - return nil, err + return res } for _, pathItem := range newFp.Items { @@ -634,31 +442,30 @@ func HandleNewFolder(cc *ClientConn, t *Transaction) (res []Transaction, err err } } newFolderPath := path.Join(cc.Server.Config.FileRoot, subPath, folderName) - newFolderPath, err = txtDecoder.String(newFolderPath) + newFolderPath, err := txtDecoder.String(newFolderPath) if err != nil { - return res, fmt.Errorf("invalid filepath encoding: %w", err) + return res } // TODO: check path and folder Name lengths if _, err := cc.Server.FS.Stat(newFolderPath); !os.IsNotExist(err) { msg := fmt.Sprintf("Cannot create folder \"%s\" because there is already a file or folder with that Name.", folderName) - return []Transaction{cc.NewErrReply(t, msg)}, nil + return cc.NewErrReply(t, msg) } if err := cc.Server.FS.Mkdir(newFolderPath, 0777); err != nil { msg := fmt.Sprintf("Cannot create folder \"%s\" because an error occurred.", folderName) - return []Transaction{cc.NewErrReply(t, msg)}, nil + return cc.NewErrReply(t, msg) } res = append(res, cc.NewReply(t)) - return res, err + return res } -func HandleSetUser(cc *ClientConn, t *Transaction) (res []Transaction, err error) { +func HandleSetUser(cc *ClientConn, t *Transaction) (res []Transaction) { if !cc.Authorize(accessModifyUser) { - res = append(res, cc.NewErrReply(t, "You are not allowed to modify accounts.")) - return res, err + return cc.NewErrReply(t, "You are not allowed to modify accounts.") } login := string(encodeString(t.GetField(FieldUserLogin).Data)) @@ -668,7 +475,7 @@ func HandleSetUser(cc *ClientConn, t *Transaction) (res []Transaction, err error account := cc.Server.Accounts[login] if account == nil { - return append(res, cc.NewErrReply(t, "Account not found.")), nil + return cc.NewErrReply(t, "Account not found.") } account.Name = userName copy(account.Access[:], newAccessLvl) @@ -685,33 +492,30 @@ func HandleSetUser(cc *ClientConn, t *Transaction) (res []Transaction, err error out, err := yaml.Marshal(&account) if err != nil { - return res, err + return res } if err := os.WriteFile(filepath.Join(cc.Server.ConfigDir, "Users", login+".yaml"), out, 0666); err != nil { - return res, err + return res } // Notify connected clients logged in as the user of the new access level for _, c := range cc.Server.Clients { if c.Account.Login == login { - // Note: comment out these two lines to test server-side deny messages newT := NewTransaction(TranUserAccess, c.ID, NewField(FieldUserAccess, newAccessLvl)) - res = append(res, *newT) + res = append(res, newT) - flagBitmap := big.NewInt(int64(binary.BigEndian.Uint16(c.Flags))) if c.Authorize(accessDisconUser) { - flagBitmap.SetBit(flagBitmap, UserFlagAdmin, 1) + c.Flags.Set(UserFlagAdmin, 1) } else { - flagBitmap.SetBit(flagBitmap, UserFlagAdmin, 0) + c.Flags.Set(UserFlagAdmin, 0) } - binary.BigEndian.PutUint16(c.Flags, uint16(flagBitmap.Int64())) c.Account.Access = account.Access cc.sendAll( TranNotifyChangeUser, - NewField(FieldUserID, *c.ID), - NewField(FieldUserFlags, c.Flags), + NewField(FieldUserID, c.ID[:]), + NewField(FieldUserFlags, c.Flags[:]), NewField(FieldUserName, c.UserName), NewField(FieldUserIconID, c.Icon), ) @@ -719,19 +523,17 @@ func HandleSetUser(cc *ClientConn, t *Transaction) (res []Transaction, err error } res = append(res, cc.NewReply(t)) - return res, err + return res } -func HandleGetUser(cc *ClientConn, t *Transaction) (res []Transaction, err error) { +func HandleGetUser(cc *ClientConn, t *Transaction) (res []Transaction) { if !cc.Authorize(accessOpenUser) { - res = append(res, cc.NewErrReply(t, "You are not allowed to view accounts.")) - return res, err + return cc.NewErrReply(t, "You are not allowed to view accounts.") } account := cc.Server.Accounts[string(t.GetField(FieldUserLogin).Data)] if account == nil { - res = append(res, cc.NewErrReply(t, "Account does not exist.")) - return res, err + return cc.NewErrReply(t, "Account does not exist.") } res = append(res, cc.NewReply(t, @@ -740,13 +542,12 @@ func HandleGetUser(cc *ClientConn, t *Transaction) (res []Transaction, err error NewField(FieldUserPassword, []byte(account.Password)), NewField(FieldUserAccess, account.Access[:]), )) - return res, err + return res } -func HandleListUsers(cc *ClientConn, t *Transaction) (res []Transaction, err error) { +func HandleListUsers(cc *ClientConn, t *Transaction) (res []Transaction) { if !cc.Authorize(accessOpenUser) { - res = append(res, cc.NewErrReply(t, "You are not allowed to view accounts.")) - return res, err + return cc.NewErrReply(t, "You are not allowed to view accounts.") } var userFields []Field @@ -754,14 +555,14 @@ func HandleListUsers(cc *ClientConn, t *Transaction) (res []Transaction, err err accCopy := *acc b, err := io.ReadAll(&accCopy) if err != nil { - return res, err + return res } userFields = append(userFields, NewField(FieldData, b)) } res = append(res, cc.NewReply(t, userFields...)) - return res, err + return res } // HandleUpdateUser is used by the v1.5+ multi-user editor to perform account editing for multiple users at a time. @@ -773,7 +574,7 @@ func HandleListUsers(cc *ClientConn, t *Transaction) (res []Transaction, err err // The Transaction sent by the client includes one data field per user that was modified. This data field in turn // contains another data field encoded in its payload with a varying number of sub fields depending on which action is // performed. This seems to be the only place in the Hotline protocol where a data field contains another data field. -func HandleUpdateUser(cc *ClientConn, t *Transaction) (res []Transaction, err error) { +func HandleUpdateUser(cc *ClientConn, t *Transaction) (res []Transaction) { for _, field := range t.Fields { var subFields []Field @@ -786,7 +587,7 @@ func HandleUpdateUser(cc *ClientConn, t *Transaction) (res []Transaction, err er var field Field if _, err := field.Write(scanner.Bytes()); err != nil { - return res, fmt.Errorf("error reading field: %w", err) + return res } subFields = append(subFields, field) } @@ -794,15 +595,14 @@ func HandleUpdateUser(cc *ClientConn, t *Transaction) (res []Transaction, err er // If there's only one subfield, that indicates this is a delete operation for the login in FieldData if len(subFields) == 1 { if !cc.Authorize(accessDeleteUser) { - res = append(res, cc.NewErrReply(t, "You are not allowed to delete accounts.")) - return res, err + return cc.NewErrReply(t, "You are not allowed to delete accounts.") } login := string(encodeString(getField(FieldData, &subFields).Data)) cc.logger.Info("DeleteUser", "login", login) if err := cc.Server.DeleteUser(login); err != nil { - return res, err + return res } continue } @@ -832,8 +632,7 @@ func HandleUpdateUser(cc *ClientConn, t *Transaction) (res []Transaction, err er // account exists, so this is an update action if !cc.Authorize(accessModifyUser) { - res = append(res, cc.NewErrReply(t, "You are not allowed to modify accounts.")) - return res, nil + return cc.NewErrReply(t, "You are not allowed to modify accounts.") } // This part is a bit tricky. There are three possibilities: @@ -856,7 +655,7 @@ func HandleUpdateUser(cc *ClientConn, t *Transaction) (res []Transaction, err er copy(acc.Access[:], getField(FieldUserAccess, &subFields).Data) } - err = cc.Server.UpdateUser( + err := cc.Server.UpdateUser( string(encodeString(getField(FieldData, &subFields).Data)), string(encodeString(getField(FieldUserLogin, &subFields).Data)), string(getField(FieldUserName, &subFields).Data), @@ -864,12 +663,11 @@ func HandleUpdateUser(cc *ClientConn, t *Transaction) (res []Transaction, err er acc.Access, ) if err != nil { - return res, err + return res } } else { if !cc.Authorize(accessCreateUser) { - res = append(res, cc.NewErrReply(t, "You are not allowed to create new accounts.")) - return res, nil + return cc.NewErrReply(t, "You are not allowed to create new accounts.") } cc.logger.Info("CreateUser", "login", userLogin) @@ -881,35 +679,32 @@ func HandleUpdateUser(cc *ClientConn, t *Transaction) (res []Transaction, err er for i := 0; i < 64; i++ { if newAccess.IsSet(i) { if !cc.Authorize(i) { - return append(res, cc.NewErrReply(t, "Cannot create account with more access than yourself.")), nil + return cc.NewErrReply(t, "Cannot create account with more access than yourself.") } } } - err = cc.Server.NewUser(userLogin, string(getField(FieldUserName, &subFields).Data), string(getField(FieldUserPassword, &subFields).Data), newAccess) + err := cc.Server.NewUser(userLogin, string(getField(FieldUserName, &subFields).Data), string(getField(FieldUserPassword, &subFields).Data), newAccess) if err != nil { - return append(res, cc.NewErrReply(t, "Cannot create account because there is already an account with that login.")), nil + return cc.NewErrReply(t, "Cannot create account because there is already an account with that login.") } } } - res = append(res, cc.NewReply(t)) - return res, err + return append(res, cc.NewReply(t)) } // HandleNewUser creates a new user account -func HandleNewUser(cc *ClientConn, t *Transaction) (res []Transaction, err error) { +func HandleNewUser(cc *ClientConn, t *Transaction) (res []Transaction) { if !cc.Authorize(accessCreateUser) { - res = append(res, cc.NewErrReply(t, "You are not allowed to create new accounts.")) - return res, err + return cc.NewErrReply(t, "You are not allowed to create new accounts.") } login := string(encodeString(t.GetField(FieldUserLogin).Data)) // If the account already dataFile, reply with an error if _, ok := cc.Server.Accounts[login]; ok { - res = append(res, cc.NewErrReply(t, "Cannot create account "+login+" because there is already an account with that login.")) - return res, err + return cc.NewErrReply(t, "Cannot create account "+login+" because there is already an account with that login.") } newAccess := accessBitmap{} @@ -919,52 +714,45 @@ func HandleNewUser(cc *ClientConn, t *Transaction) (res []Transaction, err error for i := 0; i < 64; i++ { if newAccess.IsSet(i) { if !cc.Authorize(i) { - res = append(res, cc.NewErrReply(t, "Cannot create account with more access than yourself.")) - return res, err + return cc.NewErrReply(t, "Cannot create account with more access than yourself.") } } } if err := cc.Server.NewUser(login, string(t.GetField(FieldUserName).Data), string(t.GetField(FieldUserPassword).Data), newAccess); err != nil { - res = append(res, cc.NewErrReply(t, "Cannot create account because there is already an account with that login.")) - return res, err + return cc.NewErrReply(t, "Cannot create account because there is already an account with that login.") } - res = append(res, cc.NewReply(t)) - return res, err + return append(res, cc.NewReply(t)) } -func HandleDeleteUser(cc *ClientConn, t *Transaction) (res []Transaction, err error) { +func HandleDeleteUser(cc *ClientConn, t *Transaction) (res []Transaction) { if !cc.Authorize(accessDeleteUser) { - res = append(res, cc.NewErrReply(t, "You are not allowed to delete accounts.")) - return res, nil + return cc.NewErrReply(t, "You are not allowed to delete accounts.") } login := string(encodeString(t.GetField(FieldUserLogin).Data)) if err := cc.Server.DeleteUser(login); err != nil { - return res, err + return res } - res = append(res, cc.NewReply(t)) - return res, err + return append(res, cc.NewReply(t)) } // HandleUserBroadcast sends an Administrator Message to all connected clients of the server -func HandleUserBroadcast(cc *ClientConn, t *Transaction) (res []Transaction, err error) { +func HandleUserBroadcast(cc *ClientConn, t *Transaction) (res []Transaction) { if !cc.Authorize(accessBroadcast) { - res = append(res, cc.NewErrReply(t, "You are not allowed to send broadcast messages.")) - return res, err + return cc.NewErrReply(t, "You are not allowed to send broadcast messages.") } cc.sendAll( TranServerMsg, - NewField(FieldData, t.GetField(TranGetMsgs).Data), + NewField(FieldData, t.GetField(FieldData).Data), NewField(FieldChatOptions, []byte{0}), ) - res = append(res, cc.NewReply(t)) - return res, err + return append(res, cc.NewReply(t)) } // HandleGetClientInfoText returns user information for the specific user. @@ -975,31 +763,30 @@ func HandleUserBroadcast(cc *ClientConn, t *Transaction) (res []Transaction, err // Fields used in the reply: // 102 User Name // 101 Data User info text string -func HandleGetClientInfoText(cc *ClientConn, t *Transaction) (res []Transaction, err error) { +func HandleGetClientInfoText(cc *ClientConn, t *Transaction) (res []Transaction) { if !cc.Authorize(accessGetClientInfo) { - res = append(res, cc.NewErrReply(t, "You are not allowed to get client info.")) - return res, err + return cc.NewErrReply(t, "You are not allowed to get client info.") } - clientID, _ := byteToInt(t.GetField(FieldUserID).Data) + clientID := t.GetField(FieldUserID).Data - clientConn := cc.Server.Clients[uint16(clientID)] + clientConn := cc.Server.Clients[[2]byte(clientID)] if clientConn == nil { - return append(res, cc.NewErrReply(t, "User not found.")), err + return cc.NewErrReply(t, "User not found.") } res = append(res, cc.NewReply(t, NewField(FieldData, []byte(clientConn.String())), NewField(FieldUserName, clientConn.UserName), )) - return res, err + return res } -func HandleGetUserNameList(cc *ClientConn, t *Transaction) (res []Transaction, err error) { - return []Transaction{cc.NewReply(t, cc.Server.connectedUsers()...)}, nil +func HandleGetUserNameList(cc *ClientConn, t *Transaction) (res []Transaction) { + return []Transaction{cc.NewReply(t, cc.Server.connectedUsers()...)} } -func HandleTranAgreed(cc *ClientConn, t *Transaction) (res []Transaction, err error) { +func HandleTranAgreed(cc *ClientConn, t *Transaction) (res []Transaction) { if t.GetField(FieldUserName).Data != nil { if cc.Authorize(accessAnyName) { cc.UserName = t.GetField(FieldUserName).Data @@ -1016,54 +803,46 @@ func HandleTranAgreed(cc *ClientConn, t *Transaction) (res []Transaction, err er options := t.GetField(FieldOptions).Data optBitmap := big.NewInt(int64(binary.BigEndian.Uint16(options))) - flagBitmap := big.NewInt(int64(binary.BigEndian.Uint16(cc.Flags))) - // Check refuse private PM option - if optBitmap.Bit(UserOptRefusePM) == 1 { - flagBitmap.SetBit(flagBitmap, UserFlagRefusePM, 1) - binary.BigEndian.PutUint16(cc.Flags, uint16(flagBitmap.Int64())) - } + + cc.flagsMU.Lock() + defer cc.flagsMU.Unlock() + cc.Flags.Set(UserFlagRefusePM, optBitmap.Bit(UserOptRefusePM)) // Check refuse private chat option - if optBitmap.Bit(UserOptRefuseChat) == 1 { - flagBitmap.SetBit(flagBitmap, UserFlagRefusePChat, 1) - binary.BigEndian.PutUint16(cc.Flags, uint16(flagBitmap.Int64())) - } + cc.Flags.Set(UserFlagRefusePChat, optBitmap.Bit(UserOptRefuseChat)) // Check auto response if optBitmap.Bit(UserOptAutoResponse) == 1 { cc.AutoReply = t.GetField(FieldAutomaticResponse).Data - } else { - cc.AutoReply = []byte{} } trans := cc.notifyOthers( - *NewTransaction( - TranNotifyChangeUser, nil, + NewTransaction( + TranNotifyChangeUser, [2]byte{0, 0}, NewField(FieldUserName, cc.UserName), - NewField(FieldUserID, *cc.ID), + NewField(FieldUserID, cc.ID[:]), NewField(FieldUserIconID, cc.Icon), - NewField(FieldUserFlags, cc.Flags), + NewField(FieldUserFlags, cc.Flags[:]), ), ) res = append(res, trans...) if cc.Server.Config.BannerFile != "" { - res = append(res, *NewTransaction(TranServerBanner, cc.ID, NewField(FieldBannerType, []byte("JPEG")))) + res = append(res, NewTransaction(TranServerBanner, cc.ID, NewField(FieldBannerType, []byte("JPEG")))) } res = append(res, cc.NewReply(t)) - return res, err + return res } // HandleTranOldPostNews updates the flat news // Fields used in this request: // 101 Data -func HandleTranOldPostNews(cc *ClientConn, t *Transaction) (res []Transaction, err error) { +func HandleTranOldPostNews(cc *ClientConn, t *Transaction) (res []Transaction) { if !cc.Authorize(accessNewsPostArt) { - res = append(res, cc.NewErrReply(t, "You are not allowed to post news.")) - return res, err + return cc.NewErrReply(t, "You are not allowed to post news.") } cc.Server.flatNewsMux.Lock() @@ -1087,7 +866,7 @@ func HandleTranOldPostNews(cc *ClientConn, t *Transaction) (res []Transaction, e // update news on disk if err := cc.Server.FS.WriteFile(filepath.Join(cc.Server.ConfigDir, "MessageBoard.txt"), cc.Server.FlatNews, 0644); err != nil { - return res, err + return res } // Notify all clients of updated news @@ -1097,20 +876,18 @@ func HandleTranOldPostNews(cc *ClientConn, t *Transaction) (res []Transaction, e ) res = append(res, cc.NewReply(t)) - return res, err + return res } -func HandleDisconnectUser(cc *ClientConn, t *Transaction) (res []Transaction, err error) { +func HandleDisconnectUser(cc *ClientConn, t *Transaction) (res []Transaction) { if !cc.Authorize(accessDisconUser) { - res = append(res, cc.NewErrReply(t, "You are not allowed to disconnect users.")) - return res, err + return cc.NewErrReply(t, "You are not allowed to disconnect users.") } - clientConn := cc.Server.Clients[binary.BigEndian.Uint16(t.GetField(FieldUserID).Data)] + clientConn := cc.Server.Clients[[2]byte(t.GetField(FieldUserID).Data)] if clientConn.Authorize(accessCannotBeDiscon) { - res = append(res, cc.NewErrReply(t, clientConn.Account.Login+" is not allowed to be disconnected.")) - return res, err + return cc.NewErrReply(t, clientConn.Account.Login+" is not allowed to be disconnected.") } // If FieldOptions is set, then the client IP is banned in addition to disconnected. @@ -1122,7 +899,7 @@ func HandleDisconnectUser(cc *ClientConn, t *Transaction) (res []Transaction, er // send message: "You are temporarily banned on this server" cc.logger.Info("Disconnect & temporarily ban " + string(clientConn.UserName)) - res = append(res, *NewTransaction( + res = append(res, NewTransaction( TranServerMsg, clientConn.ID, NewField(FieldData, []byte("You are temporarily banned on this server")), @@ -1135,7 +912,7 @@ func HandleDisconnectUser(cc *ClientConn, t *Transaction) (res []Transaction, er // send message: "You are permanently banned on this server" cc.logger.Info("Disconnect & ban " + string(clientConn.UserName)) - res = append(res, *NewTransaction( + res = append(res, NewTransaction( TranServerMsg, clientConn.ID, NewField(FieldData, []byte("You are permanently banned on this server")), @@ -1147,7 +924,7 @@ func HandleDisconnectUser(cc *ClientConn, t *Transaction) (res []Transaction, er err := cc.Server.writeBanList() if err != nil { - return res, err + return res } } @@ -1157,16 +934,15 @@ func HandleDisconnectUser(cc *ClientConn, t *Transaction) (res []Transaction, er clientConn.Disconnect() }() - return append(res, cc.NewReply(t)), err + return append(res, cc.NewReply(t)) } // HandleGetNewsCatNameList returns a list of news categories for a path // Fields used in the request: // 325 News path (Optional) -func HandleGetNewsCatNameList(cc *ClientConn, t *Transaction) (res []Transaction, err error) { +func HandleGetNewsCatNameList(cc *ClientConn, t *Transaction) (res []Transaction) { if !cc.Authorize(accessNewsReadArt) { - res = append(res, cc.NewErrReply(t, "You are not allowed to read news.")) - return res, err + return cc.NewErrReply(t, "You are not allowed to read news.") } pathStrs := ReadNewsPath(t.GetField(FieldNewsPath).Data) @@ -1184,21 +960,19 @@ func HandleGetNewsCatNameList(cc *ClientConn, t *Transaction) (res []Transaction var fieldData []Field for _, k := range keys { cat := cats[k] - b, _ := cat.MarshalBinary() - fieldData = append(fieldData, NewField( - FieldNewsCatListData15, - b, - )) + + b, _ := io.ReadAll(&cat) + + fieldData = append(fieldData, NewField(FieldNewsCatListData15, b)) } res = append(res, cc.NewReply(t, fieldData...)) - return res, err + return res } -func HandleNewNewsCat(cc *ClientConn, t *Transaction) (res []Transaction, err error) { +func HandleNewNewsCat(cc *ClientConn, t *Transaction) (res []Transaction) { if !cc.Authorize(accessNewsCreateCat) { - res = append(res, cc.NewErrReply(t, "You are not allowed to create news categories.")) - return res, err + return cc.NewErrReply(t, "You are not allowed to create news categories.") } name := string(t.GetField(FieldNewsCatName).Data) @@ -1213,19 +987,18 @@ func HandleNewNewsCat(cc *ClientConn, t *Transaction) (res []Transaction, err er } if err := cc.Server.writeThreadedNews(); err != nil { - return res, err + return res } res = append(res, cc.NewReply(t)) - return res, err + return res } // Fields used in the request: // 322 News category Name // 325 News path -func HandleNewNewsFldr(cc *ClientConn, t *Transaction) (res []Transaction, err error) { +func HandleNewNewsFldr(cc *ClientConn, t *Transaction) (res []Transaction) { if !cc.Authorize(accessNewsCreateFldr) { - res = append(res, cc.NewErrReply(t, "You are not allowed to create news folders.")) - return res, err + return cc.NewErrReply(t, "You are not allowed to create news folders.") } name := string(t.GetField(FieldFileName).Data) @@ -1239,10 +1012,10 @@ func HandleNewNewsFldr(cc *ClientConn, t *Transaction) (res []Transaction, err e SubCats: make(map[string]NewsCategoryListData15), } if err := cc.Server.writeThreadedNews(); err != nil { - return res, err + return res } res = append(res, cc.NewReply(t)) - return res, err + return res } // HandleGetNewsArtData gets the list of article names at the specified news path. @@ -1252,10 +1025,9 @@ func HandleNewNewsFldr(cc *ClientConn, t *Transaction) (res []Transaction, err e // Fields used in the reply: // 321 News article list data Optional -func HandleGetNewsArtNameList(cc *ClientConn, t *Transaction) (res []Transaction, err error) { +func HandleGetNewsArtNameList(cc *ClientConn, t *Transaction) (res []Transaction) { if !cc.Authorize(accessNewsReadArt) { - res = append(res, cc.NewErrReply(t, "You are not allowed to read news.")) - return res, nil + return cc.NewErrReply(t, "You are not allowed to read news.") } pathStrs := ReadNewsPath(t.GetField(FieldNewsPath).Data) @@ -1272,11 +1044,11 @@ func HandleGetNewsArtNameList(cc *ClientConn, t *Transaction) (res []Transaction b, err := io.ReadAll(&nald) if err != nil { - return res, fmt.Errorf("error loading news articles: %w", err) + return res } res = append(res, cc.NewReply(t, NewField(FieldNewsArtListData, b))) - return res, nil + return res } // HandleGetNewsArtData requests information about the specific news article. @@ -1297,10 +1069,9 @@ func HandleGetNewsArtNameList(cc *ClientConn, t *Transaction) (res []Transaction // 336 First child article ID // 327 News article data flavor "Should be “text/plain” // 333 News article data Optional (if data flavor is “text/plain”) -func HandleGetNewsArtData(cc *ClientConn, t *Transaction) (res []Transaction, err error) { +func HandleGetNewsArtData(cc *ClientConn, t *Transaction) (res []Transaction) { if !cc.Authorize(accessNewsReadArt) { - res = append(res, cc.NewErrReply(t, "You are not allowed to read news.")) - return res, err + return cc.NewErrReply(t, "You are not allowed to read news.") } var cat NewsCategoryListData15 @@ -1315,13 +1086,12 @@ func HandleGetNewsArtData(cc *ClientConn, t *Transaction) (res []Transaction, er // some third party clients such as Frogblast and Heildrun will always send 4 bytes convertedID, err := byteToInt(t.GetField(FieldNewsArtID).Data) if err != nil { - return res, err + return res } art := cat.Articles[uint32(convertedID)] if art == nil { - res = append(res, cc.NewReply(t)) - return res, err + return append(res, cc.NewReply(t)) } res = append(res, cc.NewReply(t, @@ -1335,7 +1105,7 @@ func HandleGetNewsArtData(cc *ClientConn, t *Transaction) (res []Transaction, er NewField(FieldNewsArtDataFlav, []byte("text/plain")), NewField(FieldNewsArtData, []byte(art.Data)), )) - return res, err + return res } // HandleDelNewsItem deletes an existing threaded news folder or category from the server. @@ -1343,7 +1113,7 @@ func HandleGetNewsArtData(cc *ClientConn, t *Transaction) (res []Transaction, er // 325 News path // Fields used in the reply: // None -func HandleDelNewsItem(cc *ClientConn, t *Transaction) (res []Transaction, err error) { +func HandleDelNewsItem(cc *ClientConn, t *Transaction) (res []Transaction) { pathStrs := ReadNewsPath(t.GetField(FieldNewsPath).Data) cats := cc.Server.ThreadedNews.Categories @@ -1356,27 +1126,27 @@ func HandleDelNewsItem(cc *ClientConn, t *Transaction) (res []Transaction, err e if cats[delName].Type == [2]byte{0, 3} { if !cc.Authorize(accessNewsDeleteCat) { - return append(res, cc.NewErrReply(t, "You are not allowed to delete news categories.")), nil + return cc.NewErrReply(t, "You are not allowed to delete news categories.") } } else { if !cc.Authorize(accessNewsDeleteFldr) { - return append(res, cc.NewErrReply(t, "You are not allowed to delete news folders.")), nil + return cc.NewErrReply(t, "You are not allowed to delete news folders.") } } delete(cats, delName) if err := cc.Server.writeThreadedNews(); err != nil { - return res, err + return res } - return append(res, cc.NewReply(t)), nil + return append(res, cc.NewReply(t)) } -func HandleDelNewsArt(cc *ClientConn, t *Transaction) (res []Transaction, err error) { +func HandleDelNewsArt(cc *ClientConn, t *Transaction) (res []Transaction) { if !cc.Authorize(accessNewsDeleteArt) { - res = append(res, cc.NewErrReply(t, "You are not allowed to delete news articles.")) - return res, err + return cc.NewErrReply(t, "You are not allowed to delete news articles.") + } // Request Fields @@ -1386,7 +1156,7 @@ func HandleDelNewsArt(cc *ClientConn, t *Transaction) (res []Transaction, err er pathStrs := ReadNewsPath(t.GetField(FieldNewsPath).Data) ID, err := byteToInt(t.GetField(FieldNewsArtID).Data) if err != nil { - return res, err + return res } // TODO: Delete recursive @@ -1399,11 +1169,11 @@ func HandleDelNewsArt(cc *ClientConn, t *Transaction) (res []Transaction, err er cats[catName] = cat if err := cc.Server.writeThreadedNews(); err != nil { - return res, err + return res } res = append(res, cc.NewReply(t)) - return res, err + return res } // Request fields @@ -1413,10 +1183,9 @@ func HandleDelNewsArt(cc *ClientConn, t *Transaction) (res []Transaction, err er // 334 News article flags // 327 News article data flavor Currently “text/plain” // 333 News article data -func HandlePostNewsArt(cc *ClientConn, t *Transaction) (res []Transaction, err error) { +func HandlePostNewsArt(cc *ClientConn, t *Transaction) (res []Transaction) { if !cc.Authorize(accessNewsPostArt) { - res = append(res, cc.NewErrReply(t, "You are not allowed to post news articles.")) - return res, err + return cc.NewErrReply(t, "You are not allowed to post news articles.") } pathStrs := ReadNewsPath(t.GetField(FieldNewsPath).Data) @@ -1427,7 +1196,7 @@ func HandlePostNewsArt(cc *ClientConn, t *Transaction) (res []Transaction, err e artID, err := byteToInt(t.GetField(FieldNewsArtID).Data) if err != nil { - return res, err + return res } convertedArtID := uint32(artID) bs := make([]byte, 4) @@ -1437,15 +1206,12 @@ func HandlePostNewsArt(cc *ClientConn, t *Transaction) (res []Transaction, err e defer cc.Server.mux.Unlock() newArt := NewsArtData{ - Title: string(t.GetField(FieldNewsArtTitle).Data), - Poster: string(cc.UserName), - Date: toHotlineTime(time.Now()), - PrevArt: [4]byte{}, - NextArt: [4]byte{}, - ParentArt: [4]byte(bs), - FirstChildArt: [4]byte{}, - DataFlav: []byte("text/plain"), - Data: string(t.GetField(FieldNewsArtData).Data), + Title: string(t.GetField(FieldNewsArtTitle).Data), + Poster: string(cc.UserName), + Date: toHotlineTime(time.Now()), + ParentArt: [4]byte(bs), + DataFlav: []byte("text/plain"), + Data: string(t.GetField(FieldNewsArtData).Data), } var keys []int @@ -1479,29 +1245,26 @@ func HandlePostNewsArt(cc *ClientConn, t *Transaction) (res []Transaction, err e cats[catName] = cat if err := cc.Server.writeThreadedNews(); err != nil { - return res, err + return res } - res = append(res, cc.NewReply(t)) - return res, err + return append(res, cc.NewReply(t)) } // HandleGetMsgs returns the flat news data -func HandleGetMsgs(cc *ClientConn, t *Transaction) (res []Transaction, err error) { +func HandleGetMsgs(cc *ClientConn, t *Transaction) (res []Transaction) { if !cc.Authorize(accessNewsReadArt) { - res = append(res, cc.NewErrReply(t, "You are not allowed to read news.")) - return res, err + return cc.NewErrReply(t, "You are not allowed to read news.") } res = append(res, cc.NewReply(t, NewField(FieldData, cc.Server.FlatNews))) - return res, err + return res } -func HandleDownloadFile(cc *ClientConn, t *Transaction) (res []Transaction, err error) { +func HandleDownloadFile(cc *ClientConn, t *Transaction) (res []Transaction) { if !cc.Authorize(accessDownloadFile) { - res = append(res, cc.NewErrReply(t, "You are not allowed to download files.")) - return res, err + return cc.NewErrReply(t, "You are not allowed to download files.") } fileName := t.GetField(FieldFileName).Data @@ -1512,7 +1275,7 @@ func HandleDownloadFile(cc *ClientConn, t *Transaction) (res []Transaction, err var frd FileResumeData if resumeData != nil { if err := frd.UnmarshalBinary(t.GetField(FieldFileResumeData).Data); err != nil { - return res, err + return res } // TODO: handle rsrc fork offset dataOffset = int64(binary.BigEndian.Uint32(frd.ForkInfoList[0].DataSize[:])) @@ -1520,12 +1283,12 @@ func HandleDownloadFile(cc *ClientConn, t *Transaction) (res []Transaction, err fullFilePath, err := readPath(cc.Server.Config.FileRoot, filePath, fileName) if err != nil { - return res, err + return res } hlFile, err := newFileWrapper(cc.Server.FS, fullFilePath, dataOffset) if err != nil { - return res, err + return res } xferSize := hlFile.ffo.TransferSize(0) @@ -1536,7 +1299,7 @@ func HandleDownloadFile(cc *ClientConn, t *Transaction) (res []Transaction, err if resumeData != nil { var frd FileResumeData if err := frd.UnmarshalBinary(t.GetField(FieldFileResumeData).Data); err != nil { - return res, err + return res } ft.fileResumeData = &frd } @@ -1556,45 +1319,45 @@ func HandleDownloadFile(cc *ClientConn, t *Transaction) (res []Transaction, err NewField(FieldFileSize, hlFile.ffo.FlatFileDataForkHeader.DataSize[:]), )) - return res, err + return res } // Download all files from the specified folder and sub-folders -func HandleDownloadFolder(cc *ClientConn, t *Transaction) (res []Transaction, err error) { +func HandleDownloadFolder(cc *ClientConn, t *Transaction) (res []Transaction) { if !cc.Authorize(accessDownloadFile) { - res = append(res, cc.NewErrReply(t, "You are not allowed to download folders.")) - return res, err + return cc.NewErrReply(t, "You are not allowed to download folders.") } fullFilePath, err := readPath(cc.Server.Config.FileRoot, t.GetField(FieldFilePath).Data, t.GetField(FieldFileName).Data) if err != nil { - return res, err + return res } transferSize, err := CalcTotalSize(fullFilePath) if err != nil { - return res, err + return res } itemCount, err := CalcItemCount(fullFilePath) if err != nil { - return res, err + return res } + spew.Dump(itemCount) fileTransfer := cc.newFileTransfer(FolderDownload, t.GetField(FieldFileName).Data, t.GetField(FieldFilePath).Data, transferSize) var fp FilePath _, err = fp.Write(t.GetField(FieldFilePath).Data) if err != nil { - return res, err + return res } res = append(res, cc.NewReply(t, - NewField(FieldRefNum, fileTransfer.ReferenceNumber), + NewField(FieldRefNum, fileTransfer.refNum[:]), NewField(FieldTransferSize, transferSize), NewField(FieldFolderItemCount, itemCount), NewField(FieldWaitingCount, []byte{0x00, 0x00}), // TODO: Implement waiting count )) - return res, err + return res } // Upload all files from the local folder and its subfolders to the specified path on the server @@ -1604,19 +1367,18 @@ func HandleDownloadFolder(cc *ClientConn, t *Transaction) (res []Transaction, er // 108 transfer size Total size of all items in the folder // 220 Folder item count // 204 File transfer options "Optional Currently set to 1" (TODO: ??) -func HandleUploadFolder(cc *ClientConn, t *Transaction) (res []Transaction, err error) { +func HandleUploadFolder(cc *ClientConn, t *Transaction) (res []Transaction) { var fp FilePath if t.GetField(FieldFilePath).Data != nil { - if _, err = fp.Write(t.GetField(FieldFilePath).Data); err != nil { - return res, err + if _, err := fp.Write(t.GetField(FieldFilePath).Data); err != nil { + return res } } // Handle special cases for Upload and Drop Box folders if !cc.Authorize(accessUploadAnywhere) { if !fp.IsUploadDir() && !fp.IsDropbox() { - res = append(res, cc.NewErrReply(t, fmt.Sprintf("Cannot accept upload of the folder \"%v\" because you are only allowed to upload to the \"Uploads\" folder.", string(t.GetField(FieldFileName).Data)))) - return res, err + return cc.NewErrReply(t, fmt.Sprintf("Cannot accept upload of the folder \"%v\" because you are only allowed to upload to the \"Uploads\" folder.", string(t.GetField(FieldFileName).Data))) } } @@ -1628,8 +1390,7 @@ func HandleUploadFolder(cc *ClientConn, t *Transaction) (res []Transaction, err fileTransfer.FolderItemCount = t.GetField(FieldFolderItemCount).Data - res = append(res, cc.NewReply(t, NewField(FieldRefNum, fileTransfer.ReferenceNumber))) - return res, err + return append(res, cc.NewReply(t, NewField(FieldRefNum, fileTransfer.refNum[:]))) } // HandleUploadFile @@ -1639,10 +1400,9 @@ func HandleUploadFolder(cc *ClientConn, t *Transaction) (res []Transaction, err // 204 File transfer options "Optional // Used only to resume download, currently has value 2" // 108 File transfer size "Optional used if download is not resumed" -func HandleUploadFile(cc *ClientConn, t *Transaction) (res []Transaction, err error) { +func HandleUploadFile(cc *ClientConn, t *Transaction) (res []Transaction) { if !cc.Authorize(accessUploadFile) { - res = append(res, cc.NewErrReply(t, "You are not allowed to upload files.")) - return res, err + return cc.NewErrReply(t, "You are not allowed to upload files.") } fileName := t.GetField(FieldFileName).Data @@ -1652,37 +1412,35 @@ func HandleUploadFile(cc *ClientConn, t *Transaction) (res []Transaction, err er var fp FilePath if filePath != nil { - if _, err = fp.Write(filePath); err != nil { - return res, err + if _, err := fp.Write(filePath); err != nil { + return res } } // Handle special cases for Upload and Drop Box folders if !cc.Authorize(accessUploadAnywhere) { if !fp.IsUploadDir() && !fp.IsDropbox() { - res = append(res, cc.NewErrReply(t, fmt.Sprintf("Cannot accept upload of the file \"%v\" because you are only allowed to upload to the \"Uploads\" folder.", string(fileName)))) - return res, err + return cc.NewErrReply(t, fmt.Sprintf("Cannot accept upload of the file \"%v\" because you are only allowed to upload to the \"Uploads\" folder.", string(fileName))) } } fullFilePath, err := readPath(cc.Server.Config.FileRoot, filePath, fileName) if err != nil { - return res, err + return res } if _, err := cc.Server.FS.Stat(fullFilePath); err == nil { - res = append(res, cc.NewErrReply(t, fmt.Sprintf("Cannot accept upload because there is already a file named \"%v\". Try choosing a different Name.", string(fileName)))) - return res, err + return cc.NewErrReply(t, fmt.Sprintf("Cannot accept upload because there is already a file named \"%v\". Try choosing a different Name.", string(fileName))) } ft := cc.newFileTransfer(FileUpload, fileName, filePath, transferSize) - replyT := cc.NewReply(t, NewField(FieldRefNum, ft.ReferenceNumber)) + replyT := cc.NewReply(t, NewField(FieldRefNum, ft.refNum[:])) // client has requested to resume a partially transferred file if transferOptions != nil { fileInfo, err := cc.Server.FS.Stat(fullFilePath + incompleteFileSuffix) if err != nil { - return res, err + return res } offset := make([]byte, 4) @@ -1700,10 +1458,10 @@ func HandleUploadFile(cc *ClientConn, t *Transaction) (res []Transaction, err er } res = append(res, replyT) - return res, err + return res } -func HandleSetClientUserInfo(cc *ClientConn, t *Transaction) (res []Transaction, err error) { +func HandleSetClientUserInfo(cc *ClientConn, t *Transaction) (res []Transaction) { if len(t.GetField(FieldUserIconID).Data) == 4 { cc.Icon = t.GetField(FieldUserIconID).Data[2:] } else { @@ -1713,17 +1471,20 @@ func HandleSetClientUserInfo(cc *ClientConn, t *Transaction) (res []Transaction, cc.UserName = t.GetField(FieldUserName).Data } + cc.flagsMU.Lock() + defer cc.flagsMU.Unlock() + // the options field is only passed by the client versions > 1.2.3. options := t.GetField(FieldOptions).Data if options != nil { optBitmap := big.NewInt(int64(binary.BigEndian.Uint16(options))) - flagBitmap := big.NewInt(int64(binary.BigEndian.Uint16(cc.Flags))) + flagBitmap := big.NewInt(int64(binary.BigEndian.Uint16(cc.Flags[:]))) flagBitmap.SetBit(flagBitmap, UserFlagRefusePM, optBitmap.Bit(UserOptRefusePM)) - binary.BigEndian.PutUint16(cc.Flags, uint16(flagBitmap.Int64())) + binary.BigEndian.PutUint16(cc.Flags[:], uint16(flagBitmap.Int64())) flagBitmap.SetBit(flagBitmap, UserFlagRefusePChat, optBitmap.Bit(UserOptRefuseChat)) - binary.BigEndian.PutUint16(cc.Flags, uint16(flagBitmap.Int64())) + binary.BigEndian.PutUint16(cc.Flags[:], uint16(flagBitmap.Int64())) // Check auto response if optBitmap.Bit(UserOptAutoResponse) == 1 { @@ -1733,60 +1494,59 @@ func HandleSetClientUserInfo(cc *ClientConn, t *Transaction) (res []Transaction, } } - for _, c := range sortedClients(cc.Server.Clients) { - res = append(res, *NewTransaction( + for _, c := range cc.Server.Clients { + res = append(res, NewTransaction( TranNotifyChangeUser, c.ID, - NewField(FieldUserID, *cc.ID), + NewField(FieldUserID, cc.ID[:]), NewField(FieldUserIconID, cc.Icon), - NewField(FieldUserFlags, cc.Flags), + NewField(FieldUserFlags, cc.Flags[:]), NewField(FieldUserName, cc.UserName), )) } - return res, err + return res } // HandleKeepAlive responds to keepalive transactions with an empty reply // * HL 1.9.2 Client sends keepalive msg every 3 minutes // * HL 1.2.3 Client doesn't send keepalives -func HandleKeepAlive(cc *ClientConn, t *Transaction) (res []Transaction, err error) { +func HandleKeepAlive(cc *ClientConn, t *Transaction) (res []Transaction) { res = append(res, cc.NewReply(t)) - return res, err + return res } -func HandleGetFileNameList(cc *ClientConn, t *Transaction) (res []Transaction, err error) { +func HandleGetFileNameList(cc *ClientConn, t *Transaction) (res []Transaction) { fullPath, err := readPath( cc.Server.Config.FileRoot, t.GetField(FieldFilePath).Data, nil, ) if err != nil { - return res, fmt.Errorf("error reading file path: %w", err) + return res } var fp FilePath if t.GetField(FieldFilePath).Data != nil { if _, err = fp.Write(t.GetField(FieldFilePath).Data); err != nil { - return res, fmt.Errorf("error writing file path: %w", err) + return res } } // Handle special case for drop box folders if fp.IsDropbox() && !cc.Authorize(accessViewDropBoxes) { - res = append(res, cc.NewErrReply(t, "You are not allowed to view drop boxes.")) - return res, err + return cc.NewErrReply(t, "You are not allowed to view drop boxes.") } fileNames, err := getFileNameList(fullPath, cc.Server.Config.IgnoreFiles) if err != nil { - return res, fmt.Errorf("getFileNameList: %w", err) + return res } res = append(res, cc.NewReply(t, fileNames...)) - return res, err + return res } // ================================= @@ -1802,10 +1562,9 @@ func HandleGetFileNameList(cc *ClientConn, t *Transaction) (res []Transaction, e // 1. ClientB sends TranJoinChat with FieldChatID // HandleInviteNewChat invites users to new private chat -func HandleInviteNewChat(cc *ClientConn, t *Transaction) (res []Transaction, err error) { +func HandleInviteNewChat(cc *ClientConn, t *Transaction) (res []Transaction) { if !cc.Authorize(accessOpenChat) { - res = append(res, cc.NewErrReply(t, "You are not allowed to request private chat.")) - return res, err + return cc.NewErrReply(t, "You are not allowed to request private chat.") } // Client to Invite @@ -1813,99 +1572,88 @@ func HandleInviteNewChat(cc *ClientConn, t *Transaction) (res []Transaction, err newChatID := cc.Server.NewPrivateChat(cc) // Check if target user has "Refuse private chat" flag - binary.BigEndian.Uint16(targetID) - targetClient := cc.Server.Clients[binary.BigEndian.Uint16(targetID)] - - flagBitmap := big.NewInt(int64(binary.BigEndian.Uint16(targetClient.Flags))) + targetClient := cc.Server.Clients[[2]byte(targetID)] + flagBitmap := big.NewInt(int64(binary.BigEndian.Uint16(targetClient.Flags[:]))) if flagBitmap.Bit(UserFlagRefusePChat) == 1 { res = append(res, - *NewTransaction( + NewTransaction( TranServerMsg, cc.ID, NewField(FieldData, []byte(string(targetClient.UserName)+" does not accept private chats.")), NewField(FieldUserName, targetClient.UserName), - NewField(FieldUserID, *targetClient.ID), + NewField(FieldUserID, targetClient.ID[:]), NewField(FieldOptions, []byte{0, 2}), ), ) } else { res = append(res, - *NewTransaction( + NewTransaction( TranInviteToChat, - &targetID, - NewField(FieldChatID, newChatID), + [2]byte(targetID), + NewField(FieldChatID, newChatID[:]), NewField(FieldUserName, cc.UserName), - NewField(FieldUserID, *cc.ID), + NewField(FieldUserID, cc.ID[:]), ), ) } res = append(res, cc.NewReply(t, - NewField(FieldChatID, newChatID), + NewField(FieldChatID, newChatID[:]), NewField(FieldUserName, cc.UserName), - NewField(FieldUserID, *cc.ID), + NewField(FieldUserID, cc.ID[:]), NewField(FieldUserIconID, cc.Icon), - NewField(FieldUserFlags, cc.Flags), + NewField(FieldUserFlags, cc.Flags[:]), ), ) - return res, err + return res } -func HandleInviteToChat(cc *ClientConn, t *Transaction) (res []Transaction, err error) { +func HandleInviteToChat(cc *ClientConn, t *Transaction) (res []Transaction) { if !cc.Authorize(accessOpenChat) { - res = append(res, cc.NewErrReply(t, "You are not allowed to request private chat.")) - return res, err + return cc.NewErrReply(t, "You are not allowed to request private chat.") } // Client to Invite targetID := t.GetField(FieldUserID).Data chatID := t.GetField(FieldChatID).Data - res = append(res, - *NewTransaction( + return []Transaction{ + NewTransaction( TranInviteToChat, - &targetID, + [2]byte(targetID), NewField(FieldChatID, chatID), NewField(FieldUserName, cc.UserName), - NewField(FieldUserID, *cc.ID), + NewField(FieldUserID, cc.ID[:]), ), - ) - res = append(res, cc.NewReply( t, NewField(FieldChatID, chatID), NewField(FieldUserName, cc.UserName), - NewField(FieldUserID, *cc.ID), + NewField(FieldUserID, cc.ID[:]), NewField(FieldUserIconID, cc.Icon), - NewField(FieldUserFlags, cc.Flags), + NewField(FieldUserFlags, cc.Flags[:]), ), - ) - - return res, err + } } -func HandleRejectChatInvite(cc *ClientConn, t *Transaction) (res []Transaction, err error) { - chatID := t.GetField(FieldChatID).Data - chatInt := binary.BigEndian.Uint32(chatID) - - privChat := cc.Server.PrivateChats[chatInt] - - resMsg := append(cc.UserName, []byte(" declined invitation to chat")...) +func HandleRejectChatInvite(cc *ClientConn, t *Transaction) (res []Transaction) { + chatID := [4]byte(t.GetField(FieldChatID).Data) + privChat := cc.Server.PrivateChats[chatID] - for _, c := range sortedClients(privChat.ClientConn) { + for _, c := range privChat.ClientConn { res = append(res, - *NewTransaction( + NewTransaction( TranChatMsg, c.ID, - NewField(FieldChatID, chatID), - NewField(FieldData, resMsg), + NewField(FieldChatID, chatID[:]), + NewField(FieldData, append(cc.UserName, []byte(" declined invitation to chat")...)), ), ) } - return res, err + return res } // HandleJoinChat is sent from a v1.8+ Hotline client when the joins a private chat @@ -1913,45 +1661,44 @@ func HandleRejectChatInvite(cc *ClientConn, t *Transaction) (res []Transaction, // * 115 Chat subject // * 300 User Name with info (Optional) // * 300 (more user names with info) -func HandleJoinChat(cc *ClientConn, t *Transaction) (res []Transaction, err error) { +func HandleJoinChat(cc *ClientConn, t *Transaction) (res []Transaction) { chatID := t.GetField(FieldChatID).Data - chatInt := binary.BigEndian.Uint32(chatID) - privChat := cc.Server.PrivateChats[chatInt] + privChat := cc.Server.PrivateChats[[4]byte(chatID)] // Send TranNotifyChatChangeUser to current members of the chat to inform of new user - for _, c := range sortedClients(privChat.ClientConn) { + for _, c := range privChat.ClientConn { res = append(res, - *NewTransaction( + NewTransaction( TranNotifyChatChangeUser, c.ID, NewField(FieldChatID, chatID), NewField(FieldUserName, cc.UserName), - NewField(FieldUserID, *cc.ID), + NewField(FieldUserID, cc.ID[:]), NewField(FieldUserIconID, cc.Icon), - NewField(FieldUserFlags, cc.Flags), + NewField(FieldUserFlags, cc.Flags[:]), ), ) } - privChat.ClientConn[cc.uint16ID()] = cc + privChat.ClientConn[cc.ID] = cc replyFields := []Field{NewField(FieldChatSubject, []byte(privChat.Subject))} - for _, c := range sortedClients(privChat.ClientConn) { + for _, c := range privChat.ClientConn { b, err := io.ReadAll(&User{ - ID: *c.ID, + ID: c.ID, Icon: c.Icon, - Flags: c.Flags, + Flags: c.Flags[:], Name: string(c.UserName), }) if err != nil { - return res, nil + return res } replyFields = append(replyFields, NewField(FieldUsernameWithInfo, b)) } res = append(res, cc.NewReply(t, replyFields...)) - return res, err + return res } // HandleLeaveChat is sent from a v1.8+ Hotline client when the user exits a private chat @@ -1959,30 +1706,29 @@ func HandleJoinChat(cc *ClientConn, t *Transaction) (res []Transaction, err erro // - 114 FieldChatID // // Reply is not expected. -func HandleLeaveChat(cc *ClientConn, t *Transaction) (res []Transaction, err error) { +func HandleLeaveChat(cc *ClientConn, t *Transaction) (res []Transaction) { chatID := t.GetField(FieldChatID).Data - chatInt := binary.BigEndian.Uint32(chatID) - privChat, ok := cc.Server.PrivateChats[chatInt] + privChat, ok := cc.Server.PrivateChats[[4]byte(chatID)] if !ok { - return res, nil + return res } - delete(privChat.ClientConn, cc.uint16ID()) + delete(privChat.ClientConn, cc.ID) // Notify members of the private chat that the user has left - for _, c := range sortedClients(privChat.ClientConn) { + for _, c := range privChat.ClientConn { res = append(res, - *NewTransaction( + NewTransaction( TranNotifyChatDeleteUser, c.ID, NewField(FieldChatID, chatID), - NewField(FieldUserID, *cc.ID), + NewField(FieldUserID, cc.ID[:]), ), ) } - return res, err + return res } // HandleSetChatSubject is sent from a v1.8+ Hotline client when the user sets a private chat subject @@ -1990,16 +1736,15 @@ func HandleLeaveChat(cc *ClientConn, t *Transaction) (res []Transaction, err err // * 114 Chat ID // * 115 Chat subject // Reply is not expected. -func HandleSetChatSubject(cc *ClientConn, t *Transaction) (res []Transaction, err error) { +func HandleSetChatSubject(cc *ClientConn, t *Transaction) (res []Transaction) { chatID := t.GetField(FieldChatID).Data - chatInt := binary.BigEndian.Uint32(chatID) - privChat := cc.Server.PrivateChats[chatInt] + privChat := cc.Server.PrivateChats[[4]byte(chatID)] privChat.Subject = string(t.GetField(FieldChatSubject).Data) - for _, c := range sortedClients(privChat.ClientConn) { + for _, c := range privChat.ClientConn { res = append(res, - *NewTransaction( + NewTransaction( TranNotifyChatSubject, c.ID, NewField(FieldChatID, chatID), @@ -2008,7 +1753,7 @@ func HandleSetChatSubject(cc *ClientConn, t *Transaction) (res []Transaction, er ) } - return res, err + return res } // HandleMakeAlias makes a file alias using the specified path. @@ -2019,10 +1764,9 @@ func HandleSetChatSubject(cc *ClientConn, t *Transaction) (res []Transaction, er // // Fields used in the reply: // None -func HandleMakeAlias(cc *ClientConn, t *Transaction) (res []Transaction, err error) { +func HandleMakeAlias(cc *ClientConn, t *Transaction) (res []Transaction) { if !cc.Authorize(accessMakeAlias) { - res = append(res, cc.NewErrReply(t, "You are not allowed to make aliases.")) - return res, err + return cc.NewErrReply(t, "You are not allowed to make aliases.") } fileName := t.GetField(FieldFileName).Data filePath := t.GetField(FieldFilePath).Data @@ -2030,23 +1774,22 @@ func HandleMakeAlias(cc *ClientConn, t *Transaction) (res []Transaction, err err fullFilePath, err := readPath(cc.Server.Config.FileRoot, filePath, fileName) if err != nil { - return res, err + return res } fullNewFilePath, err := readPath(cc.Server.Config.FileRoot, fileNewPath, fileName) if err != nil { - return res, err + return res } cc.logger.Debug("Make alias", "src", fullFilePath, "dst", fullNewFilePath) if err := cc.Server.FS.Symlink(fullFilePath, fullNewFilePath); err != nil { - res = append(res, cc.NewErrReply(t, "Error creating alias")) - return res, nil + return cc.NewErrReply(t, "Error creating alias") } res = append(res, cc.NewReply(t)) - return res, err + return res } // HandleDownloadBanner handles requests for a new banner from the server @@ -2055,12 +1798,12 @@ func HandleMakeAlias(cc *ClientConn, t *Transaction) (res []Transaction, err err // Fields used in the reply: // 107 FieldRefNum Used later for transfer // 108 FieldTransferSize Size of data to be downloaded -func HandleDownloadBanner(cc *ClientConn, t *Transaction) (res []Transaction, err error) { +func HandleDownloadBanner(cc *ClientConn, t *Transaction) (res []Transaction) { ft := cc.newFileTransfer(bannerDownload, []byte{}, []byte{}, make([]byte, 4)) binary.BigEndian.PutUint32(ft.TransferSize, uint32(len(cc.Server.banner))) return append(res, cc.NewReply(t, NewField(FieldRefNum, ft.refNum[:]), NewField(FieldTransferSize, ft.TransferSize), - )), err + )) } diff --git a/hotline/transaction_handlers_test.go b/hotline/transaction_handlers_test.go index 9ea4a0a..b7cc0af 100644 --- a/hotline/transaction_handlers_test.go +++ b/hotline/transaction_handlers_test.go @@ -2,7 +2,6 @@ package hotline import ( "errors" - "fmt" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "io" @@ -10,6 +9,7 @@ import ( "os" "path/filepath" "strings" + "sync" "testing" "time" ) @@ -17,13 +17,12 @@ import ( func TestHandleSetChatSubject(t *testing.T) { type args struct { cc *ClientConn - t *Transaction + t Transaction } tests := []struct { - name string - args args - want []Transaction - wantErr bool + name string + args args + want []Transaction }{ { name: "sends chat subject to private chat members", @@ -31,42 +30,42 @@ func TestHandleSetChatSubject(t *testing.T) { cc: &ClientConn{ UserName: []byte{0x00, 0x01}, Server: &Server{ - PrivateChats: map[uint32]*PrivateChat{ - uint32(1): { + PrivateChats: map[[4]byte]*PrivateChat{ + [4]byte{0, 0, 0, 1}: { Subject: "unset", - ClientConn: map[uint16]*ClientConn{ - uint16(1): { + ClientConn: map[[2]byte]*ClientConn{ + [2]byte{0, 1}: { Account: &Account{ Access: accessBitmap{255, 255, 255, 255, 255, 255, 255, 255}, }, - ID: &[]byte{0, 1}, + ID: [2]byte{0, 1}, }, - uint16(2): { + [2]byte{0, 2}: { Account: &Account{ Access: accessBitmap{255, 255, 255, 255, 255, 255, 255, 255}, }, - ID: &[]byte{0, 2}, + ID: [2]byte{0, 2}, }, }, }, }, - Clients: map[uint16]*ClientConn{ - uint16(1): { + Clients: map[[2]byte]*ClientConn{ + [2]byte{0, 1}: { Account: &Account{ Access: accessBitmap{255, 255, 255, 255, 255, 255, 255, 255}, }, - ID: &[]byte{0, 1}, + ID: [2]byte{0, 1}, }, - uint16(2): { + [2]byte{0, 2}: { Account: &Account{ Access: accessBitmap{255, 255, 255, 255, 255, 255, 255, 255}, }, - ID: &[]byte{0, 2}, + ID: [2]byte{0, 2}, }, }, }, }, - t: &Transaction{ + t: Transaction{ Type: [2]byte{0, 0x6a}, ID: [4]byte{0, 0, 0, 1}, Fields: []Field{ @@ -77,7 +76,7 @@ func TestHandleSetChatSubject(t *testing.T) { }, want: []Transaction{ { - clientID: &[]byte{0, 1}, + clientID: [2]byte{0, 1}, Type: [2]byte{0, 0x77}, Fields: []Field{ NewField(FieldChatID, []byte{0, 0, 0, 1}), @@ -85,7 +84,7 @@ func TestHandleSetChatSubject(t *testing.T) { }, }, { - clientID: &[]byte{0, 2}, + clientID: [2]byte{0, 2}, Type: [2]byte{0, 0x77}, Fields: []Field{ NewField(FieldChatID, []byte{0, 0, 0, 1}), @@ -93,16 +92,11 @@ func TestHandleSetChatSubject(t *testing.T) { }, }, }, - wantErr: false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := HandleSetChatSubject(tt.args.cc, tt.args.t) - if (err != nil) != tt.wantErr { - t.Errorf("HandleSetChatSubject() error = %v, wantErr %v", err, tt.wantErr) - return - } + got := HandleSetChatSubject(tt.args.cc, &tt.args.t) if !tranAssertEqual(t, tt.want, got) { t.Errorf("HandleSetChatSubject() got = %v, want %v", got, tt.want) } @@ -113,61 +107,58 @@ func TestHandleSetChatSubject(t *testing.T) { func TestHandleLeaveChat(t *testing.T) { type args struct { cc *ClientConn - t *Transaction + t Transaction } tests := []struct { - name string - args args - want []Transaction - wantErr bool + name string + args args + want []Transaction }{ { name: "returns expected transactions", args: args{ cc: &ClientConn{ - ID: &[]byte{0, 2}, + ID: [2]byte{0, 2}, Server: &Server{ - PrivateChats: map[uint32]*PrivateChat{ - uint32(1): { - ClientConn: map[uint16]*ClientConn{ - uint16(1): { + PrivateChats: map[[4]byte]*PrivateChat{ + [4]byte{0, 0, 0, 1}: { + ClientConn: map[[2]byte]*ClientConn{ + [2]byte{0, 1}: { Account: &Account{ Access: accessBitmap{255, 255, 255, 255, 255, 255, 255, 255}, }, - ID: &[]byte{0, 1}, + ID: [2]byte{0, 1}, }, - uint16(2): { + [2]byte{0, 2}: { Account: &Account{ Access: accessBitmap{255, 255, 255, 255, 255, 255, 255, 255}, }, - ID: &[]byte{0, 2}, + ID: [2]byte{0, 2}, }, }, }, }, - Clients: map[uint16]*ClientConn{ - uint16(1): { + Clients: map[[2]byte]*ClientConn{ + [2]byte{0, 1}: { Account: &Account{ Access: accessBitmap{255, 255, 255, 255, 255, 255, 255, 255}, }, - ID: &[]byte{0, 1}, + ID: [2]byte{0, 1}, }, - uint16(2): { + [2]byte{0, 2}: { Account: &Account{ Access: accessBitmap{255, 255, 255, 255, 255, 255, 255, 255}, }, - ID: &[]byte{0, 2}, + ID: [2]byte{0, 2}, }, }, }, }, - t: NewTransaction(TranDeleteUser, nil, NewField(FieldChatID, []byte{0, 0, 0, 1})), + t: NewTransaction(TranDeleteUser, [2]byte{}, NewField(FieldChatID, []byte{0, 0, 0, 1})), }, want: []Transaction{ { - clientID: &[]byte{0, 1}, - Flags: 0x00, - IsReply: 0x00, + clientID: [2]byte{0, 1}, Type: [2]byte{0, 0x76}, Fields: []Field{ NewField(FieldChatID, []byte{0, 0, 0, 1}), @@ -175,16 +166,11 @@ func TestHandleLeaveChat(t *testing.T) { }, }, }, - wantErr: false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := HandleLeaveChat(tt.args.cc, tt.args.t) - if (err != nil) != tt.wantErr { - t.Errorf("HandleLeaveChat() error = %v, wantErr %v", err, tt.wantErr) - return - } + got := HandleLeaveChat(tt.args.cc, &tt.args.t) if !tranAssertEqual(t, tt.want, got) { t.Errorf("HandleLeaveChat() got = %v, want %v", got, tt.want) } @@ -195,42 +181,40 @@ func TestHandleLeaveChat(t *testing.T) { func TestHandleGetUserNameList(t *testing.T) { type args struct { cc *ClientConn - t *Transaction + t Transaction } tests := []struct { - name string - args args - want []Transaction - wantErr bool + name string + args args + want []Transaction }{ { name: "replies with userlist transaction", args: args{ cc: &ClientConn{ - - ID: &[]byte{1, 1}, + ID: [2]byte{0, 1}, Server: &Server{ - Clients: map[uint16]*ClientConn{ - uint16(1): { - ID: &[]byte{0, 1}, + Clients: map[[2]byte]*ClientConn{ + [2]byte{0, 1}: { + ID: [2]byte{0, 1}, Icon: []byte{0, 2}, - Flags: []byte{0, 3}, + Flags: [2]byte{0, 3}, UserName: []byte{0, 4}, }, - uint16(2): { - ID: &[]byte{0, 2}, + [2]byte{0, 2}: { + ID: [2]byte{0, 2}, Icon: []byte{0, 2}, - Flags: []byte{0, 3}, + Flags: [2]byte{0, 3}, UserName: []byte{0, 4}, }, }, }, }, - t: &Transaction{}, + t: Transaction{}, }, want: []Transaction{ { - clientID: &[]byte{1, 1}, + clientID: [2]byte{0, 1}, IsReply: 0x01, Fields: []Field{ NewField( @@ -244,16 +228,11 @@ func TestHandleGetUserNameList(t *testing.T) { }, }, }, - wantErr: false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := HandleGetUserNameList(tt.args.cc, tt.args.t) - if (err != nil) != tt.wantErr { - t.Errorf("HandleGetUserNameList() error = %v, wantErr %v", err, tt.wantErr) - return - } + got := HandleGetUserNameList(tt.args.cc, &tt.args.t) assert.Equal(t, tt.want, got) }) } @@ -262,13 +241,12 @@ func TestHandleGetUserNameList(t *testing.T) { func TestHandleChatSend(t *testing.T) { type args struct { cc *ClientConn - t *Transaction + t Transaction } tests := []struct { - name string - args args - want []Transaction - wantErr bool + name string + args args + want []Transaction }{ { name: "sends chat msg transaction to all clients", @@ -283,23 +261,23 @@ func TestHandleChatSend(t *testing.T) { }, UserName: []byte{0x00, 0x01}, Server: &Server{ - Clients: map[uint16]*ClientConn{ - uint16(1): { + Clients: map[[2]byte]*ClientConn{ + [2]byte{0, 1}: { Account: &Account{ Access: accessBitmap{255, 255, 255, 255, 255, 255, 255, 255}, }, - ID: &[]byte{0, 1}, + ID: [2]byte{0, 1}, }, - uint16(2): { + [2]byte{0, 2}: { Account: &Account{ Access: accessBitmap{255, 255, 255, 255, 255, 255, 255, 255}, }, - ID: &[]byte{0, 2}, + ID: [2]byte{0, 2}, }, }, }, }, - t: &Transaction{ + t: Transaction{ Fields: []Field{ NewField(FieldData, []byte("hai")), }, @@ -307,7 +285,7 @@ func TestHandleChatSend(t *testing.T) { }, want: []Transaction{ { - clientID: &[]byte{0, 1}, + clientID: [2]byte{0, 1}, Flags: 0x00, IsReply: 0x00, Type: [2]byte{0, 0x6a}, @@ -316,7 +294,7 @@ func TestHandleChatSend(t *testing.T) { }, }, { - clientID: &[]byte{0, 2}, + clientID: [2]byte{0, 2}, Flags: 0x00, IsReply: 0x00, Type: [2]byte{0, 0x6a}, @@ -325,7 +303,6 @@ func TestHandleChatSend(t *testing.T) { }, }, }, - wantErr: false, }, { name: "treats Chat ID 00 00 00 00 as a public chat message", @@ -340,23 +317,23 @@ func TestHandleChatSend(t *testing.T) { }, UserName: []byte{0x00, 0x01}, Server: &Server{ - Clients: map[uint16]*ClientConn{ - uint16(1): { + Clients: map[[2]byte]*ClientConn{ + [2]byte{0, 1}: { Account: &Account{ Access: accessBitmap{255, 255, 255, 255, 255, 255, 255, 255}, }, - ID: &[]byte{0, 1}, + ID: [2]byte{0, 1}, }, - uint16(2): { + [2]byte{0, 2}: { Account: &Account{ Access: accessBitmap{255, 255, 255, 255, 255, 255, 255, 255}, }, - ID: &[]byte{0, 2}, + ID: [2]byte{0, 2}, }, }, }, }, - t: &Transaction{ + t: Transaction{ Fields: []Field{ NewField(FieldData, []byte("hai")), NewField(FieldChatID, []byte{0, 0, 0, 0}), @@ -365,21 +342,20 @@ func TestHandleChatSend(t *testing.T) { }, want: []Transaction{ { - clientID: &[]byte{0, 1}, + clientID: [2]byte{0, 1}, Type: [2]byte{0, 0x6a}, Fields: []Field{ NewField(FieldData, []byte{0x0d, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x00, 0x01, 0x3a, 0x20, 0x20, 0x68, 0x61, 0x69}), }, }, { - clientID: &[]byte{0, 2}, + clientID: [2]byte{0, 2}, Type: [2]byte{0, 0x6a}, Fields: []Field{ NewField(FieldData, []byte{0x0d, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x00, 0x01, 0x3a, 0x20, 0x20, 0x68, 0x61, 0x69}), }, }, }, - wantErr: false, }, { name: "when user does not have required permission", @@ -396,7 +372,7 @@ func TestHandleChatSend(t *testing.T) { }, }, t: NewTransaction( - TranChatSend, &[]byte{0, 1}, + TranChatSend, [2]byte{0, 1}, NewField(FieldData, []byte("hai")), ), }, @@ -409,7 +385,6 @@ func TestHandleChatSend(t *testing.T) { }, }, }, - wantErr: false, }, { name: "sends chat msg as emote if FieldChatOptions is set to 1", @@ -424,23 +399,23 @@ func TestHandleChatSend(t *testing.T) { }, UserName: []byte("Testy McTest"), Server: &Server{ - Clients: map[uint16]*ClientConn{ - uint16(1): { + Clients: map[[2]byte]*ClientConn{ + [2]byte{0, 1}: { Account: &Account{ Access: accessBitmap{255, 255, 255, 255, 255, 255, 255, 255}, }, - ID: &[]byte{0, 1}, + ID: [2]byte{0, 1}, }, - uint16(2): { + [2]byte{0, 2}: { Account: &Account{ Access: accessBitmap{255, 255, 255, 255, 255, 255, 255, 255}, }, - ID: &[]byte{0, 2}, + ID: [2]byte{0, 2}, }, }, }, }, - t: &Transaction{ + t: Transaction{ Fields: []Field{ NewField(FieldData, []byte("performed action")), NewField(FieldChatOptions, []byte{0x00, 0x01}), @@ -449,7 +424,7 @@ func TestHandleChatSend(t *testing.T) { }, want: []Transaction{ { - clientID: &[]byte{0, 1}, + clientID: [2]byte{0, 1}, Flags: 0x00, IsReply: 0x00, Type: [2]byte{0, 0x6a}, @@ -458,7 +433,7 @@ func TestHandleChatSend(t *testing.T) { }, }, { - clientID: &[]byte{0, 2}, + clientID: [2]byte{0, 2}, Flags: 0x00, IsReply: 0x00, Type: [2]byte{0, 0x6a}, @@ -467,7 +442,6 @@ func TestHandleChatSend(t *testing.T) { }, }, }, - wantErr: false, }, { name: "does not send chat msg as emote if FieldChatOptions is set to 0", @@ -482,23 +456,23 @@ func TestHandleChatSend(t *testing.T) { }, UserName: []byte("Testy McTest"), Server: &Server{ - Clients: map[uint16]*ClientConn{ - uint16(1): { + Clients: map[[2]byte]*ClientConn{ + [2]byte{0, 1}: { Account: &Account{ Access: accessBitmap{255, 255, 255, 255, 255, 255, 255, 255}, }, - ID: &[]byte{0, 1}, + ID: [2]byte{0, 1}, }, - uint16(2): { + [2]byte{0, 2}: { Account: &Account{ Access: accessBitmap{255, 255, 255, 255, 255, 255, 255, 255}, }, - ID: &[]byte{0, 2}, + ID: [2]byte{0, 2}, }, }, }, }, - t: &Transaction{ + t: Transaction{ Fields: []Field{ NewField(FieldData, []byte("hello")), NewField(FieldChatOptions, []byte{0x00, 0x00}), @@ -507,21 +481,20 @@ func TestHandleChatSend(t *testing.T) { }, want: []Transaction{ { - clientID: &[]byte{0, 1}, + clientID: [2]byte{0, 1}, Type: [2]byte{0, 0x6a}, Fields: []Field{ NewField(FieldData, []byte("\r Testy McTest: hello")), }, }, { - clientID: &[]byte{0, 2}, + clientID: [2]byte{0, 2}, Type: [2]byte{0, 0x6a}, Fields: []Field{ NewField(FieldData, []byte("\r Testy McTest: hello")), }, }, }, - wantErr: false, }, { name: "only sends chat msg to clients with accessReadChat permission", @@ -536,26 +509,26 @@ func TestHandleChatSend(t *testing.T) { }, UserName: []byte{0x00, 0x01}, Server: &Server{ - Clients: map[uint16]*ClientConn{ - uint16(1): { + Clients: map[[2]byte]*ClientConn{ + [2]byte{0, 1}: { Account: &Account{ Access: func() accessBitmap { var bits accessBitmap bits.Set(accessReadChat) return bits }()}, - ID: &[]byte{0, 1}, + ID: [2]byte{0, 1}, }, - uint16(2): { + [2]byte{0, 2}: { Account: &Account{ Access: accessBitmap{0, 0, 0, 0, 0, 0, 0, 0}, }, - ID: &[]byte{0, 2}, + ID: [2]byte{0, 2}, }, }, }, }, - t: &Transaction{ + t: Transaction{ Fields: []Field{ NewField(FieldData, []byte("hai")), }, @@ -563,14 +536,13 @@ func TestHandleChatSend(t *testing.T) { }, want: []Transaction{ { - clientID: &[]byte{0, 1}, + clientID: [2]byte{0, 1}, Type: [2]byte{0, 0x6a}, Fields: []Field{ NewField(FieldData, []byte{0x0d, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x00, 0x01, 0x3a, 0x20, 0x20, 0x68, 0x61, 0x69}), }, }, }, - wantErr: false, }, { name: "only sends private chat msg to members of private chat", @@ -585,41 +557,41 @@ func TestHandleChatSend(t *testing.T) { }, UserName: []byte{0x00, 0x01}, Server: &Server{ - PrivateChats: map[uint32]*PrivateChat{ - uint32(1): { - ClientConn: map[uint16]*ClientConn{ - uint16(1): { - ID: &[]byte{0, 1}, + PrivateChats: map[[4]byte]*PrivateChat{ + [4]byte{0, 0, 0, 1}: { + ClientConn: map[[2]byte]*ClientConn{ + [2]byte{0, 1}: { + ID: [2]byte{0, 1}, }, - uint16(2): { - ID: &[]byte{0, 2}, + [2]byte{0, 2}: { + ID: [2]byte{0, 2}, }, }, }, }, - Clients: map[uint16]*ClientConn{ - uint16(1): { + Clients: map[[2]byte]*ClientConn{ + [2]byte{0, 1}: { Account: &Account{ Access: accessBitmap{255, 255, 255, 255, 255, 255, 255, 255}, }, - ID: &[]byte{0, 1}, + ID: [2]byte{0, 1}, }, - uint16(2): { + [2]byte{0, 2}: { Account: &Account{ Access: accessBitmap{0, 0, 0, 0, 0, 0, 0, 0}, }, - ID: &[]byte{0, 2}, + ID: [2]byte{0, 2}, }, - uint16(3): { + [2]byte{0, 3}: { Account: &Account{ Access: accessBitmap{0, 0, 0, 0, 0, 0, 0, 0}, }, - ID: &[]byte{0, 3}, + ID: [2]byte{0, 3}, }, }, }, }, - t: &Transaction{ + t: Transaction{ Fields: []Field{ NewField(FieldData, []byte("hai")), NewField(FieldChatID, []byte{0, 0, 0, 1}), @@ -628,7 +600,7 @@ func TestHandleChatSend(t *testing.T) { }, want: []Transaction{ { - clientID: &[]byte{0, 1}, + clientID: [2]byte{0, 1}, Type: [2]byte{0, 0x6a}, Fields: []Field{ NewField(FieldChatID, []byte{0, 0, 0, 1}), @@ -636,7 +608,7 @@ func TestHandleChatSend(t *testing.T) { }, }, { - clientID: &[]byte{0, 2}, + clientID: [2]byte{0, 2}, Type: [2]byte{0, 0x6a}, Fields: []Field{ NewField(FieldChatID, []byte{0, 0, 0, 1}), @@ -644,17 +616,11 @@ func TestHandleChatSend(t *testing.T) { }, }, }, - wantErr: false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := HandleChatSend(tt.args.cc, tt.args.t) - - if (err != nil) != tt.wantErr { - t.Errorf("HandleChatSend() error = %v, wantErr %v", err, tt.wantErr) - return - } + got := HandleChatSend(tt.args.cc, &tt.args.t) tranAssertEqual(t, tt.want, got) }) } @@ -663,19 +629,18 @@ func TestHandleChatSend(t *testing.T) { func TestHandleGetFileInfo(t *testing.T) { type args struct { cc *ClientConn - t *Transaction + t Transaction } tests := []struct { name string args args wantRes []Transaction - wantErr bool }{ { name: "returns expected fields when a valid file is requested", args: args{ cc: &ClientConn{ - ID: &[]byte{0x00, 0x01}, + ID: [2]byte{0x00, 0x01}, Server: &Server{ FS: &OSFileStore{}, Config: &Config{ @@ -687,14 +652,14 @@ func TestHandleGetFileInfo(t *testing.T) { }, }, t: NewTransaction( - TranGetFileInfo, nil, + TranGetFileInfo, [2]byte{}, NewField(FieldFileName, []byte("testfile.txt")), NewField(FieldFilePath, []byte{0x00, 0x00}), ), }, wantRes: []Transaction{ { - clientID: &[]byte{0, 1}, + clientID: [2]byte{0, 1}, IsReply: 0x01, Type: [2]byte{0, 0}, Fields: []Field{ @@ -708,18 +673,13 @@ func TestHandleGetFileInfo(t *testing.T) { }, }, }, - wantErr: false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - gotRes, err := HandleGetFileInfo(tt.args.cc, tt.args.t) - if (err != nil) != tt.wantErr { - t.Errorf("HandleGetFileInfo() error = %v, wantErr %v", err, tt.wantErr) - return - } + gotRes := HandleGetFileInfo(tt.args.cc, &tt.args.t) - // Clear the fileWrapper timestamp fields to work around problems running the tests in multiple timezones + // Clear the file timestamp fields to work around problems running the tests in multiple timezones // TODO: revisit how to test this by mocking the stat calls gotRes[0].Fields[4].Data = make([]byte, 8) gotRes[0].Fields[5].Data = make([]byte, 8) @@ -734,13 +694,12 @@ func TestHandleGetFileInfo(t *testing.T) { func TestHandleNewFolder(t *testing.T) { type args struct { cc *ClientConn - t *Transaction + t Transaction } tests := []struct { name string args args wantRes []Transaction - wantErr bool }{ { name: "without required permission", @@ -754,8 +713,8 @@ func TestHandleNewFolder(t *testing.T) { }, }, t: NewTransaction( - accessCreateFolder, - &[]byte{0, 0}, + TranNewFolder, + [2]byte{0, 0}, ), }, wantRes: []Transaction{ @@ -767,7 +726,6 @@ func TestHandleNewFolder(t *testing.T) { }, }, }, - wantErr: false, }, { name: "when path is nested", @@ -780,7 +738,7 @@ func TestHandleNewFolder(t *testing.T) { return bits }(), }, - ID: &[]byte{0, 1}, + ID: [2]byte{0, 1}, Server: &Server{ Config: &Config{ FileRoot: "/Files/", @@ -794,7 +752,7 @@ func TestHandleNewFolder(t *testing.T) { }, }, t: NewTransaction( - TranNewFolder, &[]byte{0, 1}, + TranNewFolder, [2]byte{0, 1}, NewField(FieldFileName, []byte("testFolder")), NewField(FieldFilePath, []byte{ 0x00, 0x01, @@ -806,11 +764,10 @@ func TestHandleNewFolder(t *testing.T) { }, wantRes: []Transaction{ { - clientID: &[]byte{0, 1}, + clientID: [2]byte{0, 1}, IsReply: 0x01, }, }, - wantErr: false, }, { name: "when path is not nested", @@ -823,7 +780,7 @@ func TestHandleNewFolder(t *testing.T) { return bits }(), }, - ID: &[]byte{0, 1}, + ID: [2]byte{0, 1}, Server: &Server{ Config: &Config{ FileRoot: "/Files", @@ -837,17 +794,16 @@ func TestHandleNewFolder(t *testing.T) { }, }, t: NewTransaction( - TranNewFolder, &[]byte{0, 1}, + TranNewFolder, [2]byte{0, 1}, NewField(FieldFileName, []byte("testFolder")), ), }, wantRes: []Transaction{ { - clientID: &[]byte{0, 1}, + clientID: [2]byte{0, 1}, IsReply: 0x01, }, }, - wantErr: false, }, { name: "when Write returns an err", @@ -860,7 +816,7 @@ func TestHandleNewFolder(t *testing.T) { return bits }(), }, - ID: &[]byte{0, 1}, + ID: [2]byte{0, 1}, Server: &Server{ Config: &Config{ FileRoot: "/Files/", @@ -874,7 +830,7 @@ func TestHandleNewFolder(t *testing.T) { }, }, t: NewTransaction( - TranNewFolder, &[]byte{0, 1}, + TranNewFolder, [2]byte{0, 1}, NewField(FieldFileName, []byte("testFolder")), NewField(FieldFilePath, []byte{ 0x00, @@ -882,7 +838,6 @@ func TestHandleNewFolder(t *testing.T) { ), }, wantRes: []Transaction{}, - wantErr: true, }, { name: "FieldFileName does not allow directory traversal", @@ -895,7 +850,7 @@ func TestHandleNewFolder(t *testing.T) { return bits }(), }, - ID: &[]byte{0, 1}, + ID: [2]byte{0, 1}, Server: &Server{ Config: &Config{ FileRoot: "/Files/", @@ -909,16 +864,16 @@ func TestHandleNewFolder(t *testing.T) { }, }, t: NewTransaction( - TranNewFolder, &[]byte{0, 1}, + TranNewFolder, [2]byte{0, 1}, NewField(FieldFileName, []byte("../../testFolder")), ), }, wantRes: []Transaction{ { - clientID: &[]byte{0, 1}, + clientID: [2]byte{0, 1}, IsReply: 0x01, }, - }, wantErr: false, + }, }, { name: "FieldFilePath does not allow directory traversal", @@ -931,7 +886,7 @@ func TestHandleNewFolder(t *testing.T) { return bits }(), }, - ID: &[]byte{0, 1}, + ID: [2]byte{0, 1}, Server: &Server{ Config: &Config{ FileRoot: "/Files/", @@ -945,7 +900,7 @@ func TestHandleNewFolder(t *testing.T) { }, }, t: NewTransaction( - TranNewFolder, &[]byte{0, 1}, + TranNewFolder, [2]byte{0, 1}, NewField(FieldFileName, []byte("testFolder")), NewField(FieldFilePath, []byte{ 0x00, 0x02, @@ -960,19 +915,15 @@ func TestHandleNewFolder(t *testing.T) { }, wantRes: []Transaction{ { - clientID: &[]byte{0, 1}, + clientID: [2]byte{0, 1}, IsReply: 0x01, }, - }, wantErr: false, + }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - gotRes, err := HandleNewFolder(tt.args.cc, tt.args.t) - if (err != nil) != tt.wantErr { - t.Errorf("HandleNewFolder() error = %v, wantErr %v", err, tt.wantErr) - return - } + gotRes := HandleNewFolder(tt.args.cc, &tt.args.t) if !tranAssertEqual(t, tt.wantRes, gotRes) { t.Errorf("HandleNewFolder() gotRes = %v, want %v", gotRes, tt.wantRes) @@ -984,13 +935,12 @@ func TestHandleNewFolder(t *testing.T) { func TestHandleUploadFile(t *testing.T) { type args struct { cc *ClientConn - t *Transaction + t Transaction } tests := []struct { name string args args wantRes []Transaction - wantErr bool }{ { name: "when request is valid and user has Upload Anywhere permission", @@ -1015,7 +965,7 @@ func TestHandleUploadFile(t *testing.T) { }, }, t: NewTransaction( - TranUploadFile, &[]byte{0, 1}, + TranUploadFile, [2]byte{0, 1}, NewField(FieldFileName, []byte("testFile")), NewField(FieldFilePath, []byte{ 0x00, 0x01, @@ -1033,7 +983,6 @@ func TestHandleUploadFile(t *testing.T) { }, }, }, - wantErr: false, }, { name: "when user does not have required access", @@ -1047,7 +996,7 @@ func TestHandleUploadFile(t *testing.T) { }, }, t: NewTransaction( - TranUploadFile, &[]byte{0, 1}, + TranUploadFile, [2]byte{0, 1}, NewField(FieldFileName, []byte("testFile")), NewField(FieldFilePath, []byte{ 0x00, 0x01, @@ -1066,17 +1015,11 @@ func TestHandleUploadFile(t *testing.T) { }, }, }, - wantErr: false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - gotRes, err := HandleUploadFile(tt.args.cc, tt.args.t) - if (err != nil) != tt.wantErr { - t.Errorf("HandleUploadFile() error = %v, wantErr %v", err, tt.wantErr) - return - } - + gotRes := HandleUploadFile(tt.args.cc, &tt.args.t) tranAssertEqual(t, tt.wantRes, gotRes) }) } @@ -1085,13 +1028,12 @@ func TestHandleUploadFile(t *testing.T) { func TestHandleMakeAlias(t *testing.T) { type args struct { cc *ClientConn - t *Transaction + t Transaction } tests := []struct { name string args args wantRes []Transaction - wantErr bool }{ { name: "with valid input and required permissions", @@ -1126,7 +1068,7 @@ func TestHandleMakeAlias(t *testing.T) { }, }, t: NewTransaction( - TranMakeFileAlias, &[]byte{0, 1}, + TranMakeFileAlias, [2]byte{0, 1}, NewField(FieldFileName, []byte("testFile")), NewField(FieldFilePath, EncodeFilePath(strings.Join([]string{"foo"}, "/"))), NewField(FieldFileNewPath, EncodeFilePath(strings.Join([]string{"bar"}, "/"))), @@ -1138,7 +1080,6 @@ func TestHandleMakeAlias(t *testing.T) { Fields: []Field(nil), }, }, - wantErr: false, }, { name: "when symlink returns an error", @@ -1173,7 +1114,7 @@ func TestHandleMakeAlias(t *testing.T) { }, }, t: NewTransaction( - TranMakeFileAlias, &[]byte{0, 1}, + TranMakeFileAlias, [2]byte{0, 1}, NewField(FieldFileName, []byte("testFile")), NewField(FieldFilePath, EncodeFilePath(strings.Join([]string{"foo"}, "/"))), NewField(FieldFileNewPath, EncodeFilePath(strings.Join([]string{"bar"}, "/"))), @@ -1188,7 +1129,6 @@ func TestHandleMakeAlias(t *testing.T) { }, }, }, - wantErr: false, }, { name: "when user does not have required permission", @@ -1211,7 +1151,7 @@ func TestHandleMakeAlias(t *testing.T) { }, }, t: NewTransaction( - TranMakeFileAlias, &[]byte{0, 1}, + TranMakeFileAlias, [2]byte{0, 1}, NewField(FieldFileName, []byte("testFile")), NewField(FieldFilePath, []byte{ 0x00, 0x01, @@ -1236,17 +1176,11 @@ func TestHandleMakeAlias(t *testing.T) { }, }, }, - wantErr: false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - gotRes, err := HandleMakeAlias(tt.args.cc, tt.args.t) - if (err != nil) != tt.wantErr { - t.Errorf("HandleMakeAlias(%v, %v)", tt.args.cc, tt.args.t) - return - } - + gotRes := HandleMakeAlias(tt.args.cc, &tt.args.t) tranAssertEqual(t, tt.wantRes, gotRes) }) } @@ -1255,13 +1189,12 @@ func TestHandleMakeAlias(t *testing.T) { func TestHandleGetUser(t *testing.T) { type args struct { cc *ClientConn - t *Transaction + t Transaction } tests := []struct { name string args args wantRes []Transaction - wantErr assert.ErrorAssertionFunc }{ { name: "when account is valid", @@ -1286,7 +1219,7 @@ func TestHandleGetUser(t *testing.T) { }, }, t: NewTransaction( - TranGetUser, &[]byte{0, 1}, + TranGetUser, [2]byte{0, 1}, NewField(FieldUserLogin, []byte("guest")), ), }, @@ -1301,7 +1234,6 @@ func TestHandleGetUser(t *testing.T) { }, }, }, - wantErr: assert.NoError, }, { name: "when user does not have required permission", @@ -1318,7 +1250,7 @@ func TestHandleGetUser(t *testing.T) { }, }, t: NewTransaction( - TranGetUser, &[]byte{0, 1}, + TranGetUser, [2]byte{0, 1}, NewField(FieldUserLogin, []byte("nonExistentUser")), ), }, @@ -1331,7 +1263,6 @@ func TestHandleGetUser(t *testing.T) { }, }, }, - wantErr: assert.NoError, }, { name: "when account does not exist", @@ -1349,7 +1280,7 @@ func TestHandleGetUser(t *testing.T) { }, }, t: NewTransaction( - TranGetUser, &[]byte{0, 1}, + TranGetUser, [2]byte{0, 1}, NewField(FieldUserLogin, []byte("nonExistentUser")), ), }, @@ -1364,16 +1295,11 @@ func TestHandleGetUser(t *testing.T) { }, }, }, - wantErr: assert.NoError, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - gotRes, err := HandleGetUser(tt.args.cc, tt.args.t) - if !tt.wantErr(t, err, fmt.Sprintf("HandleGetUser(%v, %v)", tt.args.cc, tt.args.t)) { - return - } - + gotRes := HandleGetUser(tt.args.cc, &tt.args.t) tranAssertEqual(t, tt.wantRes, gotRes) }) } @@ -1382,13 +1308,12 @@ func TestHandleGetUser(t *testing.T) { func TestHandleDeleteUser(t *testing.T) { type args struct { cc *ClientConn - t *Transaction + t Transaction } tests := []struct { name string args args wantRes []Transaction - wantErr assert.ErrorAssertionFunc }{ { name: "when user dataFile", @@ -1418,7 +1343,7 @@ func TestHandleDeleteUser(t *testing.T) { }, }, t: NewTransaction( - TranDeleteUser, &[]byte{0, 1}, + TranDeleteUser, [2]byte{0, 1}, NewField(FieldUserLogin, encodeString([]byte("testuser"))), ), }, @@ -1430,7 +1355,6 @@ func TestHandleDeleteUser(t *testing.T) { Fields: []Field(nil), }, }, - wantErr: assert.NoError, }, { name: "when user does not have required permission", @@ -1447,7 +1371,7 @@ func TestHandleDeleteUser(t *testing.T) { }, }, t: NewTransaction( - TranDeleteUser, &[]byte{0, 1}, + TranDeleteUser, [2]byte{0, 1}, NewField(FieldUserLogin, encodeString([]byte("testuser"))), ), }, @@ -1460,16 +1384,11 @@ func TestHandleDeleteUser(t *testing.T) { }, }, }, - wantErr: assert.NoError, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - gotRes, err := HandleDeleteUser(tt.args.cc, tt.args.t) - if !tt.wantErr(t, err, fmt.Sprintf("HandleDeleteUser(%v, %v)", tt.args.cc, tt.args.t)) { - return - } - + gotRes := HandleDeleteUser(tt.args.cc, &tt.args.t) tranAssertEqual(t, tt.wantRes, gotRes) }) } @@ -1478,13 +1397,12 @@ func TestHandleDeleteUser(t *testing.T) { func TestHandleGetMsgs(t *testing.T) { type args struct { cc *ClientConn - t *Transaction + t Transaction } tests := []struct { name string args args wantRes []Transaction - wantErr assert.ErrorAssertionFunc }{ { name: "returns news data", @@ -1502,7 +1420,7 @@ func TestHandleGetMsgs(t *testing.T) { }, }, t: NewTransaction( - TranGetMsgs, &[]byte{0, 1}, + TranGetMsgs, [2]byte{0, 1}, ), }, wantRes: []Transaction{ @@ -1513,7 +1431,6 @@ func TestHandleGetMsgs(t *testing.T) { }, }, }, - wantErr: assert.NoError, }, { name: "when user does not have required permission", @@ -1530,7 +1447,7 @@ func TestHandleGetMsgs(t *testing.T) { }, }, t: NewTransaction( - TranGetMsgs, &[]byte{0, 1}, + TranGetMsgs, [2]byte{0, 1}, ), }, wantRes: []Transaction{ @@ -1542,16 +1459,11 @@ func TestHandleGetMsgs(t *testing.T) { }, }, }, - wantErr: assert.NoError, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - gotRes, err := HandleGetMsgs(tt.args.cc, tt.args.t) - if !tt.wantErr(t, err, fmt.Sprintf("HandleGetMsgs(%v, %v)", tt.args.cc, tt.args.t)) { - return - } - + gotRes := HandleGetMsgs(tt.args.cc, &tt.args.t) tranAssertEqual(t, tt.wantRes, gotRes) }) } @@ -1560,13 +1472,12 @@ func TestHandleGetMsgs(t *testing.T) { func TestHandleNewUser(t *testing.T) { type args struct { cc *ClientConn - t *Transaction + t Transaction } tests := []struct { name string args args wantRes []Transaction - wantErr assert.ErrorAssertionFunc }{ { name: "when user does not have required permission", @@ -1583,7 +1494,7 @@ func TestHandleNewUser(t *testing.T) { }, }, t: NewTransaction( - TranNewUser, &[]byte{0, 1}, + TranNewUser, [2]byte{0, 1}, ), }, wantRes: []Transaction{ @@ -1595,7 +1506,6 @@ func TestHandleNewUser(t *testing.T) { }, }, }, - wantErr: assert.NoError, }, { name: "when user attempts to create account with greater access", @@ -1613,7 +1523,7 @@ func TestHandleNewUser(t *testing.T) { }, }, t: NewTransaction( - TranNewUser, &[]byte{0, 1}, + TranNewUser, [2]byte{0, 1}, NewField(FieldUserLogin, []byte("userB")), NewField( FieldUserAccess, @@ -1634,16 +1544,11 @@ func TestHandleNewUser(t *testing.T) { }, }, }, - wantErr: assert.NoError, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - gotRes, err := HandleNewUser(tt.args.cc, tt.args.t) - if !tt.wantErr(t, err, fmt.Sprintf("HandleNewUser(%v, %v)", tt.args.cc, tt.args.t)) { - return - } - + gotRes := HandleNewUser(tt.args.cc, &tt.args.t) tranAssertEqual(t, tt.wantRes, gotRes) }) } @@ -1652,13 +1557,12 @@ func TestHandleNewUser(t *testing.T) { func TestHandleListUsers(t *testing.T) { type args struct { cc *ClientConn - t *Transaction + t Transaction } tests := []struct { name string args args wantRes []Transaction - wantErr assert.ErrorAssertionFunc }{ { name: "when user does not have required permission", @@ -1675,7 +1579,7 @@ func TestHandleListUsers(t *testing.T) { }, }, t: NewTransaction( - TranNewUser, &[]byte{0, 1}, + TranNewUser, [2]byte{0, 1}, ), }, wantRes: []Transaction{ @@ -1687,7 +1591,6 @@ func TestHandleListUsers(t *testing.T) { }, }, }, - wantErr: assert.NoError, }, { name: "when user has required permission", @@ -1712,7 +1615,7 @@ func TestHandleListUsers(t *testing.T) { }, }, t: NewTransaction( - TranGetClientInfoText, &[]byte{0, 1}, + TranGetClientInfoText, [2]byte{0, 1}, NewField(FieldUserID, []byte{0, 1}), ), }, @@ -1728,15 +1631,11 @@ func TestHandleListUsers(t *testing.T) { }, }, }, - wantErr: assert.NoError, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - gotRes, err := HandleListUsers(tt.args.cc, tt.args.t) - if !tt.wantErr(t, err, fmt.Sprintf("HandleListUsers(%v, %v)", tt.args.cc, tt.args.t)) { - return - } + gotRes := HandleListUsers(tt.args.cc, &tt.args.t) tranAssertEqual(t, tt.wantRes, gotRes) }) @@ -1746,13 +1645,12 @@ func TestHandleListUsers(t *testing.T) { func TestHandleDownloadFile(t *testing.T) { type args struct { cc *ClientConn - t *Transaction + t Transaction } tests := []struct { name string args args wantRes []Transaction - wantErr assert.ErrorAssertionFunc }{ { name: "when user does not have required permission", @@ -1766,7 +1664,7 @@ func TestHandleDownloadFile(t *testing.T) { }, Server: &Server{}, }, - t: NewTransaction(TranDownloadFile, &[]byte{0, 1}), + t: NewTransaction(TranDownloadFile, [2]byte{0, 1}), }, wantRes: []Transaction{ { @@ -1777,7 +1675,6 @@ func TestHandleDownloadFile(t *testing.T) { }, }, }, - wantErr: assert.NoError, }, { name: "with a valid file", @@ -1803,8 +1700,8 @@ func TestHandleDownloadFile(t *testing.T) { }, }, t: NewTransaction( - accessDownloadFile, - &[]byte{0, 1}, + TranDownloadFile, + [2]byte{0, 1}, NewField(FieldFileName, []byte("testfile.txt")), NewField(FieldFilePath, []byte{0x0, 0x00}), ), @@ -1820,7 +1717,6 @@ func TestHandleDownloadFile(t *testing.T) { }, }, }, - wantErr: assert.NoError, }, { name: "when client requests to resume 1k test file at offset 256", @@ -1863,30 +1759,23 @@ func TestHandleDownloadFile(t *testing.T) { }, }, t: NewTransaction( - accessDownloadFile, - &[]byte{0, 1}, + TranDownloadFile, + [2]byte{0, 1}, NewField(FieldFileName, []byte("testfile-1k")), NewField(FieldFilePath, []byte{0x00, 0x00}), NewField( FieldFileResumeData, func() []byte { frd := FileResumeData{ - Format: [4]byte{}, - Version: [2]byte{}, - RSVD: [34]byte{}, ForkCount: [2]byte{0, 2}, ForkInfoList: []ForkInfoList{ { Fork: [4]byte{0x44, 0x41, 0x54, 0x41}, // "DATA" DataSize: [4]byte{0, 0, 0x01, 0x00}, // request offset 256 - RSVDA: [4]byte{}, - RSVDB: [4]byte{}, }, { Fork: [4]byte{0x4d, 0x41, 0x43, 0x52}, // "MACR" DataSize: [4]byte{0, 0, 0, 0}, - RSVDA: [4]byte{}, - RSVDB: [4]byte{}, }, }, } @@ -1907,16 +1796,11 @@ func TestHandleDownloadFile(t *testing.T) { }, }, }, - wantErr: assert.NoError, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - gotRes, err := HandleDownloadFile(tt.args.cc, tt.args.t) - if !tt.wantErr(t, err, fmt.Sprintf("HandleDownloadFile(%v, %v)", tt.args.cc, tt.args.t)) { - return - } - + gotRes := HandleDownloadFile(tt.args.cc, &tt.args.t) tranAssertEqual(t, tt.wantRes, gotRes) }) } @@ -1925,13 +1809,12 @@ func TestHandleDownloadFile(t *testing.T) { func TestHandleUpdateUser(t *testing.T) { type args struct { cc *ClientConn - t *Transaction + t Transaction } tests := []struct { name string args args wantRes []Transaction - wantErr assert.ErrorAssertionFunc }{ { name: "when action is create user without required permission", @@ -1950,7 +1833,7 @@ func TestHandleUpdateUser(t *testing.T) { }, t: NewTransaction( TranUpdateUser, - &[]byte{0, 0}, + [2]byte{0, 0}, NewField(FieldData, []byte{ 0x00, 0x04, // field count @@ -1981,7 +1864,6 @@ func TestHandleUpdateUser(t *testing.T) { }, }, }, - wantErr: assert.NoError, }, { name: "when action is modify user without required permission", @@ -2003,7 +1885,7 @@ func TestHandleUpdateUser(t *testing.T) { }, t: NewTransaction( TranUpdateUser, - &[]byte{0, 0}, + [2]byte{0, 0}, NewField(FieldData, []byte{ 0x00, 0x04, // field count @@ -2034,7 +1916,6 @@ func TestHandleUpdateUser(t *testing.T) { }, }, }, - wantErr: assert.NoError, }, { name: "when action is delete user without required permission", @@ -2055,7 +1936,7 @@ func TestHandleUpdateUser(t *testing.T) { }, t: NewTransaction( TranUpdateUser, - &[]byte{0, 0}, + [2]byte{0, 0}, NewField(FieldData, []byte{ 0x00, 0x01, 0x00, 0x65, @@ -2073,16 +1954,11 @@ func TestHandleUpdateUser(t *testing.T) { }, }, }, - wantErr: assert.NoError, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - gotRes, err := HandleUpdateUser(tt.args.cc, tt.args.t) - if !tt.wantErr(t, err, fmt.Sprintf("HandleUpdateUser(%v, %v)", tt.args.cc, tt.args.t)) { - return - } - + gotRes := HandleUpdateUser(tt.args.cc, &tt.args.t) tranAssertEqual(t, tt.wantRes, gotRes) }) } @@ -2091,13 +1967,12 @@ func TestHandleUpdateUser(t *testing.T) { func TestHandleDelNewsArt(t *testing.T) { type args struct { cc *ClientConn - t *Transaction + t Transaction } tests := []struct { name string args args wantRes []Transaction - wantErr assert.ErrorAssertionFunc }{ { name: "without required permission", @@ -2112,7 +1987,7 @@ func TestHandleDelNewsArt(t *testing.T) { }, t: NewTransaction( TranDelNewsArt, - &[]byte{0, 0}, + [2]byte{0, 0}, ), }, wantRes: []Transaction{ @@ -2124,15 +1999,11 @@ func TestHandleDelNewsArt(t *testing.T) { }, }, }, - wantErr: assert.NoError, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - gotRes, err := HandleDelNewsArt(tt.args.cc, tt.args.t) - if !tt.wantErr(t, err, fmt.Sprintf("HandleDelNewsArt(%v, %v)", tt.args.cc, tt.args.t)) { - return - } + gotRes := HandleDelNewsArt(tt.args.cc, &tt.args.t) tranAssertEqual(t, tt.wantRes, gotRes) }) } @@ -2141,13 +2012,12 @@ func TestHandleDelNewsArt(t *testing.T) { func TestHandleDisconnectUser(t *testing.T) { type args struct { cc *ClientConn - t *Transaction + t Transaction } tests := []struct { name string args args wantRes []Transaction - wantErr assert.ErrorAssertionFunc }{ { name: "without required permission", @@ -2162,7 +2032,7 @@ func TestHandleDisconnectUser(t *testing.T) { }, t: NewTransaction( TranDelNewsArt, - &[]byte{0, 0}, + [2]byte{0, 0}, ), }, wantRes: []Transaction{ @@ -2174,15 +2044,14 @@ func TestHandleDisconnectUser(t *testing.T) { }, }, }, - wantErr: assert.NoError, }, { name: "when target user has 'cannot be disconnected' priv", args: args{ cc: &ClientConn{ Server: &Server{ - Clients: map[uint16]*ClientConn{ - uint16(1): { + Clients: map[[2]byte]*ClientConn{ + [2]byte{0, 1}: { Account: &Account{ Login: "unnamed", Access: func() accessBitmap { @@ -2204,7 +2073,7 @@ func TestHandleDisconnectUser(t *testing.T) { }, t: NewTransaction( TranDelNewsArt, - &[]byte{0, 0}, + [2]byte{0, 0}, NewField(FieldUserID, []byte{0, 1}), ), }, @@ -2217,15 +2086,11 @@ func TestHandleDisconnectUser(t *testing.T) { }, }, }, - wantErr: assert.NoError, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - gotRes, err := HandleDisconnectUser(tt.args.cc, tt.args.t) - if !tt.wantErr(t, err, fmt.Sprintf("HandleDisconnectUser(%v, %v)", tt.args.cc, tt.args.t)) { - return - } + gotRes := HandleDisconnectUser(tt.args.cc, &tt.args.t) tranAssertEqual(t, tt.wantRes, gotRes) }) } @@ -2234,13 +2099,12 @@ func TestHandleDisconnectUser(t *testing.T) { func TestHandleSendInstantMsg(t *testing.T) { type args struct { cc *ClientConn - t *Transaction + t Transaction } tests := []struct { name string args args wantRes []Transaction - wantErr assert.ErrorAssertionFunc }{ { name: "without required permission", @@ -2255,7 +2119,7 @@ func TestHandleSendInstantMsg(t *testing.T) { }, t: NewTransaction( TranDelNewsArt, - &[]byte{0, 0}, + [2]byte{0, 0}, ), }, wantRes: []Transaction{ @@ -2267,7 +2131,6 @@ func TestHandleSendInstantMsg(t *testing.T) { }, }, }, - wantErr: assert.Error, }, { name: "when client 1 sends a message to client 2", @@ -2280,40 +2143,39 @@ func TestHandleSendInstantMsg(t *testing.T) { return bits }(), }, - ID: &[]byte{0, 1}, + ID: [2]byte{0, 1}, UserName: []byte("User1"), Server: &Server{ - Clients: map[uint16]*ClientConn{ - uint16(2): { + Clients: map[[2]byte]*ClientConn{ + [2]byte{0, 2}: { AutoReply: []byte(nil), - Flags: []byte{0, 0}, + Flags: [2]byte{0, 0}, }, }, }, }, t: NewTransaction( TranSendInstantMsg, - &[]byte{0, 1}, + [2]byte{0, 1}, NewField(FieldData, []byte("hai")), NewField(FieldUserID, []byte{0, 2}), ), }, wantRes: []Transaction{ - *NewTransaction( + NewTransaction( TranServerMsg, - &[]byte{0, 2}, + [2]byte{0, 2}, NewField(FieldData, []byte("hai")), NewField(FieldUserName, []byte("User1")), NewField(FieldUserID, []byte{0, 1}), NewField(FieldOptions, []byte{0, 1}), ), { - clientID: &[]byte{0, 1}, + clientID: [2]byte{0, 1}, IsReply: 0x01, Fields: []Field(nil), }, }, - wantErr: assert.NoError, }, { name: "when client 2 has autoreply enabled", @@ -2326,13 +2188,13 @@ func TestHandleSendInstantMsg(t *testing.T) { return bits }(), }, - ID: &[]byte{0, 1}, + ID: [2]byte{0, 1}, UserName: []byte("User1"), Server: &Server{ - Clients: map[uint16]*ClientConn{ - uint16(2): { - Flags: []byte{0, 0}, - ID: &[]byte{0, 2}, + Clients: map[[2]byte]*ClientConn{ + [2]byte{0, 2}: { + Flags: [2]byte{0, 0}, + ID: [2]byte{0, 2}, UserName: []byte("User2"), AutoReply: []byte("autohai"), }, @@ -2341,35 +2203,34 @@ func TestHandleSendInstantMsg(t *testing.T) { }, t: NewTransaction( TranSendInstantMsg, - &[]byte{0, 1}, + [2]byte{0, 1}, NewField(FieldData, []byte("hai")), NewField(FieldUserID, []byte{0, 2}), ), }, wantRes: []Transaction{ - *NewTransaction( + NewTransaction( TranServerMsg, - &[]byte{0, 2}, + [2]byte{0, 2}, NewField(FieldData, []byte("hai")), NewField(FieldUserName, []byte("User1")), NewField(FieldUserID, []byte{0, 1}), NewField(FieldOptions, []byte{0, 1}), ), - *NewTransaction( + NewTransaction( TranServerMsg, - &[]byte{0, 1}, + [2]byte{0, 1}, NewField(FieldData, []byte("autohai")), NewField(FieldUserName, []byte("User2")), NewField(FieldUserID, []byte{0, 2}), NewField(FieldOptions, []byte{0, 1}), ), { - clientID: &[]byte{0, 1}, + clientID: [2]byte{0, 1}, IsReply: 0x01, Fields: []Field(nil), }, }, - wantErr: assert.NoError, }, { name: "when client 2 has refuse private messages enabled", @@ -2382,13 +2243,13 @@ func TestHandleSendInstantMsg(t *testing.T) { return bits }(), }, - ID: &[]byte{0, 1}, + ID: [2]byte{0, 1}, UserName: []byte("User1"), Server: &Server{ - Clients: map[uint16]*ClientConn{ - uint16(2): { - Flags: []byte{255, 255}, - ID: &[]byte{0, 2}, + Clients: map[[2]byte]*ClientConn{ + [2]byte{0, 2}: { + Flags: [2]byte{255, 255}, + ID: [2]byte{0, 2}, UserName: []byte("User2"), }, }, @@ -2396,36 +2257,31 @@ func TestHandleSendInstantMsg(t *testing.T) { }, t: NewTransaction( TranSendInstantMsg, - &[]byte{0, 1}, + [2]byte{0, 1}, NewField(FieldData, []byte("hai")), NewField(FieldUserID, []byte{0, 2}), ), }, wantRes: []Transaction{ - *NewTransaction( + NewTransaction( TranServerMsg, - &[]byte{0, 1}, + [2]byte{0, 1}, NewField(FieldData, []byte("User2 does not accept private messages.")), NewField(FieldUserName, []byte("User2")), NewField(FieldUserID, []byte{0, 2}), NewField(FieldOptions, []byte{0, 2}), ), { - clientID: &[]byte{0, 1}, + clientID: [2]byte{0, 1}, IsReply: 0x01, Fields: []Field(nil), }, }, - wantErr: assert.NoError, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - gotRes, err := HandleSendInstantMsg(tt.args.cc, tt.args.t) - if !tt.wantErr(t, err, fmt.Sprintf("HandleSendInstantMsg(%v, %v)", tt.args.cc, tt.args.t)) { - return - } - + gotRes := HandleSendInstantMsg(tt.args.cc, &tt.args.t) tranAssertEqual(t, tt.wantRes, gotRes) }) } @@ -2434,13 +2290,12 @@ func TestHandleSendInstantMsg(t *testing.T) { func TestHandleDeleteFile(t *testing.T) { type args struct { cc *ClientConn - t *Transaction + t Transaction } tests := []struct { name string args args wantRes []Transaction - wantErr assert.ErrorAssertionFunc }{ { name: "when user does not have required permission to delete a folder", @@ -2477,7 +2332,7 @@ func TestHandleDeleteFile(t *testing.T) { }, }, t: NewTransaction( - TranDeleteFile, &[]byte{0, 1}, + TranDeleteFile, [2]byte{0, 1}, NewField(FieldFileName, []byte("testfile")), NewField(FieldFilePath, []byte{ 0x00, 0x01, @@ -2496,7 +2351,6 @@ func TestHandleDeleteFile(t *testing.T) { }, }, }, - wantErr: assert.NoError, }, { name: "deletes all associated metadata files", @@ -2539,7 +2393,7 @@ func TestHandleDeleteFile(t *testing.T) { }, }, t: NewTransaction( - TranDeleteFile, &[]byte{0, 1}, + TranDeleteFile, [2]byte{0, 1}, NewField(FieldFileName, []byte("testfile")), NewField(FieldFilePath, []byte{ 0x00, 0x01, @@ -2555,16 +2409,11 @@ func TestHandleDeleteFile(t *testing.T) { Fields: []Field(nil), }, }, - wantErr: assert.NoError, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - gotRes, err := HandleDeleteFile(tt.args.cc, tt.args.t) - if !tt.wantErr(t, err, fmt.Sprintf("HandleDeleteFile(%v, %v)", tt.args.cc, tt.args.t)) { - return - } - + gotRes := HandleDeleteFile(tt.args.cc, &tt.args.t) tranAssertEqual(t, tt.wantRes, gotRes) tt.args.cc.Server.FS.(*MockFileStore).AssertExpectations(t) @@ -2575,13 +2424,12 @@ func TestHandleDeleteFile(t *testing.T) { func TestHandleGetFileNameList(t *testing.T) { type args struct { cc *ClientConn - t *Transaction + t Transaction } tests := []struct { name string args args wantRes []Transaction - wantErr assert.ErrorAssertionFunc }{ { name: "when FieldFilePath is a drop box, but user does not have accessViewDropBoxes ", @@ -2604,7 +2452,7 @@ func TestHandleGetFileNameList(t *testing.T) { }, }, t: NewTransaction( - TranGetFileNameList, &[]byte{0, 1}, + TranGetFileNameList, [2]byte{0, 1}, NewField(FieldFilePath, []byte{ 0x00, 0x01, 0x00, 0x00, @@ -2622,7 +2470,6 @@ func TestHandleGetFileNameList(t *testing.T) { }, }, }, - wantErr: assert.NoError, }, { name: "with file root", @@ -2638,7 +2485,7 @@ func TestHandleGetFileNameList(t *testing.T) { }, }, t: NewTransaction( - TranGetFileNameList, &[]byte{0, 1}, + TranGetFileNameList, [2]byte{0, 1}, NewField(FieldFilePath, []byte{ 0x00, 0x00, 0x00, 0x00, @@ -2670,16 +2517,11 @@ func TestHandleGetFileNameList(t *testing.T) { }, }, }, - wantErr: assert.NoError, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - gotRes, err := HandleGetFileNameList(tt.args.cc, tt.args.t) - if !tt.wantErr(t, err, fmt.Sprintf("HandleGetFileNameList(%v, %v)", tt.args.cc, tt.args.t)) { - return - } - + gotRes := HandleGetFileNameList(tt.args.cc, &tt.args.t) tranAssertEqual(t, tt.wantRes, gotRes) }) } @@ -2688,13 +2530,12 @@ func TestHandleGetFileNameList(t *testing.T) { func TestHandleGetClientInfoText(t *testing.T) { type args struct { cc *ClientConn - t *Transaction + t Transaction } tests := []struct { name string args args wantRes []Transaction - wantErr assert.ErrorAssertionFunc }{ { name: "when user does not have required permission", @@ -2711,7 +2552,7 @@ func TestHandleGetClientInfoText(t *testing.T) { }, }, t: NewTransaction( - TranGetClientInfoText, &[]byte{0, 1}, + TranGetClientInfoText, [2]byte{0, 1}, NewField(FieldUserID, []byte{0, 1}), ), }, @@ -2724,7 +2565,6 @@ func TestHandleGetClientInfoText(t *testing.T) { }, }, }, - wantErr: assert.NoError, }, { name: "with a valid user", @@ -2743,8 +2583,8 @@ func TestHandleGetClientInfoText(t *testing.T) { }, Server: &Server{ Accounts: map[string]*Account{}, - Clients: map[uint16]*ClientConn{ - uint16(1): { + Clients: map[[2]byte]*ClientConn{ + [2]byte{0, 1}: { UserName: []byte("Testy McTest"), RemoteAddr: "1.2.3.4:12345", Account: &Account{ @@ -2767,7 +2607,7 @@ func TestHandleGetClientInfoText(t *testing.T) { }, }, t: NewTransaction( - TranGetClientInfoText, &[]byte{0, 1}, + TranGetClientInfoText, [2]byte{0, 1}, NewField(FieldUserID, []byte{0, 1}), ), }, @@ -2807,15 +2647,11 @@ None. }, }, }, - wantErr: assert.NoError, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - gotRes, err := HandleGetClientInfoText(tt.args.cc, tt.args.t) - if !tt.wantErr(t, err, fmt.Sprintf("HandleGetClientInfoText(%v, %v)", tt.args.cc, tt.args.t)) { - return - } + gotRes := HandleGetClientInfoText(tt.args.cc, &tt.args.t) tranAssertEqual(t, tt.wantRes, gotRes) }) } @@ -2824,13 +2660,12 @@ None. func TestHandleTranAgreed(t *testing.T) { type args struct { cc *ClientConn - t *Transaction + t Transaction } tests := []struct { name string args args wantRes []Transaction - wantErr assert.ErrorAssertionFunc }{ { name: "normal request flow", @@ -2844,9 +2679,9 @@ func TestHandleTranAgreed(t *testing.T) { return bits }()}, Icon: []byte{0, 1}, - Flags: []byte{0, 1}, + Flags: [2]byte{0, 1}, Version: []byte{0, 1}, - ID: &[]byte{0, 1}, + ID: [2]byte{0, 1}, logger: NewTestLogger(), Server: &Server{ Config: &Config{ @@ -2855,7 +2690,7 @@ func TestHandleTranAgreed(t *testing.T) { }, }, t: NewTransaction( - TranAgreed, nil, + TranAgreed, [2]byte{}, NewField(FieldUserName, []byte("username")), NewField(FieldUserIconID, []byte{0, 1}), NewField(FieldOptions, []byte{0, 0}), @@ -2863,27 +2698,23 @@ func TestHandleTranAgreed(t *testing.T) { }, wantRes: []Transaction{ { - clientID: &[]byte{0, 1}, + clientID: [2]byte{0, 1}, Type: [2]byte{0, 0x7a}, Fields: []Field{ NewField(FieldBannerType, []byte("JPEG")), }, }, { - clientID: &[]byte{0, 1}, + clientID: [2]byte{0, 1}, IsReply: 0x01, Fields: []Field{}, }, }, - wantErr: assert.NoError, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - gotRes, err := HandleTranAgreed(tt.args.cc, tt.args.t) - if !tt.wantErr(t, err, fmt.Sprintf("HandleTranAgreed(%v, %v)", tt.args.cc, tt.args.t)) { - return - } + gotRes := HandleTranAgreed(tt.args.cc, &tt.args.t) tranAssertEqual(t, tt.wantRes, gotRes) }) } @@ -2892,13 +2723,12 @@ func TestHandleTranAgreed(t *testing.T) { func TestHandleSetClientUserInfo(t *testing.T) { type args struct { cc *ClientConn - t *Transaction + t Transaction } tests := []struct { name string args args wantRes []Transaction - wantErr assert.ErrorAssertionFunc }{ { name: "when client does not have accessAnyName", @@ -2910,26 +2740,26 @@ func TestHandleSetClientUserInfo(t *testing.T) { return bits }(), }, - ID: &[]byte{0, 1}, + ID: [2]byte{0, 1}, UserName: []byte("Guest"), - Flags: []byte{0, 1}, + Flags: [2]byte{0, 1}, Server: &Server{ - Clients: map[uint16]*ClientConn{ - uint16(1): { - ID: &[]byte{0, 1}, + Clients: map[[2]byte]*ClientConn{ + [2]byte{0, 1}: { + ID: [2]byte{0, 1}, }, }, }, }, t: NewTransaction( - TranSetClientUserInfo, nil, + TranSetClientUserInfo, [2]byte{}, NewField(FieldUserIconID, []byte{0, 1}), NewField(FieldUserName, []byte("NOPE")), ), }, wantRes: []Transaction{ { - clientID: &[]byte{0, 1}, + clientID: [2]byte{0, 1}, Type: [2]byte{0x01, 0x2d}, Fields: []Field{ NewField(FieldUserID, []byte{0, 1}), @@ -2938,16 +2768,11 @@ func TestHandleSetClientUserInfo(t *testing.T) { NewField(FieldUserName, []byte("Guest"))}, }, }, - wantErr: assert.NoError, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - gotRes, err := HandleSetClientUserInfo(tt.args.cc, tt.args.t) - if !tt.wantErr(t, err, fmt.Sprintf("HandleSetClientUserInfo(%v, %v)", tt.args.cc, tt.args.t)) { - return - } - + gotRes := HandleSetClientUserInfo(tt.args.cc, &tt.args.t) tranAssertEqual(t, tt.wantRes, gotRes) }) } @@ -2956,13 +2781,12 @@ func TestHandleSetClientUserInfo(t *testing.T) { func TestHandleDelNewsItem(t *testing.T) { type args struct { cc *ClientConn - t *Transaction + t Transaction } tests := []struct { name string args args wantRes []Transaction - wantErr assert.ErrorAssertionFunc }{ { name: "when user does not have permission to delete a news category", @@ -2971,7 +2795,7 @@ func TestHandleDelNewsItem(t *testing.T) { Account: &Account{ Access: accessBitmap{}, }, - ID: &[]byte{0, 1}, + ID: [2]byte{0, 1}, Server: &Server{ ThreadedNews: &ThreadedNews{Categories: map[string]NewsCategoryListData15{ "test": { @@ -2982,7 +2806,7 @@ func TestHandleDelNewsItem(t *testing.T) { }, }, t: NewTransaction( - TranDelNewsItem, nil, + TranDelNewsItem, [2]byte{}, NewField(FieldNewsPath, []byte{ 0, 1, @@ -2995,7 +2819,7 @@ func TestHandleDelNewsItem(t *testing.T) { }, wantRes: []Transaction{ { - clientID: &[]byte{0, 1}, + clientID: [2]byte{0, 1}, IsReply: 0x01, ErrorCode: [4]byte{0, 0, 0, 1}, Fields: []Field{ @@ -3003,7 +2827,6 @@ func TestHandleDelNewsItem(t *testing.T) { }, }, }, - wantErr: assert.NoError, }, { name: "when user does not have permission to delete a news folder", @@ -3012,7 +2835,7 @@ func TestHandleDelNewsItem(t *testing.T) { Account: &Account{ Access: accessBitmap{}, }, - ID: &[]byte{0, 1}, + ID: [2]byte{0, 1}, Server: &Server{ ThreadedNews: &ThreadedNews{Categories: map[string]NewsCategoryListData15{ "testcat": { @@ -3023,7 +2846,7 @@ func TestHandleDelNewsItem(t *testing.T) { }, }, t: NewTransaction( - TranDelNewsItem, nil, + TranDelNewsItem, [2]byte{}, NewField(FieldNewsPath, []byte{ 0, 1, @@ -3036,7 +2859,7 @@ func TestHandleDelNewsItem(t *testing.T) { }, wantRes: []Transaction{ { - clientID: &[]byte{0, 1}, + clientID: [2]byte{0, 1}, IsReply: 0x01, ErrorCode: [4]byte{0, 0, 0, 1}, Fields: []Field{ @@ -3044,7 +2867,6 @@ func TestHandleDelNewsItem(t *testing.T) { }, }, }, - wantErr: assert.NoError, }, { name: "when user deletes a news folder", @@ -3057,7 +2879,7 @@ func TestHandleDelNewsItem(t *testing.T) { return bits }(), }, - ID: &[]byte{0, 1}, + ID: [2]byte{0, 1}, Server: &Server{ ConfigDir: "/fakeConfigRoot", FS: func() *MockFileStore { @@ -3074,7 +2896,7 @@ func TestHandleDelNewsItem(t *testing.T) { }, }, t: NewTransaction( - TranDelNewsItem, nil, + TranDelNewsItem, [2]byte{}, NewField(FieldNewsPath, []byte{ 0, 1, @@ -3087,20 +2909,17 @@ func TestHandleDelNewsItem(t *testing.T) { }, wantRes: []Transaction{ { - clientID: &[]byte{0, 1}, + clientID: [2]byte{0, 1}, IsReply: 0x01, Fields: []Field{}, }, }, - wantErr: assert.NoError, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - gotRes, err := HandleDelNewsItem(tt.args.cc, tt.args.t) - if !tt.wantErr(t, err, fmt.Sprintf("HandleDelNewsItem(%v, %v)", tt.args.cc, tt.args.t)) { - return - } + gotRes := HandleDelNewsItem(tt.args.cc, &tt.args.t) + tranAssertEqual(t, tt.wantRes, gotRes) }) } @@ -3109,13 +2928,12 @@ func TestHandleDelNewsItem(t *testing.T) { func TestHandleTranOldPostNews(t *testing.T) { type args struct { cc *ClientConn - t *Transaction + t Transaction } tests := []struct { name string args args wantRes []Transaction - wantErr assert.ErrorAssertionFunc }{ { name: "when user does not have required permission", @@ -3129,7 +2947,7 @@ func TestHandleTranOldPostNews(t *testing.T) { }, }, t: NewTransaction( - TranOldPostNews, &[]byte{0, 1}, + TranOldPostNews, [2]byte{0, 1}, NewField(FieldData, []byte("hai")), ), }, @@ -3142,7 +2960,6 @@ func TestHandleTranOldPostNews(t *testing.T) { }, }, }, - wantErr: assert.NoError, }, { name: "when user posts news update", @@ -3166,7 +2983,7 @@ func TestHandleTranOldPostNews(t *testing.T) { }, }, t: NewTransaction( - TranOldPostNews, &[]byte{0, 1}, + TranOldPostNews, [2]byte{0, 1}, NewField(FieldData, []byte("hai")), ), }, @@ -3175,15 +2992,11 @@ func TestHandleTranOldPostNews(t *testing.T) { IsReply: 0x01, }, }, - wantErr: assert.NoError, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - gotRes, err := HandleTranOldPostNews(tt.args.cc, tt.args.t) - if !tt.wantErr(t, err, fmt.Sprintf("HandleTranOldPostNews(%v, %v)", tt.args.cc, tt.args.t)) { - return - } + gotRes := HandleTranOldPostNews(tt.args.cc, &tt.args.t) tranAssertEqual(t, tt.wantRes, gotRes) }) @@ -3193,13 +3006,12 @@ func TestHandleTranOldPostNews(t *testing.T) { func TestHandleInviteNewChat(t *testing.T) { type args struct { cc *ClientConn - t *Transaction + t Transaction } tests := []struct { name string args args wantRes []Transaction - wantErr assert.ErrorAssertionFunc }{ { name: "when user does not have required permission", @@ -3212,7 +3024,7 @@ func TestHandleInviteNewChat(t *testing.T) { }(), }, }, - t: NewTransaction(TranInviteNewChat, &[]byte{0, 1}), + t: NewTransaction(TranInviteNewChat, [2]byte{0, 1}), }, wantRes: []Transaction{ { @@ -3223,13 +3035,12 @@ func TestHandleInviteNewChat(t *testing.T) { }, }, }, - wantErr: assert.NoError, }, { name: "when userA invites userB to new private chat", args: args{ cc: &ClientConn{ - ID: &[]byte{0, 1}, + ID: [2]byte{0, 1}, Account: &Account{ Access: func() accessBitmap { var bits accessBitmap @@ -3239,26 +3050,25 @@ func TestHandleInviteNewChat(t *testing.T) { }, UserName: []byte("UserA"), Icon: []byte{0, 1}, - Flags: []byte{0, 0}, + Flags: [2]byte{0, 0}, Server: &Server{ - Clients: map[uint16]*ClientConn{ - uint16(2): { - ID: &[]byte{0, 2}, + Clients: map[[2]byte]*ClientConn{ + [2]byte{0, 2}: { + ID: [2]byte{0, 2}, UserName: []byte("UserB"), - Flags: []byte{0, 0}, }, }, - PrivateChats: make(map[uint32]*PrivateChat), + PrivateChats: make(map[[4]byte]*PrivateChat), }, }, t: NewTransaction( - TranInviteNewChat, &[]byte{0, 1}, + TranInviteNewChat, [2]byte{0, 1}, NewField(FieldUserID, []byte{0, 2}), ), }, wantRes: []Transaction{ { - clientID: &[]byte{0, 2}, + clientID: [2]byte{0, 2}, Type: [2]byte{0, 0x71}, Fields: []Field{ NewField(FieldChatID, []byte{0x52, 0xfd, 0xfc, 0x07}), @@ -3268,7 +3078,7 @@ func TestHandleInviteNewChat(t *testing.T) { }, { - clientID: &[]byte{0, 1}, + clientID: [2]byte{0, 1}, IsReply: 0x01, Fields: []Field{ NewField(FieldChatID, []byte{0x52, 0xfd, 0xfc, 0x07}), @@ -3279,13 +3089,12 @@ func TestHandleInviteNewChat(t *testing.T) { }, }, }, - wantErr: assert.NoError, }, { name: "when userA invites userB to new private chat, but UserB has refuse private chat enabled", args: args{ cc: &ClientConn{ - ID: &[]byte{0, 1}, + ID: [2]byte{0, 1}, Account: &Account{ Access: func() accessBitmap { var bits accessBitmap @@ -3295,26 +3104,26 @@ func TestHandleInviteNewChat(t *testing.T) { }, UserName: []byte("UserA"), Icon: []byte{0, 1}, - Flags: []byte{0, 0}, + Flags: [2]byte{0, 0}, Server: &Server{ - Clients: map[uint16]*ClientConn{ - uint16(2): { - ID: &[]byte{0, 2}, + Clients: map[[2]byte]*ClientConn{ + [2]byte{0, 2}: { + ID: [2]byte{0, 2}, UserName: []byte("UserB"), - Flags: []byte{255, 255}, + Flags: [2]byte{255, 255}, }, }, - PrivateChats: make(map[uint32]*PrivateChat), + PrivateChats: make(map[[4]byte]*PrivateChat), }, }, t: NewTransaction( - TranInviteNewChat, &[]byte{0, 1}, + TranInviteNewChat, [2]byte{0, 1}, NewField(FieldUserID, []byte{0, 2}), ), }, wantRes: []Transaction{ { - clientID: &[]byte{0, 1}, + clientID: [2]byte{0, 1}, Type: [2]byte{0, 0x68}, Fields: []Field{ NewField(FieldData, []byte("UserB does not accept private chats.")), @@ -3324,7 +3133,7 @@ func TestHandleInviteNewChat(t *testing.T) { }, }, { - clientID: &[]byte{0, 1}, + clientID: [2]byte{0, 1}, IsReply: 0x01, Fields: []Field{ NewField(FieldChatID, []byte{0x52, 0xfd, 0xfc, 0x07}), @@ -3335,15 +3144,12 @@ func TestHandleInviteNewChat(t *testing.T) { }, }, }, - wantErr: assert.NoError, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - gotRes, err := HandleInviteNewChat(tt.args.cc, tt.args.t) - if !tt.wantErr(t, err, fmt.Sprintf("HandleInviteNewChat(%v, %v)", tt.args.cc, tt.args.t)) { - return - } + gotRes := HandleInviteNewChat(tt.args.cc, &tt.args.t) + tranAssertEqual(t, tt.wantRes, gotRes) }) } @@ -3352,13 +3158,12 @@ func TestHandleInviteNewChat(t *testing.T) { func TestHandleGetNewsArtData(t *testing.T) { type args struct { cc *ClientConn - t *Transaction + t Transaction } tests := []struct { name string args args wantRes []Transaction - wantErr assert.ErrorAssertionFunc }{ { name: "when user does not have required permission", @@ -3375,7 +3180,7 @@ func TestHandleGetNewsArtData(t *testing.T) { }, }, t: NewTransaction( - TranGetNewsArtData, &[]byte{0, 1}, + TranGetNewsArtData, [2]byte{0, 1}, ), }, wantRes: []Transaction{ @@ -3387,15 +3192,11 @@ func TestHandleGetNewsArtData(t *testing.T) { }, }, }, - wantErr: assert.NoError, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - gotRes, err := HandleGetNewsArtData(tt.args.cc, tt.args.t) - if !tt.wantErr(t, err, fmt.Sprintf("HandleGetNewsArtData(%v, %v)", tt.args.cc, tt.args.t)) { - return - } + gotRes := HandleGetNewsArtData(tt.args.cc, &tt.args.t) tranAssertEqual(t, tt.wantRes, gotRes) }) } @@ -3404,13 +3205,12 @@ func TestHandleGetNewsArtData(t *testing.T) { func TestHandleGetNewsArtNameList(t *testing.T) { type args struct { cc *ClientConn - t *Transaction + t Transaction } tests := []struct { name string args args wantRes []Transaction - wantErr assert.ErrorAssertionFunc }{ { name: "when user does not have required permission", @@ -3427,7 +3227,7 @@ func TestHandleGetNewsArtNameList(t *testing.T) { }, }, t: NewTransaction( - TranGetNewsArtNameList, &[]byte{0, 1}, + TranGetNewsArtNameList, [2]byte{0, 1}, ), }, wantRes: []Transaction{ @@ -3441,7 +3241,6 @@ func TestHandleGetNewsArtNameList(t *testing.T) { }, }, }, - wantErr: assert.NoError, }, { name: "when user has required access", @@ -3487,7 +3286,7 @@ func TestHandleGetNewsArtNameList(t *testing.T) { }, t: NewTransaction( TranGetNewsArtNameList, - &[]byte{0, 1}, + [2]byte{0, 1}, // 00000000 00 01 00 00 10 45 78 61 6d 70 6c 65 20 43 61 74 |.....Example Cat| // 00000010 65 67 6f 72 79 |egory| NewField(FieldNewsPath, []byte{ @@ -3510,15 +3309,12 @@ func TestHandleGetNewsArtNameList(t *testing.T) { }, }, }, - wantErr: assert.NoError, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - gotRes, err := HandleGetNewsArtNameList(tt.args.cc, tt.args.t) - if !tt.wantErr(t, err, fmt.Sprintf("HandleGetNewsArtNameList(%v, %v)", tt.args.cc, tt.args.t)) { - return - } + gotRes := HandleGetNewsArtNameList(tt.args.cc, &tt.args.t) + tranAssertEqual(t, tt.wantRes, gotRes) }) } @@ -3527,13 +3323,12 @@ func TestHandleGetNewsArtNameList(t *testing.T) { func TestHandleNewNewsFldr(t *testing.T) { type args struct { cc *ClientConn - t *Transaction + t Transaction } tests := []struct { name string args args wantRes []Transaction - wantErr assert.ErrorAssertionFunc }{ { name: "when user does not have required permission", @@ -3550,7 +3345,7 @@ func TestHandleNewNewsFldr(t *testing.T) { }, }, t: NewTransaction( - TranGetNewsArtNameList, &[]byte{0, 1}, + TranGetNewsArtNameList, [2]byte{0, 1}, ), }, wantRes: []Transaction{ @@ -3564,7 +3359,6 @@ func TestHandleNewNewsFldr(t *testing.T) { }, }, }, - wantErr: assert.NoError, }, { name: "with a valid request", @@ -3578,7 +3372,7 @@ func TestHandleNewNewsFldr(t *testing.T) { }(), }, logger: NewTestLogger(), - ID: &[]byte{0, 1}, + ID: [2]byte{0, 1}, Server: &Server{ ConfigDir: "/fakeConfigRoot", FS: func() *MockFileStore { @@ -3596,7 +3390,7 @@ func TestHandleNewNewsFldr(t *testing.T) { }, }, t: NewTransaction( - TranGetNewsArtNameList, &[]byte{0, 1}, + TranGetNewsArtNameList, [2]byte{0, 1}, NewField(FieldFileName, []byte("testFolder")), NewField(FieldNewsPath, []byte{ @@ -3610,12 +3404,11 @@ func TestHandleNewNewsFldr(t *testing.T) { }, wantRes: []Transaction{ { - clientID: &[]byte{0, 1}, + clientID: [2]byte{0, 1}, IsReply: 0x01, Fields: []Field{}, }, }, - wantErr: assert.NoError, }, //{ // Name: "when there is an error writing the threaded news file", @@ -3629,7 +3422,7 @@ func TestHandleNewNewsFldr(t *testing.T) { // }(), // }, // logger: NewTestLogger(), - // ID: &[]byte{0, 1}, + // ID: [2]byte{0, 1}, // Server: &Server{ // ConfigDir: "/fakeConfigRoot", // FS: func() *MockFileStore { @@ -3649,7 +3442,7 @@ func TestHandleNewNewsFldr(t *testing.T) { // }, // }, // t: NewTransaction( - // TranGetNewsArtNameList, &[]byte{0, 1}, + // TranGetNewsArtNameList, [2]byte{0, 1}, // NewField(FieldFileName, []byte("testFolder")), // NewField(FieldNewsPath, // []byte{ @@ -3663,7 +3456,7 @@ func TestHandleNewNewsFldr(t *testing.T) { // }, // wantRes: []Transaction{ // { - // clientID: &[]byte{0, 1}, + // clientID: [2]byte{0, 1}, // Flags: 0x00, // IsReply: 0x01, // Type: [2]byte{0, 0}, @@ -3673,15 +3466,11 @@ func TestHandleNewNewsFldr(t *testing.T) { // }, // }, // }, - // wantErr: assert.Error, - // }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - gotRes, err := HandleNewNewsFldr(tt.args.cc, tt.args.t) - if !tt.wantErr(t, err, fmt.Sprintf("HandleNewNewsFldr(%v, %v)", tt.args.cc, tt.args.t)) { - return - } + gotRes := HandleNewNewsFldr(tt.args.cc, &tt.args.t) + tranAssertEqual(t, tt.wantRes, gotRes) }) } @@ -3690,23 +3479,114 @@ func TestHandleNewNewsFldr(t *testing.T) { func TestHandleDownloadBanner(t *testing.T) { type args struct { cc *ClientConn - t *Transaction + t Transaction } tests := []struct { name string args args wantRes []Transaction - wantErr assert.ErrorAssertionFunc }{ // TODO: Add test cases. } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - gotRes, err := HandleDownloadBanner(tt.args.cc, tt.args.t) - if !tt.wantErr(t, err, fmt.Sprintf("HandleDownloadBanner(%v, %v)", tt.args.cc, tt.args.t)) { - return - } - assert.Equalf(t, tt.wantRes, gotRes, "HandleDownloadBanner(%v, %v)", tt.args.cc, tt.args.t) + gotRes := HandleDownloadBanner(tt.args.cc, &tt.args.t) + + assert.Equalf(t, tt.wantRes, gotRes, "HandleDownloadBanner(%v, %v)", tt.args.cc, &tt.args.t) + }) + } +} + +func TestHandlePostNewsArt(t *testing.T) { + type args struct { + cc *ClientConn + t Transaction + } + tests := []struct { + name string + args args + wantRes []Transaction + }{ + { + name: "without required permission", + args: args{ + cc: &ClientConn{ + Account: &Account{ + Access: func() accessBitmap { + var bits accessBitmap + return bits + }(), + }, + }, + t: NewTransaction( + TranPostNewsArt, + [2]byte{0, 0}, + ), + }, + wantRes: []Transaction{ + { + IsReply: 0x01, + ErrorCode: [4]byte{0, 0, 0, 1}, + Fields: []Field{ + NewField(FieldError, []byte("You are not allowed to post news articles.")), + }, + }, + }, + }, + { + name: "with required permission", + args: args{ + cc: &ClientConn{ + Server: &Server{ + FS: func() *MockFileStore { + mfs := &MockFileStore{} + mfs.On("WriteFile", "ThreadedNews.yaml", mock.Anything, mock.Anything).Return(nil, os.ErrNotExist) + return mfs + }(), + mux: sync.Mutex{}, + threadedNewsMux: sync.Mutex{}, + ThreadedNews: &ThreadedNews{ + Categories: map[string]NewsCategoryListData15{ + "www": { + Type: [2]byte{}, + Name: "www", + Articles: map[uint32]*NewsArtData{}, + SubCats: nil, + GUID: [16]byte{}, + AddSN: [4]byte{}, + DeleteSN: [4]byte{}, + readOffset: 0, + }, + }, + }, + }, + Account: &Account{ + Access: func() accessBitmap { + var bits accessBitmap + bits.Set(accessNewsPostArt) + return bits + }(), + }, + }, + t: NewTransaction( + TranPostNewsArt, + [2]byte{0, 0}, + NewField(FieldNewsPath, []byte{0x00, 0x01, 0x00, 0x00, 0x03, 0x77, 0x77, 0x77}), + NewField(FieldNewsArtID, []byte{0x00, 0x00, 0x00, 0x00}), + ), + }, + wantRes: []Transaction{ + { + IsReply: 0x01, + ErrorCode: [4]byte{0, 0, 0, 0}, + Fields: []Field{}, + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tranAssertEqual(t, tt.wantRes, HandlePostNewsArt(tt.args.cc, &tt.args.t)) }) } } diff --git a/hotline/transaction_test.go b/hotline/transaction_test.go index b5ee990..820b08c 100644 --- a/hotline/transaction_test.go +++ b/hotline/transaction_test.go @@ -301,7 +301,7 @@ func Test_transactionScanner(t *testing.T) { func TestTransaction_Read(t1 *testing.T) { type fields struct { - clientID *[]byte + clientID [2]byte Flags byte IsReply byte Type [2]byte @@ -408,3 +408,80 @@ func TestTransaction_Read(t1 *testing.T) { }) } } + +func TestTransaction_Write(t1 *testing.T) { + type args struct { + p []byte + } + tests := []struct { + name string + args args + wantN int + wantErr assert.ErrorAssertionFunc + wantTransaction Transaction + }{ + { + name: "returns error if arg p is too small", + args: args{p: []byte{ + 0x00, 0x00, + }}, + wantN: 0, + wantErr: assert.Error, + wantTransaction: Transaction{}, + }, + //{ + // name: "returns error if param data is invalid", + // args: args{p: []byte{ + // 0x00, 0x00, 0x00, 0x69, 0x00, 0x00, 0x15, 0x72, + // 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x09, + // 0x00, 0x00, 0x00, 0x09, 0x00, 0x00, 0x00, 0x65, + // 0x00, 0x03, 0x68, 0x61, 0x69, + // }}, + // wantN: 0, + // wantErr: assert.Error, + // wantTransaction: Transaction{}, + //}, + { + name: "writes bytes to transaction", + args: args{p: []byte{ + 0x00, 0x00, 0x00, 0x69, 0x00, 0x00, 0x15, 0x72, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x09, + 0x00, 0x00, 0x00, 0x09, 0x00, 0x01, 0x00, 0x65, + 0x00, 0x03, 0x68, 0x61, 0x69, + }}, + wantN: 29, + wantErr: assert.NoError, + wantTransaction: Transaction{ + Flags: 0, + IsReply: 0, + Type: TranChatSend, + ID: [4]byte{}, + ErrorCode: [4]byte{}, + TotalSize: [4]byte{0, 0, 0, 9}, + DataSize: [4]byte{0, 0, 0, 9}, + ParamCount: [2]byte{0, 1}, + Fields: []Field{ + { + ID: FieldData, + FieldSize: [2]byte{0, 3}, + Data: []byte("hai"), + }, + }, + clientID: [2]byte{}, + readOffset: 0, + }, + }, + } + for _, tt := range tests { + t1.Run(tt.name, func(t1 *testing.T) { + t := &Transaction{} + gotN, err := t.Write(tt.args.p) + if !tt.wantErr(t1, err, fmt.Sprintf("Write(%v)", tt.args.p)) { + return + } + assert.Equalf(t1, tt.wantN, gotN, "Write(%v)", tt.args.p) + + tranAssertEqual(t1, []Transaction{tt.wantTransaction}, []Transaction{*t}) + }) + } +} diff --git a/hotline/user.go b/hotline/user.go index 02deb31..26625a2 100644 --- a/hotline/user.go +++ b/hotline/user.go @@ -3,26 +3,40 @@ package hotline import ( "encoding/binary" "io" + "math/big" "slices" ) // User flags are stored as a 2 byte bitmap and represent various user states const ( - UserFlagAway = iota // User is away - UserFlagAdmin // User is admin - UserFlagRefusePM // User refuses private messages - UserFlagRefusePChat // User refuses private chat + UserFlagAway = 0 // User is away + UserFlagAdmin = 1 // User is admin + UserFlagRefusePM = 2 // User refuses private messages + UserFlagRefusePChat = 3 // User refuses private chat ) // FieldOptions flags are sent from v1.5+ clients as part of TranAgreed const ( - UserOptRefusePM = iota // User has "Refuse private messages" pref set - UserOptRefuseChat // User has "Refuse private chat" pref set - UserOptAutoResponse // User has "Automatic response" pref set + UserOptRefusePM = 0 // User has "Refuse private messages" pref set + UserOptRefuseChat = 1 // User has "Refuse private chat" pref set + UserOptAutoResponse = 2 // User has "Automatic response" pref set ) +type UserFlags [2]byte + +func (flag *UserFlags) IsSet(i int) bool { + flagBitmap := big.NewInt(int64(binary.BigEndian.Uint16(flag[:]))) + return flagBitmap.Bit(i) == 1 +} + +func (flag *UserFlags) Set(i int, newVal uint) { + flagBitmap := big.NewInt(int64(binary.BigEndian.Uint16(flag[:]))) + flagBitmap.SetBit(flagBitmap, i, newVal) + binary.BigEndian.PutUint16(flag[:], uint16(flagBitmap.Int64())) +} + type User struct { - ID []byte // Size 2 + ID [2]byte Icon []byte // Size 2 Flags []byte // Size 2 Name string // Variable length user name @@ -43,7 +57,7 @@ func (u *User) Read(p []byte) (int, error) { } b := slices.Concat( - u.ID, + u.ID[:], u.Icon, u.Flags, nameLen, @@ -55,13 +69,14 @@ func (u *User) Read(p []byte) (int, error) { } n := copy(p, b) + u.readOffset = n - return n, io.EOF + return n, nil } func (u *User) Write(p []byte) (int, error) { namelen := int(binary.BigEndian.Uint16(p[6:8])) - u.ID = p[0:2] + u.ID = [2]byte(p[0:2]) u.Icon = p[2:4] u.Flags = p[4:6] u.Name = string(p[8 : 8+namelen]) diff --git a/hotline/user_test.go b/hotline/user_test.go index c55e35d..90c59f7 100644 --- a/hotline/user_test.go +++ b/hotline/user_test.go @@ -28,7 +28,7 @@ func TestReadUser(t *testing.T) { }, }, want: &User{ - ID: []byte{ + ID: [2]byte{ 0x00, 0x01, }, Icon: []byte{