git.r.bdr.sh - rbdr/mobius/commitdiff
Convert bespoke methods to io.Reader/io.Writer interfaces
author Jeff Halter <redacted>
Sun, 9 Jun 2024 03:36:04 +0000 (20:36 -0700)
committer Jeff Halter <redacted>
Sun, 9 Jun 2024 21:05:15 +0000 (14:05 -0700)
hotline/account.go
hotline/client.go
hotline/file_name_with_info.go
hotline/file_name_with_info_test.go
hotline/files.go
hotline/transaction_handlers.go
hotline/transaction_handlers_test.go

index b041f3cd6c4eb025381d961285eb3bacd7aefddf..4c5a9b98a608810047a8830b01495973be172bda 100644 (file)
@@ -3,7 +3,9 @@ package hotline
 import (
        "encoding/binary"
        "golang.org/x/crypto/bcrypt"
+       "io"
        "log"
+       "slices"
 )
 
 const GuestAccount = "guest" // default account used when no login is provided for a connection
@@ -30,13 +32,12 @@ func (a *Account) Read(p []byte) (n int, err error) {
        fieldCount := make([]byte, 2)
        binary.BigEndian.PutUint16(fieldCount, uint16(len(fields)))
 
-       p = append(p, fieldCount...)
-
+       var fieldBytes []byte
        for _, field := range fields {
-               p = append(p, field.Payload()...)
+               fieldBytes = append(fieldBytes, field.Payload()...)
        }
 
-       return len(p), nil
+       return copy(p, slices.Concat(fieldCount, fieldBytes)), io.EOF
 }
 
 // hashAndSalt generates a password hash from a users obfuscated plaintext password
