git.r.bdr.sh - rbdr/mobius/commitdiff
Implement bufio.Scanner for transaction parsing
author Jeff Halter <redacted>
Fri, 24 Jun 2022 17:41:37 +0000 (10:41 -0700)
committer Jeff Halter <redacted>
Fri, 24 Jun 2022 17:41:37 +0000 (10:41 -0700)
go.mod
hotline/client.go
hotline/client_conn.go
hotline/client_conn_test.go
hotline/server.go
hotline/transaction.go
hotline/transaction_test.go
hotline/ui.go

diff --git a/go.mod b/go.mod
index 737889a2cc87b3b89df1bbc6e2de64b3e5e5886e..59a496765291db3211cafb8d8adbd9fd06c6b5d0 100644 (file)
--- a/go.mod
+++ b/go.mod
@@ -17,12 +17,16 @@ require (
        github.com/gdamore/encoding v1.0.0 // indirect
        github.com/go-playground/locales v0.14.0 // indirect
        github.com/go-playground/universal-translator v0.18.0 // indirect
+       github.com/kr/pretty v0.3.0 // indirect
+       github.com/kr/text v0.2.0 // indirect
        github.com/leodido/go-urn v1.2.1 // indirect
        github.com/lucasb-eyer/go-colorful v1.2.0 // indirect
        github.com/mattn/go-runewidth v0.0.13 // indirect
        github.com/pkg/errors v0.9.1 // indirect
        github.com/pmezard/go-difflib v1.0.0 // indirect
        github.com/rivo/uniseg v0.2.0 // indirect
+       github.com/rogpeppe/go-internal v1.8.0 // indirect
+       github.com/ryboe/q v1.0.16 // indirect
        github.com/stretchr/objx v0.4.0 // indirect
        go.uber.org/atomic v1.9.0 // indirect
        go.uber.org/multierr v1.8.0 // indirect
diff --git a/hotline/client.go b/hotline/client.go
index e5b396b75192f367239b99541edb225618b7a7b9..c32256298def76bf16373d678d985071ecc8f308 100644 (file)
--- a/hotline/client.go
+++ b/hotline/client.go
@@ -397,54 +397,6 @@ func handleNotifyDeleteUser(c *Client, t *Transaction) (res []Transaction, err e
        return res, err
 }
 
-const readBuffSize = 1024000 // 1KB - TODO: what should this be?
-
-func (c *Client) ReadLoop() error {
-       tranBuff := make([]byte, 0)
-       tReadlen := 0
-       // Infinite loop where take action on incoming client requests until the connection is closed
-       for {
-               buf := make([]byte, readBuffSize)
-               tranBuff = tranBuff[tReadlen:]
-
-               readLen, err := c.Connection.Read(buf)
-               if err != nil {
-                       return err
-               }
-               tranBuff = append(tranBuff, buf[:readLen]...)
-
-               // We may have read multiple requests worth of bytes from Connection.Read.  readTransactions splits them
-               // into a slice of transactions
-               var transactions []Transaction
-               if transactions, tReadlen, err = readTransactions(tranBuff); err != nil {
-                       c.Logger.Errorw("Error handling transaction", "err", err)
-               }
-
-               // iterate over all of the transactions that were parsed from the byte slice and handle them
-               for _, t := range transactions {
-                       if err := c.HandleTransaction(&t); err != nil {
-                               c.Logger.Errorw("Error handling transaction", "err", err)
-                       }
-               }
-       }
-}
-
-func (c *Client) GetTransactions() error {
-       tranBuff := make([]byte, 0)
-       tReadlen := 0
-
-       buf := make([]byte, readBuffSize)
-       tranBuff = tranBuff[tReadlen:]
-
-       readLen, err := c.Connection.Read(buf)
-       if err != nil {
-               return err
-       }
-       tranBuff = append(tranBuff, buf[:readLen]...)
-
-       return nil
-}
-
 func handleClientGetUserNameList(c *Client, t *Transaction) (res []Transaction, err error) {
        var users []User
        for _, field := range t.Fields {
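diff --git a/hotline/client_conn.go b/hotline/client_conn.go
index 1186947ebeac57d3f469584ae60cafbe66f645ea..41810a8d439b62914001bdc50bd81357317afada 100644 (file)
--- a/hotline/client_conn.go
+++ b/hotline/client_conn.go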
@@ -54,7 +54,7 @@ func (cc *ClientConn) sendAll(t int, fields ...Field) {
        }
 }
 
-func (cc *ClientConn) handleTransaction(transaction *Transaction) error {
+func (cc *ClientConn) handleTransaction(transaction Transaction) error {
        requestNum := binary.BigEndian.Uint16(transaction.Type)
        if handler, ok := TransactionHandlers[requestNum]; ok {
                for _, reqField := range handler.RequiredFields {
@@ -80,7 +80,7 @@ func (cc *ClientConn) handleTransaction(transaction *Transaction) error {
 
                cc.logger.Debugw("Received Transaction", "RequestType", handler.Name)
 
-               transactions, err := handler.Handler(cc, transaction)
+               transactions, err := handler.Handler(cc, &transaction)
                if err != nil {
                        return err
                }
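diff --git a/hotline/client_conn_test.go b/hotline/client_conn_test.go
index da0a2cc1187181d49dd5df8c399d02547845817a..9f04d66c8d10d7232d331092baa368f15790e175 100644 (file)
--- a/hotline/client_conn_test.go
+++ b/hotline/client_conn_test.go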
@@ -1,52 +1,2 @@
 package hotline
 
-import (
-       "net"
-       "testing"
-)
-
-func TestClientConn_handleTransaction(t *testing.T) {
-       type fields struct {
-               Connection net.Conn
-               ID         *[]byte
-               Icon       *[]byte
-               Flags      *[]byte
-               UserName   []byte
-               Account    *Account
-               IdleTime   int
-               Server     *Server
-               Version    *[]byte
-               Idle       bool
-               AutoReply  []byte
-       }
-       type args struct {
-               transaction *Transaction
-       }
-       tests := []struct {
-               name    string
-               fields  fields
-               args    args
-               wantErr bool
-       }{
-               // TODO: Add test cases.
-       }
-       for _, tt := range tests {
-               t.Run(tt.name, func(t *testing.T) {
-                       cc := &ClientConn{
-                               Connection: tt.fields.Connection,
-                               ID:         tt.fields.ID,
-                               Icon:       tt.fields.Icon,
-                               Flags:      tt.fields.Flags,
-                               UserName:   tt.fields.UserName,
-                               Account:    tt.fields.Account,
-                               Server:     tt.fields.Server,
-                               Version:    tt.fields.Version,
-                               Idle:       tt.fields.Idle,
-                               AutoReply:  tt.fields.AutoReply,
-                       }
-                       if err := cc.handleTransaction(tt.args.transaction); (err != nil) != tt.wantErr {
-                               t.Errorf("handleTransaction() error = %v, wantErr %v", err, tt.wantErr)
-                       }
-               })
-       }
-}
diff --git a/hotline/server.go b/hotline/server.go
index 0677a813a6d48da4a2aaf99fca228d0c1a28c366..1edfd4a9d03a65224bd66a52c2b5dba185621780 100644 (file)
--- a/hotline/server.go
+++ b/hotline/server.go
@@ -133,7 +133,6 @@ func (s *Server) ServeFileTransfers(ctx context.Context, ln net.Listener) error
 }
 
 func (s *Server) sendTransaction(t Transaction) error {
-       requestNum := binary.BigEndian.Uint16(t.Type)
        clientID, err := byteToInt(*t.clientID)
        if err != nil {
                return err
@@ -141,31 +140,21 @@ func (s *Server) sendTransaction(t Transaction) error {
 
+       // Unlock right after the lookup so the nil-client return below cannot leak the mutex and Write runs unlocked.
        s.mux.Lock()
        client := s.Clients[uint16(clientID)]
        s.mux.Unlock()
        if client == nil {
                return fmt.Errorf("invalid client id %v", *t.clientID)
        }
-       userName := string(client.UserName)
-       login := client.Account.Login
 
-       handler := TransactionHandlers[requestNum]
-
        b, err := t.MarshalBinary()
        if err != nil {
                return err
        }
-       var n int
-       if n, err = client.Connection.Write(b); err != nil {
+
+       if _, err := client.Connection.Write(b); err != nil {
                return err
        }
-       s.Logger.Debugw("Sent Transaction",
-               "name", userName,
-               "login", login,
-               "IsReply", t.IsReply,
-               "type", handler.Name,
-               "sentBytes", n,
-               "remoteAddr", client.RemoteAddr,
-       )
+
        return nil
 }
 
@@ -518,12 +507,7 @@ func (s *Server) loadConfig(path string) error {
        return nil
 }
 
-const (
-       minTransactionLen = 22 // minimum length of any transaction
-)
-
-// dontPanic recovers and logs panics instead of crashing
-// TODO: remove this after known issues are fixed
+// dontPanic logs panics instead of crashing
 func dontPanic(logger *zap.SugaredLogger) {
        if r := recover(); r != nil {
                fmt.Println("stacktrace from panic: \n" + string(debug.Stack()))
@@ -532,29 +516,25 @@ func dontPanic(logger *zap.SugaredLogger) {
 }
 
 // handleNewConnection takes a new net.Conn and performs the initial login sequence
-func (s *Server) handleNewConnection(ctx context.Context, conn io.ReadWriteCloser, remoteAddr string) error {
+func (s *Server) handleNewConnection(ctx context.Context, rwc io.ReadWriteCloser, remoteAddr string) error {
        defer dontPanic(s.Logger)
 
-       if err := Handshake(conn); err != nil {
+       if err := Handshake(rwc); err != nil {
                return err
        }
 
-       buf := make([]byte, 1024)
-       // TODO: fix potential short read with io.ReadFull
-       readLen, err := conn.Read(buf)
-       if readLen < minTransactionLen {
-               return err
-       }
-       if err != nil {
-               return err
-       }
+       scanner := bufio.NewScanner(rwc) // splits incoming bytes into transaction tokens
+       scanner.Split(transactionScanner)
+       if !scanner.Scan() {
+               return scanner.Err()
+       }
 
-       clientLogin, _, err := ReadTransaction(buf[:readLen])
+       clientLogin, _, err := ReadTransaction(scanner.Bytes())
        if err != nil {
                return err
        }
 
-       c := s.NewClientConn(conn, remoteAddr)
+       c := s.NewClientConn(rwc, remoteAddr)
        defer c.Disconnect()
 
        encodedLogin := clientLogin.GetField(fieldUserLogin).Data
@@ -578,7 +558,7 @@ func (s *Server) handleNewConnection(ctx context.Context, conn io.ReadWriteClose
                if err != nil {
                        return err
                }
-               if _, err := conn.Write(b); err != nil {
+               if _, err := rwc.Write(b); err != nil {
                        return err
                }
 
@@ -634,34 +614,22 @@ func (s *Server) handleNewConnection(ctx context.Context, conn io.ReadWriteClose
 
        c.Server.Stats.LoginCount += 1
 
-       const readBuffSize = 1024000 // 1KB - TODO: what should this be?
-       tranBuff := make([]byte, 0)
-       tReadlen := 0
-       // Infinite loop where take action on incoming client requests until the connection is closed
-       for {
-               buf = make([]byte, readBuffSize)
-               tranBuff = tranBuff[tReadlen:]
+       // Scan for new transactions and handle them as they come in.
+       for scanner.Scan() {
+               // Make a new []byte slice and copy the scanner bytes to it.  This is critical to avoid a data race as the
+               // scanner re-uses the buffer for subsequent scans.
+               buf := make([]byte, len(scanner.Bytes()))
+               copy(buf, scanner.Bytes())
 
-               readLen, err := c.Connection.Read(buf)
+               t, _, err := ReadTransaction(buf)
                if err != nil {
                        return err
                }
-               tranBuff = append(tranBuff, buf[:readLen]...)
-
-               // We may have read multiple requests worth of bytes from Connection.Read.  readTransactions splits them
-               // into a slice of transactions
-               var transactions []Transaction
-               if transactions, tReadlen, err = readTransactions(tranBuff); err != nil {
+               if err := c.handleTransaction(*t); err != nil {
                        c.logger.Errorw("Error handling transaction", "err", err)
                }
-
-               // iterate over all the transactions that were parsed from the byte slice and handle them
-               for _, t := range transactions {
-                       if err := c.handleTransaction(&t); err != nil {
-                               c.logger.Errorw("Error handling transaction", "err", err)
-                       }
-               }
        }
+       return scanner.Err()
 }
 
 func (s *Server) NewPrivateChat(cc *ClientConn) []byte {
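diff --git a/hotline/transaction.go b/hotline/transaction.go
index ece2924da5969f6441ba1bb75188f7c6504d6f9f..88e17dfcad2471d707575b6ba5422546e2578632 100644 (file)
--- a/hotline/transaction.go
+++ b/hotline/transaction.go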
@@ -131,23 +131,24 @@ func ReadTransaction(buf []byte) (*Transaction, int, error) {
        }, tranLen, nil
 }
 
-func readTransactions(buf []byte) ([]Transaction, int, error) {
-       var transactions []Transaction
+const tranHeaderLen = 20 // fixed length of transaction fields before the variable length fields
 
-       bufLen := len(buf)
+// transactionScanner is a bufio.SplitFunc yielding one complete transaction per token; atEOF is ignored, so a partial trailing transaction is never returned
+func transactionScanner(data []byte, _ bool) (advance int, token []byte, err error) {
+       // The bytes that contain the size of a transaction are from 12:16, so we need at least 16 bytes
+       if len(data) < 16 {
+               return 0, nil, nil
+       }
 
-       var bytesRead = 0
-       for bytesRead < bufLen {
-               t, tReadLen, err := ReadTransaction(buf[bytesRead:])
-               if err != nil {
-                       return transactions, bytesRead, err
-               }
-               bytesRead += tReadLen
+       totalSize := binary.BigEndian.Uint32(data[12:16])
 
-               transactions = append(transactions, *t)
+       // Full transaction length, summed in uint64 so a hostile totalSize near MaxUint32 cannot wrap to a short token
+       tranLen := uint64(tranHeaderLen) + uint64(totalSize)
+       if tranLen > uint64(len(data)) {
+               return 0, nil, nil
        }
 
-       return transactions, bytesRead, nil
+       return int(tranLen), data[:tranLen], nil
 }
 
 const minFieldLen = 4
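diff --git a/hotline/transaction_test.go b/hotline/transaction_test.go
index 1fade255c63bc5e2047624b2d459b7257e93b066..04bcde0e7851e1b1538bced6411411113a438bff 100644 (file)
--- a/hotline/transaction_test.go
+++ b/hotline/transaction_test.go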
@@ -1,6 +1,7 @@
 package hotline
 
 import (
+       "fmt"
        "github.com/stretchr/testify/assert"
        "testing"
 )
@@ -147,8 +148,8 @@ func TestReadTransaction(t *testing.T) {
                                        return b
                                }(),
                        },
-                       want:    sampleTransaction,
-                       want1:   func() int {
+                       want: sampleTransaction,
+                       want1: func() int {
                                b, _ := sampleTransaction.MarshalBinary()
                                return len(b)
                        }(),
@@ -183,3 +184,191 @@ func TestReadTransaction(t *testing.T) {
                })
        }
 }
+
+func Test_transactionScanner(t *testing.T) {
+       type args struct {
+               data []byte
+               in1  bool
+       }
+       tests := []struct {
+               name        string
+               args        args
+               wantAdvance int
+               wantToken   []byte
+               wantErr     assert.ErrorAssertionFunc
+       }{
+               {
+                       name: "when too few bytes are provided to read the transaction size",
+                       args: args{
+                               data: []byte{},
+                               in1:  false,
+                       },
+                       wantAdvance: 0,
+                       wantToken:   []byte(nil),
+                       wantErr:     assert.NoError,
+               },
+               {
+                       name: "when too few bytes are provided to read the full payload",
+                       args: args{
+                               data: []byte{
+                                       0,
+                                       1,
+                                       0, 0,
+                                       0, 00, 00, 04,
+                                       00, 00, 00, 00,
+                                       00, 00, 00, 10,
+                                       00, 00, 00, 10,
+                               },
+                               in1: false,
+                       },
+                       wantAdvance: 0,
+                       wantToken:   []byte(nil),
+                       wantErr:     assert.NoError,
+               },
+               {
+                       name: "when a full transaction is provided",
+                       args: args{
+                               data: []byte{
+                                       0,
+                                       1,
+                                       0, 0,
+                                       0, 00, 00, 0x04,
+                                       00, 00, 00, 0x00,
+                                       00, 00, 00, 0x10,
+                                       00, 00, 00, 0x10,
+                                       00, 02,
+                                       00, 0x6c, // 108 - fieldTransferSize
+                                       00, 02,
+                                       0x63, 0x3b,
+                                       00, 0x6b, // 107 = fieldRefNum
+                                       00, 0x04,
+                                       00, 0x02, 0x93, 0x47,
+                               },
+                               in1: false,
+                       },
+                       wantAdvance: 36,
+                       wantToken: []byte{
+                               0,
+                               1,
+                               0, 0,
+                               0, 00, 00, 0x04,
+                               00, 00, 00, 0x00,
+                               00, 00, 00, 0x10,
+                               00, 00, 00, 0x10,
+                               00, 02,
+                               00, 0x6c, // 108 - fieldTransferSize
+                               00, 02,
+                               0x63, 0x3b,
+                               00, 0x6b, // 107 = fieldRefNum
+                               00, 0x04,
+                               00, 0x02, 0x93, 0x47,
+                       },
+                       wantErr: assert.NoError,
+               },
+               {
+                       name: "when a full transaction plus extra bytes are provided",
+                       args: args{
+                               data: []byte{
+                                       0,
+                                       1,
+                                       0, 0,
+                                       0, 00, 00, 0x04,
+                                       00, 00, 00, 0x00,
+                                       00, 00, 00, 0x10,
+                                       00, 00, 00, 0x10,
+                                       00, 02,
+                                       00, 0x6c, // 108 - fieldTransferSize
+                                       00, 02,
+                                       0x63, 0x3b,
+                                       00, 0x6b, // 107 = fieldRefNum
+                                       00, 0x04,
+                                       00, 0x02, 0x93, 0x47,
+                                       1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
+                               },
+                               in1: false,
+                       },
+                       wantAdvance: 36,
+                       wantToken: []byte{
+                               0,
+                               1,
+                               0, 0,
+                               0, 00, 00, 0x04,
+                               00, 00, 00, 0x00,
+                               00, 00, 00, 0x10,
+                               00, 00, 00, 0x10,
+                               00, 02,
+                               00, 0x6c, // 108 - fieldTransferSize
+                               00, 02,
+                               0x63, 0x3b,
+                               00, 0x6b, // 107 = fieldRefNum
+                               00, 0x04,
+                               00, 0x02, 0x93, 0x47,
+                       },
+                       wantErr: assert.NoError,
+               },
+               {
+                       name: "when two full transactions are provided",
+                       args: args{
+                               data: []byte{
+                                       0,
+                                       1,
+                                       0, 0,
+                                       0, 00, 00, 0x04,
+                                       00, 00, 00, 0x00,
+                                       00, 00, 00, 0x10,
+                                       00, 00, 00, 0x10,
+                                       00, 02,
+                                       00, 0x6c, // 108 - fieldTransferSize
+                                       00, 02,
+                                       0x63, 0x3b,
+                                       00, 0x6b, // 107 = fieldRefNum
+                                       00, 0x04,
+                                       00, 0x02, 0x93, 0x47,
+                                       0,
+                                       1,
+                                       0, 0,
+                                       0, 00, 00, 0x04,
+                                       00, 00, 00, 0x00,
+                                       00, 00, 00, 0x10,
+                                       00, 00, 00, 0x10,
+                                       00, 02,
+                                       00, 0x6c, // 108 - fieldTransferSize
+                                       00, 02,
+                                       0x63, 0x3b,
+                                       00, 0x6b, // 107 = fieldRefNum
+                                       00, 0x04,
+                                       00, 0x02, 0x93, 0x47,
+                               },
+                               in1: false,
+                       },
+                       wantAdvance: 36,
+                       wantToken: []byte{
+                               0,
+                               1,
+                               0, 0,
+                               0, 00, 00, 0x04,
+                               00, 00, 00, 0x00,
+                               00, 00, 00, 0x10,
+                               00, 00, 00, 0x10,
+                               00, 02,
+                               00, 0x6c, // 108 - fieldTransferSize
+                               00, 02,
+                               0x63, 0x3b,
+                               00, 0x6b, // 107 = fieldRefNum
+                               00, 0x04,
+                               00, 0x02, 0x93, 0x47,
+                       },
+                       wantErr: assert.NoError,
+               },
+       }
+       for _, tt := range tests {
+               t.Run(tt.name, func(t *testing.T) {
+                       gotAdvance, gotToken, err := transactionScanner(tt.args.data, tt.args.in1)
+                       if !tt.wantErr(t, err, fmt.Sprintf("transactionScanner(%v, %v)", tt.args.data, tt.args.in1)) {
+                               return
+                       }
+                       assert.Equalf(t, tt.wantAdvance, gotAdvance, "transactionScanner(%v, %v)", tt.args.data, tt.args.in1)
+                       assert.Equalf(t, tt.wantToken, gotToken, "transactionScanner(%v, %v)", tt.args.data, tt.args.in1)
+               })
+       }
+}
diff --git a/hotline/ui.go b/hotline/ui.go
index b9e45dd9bafaa662f9284ce2fe06c792e5b7dfd0..6eba01175121534c57808122518445d6d16ec7c1 100644 (file)
--- a/hotline/ui.go
+++ b/hotline/ui.go
@@ -1,11 +1,11 @@
 package hotline
 
 import (
+       "bufio"
        "fmt"
        "github.com/gdamore/tcell/v2"
        "github.com/rivo/tview"
        "gopkg.in/yaml.v3"
-       "io"
        "io/ioutil"
        "os"
        "strconv"
@@ -197,29 +197,41 @@ func (ui *UI) joinServer(addr, login, password string) error {
        }
 
        go func() {
-               for {
-                       err := ui.HLClient.ReadLoop()
+               // Create a new scanner for parsing incoming bytes into transaction tokens
+               scanner := bufio.NewScanner(ui.HLClient.Connection)
+               scanner.Split(transactionScanner)
+
+               // Scan for new transactions and handle them as they come in.
+               for scanner.Scan() {
+                       // Make a new []byte slice and copy the scanner bytes to it.  This is critical to avoid a data race as the
+                       // scanner re-uses the buffer for subsequent scans.
+                       buf := make([]byte, len(scanner.Bytes()))
+                       copy(buf, scanner.Bytes())
+
+                       t, _, err := ReadTransaction(buf)
                        if err != nil {
-                               ui.HLClient.Logger.Errorw("read error", "err", err)
-
-                               if err == io.EOF {
-                                       loginErrModal := tview.NewModal().
-                                               AddButtons([]string{"Ok"}).
-                                               SetText("The server connection has closed.").
-                                               SetDoneFunc(func(buttonIndex int, buttonLabel string) {
-                                                       ui.Pages.SwitchToPage("home")
-                                               })
-                                       loginErrModal.Box.SetTitle("Server Connection Error")
-
-                                       ui.Pages.AddPage("loginErr", loginErrModal, false, true)
-                                       ui.App.Draw()
-                                       return
-                               }
-                               ui.Pages.SwitchToPage("home")
-
-                               return
+                               break
+                       }
+                       if err := ui.HLClient.HandleTransaction(t); err != nil {
+                               ui.HLClient.Logger.Errorw("Error handling transaction", "err", err)
                        }
                }
+
+               if scanner.Err() == nil {
+                       loginErrModal := tview.NewModal().
+                               AddButtons([]string{"Ok"}).
+                               SetText("The server connection has closed.").
+                               SetDoneFunc(func(buttonIndex int, buttonLabel string) {
+                                       ui.Pages.SwitchToPage("home")
+                               })
+                       loginErrModal.Box.SetTitle("Server Connection Error")
+
+                       ui.Pages.AddPage("loginErr", loginErrModal, false, true)
+                       ui.App.Draw()
+                       return
+               }
+               ui.Pages.SwitchToPage("home")
+
        }()
 
        return nil
@@ -324,6 +336,7 @@ func (ui *UI) renderServerUI() *tview.Flex {
                if buttonIndex == 1 {
                        _ = ui.HLClient.Disconnect()
                        ui.Pages.RemovePage(pageServerUI)
+                       ui.Pages.SwitchToPage("home")
                } else {
                        ui.Pages.HidePage("modal")
                }