From: Jeff Halter Date: Sun, 3 Jul 2022 22:36:08 +0000 (-0700) Subject: Implement io.Writer interface for Transaction X-Git-Url: https://git.r.bdr.sh/rbdr/mobius/commitdiff_plain/854a92fc2755ace61c405df335ddf69b02a3d932?ds=sidebyside Implement io.Writer interface for Transaction --- diff --git a/hotline/server.go b/hotline/server.go index ccf992e..5b8aad0 100644 --- a/hotline/server.go +++ b/hotline/server.go @@ -555,9 +555,9 @@ func (s *Server) handleNewConnection(ctx context.Context, rwc io.ReadWriteCloser scanner.Scan() - clientLogin, _, err := ReadTransaction(scanner.Bytes()) - if err != nil { - panic(err) + var clientLogin Transaction + if _, err := clientLogin.Write(scanner.Bytes()); err != nil { + return err } c := s.NewClientConn(rwc, remoteAddr) @@ -604,7 +604,7 @@ func (s *Server) handleNewConnection(ctx context.Context, rwc io.ReadWriteCloser // If authentication fails, send error reply and close connection if !c.Authenticate(login, encodedPassword) { - t := c.NewErrReply(clientLogin, "Incorrect login.") + t := c.NewErrReply(&clientLogin, "Incorrect login.") b, err := t.MarshalBinary() if err != nil { return err @@ -636,7 +636,7 @@ func (s *Server) handleNewConnection(ctx context.Context, rwc io.ReadWriteCloser c.Flags = []byte{0, 2} } - s.outbox <- c.NewReply(clientLogin, + s.outbox <- c.NewReply(&clientLogin, NewField(fieldVersion, []byte{0x00, 0xbe}), NewField(fieldCommunityBannerID, []byte{0, 0}), NewField(fieldServerName, []byte(s.Config.Name)), @@ -685,11 +685,12 @@ func (s *Server) handleNewConnection(ctx context.Context, rwc io.ReadWriteCloser buf := make([]byte, len(scanner.Bytes())) copy(buf, scanner.Bytes()) - t, _, err := ReadTransaction(buf) - if err != nil { - panic(err) + var t Transaction + if _, err := t.Write(buf); err != nil { + return err } - if err := c.handleTransaction(*t); err != nil { + + if err := c.handleTransaction(t); err != nil { c.logger.Errorw("Error handling transaction", "err", err) } } diff --git a/hotline/transaction.go b/hotline/transaction.go index 9b0ac40..9800be3 100644 --- a/hotline/transaction.go +++ b/hotline/transaction.go @@ -102,34 +102,33 @@ func NewTransaction(t int, clientID *[]byte, fields ...Field) *Transaction { } } -// ReadTransaction parses a byte slice into a struct. The input slice may be shorter or longer -// that the transaction size depending on what was read from the network connection. -func ReadTransaction(buf []byte) (*Transaction, int, error) { - totalSize := binary.BigEndian.Uint32(buf[12:16]) +// Write implements io.Writer interface for Transaction +func (t *Transaction) Write(p []byte) (n int, err error) { + totalSize := binary.BigEndian.Uint32(p[12:16]) // the buf may include extra bytes that are not part of the transaction // tranLen represents the length of bytes that are part of the transaction tranLen := int(20 + totalSize) - if tranLen > len(buf) { - return nil, 0, errors.New("buflen too small for tranLen") + if tranLen > len(p) { + return n, errors.New("buflen too small for tranLen") } - fields, err := ReadFields(buf[20:22], buf[22:tranLen]) + fields, err := ReadFields(p[20:22], p[22:tranLen]) if err != nil { - return nil, 0, err + return n, err } - return &Transaction{ - Flags: buf[0], - IsReply: buf[1], - Type: buf[2:4], - ID: buf[4:8], - ErrorCode: buf[8:12], - TotalSize: buf[12:16], - DataSize: buf[16:20], - ParamCount: buf[20:22], - Fields: fields, - }, tranLen, nil + t.Flags = p[0] + t.IsReply = p[1] + t.Type = p[2:4] + t.ID = p[4:8] + t.ErrorCode = p[8:12] + t.TotalSize = p[12:16] + t.DataSize = p[16:20] + t.ParamCount = p[20:22] + t.Fields = fields + + return len(p), err } const tranHeaderLen = 20 // fixed length of transaction fields before the variable length fields diff --git a/hotline/transaction_test.go b/hotline/transaction_test.go index 04bcde0..4b4f183 100644 --- a/hotline/transaction_test.go +++ b/hotline/transaction_test.go @@ -111,80 +111,6 @@ func TestReadFields(t *testing.T) { } } -func TestReadTransaction(t *testing.T) { - sampleTransaction := &Transaction{ - Flags: byte(0), - IsReply: byte(0), - Type: []byte{0x000, 0x93}, - ID: []byte{0x000, 0x00, 0x00, 0x01}, - ErrorCode: []byte{0x000, 0x00, 0x00, 0x00}, - TotalSize: []byte{0x000, 0x00, 0x00, 0x08}, - DataSize: []byte{0x000, 0x00, 0x00, 0x08}, - ParamCount: []byte{0x00, 0x01}, - Fields: []Field{ - { - ID: []byte{0x00, 0x01}, - FieldSize: []byte{0x00, 0x02}, - Data: []byte{0xff, 0xff}, - }, - }, - } - - type args struct { - buf []byte - } - tests := []struct { - name string - args args - want *Transaction - want1 int - wantErr bool - }{ - { - name: "when buf contains all bytes for a single transaction", - args: args{ - buf: func() []byte { - b, _ := sampleTransaction.MarshalBinary() - return b - }(), - }, - want: sampleTransaction, - want1: func() int { - b, _ := sampleTransaction.MarshalBinary() - return len(b) - }(), - wantErr: false, - }, - { - name: "when len(buf) is less than the length of the transaction", - args: args{ - buf: func() []byte { - b, _ := sampleTransaction.MarshalBinary() - return b[:len(b)-1] - }(), - }, - want: nil, - want1: 0, - wantErr: true, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, got1, err := ReadTransaction(tt.args.buf) - if (err != nil) != tt.wantErr { - t.Errorf("ReadTransaction() error = %v, wantErr %v", err, tt.wantErr) - return - } - if !assert.Equal(t, tt.want, got) { - t.Errorf("ReadTransaction() got = %v, want %v", got, tt.want) - } - if got1 != tt.want1 { - t.Errorf("ReadTransaction() got1 = %v, want %v", got1, tt.want1) - } - }) - } -} - func Test_transactionScanner(t *testing.T) { type args struct { data []byte @@ -372,3 +298,80 @@ func Test_transactionScanner(t *testing.T) { }) } } + +func TestTransaction_Write(t1 *testing.T) { + sampleTransaction := &Transaction{ + Flags: byte(0), + IsReply: byte(0), + Type: []byte{0x000, 0x93}, + ID: []byte{0x000, 0x00, 0x00, 0x01}, + ErrorCode: []byte{0x000, 0x00, 0x00, 0x00}, + TotalSize: []byte{0x000, 0x00, 0x00, 0x08}, + DataSize: []byte{0x000, 0x00, 0x00, 0x08}, + ParamCount: []byte{0x00, 0x01}, + Fields: []Field{ + { + ID: []byte{0x00, 0x01}, + FieldSize: []byte{0x00, 0x02}, + Data: []byte{0xff, 0xff}, + }, + }, + } + + type fields struct { + clientID *[]byte + Flags byte + IsReply byte + Type []byte + ID []byte + ErrorCode []byte + TotalSize []byte + DataSize []byte + ParamCount []byte + Fields []Field + } + type args struct { + p []byte + } + tests := []struct { + name string + fields fields + args args + wantN int + wantErr assert.ErrorAssertionFunc + }{ + { + name: "when buf contains all bytes for a single transaction", + fields: fields{}, + args: args{ + p: func() []byte { + b, _ := sampleTransaction.MarshalBinary() + return b + }(), + }, + wantN: 28, + wantErr: assert.NoError, + }, + } + for _, tt := range tests { + t1.Run(tt.name, func(t1 *testing.T) { + t := &Transaction{ + clientID: tt.fields.clientID, + Flags: tt.fields.Flags, + IsReply: tt.fields.IsReply, + Type: tt.fields.Type, + ID: tt.fields.ID, + ErrorCode: tt.fields.ErrorCode, + TotalSize: tt.fields.TotalSize, + DataSize: tt.fields.DataSize, + ParamCount: tt.fields.ParamCount, + Fields: tt.fields.Fields, + } + gotN, err := t.Write(tt.args.p) + if !tt.wantErr(t1, err, fmt.Sprintf("Write(%v)", tt.args.p)) { + return + } + assert.Equalf(t1, tt.wantN, gotN, "Write(%v)", tt.args.p) + }) + } +} diff --git a/hotline/ui.go b/hotline/ui.go index 6eba011..709137e 100644 --- a/hotline/ui.go +++ b/hotline/ui.go @@ -208,11 +208,12 @@ func (ui *UI) joinServer(addr, login, password string) error { buf := make([]byte, len(scanner.Bytes())) copy(buf, scanner.Bytes()) - t, _, err := ReadTransaction(buf) + var t Transaction + _, err := t.Write(buf) if err != nil { break } - if err := ui.HLClient.HandleTransaction(t); err != nil { + if err := ui.HLClient.HandleTransaction(&t); err != nil { ui.HLClient.Logger.Errorw("Error handling transaction", "err", err) } }