index 35a70862e8da6c56401db6d4427070843f352e21..e056438fbdef7c1ddfcad4d0043508c75f52da7c 100644 (file)
@@ -290,7 +290,7 @@ func handleGetFileNameList(ctx context.Context, c *Client, t *Transaction) (res
 
        for _, f := range t.Fields {
                var fn FileNameWithInfo
-               err = fn.UnmarshalBinary(f.Data)
+               _, err = fn.Write(f.Data)
                if err != nil {
                        return nil, nil
                }
@@ -649,7 +649,6 @@ func (c *Client) Disconnect() error {
        return c.Connection.Close()
 }
 
-
 func (c *Client) HandleTransactions(ctx context.Context) error {
        // Create a new scanner for parsing incoming bytes into transaction tokens
        scanner := bufio.NewScanner(c.Connection)
index f230856d27e57c366da4860ae629ad02e340412c..ac64c74567493646f60357ca4b7e75616c488413 100644 (file)
@@ -3,6 +3,8 @@ package hotline
 import (
        "bytes"
        "encoding/binary"
+       "io"
+       "slices"
 )
 
 type FileNameWithInfo struct {
@@ -24,28 +26,28 @@ func (f *fileNameWithInfoHeader) nameLen() int {
        return int(binary.BigEndian.Uint16(f.NameSize[:]))
 }
 
-func (f *FileNameWithInfo) MarshalBinary() (data []byte, err error) {
-       var buf bytes.Buffer
-       err = binary.Write(&buf, binary.LittleEndian, f.fileNameWithInfoHeader)
-       if err != nil {
-               return data, err
-       }
-
-       _, err = buf.Write(f.name)
-       if err != nil {
-               return data, err
-       }
-
-       return buf.Bytes(), err
+// Read implements io.Reader for FileNameWithInfo
+func (f *FileNameWithInfo) Read(b []byte) (int, error) {
+       return copy(b,
+               slices.Concat(
+                       f.Type[:],
+                       f.Creator[:],
+                       f.FileSize[:],
+                       f.RSVD[:],
+                       f.NameScript[:],
+                       f.NameSize[:],
+                       f.name,
+               ),
+       ), io.EOF
 }
 
-func (f *FileNameWithInfo) UnmarshalBinary(data []byte) error {
-       err := binary.Read(bytes.NewReader(data), binary.BigEndian, &f.fileNameWithInfoHeader)
+func (f *FileNameWithInfo) Write(p []byte) (int, error) {
+       err := binary.Read(bytes.NewReader(p), binary.BigEndian, &f.fileNameWithInfoHeader)
        if err != nil {
-               return err
+               return 0, err
        }
        headerLen := binary.Size(f.fileNameWithInfoHeader)
-       f.name = data[headerLen : headerLen+f.nameLen()]
+       f.name = p[headerLen : headerLen+f.nameLen()]
 
-       return err
+       return len(p), nil
 }
index 1637a64ce38c34a352ea0ee90dbad07e67cc506b..0473631ab12e1996000a449312b837b8db26e2b0 100644 (file)
@@ -2,6 +2,7 @@ package hotline
 
 import (
        "github.com/stretchr/testify/assert"
+       "io"
        "reflect"
        "testing"
 )
@@ -48,7 +49,7 @@ func TestFileNameWithInfo_MarshalBinary(t *testing.T) {
                                fileNameWithInfoHeader: tt.fields.fileNameWithInfoHeader,
                                name:                   tt.fields.name,
                        }
-                       gotData, err := f.MarshalBinary()
+                       gotData, err := io.ReadAll(f)
                        if (err != nil) != tt.wantErr {
                                t.Errorf("MarshalBinary() error = %v, wantErr %v", err, tt.wantErr)
                                return
@@ -108,7 +109,7 @@ func TestFileNameWithInfo_UnmarshalBinary(t *testing.T) {
                                fileNameWithInfoHeader: tt.fields.fileNameWithInfoHeader,
                                name:                   tt.fields.name,
                        }
-                       if err := f.UnmarshalBinary(tt.args.data); (err != nil) != tt.wantErr {
+                       if _, err := f.Write(tt.args.data); (err != nil) != tt.wantErr {
                                t.Errorf("Write() error = %v, wantErr %v", err, tt.wantErr)
                        }
                        if !assert.Equal(t, tt.want, f) {
index dcabf9d986e8057d52ae3302528afc743bf9eb0c..d4bea538258dc2ebc362b455bb48e27bd02952fa 100644 (file)
@@ -3,6 +3,7 @@ package hotline
 import (
        "encoding/binary"
        "errors"
+       "io"
        "io/fs"
        "os"
        "path/filepath"
@@ -132,7 +133,7 @@ func getFileNameList(path string, ignoreList []string) (fields []Field, err erro
 
                fnwi.name = []byte(strippedName)
 
-               b, err := fnwi.MarshalBinary()
+               b, err := io.ReadAll(&fnwi)
                if err != nil {
                        return nil, err
                }
index d4269af7dab773bb0c163982720913d15a4ed9ed..793bb7420cd234191b875ddf94c063959b21cec7 100644 (file)
@@ -6,6 +6,7 @@ import (
        "errors"
        "fmt"
        "gopkg.in/yaml.v3"
+       "io"
        "math/big"
        "os"
        "path"
@@ -747,13 +748,12 @@ func HandleListUsers(cc *ClientConn, t *Transaction) (res []Transaction, err err
 
        var userFields []Field
        for _, acc := range cc.Server.Accounts {
-               b := make([]byte, 0, 100)
-               n, err := acc.Read(b)
+               b, err := io.ReadAll(acc)
                if err != nil {
                        return res, err
                }
 
-               userFields = append(userFields, NewField(FieldData, b[:n]))
+               userFields = append(userFields, NewField(FieldData, b))
        }
 
        res = append(res, cc.NewReply(t, userFields...))
index b03e3b18b77185dc3b080b660768b74182ede5fa..94854f2f68d89ca825d23c8eb17ad07ed2f920c5 100644 (file)
@@ -5,6 +5,7 @@ import (
        "fmt"
        "github.com/stretchr/testify/assert"
        "github.com/stretchr/testify/mock"
+       "io"
        "io/fs"
        "math/rand"
        "os"
@@ -2859,7 +2860,7 @@ func TestHandleGetFileNameList(t *testing.T) {
                                                                        },
                                                                        name: []byte("testfile-1k"),
                                                                }
-                                                               b, _ := fnwi.MarshalBinary()
+                                                               b, _ := io.ReadAll(&fnwi)
                                                                return b
                                                        }(),
                                                ),