github.com/gdamore/encoding v1.0.0 // indirect
github.com/go-playground/locales v0.14.0 // indirect
github.com/go-playground/universal-translator v0.18.0 // indirect
+ github.com/kr/pretty v0.3.0 // indirect
+ github.com/kr/text v0.2.0 // indirect
github.com/leodido/go-urn v1.2.1 // indirect
github.com/lucasb-eyer/go-colorful v1.2.0 // indirect
github.com/mattn/go-runewidth v0.0.13 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/rivo/uniseg v0.2.0 // indirect
+ github.com/rogpeppe/go-internal v1.8.0 // indirect
+ github.com/ryboe/q v1.0.16 // indirect
github.com/stretchr/objx v0.4.0 // indirect
go.uber.org/atomic v1.9.0 // indirect
go.uber.org/multierr v1.8.0 // indirect
return res, err
}
-const readBuffSize = 1024000 // 1KB - TODO: what should this be?
-
-func (c *Client) ReadLoop() error {
- tranBuff := make([]byte, 0)
- tReadlen := 0
- // Infinite loop where take action on incoming client requests until the connection is closed
- for {
- buf := make([]byte, readBuffSize)
- tranBuff = tranBuff[tReadlen:]
-
- readLen, err := c.Connection.Read(buf)
- if err != nil {
- return err
- }
- tranBuff = append(tranBuff, buf[:readLen]...)
-
- // We may have read multiple requests worth of bytes from Connection.Read. readTransactions splits them
- // into a slice of transactions
- var transactions []Transaction
- if transactions, tReadlen, err = readTransactions(tranBuff); err != nil {
- c.Logger.Errorw("Error handling transaction", "err", err)
- }
-
- // iterate over all of the transactions that were parsed from the byte slice and handle them
- for _, t := range transactions {
- if err := c.HandleTransaction(&t); err != nil {
- c.Logger.Errorw("Error handling transaction", "err", err)
- }
- }
- }
-}
-
-func (c *Client) GetTransactions() error {
- tranBuff := make([]byte, 0)
- tReadlen := 0
-
- buf := make([]byte, readBuffSize)
- tranBuff = tranBuff[tReadlen:]
-
- readLen, err := c.Connection.Read(buf)
- if err != nil {
- return err
- }
- tranBuff = append(tranBuff, buf[:readLen]...)
-
- return nil
-}
-
func handleClientGetUserNameList(c *Client, t *Transaction) (res []Transaction, err error) {
var users []User
for _, field := range t.Fields {
}
}
-func (cc *ClientConn) handleTransaction(transaction *Transaction) error {
+func (cc *ClientConn) handleTransaction(transaction Transaction) error {
requestNum := binary.BigEndian.Uint16(transaction.Type)
if handler, ok := TransactionHandlers[requestNum]; ok {
for _, reqField := range handler.RequiredFields {
cc.logger.Debugw("Received Transaction", "RequestType", handler.Name)
- transactions, err := handler.Handler(cc, transaction)
+ transactions, err := handler.Handler(cc, &transaction)
if err != nil {
return err
}
package hotline
-import (
- "net"
- "testing"
-)
-
-func TestClientConn_handleTransaction(t *testing.T) {
- type fields struct {
- Connection net.Conn
- ID *[]byte
- Icon *[]byte
- Flags *[]byte
- UserName []byte
- Account *Account
- IdleTime int
- Server *Server
- Version *[]byte
- Idle bool
- AutoReply []byte
- }
- type args struct {
- transaction *Transaction
- }
- tests := []struct {
- name string
- fields fields
- args args
- wantErr bool
- }{
- // TODO: Add test cases.
- }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- cc := &ClientConn{
- Connection: tt.fields.Connection,
- ID: tt.fields.ID,
- Icon: tt.fields.Icon,
- Flags: tt.fields.Flags,
- UserName: tt.fields.UserName,
- Account: tt.fields.Account,
- Server: tt.fields.Server,
- Version: tt.fields.Version,
- Idle: tt.fields.Idle,
- AutoReply: tt.fields.AutoReply,
- }
- if err := cc.handleTransaction(tt.args.transaction); (err != nil) != tt.wantErr {
- t.Errorf("handleTransaction() error = %v, wantErr %v", err, tt.wantErr)
- }
- })
- }
-}
}
func (s *Server) sendTransaction(t Transaction) error {
- requestNum := binary.BigEndian.Uint16(t.Type)
clientID, err := byteToInt(*t.clientID)
if err != nil {
return err
s.mux.Lock()
client := s.Clients[uint16(clientID)]
- s.mux.Unlock()
if client == nil {
return fmt.Errorf("invalid client id %v", *t.clientID)
}
- userName := string(client.UserName)
- login := client.Account.Login
- handler := TransactionHandlers[requestNum]
+ s.mux.Unlock()
b, err := t.MarshalBinary()
if err != nil {
return err
}
- var n int
- if n, err = client.Connection.Write(b); err != nil {
+
+ if _, err := client.Connection.Write(b); err != nil {
return err
}
- s.Logger.Debugw("Sent Transaction",
- "name", userName,
- "login", login,
- "IsReply", t.IsReply,
- "type", handler.Name,
- "sentBytes", n,
- "remoteAddr", client.RemoteAddr,
- )
+
return nil
}
return nil
}
-const (
- minTransactionLen = 22 // minimum length of any transaction
-)
-
-// dontPanic recovers and logs panics instead of crashing
-// TODO: remove this after known issues are fixed
+// dontPanic logs panics instead of crashing
func dontPanic(logger *zap.SugaredLogger) {
if r := recover(); r != nil {
fmt.Println("stacktrace from panic: \n" + string(debug.Stack()))
}
// handleNewConnection takes a new net.Conn and performs the initial login sequence
-func (s *Server) handleNewConnection(ctx context.Context, conn io.ReadWriteCloser, remoteAddr string) error {
+func (s *Server) handleNewConnection(ctx context.Context, rwc io.ReadWriteCloser, remoteAddr string) error {
defer dontPanic(s.Logger)
- if err := Handshake(conn); err != nil {
+ if err := Handshake(rwc); err != nil {
return err
}
- buf := make([]byte, 1024)
- // TODO: fix potential short read with io.ReadFull
- readLen, err := conn.Read(buf)
- if readLen < minTransactionLen {
- return err
- }
- if err != nil {
- return err
- }
+ // Create a new scanner for parsing incoming bytes into transaction tokens
+ scanner := bufio.NewScanner(rwc)
+ scanner.Split(transactionScanner)
+
+ scanner.Scan()
- clientLogin, _, err := ReadTransaction(buf[:readLen])
+ clientLogin, _, err := ReadTransaction(scanner.Bytes())
if err != nil {
- return err
+ panic(err)
}
- c := s.NewClientConn(conn, remoteAddr)
+ c := s.NewClientConn(rwc, remoteAddr)
defer c.Disconnect()
encodedLogin := clientLogin.GetField(fieldUserLogin).Data
if err != nil {
return err
}
- if _, err := conn.Write(b); err != nil {
+ if _, err := rwc.Write(b); err != nil {
return err
}
c.Server.Stats.LoginCount += 1
- const readBuffSize = 1024000 // 1KB - TODO: what should this be?
- tranBuff := make([]byte, 0)
- tReadlen := 0
- // Infinite loop where take action on incoming client requests until the connection is closed
- for {
- buf = make([]byte, readBuffSize)
- tranBuff = tranBuff[tReadlen:]
+ // Scan for new transactions and handle them as they come in.
+ for scanner.Scan() {
+ // Make a new []byte slice and copy the scanner bytes to it. This is critical to avoid a data race as the
+ // scanner re-uses the buffer for subsequent scans.
+ buf := make([]byte, len(scanner.Bytes()))
+ copy(buf, scanner.Bytes())
- readLen, err := c.Connection.Read(buf)
+ t, _, err := ReadTransaction(buf)
if err != nil {
- return err
+ panic(err)
}
- tranBuff = append(tranBuff, buf[:readLen]...)
-
- // We may have read multiple requests worth of bytes from Connection.Read. readTransactions splits them
- // into a slice of transactions
- var transactions []Transaction
- if transactions, tReadlen, err = readTransactions(tranBuff); err != nil {
+ if err := c.handleTransaction(*t); err != nil {
c.logger.Errorw("Error handling transaction", "err", err)
}
-
- // iterate over all the transactions that were parsed from the byte slice and handle them
- for _, t := range transactions {
- if err := c.handleTransaction(&t); err != nil {
- c.logger.Errorw("Error handling transaction", "err", err)
- }
- }
}
+ return nil
}
func (s *Server) NewPrivateChat(cc *ClientConn) []byte {
}, tranLen, nil
}
// tranHeaderLen is the fixed length of transaction fields before the variable length fields.
const tranHeaderLen = 20

// transactionScanner implements bufio.SplitFunc for parsing incoming byte slices into
// tokens that each hold one complete transaction.
//
// The size of the variable length part is read as a big-endian uint32 from data[12:16]
// and added to tranHeaderLen to get the token length. When data is too short to hold the
// size bytes, or shorter than the computed token length, it returns (0, nil, nil) so the
// scanner reads more data before calling again. The second argument is ignored and the
// returned error is always nil.
func transactionScanner(data []byte, _ bool) (advance int, token []byte, err error) {
	// The bytes that contain the size of a transaction are from 12:16, so we need at least 16 bytes
	if len(data) < 16 {
		return 0, nil, nil
	}

	totalSize := binary.BigEndian.Uint32(data[12:16])

	// tranLen represents the length of bytes that are part of the transaction.
	// The sum is done in uint64: in uint32 it wraps for sizes near the uint32 maximum,
	// producing a bogus short length and a truncated token.
	tranLen := uint64(tranHeaderLen) + uint64(totalSize)
	if tranLen > uint64(len(data)) {
		return 0, nil, nil
	}

	// tranLen <= len(data) here, so the conversion to int cannot overflow.
	return int(tranLen), data[:tranLen], nil
}
const minFieldLen = 4
package hotline
import (
+ "fmt"
"github.com/stretchr/testify/assert"
"testing"
)
return b
}(),
},
- want: sampleTransaction,
- want1: func() int {
+ want: sampleTransaction,
+ want1: func() int {
b, _ := sampleTransaction.MarshalBinary()
return len(b)
}(),
})
}
}
+
+func Test_transactionScanner(t *testing.T) {
+ type args struct {
+ data []byte
+ in1 bool
+ }
+ tests := []struct {
+ name string
+ args args
+ wantAdvance int
+ wantToken []byte
+ wantErr assert.ErrorAssertionFunc
+ }{
+ {
+ name: "when too few bytes are provided to read the transaction size",
+ args: args{
+ data: []byte{},
+ in1: false,
+ },
+ wantAdvance: 0,
+ wantToken: []byte(nil),
+ wantErr: assert.NoError,
+ },
+ {
+ name: "when too few bytes are provided to read the full payload",
+ args: args{
+ data: []byte{
+ 0,
+ 1,
+ 0, 0,
+ 0, 00, 00, 04,
+ 00, 00, 00, 00,
+ 00, 00, 00, 10,
+ 00, 00, 00, 10,
+ },
+ in1: false,
+ },
+ wantAdvance: 0,
+ wantToken: []byte(nil),
+ wantErr: assert.NoError,
+ },
+ {
+ name: "when a full transaction is provided",
+ args: args{
+ data: []byte{
+ 0,
+ 1,
+ 0, 0,
+ 0, 00, 00, 0x04,
+ 00, 00, 00, 0x00,
+ 00, 00, 00, 0x10,
+ 00, 00, 00, 0x10,
+ 00, 02,
+ 00, 0x6c, // 108 - fieldTransferSize
+ 00, 02,
+ 0x63, 0x3b,
+ 00, 0x6b, // 107 = fieldRefNum
+ 00, 0x04,
+ 00, 0x02, 0x93, 0x47,
+ },
+ in1: false,
+ },
+ wantAdvance: 36,
+ wantToken: []byte{
+ 0,
+ 1,
+ 0, 0,
+ 0, 00, 00, 0x04,
+ 00, 00, 00, 0x00,
+ 00, 00, 00, 0x10,
+ 00, 00, 00, 0x10,
+ 00, 02,
+ 00, 0x6c, // 108 - fieldTransferSize
+ 00, 02,
+ 0x63, 0x3b,
+ 00, 0x6b, // 107 = fieldRefNum
+ 00, 0x04,
+ 00, 0x02, 0x93, 0x47,
+ },
+ wantErr: assert.NoError,
+ },
+ {
+ name: "when a full transaction plus extra bytes are provided",
+ args: args{
+ data: []byte{
+ 0,
+ 1,
+ 0, 0,
+ 0, 00, 00, 0x04,
+ 00, 00, 00, 0x00,
+ 00, 00, 00, 0x10,
+ 00, 00, 00, 0x10,
+ 00, 02,
+ 00, 0x6c, // 108 - fieldTransferSize
+ 00, 02,
+ 0x63, 0x3b,
+ 00, 0x6b, // 107 = fieldRefNum
+ 00, 0x04,
+ 00, 0x02, 0x93, 0x47,
+ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
+ },
+ in1: false,
+ },
+ wantAdvance: 36,
+ wantToken: []byte{
+ 0,
+ 1,
+ 0, 0,
+ 0, 00, 00, 0x04,
+ 00, 00, 00, 0x00,
+ 00, 00, 00, 0x10,
+ 00, 00, 00, 0x10,
+ 00, 02,
+ 00, 0x6c, // 108 - fieldTransferSize
+ 00, 02,
+ 0x63, 0x3b,
+ 00, 0x6b, // 107 = fieldRefNum
+ 00, 0x04,
+ 00, 0x02, 0x93, 0x47,
+ },
+ wantErr: assert.NoError,
+ },
+ {
+ name: "when two full transactions are provided",
+ args: args{
+ data: []byte{
+ 0,
+ 1,
+ 0, 0,
+ 0, 00, 00, 0x04,
+ 00, 00, 00, 0x00,
+ 00, 00, 00, 0x10,
+ 00, 00, 00, 0x10,
+ 00, 02,
+ 00, 0x6c, // 108 - fieldTransferSize
+ 00, 02,
+ 0x63, 0x3b,
+ 00, 0x6b, // 107 = fieldRefNum
+ 00, 0x04,
+ 00, 0x02, 0x93, 0x47,
+ 0,
+ 1,
+ 0, 0,
+ 0, 00, 00, 0x04,
+ 00, 00, 00, 0x00,
+ 00, 00, 00, 0x10,
+ 00, 00, 00, 0x10,
+ 00, 02,
+ 00, 0x6c, // 108 - fieldTransferSize
+ 00, 02,
+ 0x63, 0x3b,
+ 00, 0x6b, // 107 = fieldRefNum
+ 00, 0x04,
+ 00, 0x02, 0x93, 0x47,
+ },
+ in1: false,
+ },
+ wantAdvance: 36,
+ wantToken: []byte{
+ 0,
+ 1,
+ 0, 0,
+ 0, 00, 00, 0x04,
+ 00, 00, 00, 0x00,
+ 00, 00, 00, 0x10,
+ 00, 00, 00, 0x10,
+ 00, 02,
+ 00, 0x6c, // 108 - fieldTransferSize
+ 00, 02,
+ 0x63, 0x3b,
+ 00, 0x6b, // 107 = fieldRefNum
+ 00, 0x04,
+ 00, 0x02, 0x93, 0x47,
+ },
+ wantErr: assert.NoError,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ gotAdvance, gotToken, err := transactionScanner(tt.args.data, tt.args.in1)
+ if !tt.wantErr(t, err, fmt.Sprintf("transactionScanner(%v, %v)", tt.args.data, tt.args.in1)) {
+ return
+ }
+ assert.Equalf(t, tt.wantAdvance, gotAdvance, "transactionScanner(%v, %v)", tt.args.data, tt.args.in1)
+ assert.Equalf(t, tt.wantToken, gotToken, "transactionScanner(%v, %v)", tt.args.data, tt.args.in1)
+ })
+ }
+}
package hotline
import (
+ "bufio"
"fmt"
"github.com/gdamore/tcell/v2"
"github.com/rivo/tview"
"gopkg.in/yaml.v3"
- "io"
"io/ioutil"
"os"
"strconv"
}
go func() {
- for {
- err := ui.HLClient.ReadLoop()
+ // Create a new scanner for parsing incoming bytes into transaction tokens
+ scanner := bufio.NewScanner(ui.HLClient.Connection)
+ scanner.Split(transactionScanner)
+
+ // Scan for new transactions and handle them as they come in.
+ for scanner.Scan() {
+ // Make a new []byte slice and copy the scanner bytes to it. This is critical to avoid a data race as the
+ // scanner re-uses the buffer for subsequent scans.
+ buf := make([]byte, len(scanner.Bytes()))
+ copy(buf, scanner.Bytes())
+
+ t, _, err := ReadTransaction(buf)
if err != nil {
- ui.HLClient.Logger.Errorw("read error", "err", err)
-
- if err == io.EOF {
- loginErrModal := tview.NewModal().
- AddButtons([]string{"Ok"}).
- SetText("The server connection has closed.").
- SetDoneFunc(func(buttonIndex int, buttonLabel string) {
- ui.Pages.SwitchToPage("home")
- })
- loginErrModal.Box.SetTitle("Server Connection Error")
-
- ui.Pages.AddPage("loginErr", loginErrModal, false, true)
- ui.App.Draw()
- return
- }
- ui.Pages.SwitchToPage("home")
-
- return
+ break
+ }
+ if err := ui.HLClient.HandleTransaction(t); err != nil {
+ ui.HLClient.Logger.Errorw("Error handling transaction", "err", err)
}
}
+
+ if scanner.Err() == nil {
+ loginErrModal := tview.NewModal().
+ AddButtons([]string{"Ok"}).
+ SetText("The server connection has closed.").
+ SetDoneFunc(func(buttonIndex int, buttonLabel string) {
+ ui.Pages.SwitchToPage("home")
+ })
+ loginErrModal.Box.SetTitle("Server Connection Error")
+
+ ui.Pages.AddPage("loginErr", loginErrModal, false, true)
+ ui.App.Draw()
+ return
+ }
+ ui.Pages.SwitchToPage("home")
+
}()
return nil
if buttonIndex == 1 {
_ = ui.HLClient.Disconnect()
ui.Pages.RemovePage(pageServerUI)
+ ui.Pages.SwitchToPage("home")
} else {
ui.Pages.HidePage("modal")
}