]> git.r.bdr.sh - rbdr/mobius/blobdiff - hotline/transaction.go
Use fixed size array types in Transaction fields
[rbdr/mobius] / hotline / transaction.go
index b01fdfb112e8c6bb61922728cd0d879dabfb00e9..0caf72f652e3d255505f4b15dbeebbf219996621 100644 (file)
@@ -1,12 +1,14 @@
 package hotline
 
 import (
+       "bufio"
        "bytes"
        "encoding/binary"
        "errors"
        "fmt"
-       "github.com/jhalter/mobius/concat"
+       "io"
        "math/rand"
+       "slices"
 )
 
 const (
@@ -71,17 +73,18 @@ const (
 )
 
 type Transaction struct {
-       clientID *[]byte
-
-       Flags      byte   // Reserved (should be 0)
-       IsReply    byte   // Request (0) or reply (1)
-       Type       []byte // Requested operation (user defined)
-       ID         []byte // Unique transaction ID (must be != 0)
-       ErrorCode  []byte // Used in the reply (user defined, 0 = no error)
-       TotalSize  []byte // Total data size for the transaction (all parts)
-       DataSize   []byte // Size of data in this transaction part. This allows splitting large transactions into smaller parts.
-       ParamCount []byte // Number of the parameters for this transaction
+       Flags      byte    // Reserved (should be 0)
+       IsReply    byte    // Request (0) or reply (1)
+       Type       [2]byte // Requested operation (user defined)
+       ID         [4]byte // Unique transaction ID (must be != 0)
+       ErrorCode  [4]byte // Used in the reply (user defined, 0 = no error)
+       TotalSize  [4]byte // Total data size for the transaction (all parts)
+       DataSize   [4]byte // Size of data in this transaction part. This allows splitting large transactions into smaller parts.
+       ParamCount [2]byte // Number of the parameters for this transaction
        Fields     []Field
+
+       clientID   *[]byte // Internal identifier for target client
+       readOffset int     // Internal offset to track read progress
 }
 
 func NewTransaction(t int, clientID *[]byte, fields ...Field) *Transaction {
@@ -92,13 +95,10 @@ func NewTransaction(t int, clientID *[]byte, fields ...Field) *Transaction {
        binary.BigEndian.PutUint32(idSlice, rand.Uint32())
 
        return &Transaction{
-               clientID:  clientID,
-               Flags:     0x00,
-               IsReply:   0x00,
-               Type:      typeSlice,
-               ID:        idSlice,
-               ErrorCode: []byte{0, 0, 0, 0},
-               Fields:    fields,
+               clientID: clientID,
+               Type:     [2]byte(typeSlice),
+               ID:       [4]byte(idSlice),
+               Fields:   fields,
        }
 }
 
@@ -113,20 +113,29 @@ func (t *Transaction) Write(p []byte) (n int, err error) {
        if tranLen > len(p) {
                return n, errors.New("buflen too small for tranLen")
        }
-       fields, err := ReadFields(p[20:22], p[22:tranLen])
-       if err != nil {
-               return n, err
+
+       // Create a new scanner for parsing incoming bytes into transaction tokens
+       scanner := bufio.NewScanner(bytes.NewReader(p[22:tranLen]))
+       scanner.Split(fieldScanner)
+
+       for i := 0; i < int(binary.BigEndian.Uint16(p[20:22])); i++ {
+               scanner.Scan()
+
+               var field Field
+               if _, err := field.Write(scanner.Bytes()); err != nil {
+                       return 0, fmt.Errorf("error reading field: %w", err)
+               }
+               t.Fields = append(t.Fields, field)
        }
 
        t.Flags = p[0]
        t.IsReply = p[1]
-       t.Type = p[2:4]
-       t.ID = p[4:8]
-       t.ErrorCode = p[8:12]
-       t.TotalSize = p[12:16]
-       t.DataSize = p[16:20]
-       t.ParamCount = p[20:22]
-       t.Fields = fields
+       t.Type = [2]byte(p[2:4])
+       t.ID = [4]byte(p[4:8])
+       t.ErrorCode = [4]byte(p[8:12])
+       t.TotalSize = [4]byte(p[12:16])
+       t.DataSize = [4]byte(p[16:20])
+       t.ParamCount = [2]byte(p[20:22])
 
        return len(p), err
 }
@@ -177,8 +186,8 @@ func ReadFields(paramCount []byte, buf []byte) ([]Field, error) {
                }
 
                fields = append(fields, Field{
-                       ID:        fieldID,
-                       FieldSize: fieldSize,
+                       ID:        [2]byte(fieldID),
+                       FieldSize: [2]byte(fieldSize),
                        Data:      buf[4 : 4+fieldSizeInt],
                })
 
@@ -192,27 +201,42 @@ func ReadFields(paramCount []byte, buf []byte) ([]Field, error) {
        return fields, nil
 }
 
-func (t *Transaction) MarshalBinary() (data []byte, err error) {
+// Read implements the io.Reader interface for Transaction
+func (t *Transaction) Read(p []byte) (int, error) {
        payloadSize := t.Size()
 
        fieldCount := make([]byte, 2)
        binary.BigEndian.PutUint16(fieldCount, uint16(len(t.Fields)))
 
-       var fieldPayload []byte
+       bbuf := new(bytes.Buffer)
+
        for _, field := range t.Fields {
-               fieldPayload = append(fieldPayload, field.Payload()...)
+               f := field
+               _, err := bbuf.ReadFrom(&f)
+               if err != nil {
+                       return 0, fmt.Errorf("error reading field: %w", err)
+               }
        }
 
-       return concat.Slices(
+       buf := slices.Concat(
                []byte{t.Flags, t.IsReply},
-               t.Type,
-               t.ID,
-               t.ErrorCode,
+               t.Type[:],
+               t.ID[:],
+               t.ErrorCode[:],
                payloadSize,
                payloadSize, // this is the dataSize field, but seeming the same as totalSize
                fieldCount,
-               fieldPayload,
-       ), err
+               bbuf.Bytes(),
+       )
+
+       if t.readOffset >= len(buf) {
+               return 0, io.EOF // All bytes have been read
+       }
+
+       n := copy(p, buf[t.readOffset:])
+       t.readOffset += n
+
+       return n, nil
 }
 
 // Size returns the total size of the transaction payload
@@ -231,7 +255,7 @@ func (t *Transaction) Size() []byte {
 
 func (t *Transaction) GetField(id int) Field {
        for _, field := range t.Fields {
-               if id == int(binary.BigEndian.Uint16(field.ID)) {
+               if id == int(binary.BigEndian.Uint16(field.ID[:])) {
                        return field
                }
        }
@@ -240,5 +264,5 @@ func (t *Transaction) GetField(id int) Field {
 }
 
 func (t *Transaction) IsError() bool {
-       return bytes.Equal(t.ErrorCode, []byte{0, 0, 0, 1})
+       return t.ErrorCode == [4]byte{0, 0, 0, 1}
 }