git.r.bdr.sh - rbdr/mobius/commitdiff
Implement io.Writer interface for Transaction
author Jeff Halter <redacted>
Sun, 3 Jul 2022 22:36:08 +0000 (15:36 -0700)
committer Jeff Halter <redacted>
Sun, 3 Jul 2022 22:36:08 +0000 (15:36 -0700)
hotline/server.go
hotline/transaction.go
hotline/transaction_test.go
hotline/ui.go

index ccf992e2a68eeff39043055e3d7dc6f4861f2693..5b8aad0401f348a9d8fd118cb9a94e527f1af18c 100644 (file)
@@ -555,9 +555,9 @@ func (s *Server) handleNewConnection(ctx context.Context, rwc io.ReadWriteCloser
 
        scanner.Scan()
 
-       clientLogin, _, err := ReadTransaction(scanner.Bytes())
-       if err != nil {
-               panic(err)
+       var clientLogin Transaction
+       if _, err := clientLogin.Write(scanner.Bytes()); err != nil {
+               return err
        }
 
        c := s.NewClientConn(rwc, remoteAddr)
@@ -604,7 +604,7 @@ func (s *Server) handleNewConnection(ctx context.Context, rwc io.ReadWriteCloser
 
        // If authentication fails, send error reply and close connection
        if !c.Authenticate(login, encodedPassword) {
-               t := c.NewErrReply(clientLogin, "Incorrect login.")
+               t := c.NewErrReply(&clientLogin, "Incorrect login.")
                b, err := t.MarshalBinary()
                if err != nil {
                        return err
@@ -636,7 +636,7 @@ func (s *Server) handleNewConnection(ctx context.Context, rwc io.ReadWriteCloser
                c.Flags = []byte{0, 2}
        }
 
-       s.outbox <- c.NewReply(clientLogin,
+       s.outbox <- c.NewReply(&clientLogin,
                NewField(fieldVersion, []byte{0x00, 0xbe}),
                NewField(fieldCommunityBannerID, []byte{0, 0}),
                NewField(fieldServerName, []byte(s.Config.Name)),
@@ -685,11 +685,12 @@ func (s *Server) handleNewConnection(ctx context.Context, rwc io.ReadWriteCloser
                buf := make([]byte, len(scanner.Bytes()))
                copy(buf, scanner.Bytes())
 
-               t, _, err := ReadTransaction(buf)
-               if err != nil {
-                       panic(err)
+               var t Transaction
+               if _, err := t.Write(buf); err != nil {
+                       return err
                }
-               if err := c.handleTransaction(*t); err != nil {
+
+               if err := c.handleTransaction(t); err != nil {
                        c.logger.Errorw("Error handling transaction", "err", err)
                }
        }
index 9b0ac407f81e9c694f2d710f10416392b2eccc6b..9800be3f336dea1b2234f614c885faa9c3a2a02d 100644 (file)
@@ -102,34 +102,33 @@ func NewTransaction(t int, clientID *[]byte, fields ...Field) *Transaction {
        }
 }
 
-// ReadTransaction parses a byte slice into a struct.  The input slice may be shorter or longer
-// that the transaction size depending on what was read from the network connection.
-func ReadTransaction(buf []byte) (*Transaction, int, error) {
-       totalSize := binary.BigEndian.Uint32(buf[12:16])
+// Write implements io.Writer interface for Transaction
+func (t *Transaction) Write(p []byte) (n int, err error) {
+       totalSize := binary.BigEndian.Uint32(p[12:16])
 
        // the buf may include extra bytes that are not part of the transaction
        // tranLen represents the length of bytes that are part of the transaction
        tranLen := int(20 + totalSize)
 
-       if tranLen > len(buf) {
-               return nil, 0, errors.New("buflen too small for tranLen")
+       if tranLen > len(p) {
+               return n, errors.New("buflen too small for tranLen")
        }
-       fields, err := ReadFields(buf[20:22], buf[22:tranLen])
+       fields, err := ReadFields(p[20:22], p[22:tranLen])
        if err != nil {
-               return nil, 0, err
+               return n, err
        }
 
-       return &Transaction{
-               Flags:      buf[0],
-               IsReply:    buf[1],
-               Type:       buf[2:4],
-               ID:         buf[4:8],
-               ErrorCode:  buf[8:12],
-               TotalSize:  buf[12:16],
-               DataSize:   buf[16:20],
-               ParamCount: buf[20:22],
-               Fields:     fields,
-       }, tranLen, nil
+       t.Flags = p[0]
+       t.IsReply = p[1]
+       t.Type = p[2:4]
+       t.ID = p[4:8]
+       t.ErrorCode = p[8:12]
+       t.TotalSize = p[12:16]
+       t.DataSize = p[16:20]
+       t.ParamCount = p[20:22]
+       t.Fields = fields
+
+       return len(p), err
 }
 
 const tranHeaderLen = 20 // fixed length of transaction fields before the variable length fields
index 04bcde0e7851e1b1538bced6411411113a438bff..4b4f183a52b02b2939811d28d2a85ac7fce9a5b2 100644 (file)
@@ -111,80 +111,6 @@ func TestReadFields(t *testing.T) {
        }
 }
 
-func TestReadTransaction(t *testing.T) {
-       sampleTransaction := &Transaction{
-               Flags:      byte(0),
-               IsReply:    byte(0),
-               Type:       []byte{0x000, 0x93},
-               ID:         []byte{0x000, 0x00, 0x00, 0x01},
-               ErrorCode:  []byte{0x000, 0x00, 0x00, 0x00},
-               TotalSize:  []byte{0x000, 0x00, 0x00, 0x08},
-               DataSize:   []byte{0x000, 0x00, 0x00, 0x08},
-               ParamCount: []byte{0x00, 0x01},
-               Fields: []Field{
-                       {
-                               ID:        []byte{0x00, 0x01},
-                               FieldSize: []byte{0x00, 0x02},
-                               Data:      []byte{0xff, 0xff},
-                       },
-               },
-       }
-
-       type args struct {
-               buf []byte
-       }
-       tests := []struct {
-               name    string
-               args    args
-               want    *Transaction
-               want1   int
-               wantErr bool
-       }{
-               {
-                       name: "when buf contains all bytes for a single transaction",
-                       args: args{
-                               buf: func() []byte {
-                                       b, _ := sampleTransaction.MarshalBinary()
-                                       return b
-                               }(),
-                       },
-                       want: sampleTransaction,
-                       want1: func() int {
-                               b, _ := sampleTransaction.MarshalBinary()
-                               return len(b)
-                       }(),
-                       wantErr: false,
-               },
-               {
-                       name: "when len(buf) is less than the length of the transaction",
-                       args: args{
-                               buf: func() []byte {
-                                       b, _ := sampleTransaction.MarshalBinary()
-                                       return b[:len(b)-1]
-                               }(),
-                       },
-                       want:    nil,
-                       want1:   0,
-                       wantErr: true,
-               },
-       }
-       for _, tt := range tests {
-               t.Run(tt.name, func(t *testing.T) {
-                       got, got1, err := ReadTransaction(tt.args.buf)
-                       if (err != nil) != tt.wantErr {
-                               t.Errorf("ReadTransaction() error = %v, wantErr %v", err, tt.wantErr)
-                               return
-                       }
-                       if !assert.Equal(t, tt.want, got) {
-                               t.Errorf("ReadTransaction() got = %v, want %v", got, tt.want)
-                       }
-                       if got1 != tt.want1 {
-                               t.Errorf("ReadTransaction() got1 = %v, want %v", got1, tt.want1)
-                       }
-               })
-       }
-}
-
 func Test_transactionScanner(t *testing.T) {
        type args struct {
                data []byte
@@ -372,3 +298,80 @@ func Test_transactionScanner(t *testing.T) {
                })
        }
 }
+
+func TestTransaction_Write(t1 *testing.T) {
+       sampleTransaction := &Transaction{
+               Flags:      byte(0),
+               IsReply:    byte(0),
+               Type:       []byte{0x000, 0x93},
+               ID:         []byte{0x000, 0x00, 0x00, 0x01},
+               ErrorCode:  []byte{0x000, 0x00, 0x00, 0x00},
+               TotalSize:  []byte{0x000, 0x00, 0x00, 0x08},
+               DataSize:   []byte{0x000, 0x00, 0x00, 0x08},
+               ParamCount: []byte{0x00, 0x01},
+               Fields: []Field{
+                       {
+                               ID:        []byte{0x00, 0x01},
+                               FieldSize: []byte{0x00, 0x02},
+                               Data:      []byte{0xff, 0xff},
+                       },
+               },
+       }
+
+       type fields struct {
+               clientID   *[]byte
+               Flags      byte
+               IsReply    byte
+               Type       []byte
+               ID         []byte
+               ErrorCode  []byte
+               TotalSize  []byte
+               DataSize   []byte
+               ParamCount []byte
+               Fields     []Field
+       }
+       type args struct {
+               p []byte
+       }
+       tests := []struct {
+               name    string
+               fields  fields
+               args    args
+               wantN   int
+               wantErr assert.ErrorAssertionFunc
+       }{
+               {
+                       name:   "when buf contains all bytes for a single transaction",
+                       fields: fields{},
+                       args: args{
+                               p: func() []byte {
+                                       b, _ := sampleTransaction.MarshalBinary()
+                                       return b
+                               }(),
+                       },
+                       wantN:   28,
+                       wantErr: assert.NoError,
+               },
+       }
+       for _, tt := range tests {
+               t1.Run(tt.name, func(t1 *testing.T) {
+                       t := &Transaction{
+                               clientID:   tt.fields.clientID,
+                               Flags:      tt.fields.Flags,
+                               IsReply:    tt.fields.IsReply,
+                               Type:       tt.fields.Type,
+                               ID:         tt.fields.ID,
+                               ErrorCode:  tt.fields.ErrorCode,
+                               TotalSize:  tt.fields.TotalSize,
+                               DataSize:   tt.fields.DataSize,
+                               ParamCount: tt.fields.ParamCount,
+                               Fields:     tt.fields.Fields,
+                       }
+                       gotN, err := t.Write(tt.args.p)
+                       if !tt.wantErr(t1, err, fmt.Sprintf("Write(%v)", tt.args.p)) {
+                               return
+                       }
+                       assert.Equalf(t1, tt.wantN, gotN, "Write(%v)", tt.args.p)
+               })
+       }
+}
index 6eba01175121534c57808122518445d6d16ec7c1..709137e6de6d7c2a33b0279df90345a23d373a69 100644 (file)
@@ -208,11 +208,12 @@ func (ui *UI) joinServer(addr, login, password string) error {
                        buf := make([]byte, len(scanner.Bytes()))
                        copy(buf, scanner.Bytes())
 
-                       t, _, err := ReadTransaction(buf)
+                       var t Transaction
+                       _, err := t.Write(buf)
                        if err != nil {
                                break
                        }
-                       if err := ui.HLClient.HandleTransaction(t); err != nil {
+                       if err := ui.HLClient.HandleTransaction(&t); err != nil {
                                ui.HLClient.Logger.Errorw("Error handling transaction", "err", err)
                        }
                }