From: Jeff Halter Date: Mon, 17 Jun 2024 20:50:28 +0000 (-0700) Subject: Fix io.Reader implementations and wrap more errors X-Git-Url: https://git.r.bdr.sh/rbdr/mobius/commitdiff_plain/45ca5d60383cbe270624c713b916da29af7ba88f?ds=sidebyside Fix io.Reader implementations and wrap more errors --- diff --git a/hotline/client_conn.go b/hotline/client_conn.go index e21d8b1..c886dfc 100644 --- a/hotline/client_conn.go +++ b/hotline/client_conn.go @@ -81,7 +81,7 @@ func (cc *ClientConn) handleTransaction(transaction Transaction) error { transactions, err := handler.Handler(cc, &transaction) if err != nil { - return err + return fmt.Errorf("error handling transaction: %w", err) } for _, t := range transactions { cc.Server.outbox <- t diff --git a/hotline/file_header.go b/hotline/file_header.go index 71a1c00..469f321 100644 --- a/hotline/file_header.go +++ b/hotline/file_header.go @@ -10,6 +10,8 @@ type FileHeader struct { Size [2]byte // Total size of FileHeader payload Type [2]byte // 0 for file, 1 for dir FilePath []byte // encoded file path + + readOffset int // Internal offset to track read progress } func NewFileHeader(fileName string, isDir bool) FileHeader { @@ -28,10 +30,18 @@ func NewFileHeader(fileName string, isDir bool) FileHeader { } func (fh *FileHeader) Read(p []byte) (int, error) { - return copy(p, slices.Concat( + buf := slices.Concat( fh.Size[:], fh.Type[:], fh.FilePath, - ), - ), io.EOF + ) + + if fh.readOffset >= len(buf) { + return 0, io.EOF // All bytes have been read + } + + n := copy(p, buf[fh.readOffset:]) + fh.readOffset += n + + return n, nil } diff --git a/hotline/file_wrapper.go b/hotline/file_wrapper.go index 997325a..342355b 100644 --- a/hotline/file_wrapper.go +++ b/hotline/file_wrapper.go @@ -246,7 +246,7 @@ func (f *fileWrapper) flattenedFileObject() (*flattenedFileObject, error) { _, err = io.Copy(&f.ffo.FlatFileInformationFork, bytes.NewReader(b)) if err != nil { - return nil, err + return nil, fmt.Errorf("error copying FlatFileInformationFork: %w", err) } } else { diff --git a/hotline/files.go b/hotline/files.go index 0963fbf..901b98c 100644 --- a/hotline/files.go +++ b/hotline/files.go @@ -3,6 +3,7 @@ package hotline import ( "encoding/binary" "errors" + "fmt" "io" "io/fs" "os" @@ -36,10 +37,11 @@ const maxFileSize = 4294967296 func getFileNameList(path string, ignoreList []string) (fields []Field, err error) { files, err := os.ReadDir(path) if err != nil { - return fields, nil + return fields, fmt.Errorf("error reading path: %s: %w", path, err) } - for _, file := range files { + for i, _ := range files { + file := files[i] var fnwi FileNameWithInfo if ignoreFile(file.Name(), ignoreList) { @@ -50,14 +52,14 @@ func getFileNameList(path string, ignoreList []string) (fields []Field, err erro fileInfo, err := file.Info() if err != nil { - return fields, err + return fields, fmt.Errorf("error getting file info: %s: %w", file.Name(), err) } // Check if path is a symlink. If so, follow it. if fileInfo.Mode()&os.ModeSymlink != 0 { resolvedPath, err := os.Readlink(filepath.Join(path, file.Name())) if err != nil { - return fields, err + return fields, fmt.Errorf("error following symlink: %s: %w", resolvedPath, err) } rFile, err := os.Stat(resolvedPath) @@ -92,7 +94,7 @@ func getFileNameList(path string, ignoreList []string) (fields []Field, err erro } else if file.IsDir() { dir, err := os.ReadDir(filepath.Join(path, file.Name())) if err != nil { - return fields, err + return fields, fmt.Errorf("readDir: %w", err) } var c uint32 @@ -113,7 +115,7 @@ func getFileNameList(path string, ignoreList []string) (fields []Field, err erro hlFile, err := newFileWrapper(&OSFileStore{}, path+"/"+file.Name(), 0) if err != nil { - return nil, err + return nil, fmt.Errorf("newFileWrapper: %w", err) } copy(fnwi.FileSize[:], hlFile.totalSize()) @@ -135,7 +137,7 @@ func getFileNameList(path string, ignoreList []string) (fields []Field, err erro b, err := io.ReadAll(&fnwi) if err != nil { - return nil, err + return nil, fmt.Errorf("error io.ReadAll: %w", err) } fields = append(fields, NewField(FieldFileNameWithInfo, b)) } diff --git a/hotline/flattened_file_object.go b/hotline/flattened_file_object.go index 4bf8b4d..6419a1f 100644 --- a/hotline/flattened_file_object.go +++ b/hotline/flattened_file_object.go @@ -13,6 +13,8 @@ type flattenedFileObject struct { FlatFileInformationFork FlatFileInformationFork FlatFileDataForkHeader FlatFileForkHeader FlatFileResForkHeader FlatFileForkHeader + + readOffset int // Internal offset to track read progress } // FlatFileHeader is the first section of a "Flattened File Object". All fields have static values. @@ -37,6 +39,8 @@ type FlatFileInformationFork struct { Name []byte // File name CommentSize []byte // Length of the comment Comment []byte // File comment + + readOffset int // Internal offset to track read progress } func NewFlatFileInformationFork(fileName string, modifyTime []byte, typeSignature string, creatorSignature string) FlatFileInformationFork { @@ -133,23 +137,30 @@ type FlatFileForkHeader struct { } func (ffif *FlatFileInformationFork) Read(p []byte) (int, error) { - return copy(p, - slices.Concat( - ffif.Platform, - ffif.TypeSignature, - ffif.CreatorSignature, - ffif.Flags, - ffif.PlatformFlags, - ffif.RSVD, - ffif.CreateDate, - ffif.ModifyDate, - ffif.NameScript, - ffif.ReadNameSize(), - ffif.Name, - ffif.CommentSize, - ffif.Comment, - ), - ), io.EOF + buf := slices.Concat( + ffif.Platform, + ffif.TypeSignature, + ffif.CreatorSignature, + ffif.Flags, + ffif.PlatformFlags, + ffif.RSVD, + ffif.CreateDate, + ffif.ModifyDate, + ffif.NameScript, + ffif.ReadNameSize(), + ffif.Name, + ffif.CommentSize, + ffif.Comment, + ) + + if ffif.readOffset >= len(buf) { + return 0, io.EOF // All bytes have been read + } + + n := copy(p, buf[ffif.readOffset:]) + ffif.readOffset += n + + return n, nil } // Write implements the io.Writer interface for FlatFileInformationFork @@ -173,7 +184,6 @@ func (ffif *FlatFileInformationFork) Write(p []byte) (int, error) { if len(p) > int(total) { ffif.CommentSize = p[total : total+2] commentLen := binary.BigEndian.Uint16(ffif.CommentSize) - commentStartPos := int(total) + 2 commentEndPos := int(total) + 2 + int(commentLen) @@ -182,7 +192,7 @@ func (ffif *FlatFileInformationFork) Write(p []byte) (int, error) { total = uint16(commentEndPos) } - return int(total), nil + return len(p), nil } func (ffif *FlatFileInformationFork) UnmarshalBinary(b []byte) error { @@ -217,7 +227,7 @@ func (ffif *FlatFileInformationFork) UnmarshalBinary(b []byte) error { // Read implements the io.Reader interface for flattenedFileObject func (ffo *flattenedFileObject) Read(p []byte) (int, error) { - return copy(p, slices.Concat( + buf := slices.Concat( ffo.FlatFileHeader.Format[:], ffo.FlatFileHeader.Version[:], ffo.FlatFileHeader.RSVD[:], @@ -243,8 +253,16 @@ func (ffo *flattenedFileObject) Read(p []byte) (int, error) { ffo.FlatFileDataForkHeader.CompressionType[:], ffo.FlatFileDataForkHeader.RSVD[:], ffo.FlatFileDataForkHeader.DataSize[:], - ), - ), io.EOF + ) + + if ffo.readOffset >= len(buf) { + return 0, io.EOF // All bytes have been read + } + + n := copy(p, buf[ffo.readOffset:]) + ffo.readOffset += n + + return n, nil } func (ffo *flattenedFileObject) ReadFrom(r io.Reader) (int64, error) { diff --git a/hotline/tracker.go b/hotline/tracker.go index 6ee2a9f..8f8ddd1 100644 --- a/hotline/tracker.go +++ b/hotline/tracker.go @@ -18,6 +18,8 @@ type TrackerRegistration struct { PassID [4]byte // Random number generated by the server Name string // Server Name Description string // Description of the server + + readOffset int // Internal offset to track read progress } // Read implements io.Reader to write tracker registration payload bytes to slice @@ -25,7 +27,7 @@ func (tr *TrackerRegistration) Read(p []byte) (int, error) { userCount := make([]byte, 2) binary.BigEndian.PutUint16(userCount, uint16(tr.UserCount)) - return copy(p, slices.Concat( + buf := slices.Concat( []byte{0x00, 0x01}, // Magic number, always 1 tr.Port[:], userCount, @@ -35,7 +37,16 @@ func (tr *TrackerRegistration) Read(p []byte) (int, error) { []byte(tr.Name), []byte{uint8(len(tr.Description))}, []byte(tr.Description), - )[:]), io.EOF + ) + + if tr.readOffset >= len(buf) { + return 0, io.EOF // All bytes have been read + } + + n := copy(p, buf[tr.readOffset:]) + tr.readOffset += n + + return n, nil } func register(tracker string, tr *TrackerRegistration) error { diff --git a/hotline/transaction_handlers.go b/hotline/transaction_handlers.go index 2d4eabd..20b87e2 100644 --- a/hotline/transaction_handlers.go +++ b/hotline/transaction_handlers.go @@ -1763,13 +1763,13 @@ func HandleGetFileNameList(cc *ClientConn, t *Transaction) (res []Transaction, e nil, ) if err != nil { - return res, err + return res, fmt.Errorf("error reading file path: %w", err) } var fp FilePath if t.GetField(FieldFilePath).Data != nil { if _, err = fp.Write(t.GetField(FieldFilePath).Data); err != nil { - return res, err + return res, fmt.Errorf("error writing file path: %w", err) } } @@ -1781,7 +1781,7 @@ func HandleGetFileNameList(cc *ClientConn, t *Transaction) (res []Transaction, e fileNames, err := getFileNameList(fullPath, cc.Server.Config.IgnoreFiles) if err != nil { - return res, err + return res, fmt.Errorf("getFileNameList: %w", err) } res = append(res, cc.NewReply(t, fileNames...))