From: Jeff Halter Date: Fri, 24 Jun 2022 17:41:37 +0000 (-0700) Subject: Implement bufio.Scanner for transaction parsing X-Git-Url: https://git.r.bdr.sh/rbdr/mobius/commitdiff_plain/3178ae580a3fe97d6a1167b4346d209f04e9b7e3 Implement bufio.Scanner for transaction parsing --- diff --git a/go.mod b/go.mod index 737889a..59a4967 100644 --- a/go.mod +++ b/go.mod @@ -17,12 +17,16 @@ require ( github.com/gdamore/encoding v1.0.0 // indirect github.com/go-playground/locales v0.14.0 // indirect github.com/go-playground/universal-translator v0.18.0 // indirect + github.com/kr/pretty v0.3.0 // indirect + github.com/kr/text v0.2.0 // indirect github.com/leodido/go-urn v1.2.1 // indirect github.com/lucasb-eyer/go-colorful v1.2.0 // indirect github.com/mattn/go-runewidth v0.0.13 // indirect github.com/pkg/errors v0.9.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/rivo/uniseg v0.2.0 // indirect + github.com/rogpeppe/go-internal v1.8.0 // indirect + github.com/ryboe/q v1.0.16 // indirect github.com/stretchr/objx v0.4.0 // indirect go.uber.org/atomic v1.9.0 // indirect go.uber.org/multierr v1.8.0 // indirect diff --git a/hotline/client.go b/hotline/client.go index e5b396b..c322562 100644 --- a/hotline/client.go +++ b/hotline/client.go @@ -397,54 +397,6 @@ func handleNotifyDeleteUser(c *Client, t *Transaction) (res []Transaction, err e return res, err } -const readBuffSize = 1024000 // 1KB - TODO: what should this be? - -func (c *Client) ReadLoop() error { - tranBuff := make([]byte, 0) - tReadlen := 0 - // Infinite loop where take action on incoming client requests until the connection is closed - for { - buf := make([]byte, readBuffSize) - tranBuff = tranBuff[tReadlen:] - - readLen, err := c.Connection.Read(buf) - if err != nil { - return err - } - tranBuff = append(tranBuff, buf[:readLen]...) - - // We may have read multiple requests worth of bytes from Connection.Read. readTransactions splits them - // into a slice of transactions - var transactions []Transaction - if transactions, tReadlen, err = readTransactions(tranBuff); err != nil { - c.Logger.Errorw("Error handling transaction", "err", err) - } - - // iterate over all of the transactions that were parsed from the byte slice and handle them - for _, t := range transactions { - if err := c.HandleTransaction(&t); err != nil { - c.Logger.Errorw("Error handling transaction", "err", err) - } - } - } -} - -func (c *Client) GetTransactions() error { - tranBuff := make([]byte, 0) - tReadlen := 0 - - buf := make([]byte, readBuffSize) - tranBuff = tranBuff[tReadlen:] - - readLen, err := c.Connection.Read(buf) - if err != nil { - return err - } - tranBuff = append(tranBuff, buf[:readLen]...) - - return nil -} - func handleClientGetUserNameList(c *Client, t *Transaction) (res []Transaction, err error) { var users []User for _, field := range t.Fields { diff --git a/hotline/client_conn.go b/hotline/client_conn.go index 1186947..41810a8 100644 --- a/hotline/client_conn.go +++ b/hotline/client_conn.go @@ -54,7 +54,7 @@ func (cc *ClientConn) sendAll(t int, fields ...Field) { } } -func (cc *ClientConn) handleTransaction(transaction *Transaction) error { +func (cc *ClientConn) handleTransaction(transaction Transaction) error { requestNum := binary.BigEndian.Uint16(transaction.Type) if handler, ok := TransactionHandlers[requestNum]; ok { for _, reqField := range handler.RequiredFields { @@ -80,7 +80,7 @@ func (cc *ClientConn) handleTransaction(transaction *Transaction) error { cc.logger.Debugw("Received Transaction", "RequestType", handler.Name) - transactions, err := handler.Handler(cc, transaction) + transactions, err := handler.Handler(cc, &transaction) if err != nil { return err } diff --git a/hotline/client_conn_test.go b/hotline/client_conn_test.go index da0a2cc..9f04d66 100644 --- a/hotline/client_conn_test.go +++ b/hotline/client_conn_test.go @@ -1,52 +1,2 @@ package hotline -import ( - "net" - "testing" -) - -func TestClientConn_handleTransaction(t *testing.T) { - type fields struct { - Connection net.Conn - ID *[]byte - Icon *[]byte - Flags *[]byte - UserName []byte - Account *Account - IdleTime int - Server *Server - Version *[]byte - Idle bool - AutoReply []byte - } - type args struct { - transaction *Transaction - } - tests := []struct { - name string - fields fields - args args - wantErr bool - }{ - // TODO: Add test cases. - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - cc := &ClientConn{ - Connection: tt.fields.Connection, - ID: tt.fields.ID, - Icon: tt.fields.Icon, - Flags: tt.fields.Flags, - UserName: tt.fields.UserName, - Account: tt.fields.Account, - Server: tt.fields.Server, - Version: tt.fields.Version, - Idle: tt.fields.Idle, - AutoReply: tt.fields.AutoReply, - } - if err := cc.handleTransaction(tt.args.transaction); (err != nil) != tt.wantErr { - t.Errorf("handleTransaction() error = %v, wantErr %v", err, tt.wantErr) - } - }) - } -} diff --git a/hotline/server.go b/hotline/server.go index 0677a81..1edfd4a 100644 --- a/hotline/server.go +++ b/hotline/server.go @@ -133,7 +133,6 @@ func (s *Server) ServeFileTransfers(ctx context.Context, ln net.Listener) error } func (s *Server) sendTransaction(t Transaction) error { - requestNum := binary.BigEndian.Uint16(t.Type) clientID, err := byteToInt(*t.clientID) if err != nil { return err @@ -141,31 +140,21 @@ func (s *Server) sendTransaction(t Transaction) error { s.mux.Lock() client := s.Clients[uint16(clientID)] - s.mux.Unlock() if client == nil { return fmt.Errorf("invalid client id %v", *t.clientID) } - userName := string(client.UserName) - login := client.Account.Login - handler := TransactionHandlers[requestNum] + s.mux.Unlock() b, err := t.MarshalBinary() if err != nil { return err } - var n int - if n, err = client.Connection.Write(b); err != nil { + + if _, err := client.Connection.Write(b); err != nil { return err } - s.Logger.Debugw("Sent Transaction", - "name", userName, - "login", login, - "IsReply", t.IsReply, - "type", handler.Name, - "sentBytes", n, - "remoteAddr", client.RemoteAddr, - ) + return nil } @@ -518,12 +507,7 @@ func (s *Server) loadConfig(path string) error { return nil } -const ( - minTransactionLen = 22 // minimum length of any transaction -) - -// dontPanic recovers and logs panics instead of crashing -// TODO: remove this after known issues are fixed +// dontPanic logs panics instead of crashing func dontPanic(logger *zap.SugaredLogger) { if r := recover(); r != nil { fmt.Println("stacktrace from panic: \n" + string(debug.Stack())) @@ -532,29 +516,25 @@ func dontPanic(logger *zap.SugaredLogger) { } // handleNewConnection takes a new net.Conn and performs the initial login sequence -func (s *Server) handleNewConnection(ctx context.Context, conn io.ReadWriteCloser, remoteAddr string) error { +func (s *Server) handleNewConnection(ctx context.Context, rwc io.ReadWriteCloser, remoteAddr string) error { defer dontPanic(s.Logger) - if err := Handshake(conn); err != nil { + if err := Handshake(rwc); err != nil { return err } - buf := make([]byte, 1024) - // TODO: fix potential short read with io.ReadFull - readLen, err := conn.Read(buf) - if readLen < minTransactionLen { - return err - } - if err != nil { - return err - } + // Create a new scanner for parsing incoming bytes into transaction tokens + scanner := bufio.NewScanner(rwc) + scanner.Split(transactionScanner) + + scanner.Scan() - clientLogin, _, err := ReadTransaction(buf[:readLen]) + clientLogin, _, err := ReadTransaction(scanner.Bytes()) if err != nil { - return err + panic(err) } - c := s.NewClientConn(conn, remoteAddr) + c := s.NewClientConn(rwc, remoteAddr) defer c.Disconnect() encodedLogin := clientLogin.GetField(fieldUserLogin).Data @@ -578,7 +558,7 @@ func (s *Server) handleNewConnection(ctx context.Context, conn io.ReadWriteClose if err != nil { return err } - if _, err := conn.Write(b); err != nil { + if _, err := rwc.Write(b); err != nil { return err } @@ -634,34 +614,22 @@ func (s *Server) handleNewConnection(ctx context.Context, conn io.ReadWriteClose c.Server.Stats.LoginCount += 1 - const readBuffSize = 1024000 // 1KB - TODO: what should this be? - tranBuff := make([]byte, 0) - tReadlen := 0 - // Infinite loop where take action on incoming client requests until the connection is closed - for { - buf = make([]byte, readBuffSize) - tranBuff = tranBuff[tReadlen:] + // Scan for new transactions and handle them as they come in. + for scanner.Scan() { + // Make a new []byte slice and copy the scanner bytes to it. This is critical to avoid a data race as the + // scanner re-uses the buffer for subsequent scans. + buf := make([]byte, len(scanner.Bytes())) + copy(buf, scanner.Bytes()) - readLen, err := c.Connection.Read(buf) + t, _, err := ReadTransaction(buf) if err != nil { - return err + panic(err) } - tranBuff = append(tranBuff, buf[:readLen]...) - - // We may have read multiple requests worth of bytes from Connection.Read. readTransactions splits them - // into a slice of transactions - var transactions []Transaction - if transactions, tReadlen, err = readTransactions(tranBuff); err != nil { + if err := c.handleTransaction(*t); err != nil { c.logger.Errorw("Error handling transaction", "err", err) } - - // iterate over all the transactions that were parsed from the byte slice and handle them - for _, t := range transactions { - if err := c.handleTransaction(&t); err != nil { - c.logger.Errorw("Error handling transaction", "err", err) - } - } } + return nil } func (s *Server) NewPrivateChat(cc *ClientConn) []byte { diff --git a/hotline/transaction.go b/hotline/transaction.go index ece2924..88e17df 100644 --- a/hotline/transaction.go +++ b/hotline/transaction.go @@ -131,23 +131,24 @@ func ReadTransaction(buf []byte) (*Transaction, int, error) { }, tranLen, nil } -func readTransactions(buf []byte) ([]Transaction, int, error) { - var transactions []Transaction +const tranHeaderLen = 20 // fixed length of transaction fields before the variable length fields - bufLen := len(buf) +// transactionScanner implements bufio.SplitFunc for parsing incoming byte slices into complete tokens +func transactionScanner(data []byte, _ bool) (advance int, token []byte, err error) { + // The bytes that contain the size of a transaction are from 12:16, so we need at least 16 bytes + if len(data) < 16 { + return 0, nil, nil + } - var bytesRead = 0 - for bytesRead < bufLen { - t, tReadLen, err := ReadTransaction(buf[bytesRead:]) - if err != nil { - return transactions, bytesRead, err - } - bytesRead += tReadLen + totalSize := binary.BigEndian.Uint32(data[12:16]) - transactions = append(transactions, *t) + // tranLen represents the length of bytes that are part of the transaction + tranLen := int(tranHeaderLen + totalSize) + if tranLen > len(data) { + return 0, nil, nil } - return transactions, bytesRead, nil + return tranLen, data[0:tranLen], nil } const minFieldLen = 4 diff --git a/hotline/transaction_test.go b/hotline/transaction_test.go index 1fade25..04bcde0 100644 --- a/hotline/transaction_test.go +++ b/hotline/transaction_test.go @@ -1,6 +1,7 @@ package hotline import ( + "fmt" "github.com/stretchr/testify/assert" "testing" ) @@ -147,8 +148,8 @@ func TestReadTransaction(t *testing.T) { return b }(), }, - want: sampleTransaction, - want1: func() int { + want: sampleTransaction, + want1: func() int { b, _ := sampleTransaction.MarshalBinary() return len(b) }(), @@ -183,3 +184,191 @@ func TestReadTransaction(t *testing.T) { }) } } + +func Test_transactionScanner(t *testing.T) { + type args struct { + data []byte + in1 bool + } + tests := []struct { + name string + args args + wantAdvance int + wantToken []byte + wantErr assert.ErrorAssertionFunc + }{ + { + name: "when too few bytes are provided to read the transaction size", + args: args{ + data: []byte{}, + in1: false, + }, + wantAdvance: 0, + wantToken: []byte(nil), + wantErr: assert.NoError, + }, + { + name: "when too few bytes are provided to read the full payload", + args: args{ + data: []byte{ + 0, + 1, + 0, 0, + 0, 00, 00, 04, + 00, 00, 00, 00, + 00, 00, 00, 10, + 00, 00, 00, 10, + }, + in1: false, + }, + wantAdvance: 0, + wantToken: []byte(nil), + wantErr: assert.NoError, + }, + { + name: "when a full transaction is provided", + args: args{ + data: []byte{ + 0, + 1, + 0, 0, + 0, 00, 00, 0x04, + 00, 00, 00, 0x00, + 00, 00, 00, 0x10, + 00, 00, 00, 0x10, + 00, 02, + 00, 0x6c, // 108 - fieldTransferSize + 00, 02, + 0x63, 0x3b, + 00, 0x6b, // 107 = fieldRefNum + 00, 0x04, + 00, 0x02, 0x93, 0x47, + }, + in1: false, + }, + wantAdvance: 36, + wantToken: []byte{ + 0, + 1, + 0, 0, + 0, 00, 00, 0x04, + 00, 00, 00, 0x00, + 00, 00, 00, 0x10, + 00, 00, 00, 0x10, + 00, 02, + 00, 0x6c, // 108 - fieldTransferSize + 00, 02, + 0x63, 0x3b, + 00, 0x6b, // 107 = fieldRefNum + 00, 0x04, + 00, 0x02, 0x93, 0x47, + }, + wantErr: assert.NoError, + }, + { + name: "when a full transaction plus extra bytes are provided", + args: args{ + data: []byte{ + 0, + 1, + 0, 0, + 0, 00, 00, 0x04, + 00, 00, 00, 0x00, + 00, 00, 00, 0x10, + 00, 00, 00, 0x10, + 00, 02, + 00, 0x6c, // 108 - fieldTransferSize + 00, 02, + 0x63, 0x3b, + 00, 0x6b, // 107 = fieldRefNum + 00, 0x04, + 00, 0x02, 0x93, 0x47, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + }, + in1: false, + }, + wantAdvance: 36, + wantToken: []byte{ + 0, + 1, + 0, 0, + 0, 00, 00, 0x04, + 00, 00, 00, 0x00, + 00, 00, 00, 0x10, + 00, 00, 00, 0x10, + 00, 02, + 00, 0x6c, // 108 - fieldTransferSize + 00, 02, + 0x63, 0x3b, + 00, 0x6b, // 107 = fieldRefNum + 00, 0x04, + 00, 0x02, 0x93, 0x47, + }, + wantErr: assert.NoError, + }, + { + name: "when two full transactions are provided", + args: args{ + data: []byte{ + 0, + 1, + 0, 0, + 0, 00, 00, 0x04, + 00, 00, 00, 0x00, + 00, 00, 00, 0x10, + 00, 00, 00, 0x10, + 00, 02, + 00, 0x6c, // 108 - fieldTransferSize + 00, 02, + 0x63, 0x3b, + 00, 0x6b, // 107 = fieldRefNum + 00, 0x04, + 00, 0x02, 0x93, 0x47, + 0, + 1, + 0, 0, + 0, 00, 00, 0x04, + 00, 00, 00, 0x00, + 00, 00, 00, 0x10, + 00, 00, 00, 0x10, + 00, 02, + 00, 0x6c, // 108 - fieldTransferSize + 00, 02, + 0x63, 0x3b, + 00, 0x6b, // 107 = fieldRefNum + 00, 0x04, + 00, 0x02, 0x93, 0x47, + }, + in1: false, + }, + wantAdvance: 36, + wantToken: []byte{ + 0, + 1, + 0, 0, + 0, 00, 00, 0x04, + 00, 00, 00, 0x00, + 00, 00, 00, 0x10, + 00, 00, 00, 0x10, + 00, 02, + 00, 0x6c, // 108 - fieldTransferSize + 00, 02, + 0x63, 0x3b, + 00, 0x6b, // 107 = fieldRefNum + 00, 0x04, + 00, 0x02, 0x93, 0x47, + }, + wantErr: assert.NoError, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotAdvance, gotToken, err := transactionScanner(tt.args.data, tt.args.in1) + if !tt.wantErr(t, err, fmt.Sprintf("transactionScanner(%v, %v)", tt.args.data, tt.args.in1)) { + return + } + assert.Equalf(t, tt.wantAdvance, gotAdvance, "transactionScanner(%v, %v)", tt.args.data, tt.args.in1) + assert.Equalf(t, tt.wantToken, gotToken, "transactionScanner(%v, %v)", tt.args.data, tt.args.in1) + }) + } +} diff --git a/hotline/ui.go b/hotline/ui.go index b9e45dd..6eba011 100644 --- a/hotline/ui.go +++ b/hotline/ui.go @@ -1,11 +1,11 @@ package hotline import ( + "bufio" "fmt" "github.com/gdamore/tcell/v2" "github.com/rivo/tview" "gopkg.in/yaml.v3" - "io" "io/ioutil" "os" "strconv" @@ -197,29 +197,41 @@ func (ui *UI) joinServer(addr, login, password string) error { } go func() { - for { - err := ui.HLClient.ReadLoop() + // Create a new scanner for parsing incoming bytes into transaction tokens + scanner := bufio.NewScanner(ui.HLClient.Connection) + scanner.Split(transactionScanner) + + // Scan for new transactions and handle them as they come in. + for scanner.Scan() { + // Make a new []byte slice and copy the scanner bytes to it. This is critical to avoid a data race as the + // scanner re-uses the buffer for subsequent scans. + buf := make([]byte, len(scanner.Bytes())) + copy(buf, scanner.Bytes()) + + t, _, err := ReadTransaction(buf) if err != nil { - ui.HLClient.Logger.Errorw("read error", "err", err) - - if err == io.EOF { - loginErrModal := tview.NewModal(). - AddButtons([]string{"Ok"}). - SetText("The server connection has closed."). - SetDoneFunc(func(buttonIndex int, buttonLabel string) { - ui.Pages.SwitchToPage("home") - }) - loginErrModal.Box.SetTitle("Server Connection Error") - - ui.Pages.AddPage("loginErr", loginErrModal, false, true) - ui.App.Draw() - return - } - ui.Pages.SwitchToPage("home") - - return + break + } + if err := ui.HLClient.HandleTransaction(t); err != nil { + ui.HLClient.Logger.Errorw("Error handling transaction", "err", err) } } + + if scanner.Err() == nil { + loginErrModal := tview.NewModal(). + AddButtons([]string{"Ok"}). + SetText("The server connection has closed."). + SetDoneFunc(func(buttonIndex int, buttonLabel string) { + ui.Pages.SwitchToPage("home") + }) + loginErrModal.Box.SetTitle("Server Connection Error") + + ui.Pages.AddPage("loginErr", loginErrModal, false, true) + ui.App.Draw() + return + } + ui.Pages.SwitchToPage("home") + }() return nil @@ -324,6 +336,7 @@ func (ui *UI) renderServerUI() *tview.Flex { if buttonIndex == 1 { _ = ui.HLClient.Disconnect() ui.Pages.RemovePage(pageServerUI) + ui.Pages.SwitchToPage("home") } else { ui.Pages.HidePage("modal") }