scanner.Scan()
- clientLogin, _, err := ReadTransaction(scanner.Bytes())
- if err != nil {
- panic(err)
+ var clientLogin Transaction
+ if _, err := clientLogin.Write(scanner.Bytes()); err != nil {
+ return err
}
c := s.NewClientConn(rwc, remoteAddr)
// If authentication fails, send error reply and close connection
if !c.Authenticate(login, encodedPassword) {
- t := c.NewErrReply(clientLogin, "Incorrect login.")
+ t := c.NewErrReply(&clientLogin, "Incorrect login.")
b, err := t.MarshalBinary()
if err != nil {
return err
c.Flags = []byte{0, 2}
}
- s.outbox <- c.NewReply(clientLogin,
+ s.outbox <- c.NewReply(&clientLogin,
NewField(fieldVersion, []byte{0x00, 0xbe}),
NewField(fieldCommunityBannerID, []byte{0, 0}),
NewField(fieldServerName, []byte(s.Config.Name)),
buf := make([]byte, len(scanner.Bytes()))
copy(buf, scanner.Bytes())
- t, _, err := ReadTransaction(buf)
- if err != nil {
- panic(err)
+ var t Transaction
+ if _, err := t.Write(buf); err != nil {
+ return err
}
- if err := c.handleTransaction(*t); err != nil {
+
+ if err := c.handleTransaction(t); err != nil {
c.logger.Errorw("Error handling transaction", "err", err)
}
}
}
}
-// ReadTransaction parses a byte slice into a struct. The input slice may be shorter or longer
-// that the transaction size depending on what was read from the network connection.
-func ReadTransaction(buf []byte) (*Transaction, int, error) {
- totalSize := binary.BigEndian.Uint32(buf[12:16])
+// Write implements the io.Writer interface for Transaction by parsing one
+// serialized transaction from p into t.
+//
+// p may be longer than the transaction (for example when it also holds bytes
+// read from the connection after this transaction); only the first
+// 20+TotalSize bytes are parsed. On success Write reports all of p as
+// consumed, as io.Writer requires. On error t is left unmodified.
+//
+// The io.Writer contract forbids retaining p, so the transaction bytes are
+// copied before t's header and field slices are taken from them.
+func (t *Transaction) Write(p []byte) (n int, err error) {
+ // The fixed 20-byte header must be present before TotalSize can be read.
+ if len(p) < 20 {
+ return 0, errors.New("buflen too small for transaction header")
+ }
+ totalSize := binary.BigEndian.Uint32(p[12:16])
// the buf may include extra bytes that are not part of the transaction
// tranLen represents the length of bytes that are part of the transaction
tranLen := int(20 + totalSize)
- if tranLen > len(buf) {
- return nil, 0, errors.New("buflen too small for tranLen")
+ if tranLen > len(p) {
+ return 0, errors.New("buflen too small for tranLen")
}
+ // The body must at least hold the 2-byte parameter count. This also rejects
+ // totalSize values large enough to wrap the uint32 addition above, which
+ // previously caused a slice-bounds panic instead of an error.
+ if tranLen < 22 {
+ return 0, errors.New("tranLen too small for param count")
+ }
+
+ // Detach from the caller's buffer (e.g. bufio.Scanner.Bytes(), whose
+ // backing array may be overwritten by a subsequent Scan).
+ b := make([]byte, tranLen)
+ copy(b, p[:tranLen])
+
- fields, err := ReadFields(buf[20:22], buf[22:tranLen])
+ fields, err := ReadFields(b[20:22], b[22:])
if err != nil {
- return nil, 0, err
+ return 0, err
}
- return &Transaction{
- Flags: buf[0],
- IsReply: buf[1],
- Type: buf[2:4],
- ID: buf[4:8],
- ErrorCode: buf[8:12],
- TotalSize: buf[12:16],
- DataSize: buf[16:20],
- ParamCount: buf[20:22],
- Fields: fields,
- }, tranLen, nil
+ t.Flags = b[0]
+ t.IsReply = b[1]
+ t.Type = b[2:4]
+ t.ID = b[4:8]
+ t.ErrorCode = b[8:12]
+ t.TotalSize = b[12:16]
+ t.DataSize = b[16:20]
+ t.ParamCount = b[20:22]
+ t.Fields = fields
+
+ return len(p), nil
}
const tranHeaderLen = 20 // fixed length of transaction fields before the variable length fields
}
}
-func TestReadTransaction(t *testing.T) {
- sampleTransaction := &Transaction{
- Flags: byte(0),
- IsReply: byte(0),
- Type: []byte{0x000, 0x93},
- ID: []byte{0x000, 0x00, 0x00, 0x01},
- ErrorCode: []byte{0x000, 0x00, 0x00, 0x00},
- TotalSize: []byte{0x000, 0x00, 0x00, 0x08},
- DataSize: []byte{0x000, 0x00, 0x00, 0x08},
- ParamCount: []byte{0x00, 0x01},
- Fields: []Field{
- {
- ID: []byte{0x00, 0x01},
- FieldSize: []byte{0x00, 0x02},
- Data: []byte{0xff, 0xff},
- },
- },
- }
-
- type args struct {
- buf []byte
- }
- tests := []struct {
- name string
- args args
- want *Transaction
- want1 int
- wantErr bool
- }{
- {
- name: "when buf contains all bytes for a single transaction",
- args: args{
- buf: func() []byte {
- b, _ := sampleTransaction.MarshalBinary()
- return b
- }(),
- },
- want: sampleTransaction,
- want1: func() int {
- b, _ := sampleTransaction.MarshalBinary()
- return len(b)
- }(),
- wantErr: false,
- },
- {
- name: "when len(buf) is less than the length of the transaction",
- args: args{
- buf: func() []byte {
- b, _ := sampleTransaction.MarshalBinary()
- return b[:len(b)-1]
- }(),
- },
- want: nil,
- want1: 0,
- wantErr: true,
- },
- }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- got, got1, err := ReadTransaction(tt.args.buf)
- if (err != nil) != tt.wantErr {
- t.Errorf("ReadTransaction() error = %v, wantErr %v", err, tt.wantErr)
- return
- }
- if !assert.Equal(t, tt.want, got) {
- t.Errorf("ReadTransaction() got = %v, want %v", got, tt.want)
- }
- if got1 != tt.want1 {
- t.Errorf("ReadTransaction() got1 = %v, want %v", got1, tt.want1)
- }
- })
- }
-}
-
func Test_transactionScanner(t *testing.T) {
type args struct {
data []byte
})
}
}
+
+func TestTransaction_Write(t1 *testing.T) {
+ sampleTransaction := &Transaction{
+ Flags: byte(0),
+ IsReply: byte(0),
+ Type: []byte{0x000, 0x93},
+ ID: []byte{0x000, 0x00, 0x00, 0x01},
+ ErrorCode: []byte{0x000, 0x00, 0x00, 0x00},
+ TotalSize: []byte{0x000, 0x00, 0x00, 0x08},
+ DataSize: []byte{0x000, 0x00, 0x00, 0x08},
+ ParamCount: []byte{0x00, 0x01},
+ Fields: []Field{
+ {
+ ID: []byte{0x00, 0x01},
+ FieldSize: []byte{0x00, 0x02},
+ Data: []byte{0xff, 0xff},
+ },
+ },
+ }
+
+ // Serialize the fixture once; every case below feeds some slice of it to Write.
+ sampleBytes, err := sampleTransaction.MarshalBinary()
+ if err != nil {
+ t1.Fatalf("MarshalBinary() error = %v", err)
+ }
+
+ tests := []struct {
+ name string
+ p []byte
+ want *Transaction // expected parsed transaction; nil when Write should fail
+ wantN int
+ wantErr assert.ErrorAssertionFunc
+ }{
+ {
+ name: "when buf contains all bytes for a single transaction",
+ p: sampleBytes,
+ want: sampleTransaction,
+ wantN: len(sampleBytes),
+ wantErr: assert.NoError,
+ },
+ {
+ // Write must ignore bytes past 20+TotalSize but still report all of p consumed.
+ name: "when buf contains extra bytes after the transaction",
+ p: append(append([]byte{}, sampleBytes...), 0xaa, 0xbb),
+ want: sampleTransaction,
+ wantN: len(sampleBytes) + 2,
+ wantErr: assert.NoError,
+ },
+ {
+ name: "when len(buf) is less than the length of the transaction",
+ p: sampleBytes[:len(sampleBytes)-1],
+ want: nil,
+ wantN: 0,
+ wantErr: assert.Error,
+ },
+ }
+ for _, tt := range tests {
+ t1.Run(tt.name, func(t1 *testing.T) {
+ var t Transaction
+ gotN, err := t.Write(tt.p)
+ if !tt.wantErr(t1, err, fmt.Sprintf("Write(%v)", tt.p)) {
+ return
+ }
+ assert.Equalf(t1, tt.wantN, gotN, "Write(%v)", tt.p)
+ // Verify the parsed contents, not just the byte count.
+ if tt.want != nil {
+ assert.Equalf(t1, tt.want, &t, "Write(%v)", tt.p)
+ }
+ })
+ }
+}