]> git.r.bdr.sh - rbdr/mobius/blame - hotline/client.go
Make keepalive interval a const
[rbdr/mobius] / hotline / client.go
CommitLineData
6988a057
JH
1package hotline
2
import (
	"bufio"
	"bytes"
	"embed"
	"encoding/binary"
	"errors"
	"fmt"
	"io"
	"math/big"
	"math/rand"
	"net"
	"os"
	"strings"
	"time"

	"github.com/gdamore/tcell/v2"
	"github.com/rivo/tview"
	"github.com/stretchr/testify/mock"
	"go.uber.org/zap"
	"gopkg.in/yaml.v3"
)
22
4f3c459c
JH
// trackerListPage is the tview page name used for the tracker list view.
const trackerListPage = "trackerList"

// serverUIPage is the tview page name for the connected-server UI that is
// shown after a successful login.
const serverUIPage = "serverUI"
6988a057 27
22c599ab 28//go:embed banners/*.txt
6988a057
JH
29var bannerDir embed.FS
30
// Bookmark is a saved server entry stored in the client config file.
type Bookmark struct {
	Name     string `yaml:"Name"`
	Addr     string `yaml:"Addr"`
	Login    string `yaml:"Login"`
	Password string `yaml:"Password"`
}

// ClientPrefs holds the user preferences decoded from the client YAML config.
type ClientPrefs struct {
	Username   string     `yaml:"Username"`
	IconID     int        `yaml:"IconID"`
	Bookmarks  []Bookmark `yaml:"Bookmarks"`
	Tracker    string     `yaml:"Tracker"`
	EnableBell bool       `yaml:"EnableBell"`
}

// IconBytes returns IconID encoded as a 2-byte big-endian value, the form
// used for the user icon ID field. Values outside the uint16 range are
// truncated by the conversion.
func (cp *ClientPrefs) IconBytes() []byte {
	iconBytes := make([]byte, 2)
	binary.BigEndian.PutUint16(iconBytes, uint16(cp.IconID))
	return iconBytes
}

// AddBookmark appends a bookmark with the given name, address and
// credentials to cp.Bookmarks. It currently always returns nil; the error
// result is kept so callers need not change if validation is added.
func (cp *ClientPrefs) AddBookmark(name, addr, login, pass string) error {
	cp.Bookmarks = append(cp.Bookmarks, Bookmark{
		Name:     name, // store the name so the bookmark is identifiable
		Addr:     addr,
		Login:    login,
		Password: pass,
	})

	return nil
}
57
6988a057
JH
58func readConfig(cfgPath string) (*ClientPrefs, error) {
59 fh, err := os.Open(cfgPath)
60 if err != nil {
61 return nil, err
62 }
63
64 prefs := ClientPrefs{}
65 decoder := yaml.NewDecoder(fh)
6988a057
JH
66 if err := decoder.Decode(&prefs); err != nil {
67 return nil, err
68 }
69 return &prefs, nil
70}
71
72type Client struct {
95753255 73 cfgPath string
6988a057
JH
74 DebugBuf *DebugBuffer
75 Connection net.Conn
6988a057 76 UserAccess []byte
43ecc0f4 77 filePath []string
6988a057
JH
78 UserList []User
79 Logger *zap.SugaredLogger
80 activeTasks map[uint32]*Transaction
e005c191 81 serverName string
6988a057 82
d005ef04 83 Pref *ClientPrefs
6988a057 84
5978b74f 85 Handlers map[uint16]ClientHandler
6988a057
JH
86
87 UI *UI
88
72dd37f1 89 Inbox chan *Transaction
6988a057
JH
90}
91
5978b74f
JH
92type ClientHandler func(*Client, *Transaction) ([]Transaction, error)
93
94func (c *Client) HandleFunc(transactionID uint16, handler ClientHandler) {
95 c.Handlers[transactionID] = handler
96}
97
d005ef04
JH
98func NewClient(username string, logger *zap.SugaredLogger) *Client {
99 c := &Client{
100 Logger: logger,
101 activeTasks: make(map[uint32]*Transaction),
5978b74f 102 Handlers: make(map[uint16]ClientHandler),
d005ef04
JH
103 }
104 c.Pref = &ClientPrefs{Username: username}
105
106 return c
107}
108
109func NewUIClient(cfgPath string, logger *zap.SugaredLogger) *Client {
b198b22b
JH
110 c := &Client{
111 cfgPath: cfgPath,
112 Logger: logger,
113 activeTasks: make(map[uint32]*Transaction),
114 Handlers: clientHandlers,
6988a057 115 }
b198b22b 116 c.UI = NewUI(c)
6988a057 117
b198b22b 118 prefs, err := readConfig(cfgPath)
6988a057 119 if err != nil {
f4a69647 120 logger.Fatal(fmt.Sprintf("unable to read config file %s\n", cfgPath))
6988a057 121 }
d005ef04 122 c.Pref = prefs
6988a057 123
b198b22b 124 return c
6988a057
JH
125}
126
6988a057
JH
127// DebugBuffer wraps a *tview.TextView and adds a Sync() method to make it available as a Zap logger
128type DebugBuffer struct {
129 TextView *tview.TextView
130}
131
132func (db *DebugBuffer) Write(p []byte) (int, error) {
133 return db.TextView.Write(p)
134}
135
7cd900d6 136// Sync is a noop function that dataFile to satisfy the zapcore.WriteSyncer interface
6988a057
JH
137func (db *DebugBuffer) Sync() error {
138 return nil
139}
140
6988a057
JH
141func randomBanner() string {
142 rand.Seed(time.Now().UnixNano())
143
ce348eb8
JH
144 bannerFiles, _ := bannerDir.ReadDir("banners")
145 file, _ := bannerDir.ReadFile("banners/" + bannerFiles[rand.Intn(len(bannerFiles))].Name())
6988a057
JH
146
147 return fmt.Sprintf("\n\n\nWelcome to...\n\n[red::b]%s[-:-:-]\n\n", file)
148}
149
d005ef04 150type ClientTransaction struct {
6988a057
JH
151 Name string
152 Handler func(*Client, *Transaction) ([]Transaction, error)
153}
154
d005ef04 155func (ch ClientTransaction) Handle(cc *Client, t *Transaction) ([]Transaction, error) {
6988a057
JH
156 return ch.Handler(cc, t)
157}
158
d005ef04 159type ClientTHandler interface {
6988a057
JH
160 Handle(*Client, *Transaction) ([]Transaction, error)
161}
162
163type mockClientHandler struct {
164 mock.Mock
165}
166
167func (mh *mockClientHandler) Handle(cc *Client, t *Transaction) ([]Transaction, error) {
168 args := mh.Called(cc, t)
169 return args.Get(0).([]Transaction), args.Error(1)
170}
171
5978b74f
JH
172var clientHandlers = map[uint16]ClientHandler{
173 TranChatMsg: handleClientChatMsg,
174 TranLogin: handleClientTranLogin,
175 TranShowAgreement: handleClientTranShowAgreement,
176 TranUserAccess: handleClientTranUserAccess,
177 TranGetUserNameList: handleClientGetUserNameList,
178 TranNotifyChangeUser: handleNotifyChangeUser,
179 TranNotifyDeleteUser: handleNotifyDeleteUser,
180 TranGetMsgs: handleGetMsgs,
181 TranGetFileNameList: handleGetFileNameList,
182 TranServerMsg: handleTranServerMsg,
183 TranKeepAlive: func(client *Client, transaction *Transaction) (t []Transaction, err error) {
184 return t, err
9d41bcdf 185 },
3d2bd095
JH
186}
187
188func handleTranServerMsg(c *Client, t *Transaction) (res []Transaction, err error) {
189 time := time.Now().Format(time.RFC850)
190
d005ef04 191 msg := strings.ReplaceAll(string(t.GetField(FieldData).Data), "\r", "\n")
5c34f875 192 msg += "\n\nAt " + time
d005ef04 193 title := fmt.Sprintf("| Private Message From: %s |", t.GetField(FieldUserName).Data)
3d2bd095
JH
194
195 msgBox := tview.NewTextView().SetScrollable(true)
196 msgBox.SetText(msg).SetBackgroundColor(tcell.ColorDarkSlateBlue)
197 msgBox.SetTitle(title).SetBorder(true)
198 msgBox.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey {
199 switch event.Key() {
200 case tcell.KeyEscape:
201 c.UI.Pages.RemovePage("serverMsgModal" + time)
202 }
203 return event
204 })
205
206 centeredFlex := tview.NewFlex().
207 AddItem(nil, 0, 1, false).
208 AddItem(tview.NewFlex().SetDirection(tview.FlexRow).
209 AddItem(nil, 0, 1, false).
210 AddItem(msgBox, 0, 2, true).
211 AddItem(nil, 0, 1, false), 0, 2, true).
212 AddItem(nil, 0, 1, false)
213
5c34f875 214 c.UI.Pages.AddPage("serverMsgModal"+time, centeredFlex, true, true)
3d2bd095
JH
215 c.UI.App.Draw() // TODO: errModal doesn't render without this. wtf?
216
217 return res, err
43ecc0f4
JH
218}
219
9174dbe8
JH
220func (c *Client) showErrMsg(msg string) {
221 time := time.Now().Format(time.RFC850)
222
223 title := "| Error |"
224
225 msgBox := tview.NewTextView().SetScrollable(true)
226 msgBox.SetText(msg).SetBackgroundColor(tcell.ColorDarkRed)
227 msgBox.SetTitle(title).SetBorder(true)
228 msgBox.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey {
229 switch event.Key() {
230 case tcell.KeyEscape:
231 c.UI.Pages.RemovePage("serverMsgModal" + time)
232 }
233 return event
234 })
235
236 centeredFlex := tview.NewFlex().
237 AddItem(nil, 0, 1, false).
238 AddItem(tview.NewFlex().SetDirection(tview.FlexRow).
239 AddItem(nil, 0, 1, false).
240 AddItem(msgBox, 0, 2, true).
241 AddItem(nil, 0, 1, false), 0, 2, true).
242 AddItem(nil, 0, 1, false)
243
244 c.UI.Pages.AddPage("serverMsgModal"+time, centeredFlex, true, true)
245 c.UI.App.Draw() // TODO: errModal doesn't render without this. wtf?
246}
247
43ecc0f4 248func handleGetFileNameList(c *Client, t *Transaction) (res []Transaction, err error) {
9174dbe8 249 if t.IsError() {
d005ef04
JH
250 c.showErrMsg(string(t.GetField(FieldError).Data))
251 c.Logger.Infof("Error: %s", t.GetField(FieldError).Data)
9174dbe8
JH
252 return res, err
253 }
254
43ecc0f4
JH
255 fTree := tview.NewTreeView().SetTopLevel(1)
256 root := tview.NewTreeNode("Root")
257 fTree.SetRoot(root).SetCurrentNode(root)
258 fTree.SetBorder(true).SetTitle("| Files |")
259 fTree.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey {
260 switch event.Key() {
261 case tcell.KeyEscape:
262 c.UI.Pages.RemovePage("files")
263 c.filePath = []string{}
264 case tcell.KeyEnter:
265 selectedNode := fTree.GetCurrentNode()
266
267 if selectedNode.GetText() == "<- Back" {
268 c.filePath = c.filePath[:len(c.filePath)-1]
d005ef04 269 f := NewField(FieldFilePath, EncodeFilePath(strings.Join(c.filePath, "/")))
43ecc0f4 270
d005ef04 271 if err := c.UI.HLClient.Send(*NewTransaction(TranGetFileNameList, nil, f)); err != nil {
43ecc0f4
JH
272 c.UI.HLClient.Logger.Errorw("err", "err", err)
273 }
274 return event
275 }
276
277 entry := selectedNode.GetReference().(*FileNameWithInfo)
278
72dd37f1
JH
279 if bytes.Equal(entry.Type[:], []byte("fldr")) {
280 c.Logger.Infow("get new directory listing", "name", string(entry.name))
43ecc0f4 281
72dd37f1 282 c.filePath = append(c.filePath, string(entry.name))
d005ef04 283 f := NewField(FieldFilePath, EncodeFilePath(strings.Join(c.filePath, "/")))
43ecc0f4 284
d005ef04 285 if err := c.UI.HLClient.Send(*NewTransaction(TranGetFileNameList, nil, f)); err != nil {
43ecc0f4
JH
286 c.UI.HLClient.Logger.Errorw("err", "err", err)
287 }
288 } else {
289 // TODO: initiate file download
72dd37f1 290 c.Logger.Infow("download file", "name", string(entry.name))
43ecc0f4
JH
291 }
292 }
293
294 return event
295 })
296
297 if len(c.filePath) > 0 {
298 node := tview.NewTreeNode("<- Back")
299 root.AddChild(node)
300 }
301
43ecc0f4
JH
302 for _, f := range t.Fields {
303 var fn FileNameWithInfo
72dd37f1
JH
304 err = fn.UnmarshalBinary(f.Data)
305 if err != nil {
306 return nil, nil
307 }
43ecc0f4 308
72dd37f1
JH
309 if bytes.Equal(fn.Type[:], []byte("fldr")) {
310 node := tview.NewTreeNode(fmt.Sprintf("[blue::]📁 %s[-:-:-]", fn.name))
43ecc0f4
JH
311 node.SetReference(&fn)
312 root.AddChild(node)
313 } else {
72dd37f1 314 size := binary.BigEndian.Uint32(fn.FileSize[:]) / 1024
43ecc0f4 315
72dd37f1 316 node := tview.NewTreeNode(fmt.Sprintf(" %-40s %10v KB", fn.name, size))
43ecc0f4
JH
317 node.SetReference(&fn)
318 root.AddChild(node)
319 }
320
321 }
322
323 centerFlex := tview.NewFlex().
324 AddItem(nil, 0, 1, false).
325 AddItem(tview.NewFlex().
326 SetDirection(tview.FlexRow).
327 AddItem(nil, 0, 1, false).
328 AddItem(fTree, 20, 1, true).
246ed3a1 329 AddItem(nil, 0, 1, false), 60, 1, true).
43ecc0f4
JH
330 AddItem(nil, 0, 1, false)
331
332 c.UI.Pages.AddPage("files", centerFlex, true, true)
333 c.UI.App.Draw()
334
335 return res, err
6988a057
JH
336}
337
338func handleGetMsgs(c *Client, t *Transaction) (res []Transaction, err error) {
d005ef04 339 newsText := string(t.GetField(FieldData).Data)
6988a057
JH
340 newsText = strings.ReplaceAll(newsText, "\r", "\n")
341
342 newsTextView := tview.NewTextView().
343 SetText(newsText).
344 SetDoneFunc(func(key tcell.Key) {
40afb444 345 c.UI.Pages.SwitchToPage(serverUIPage)
6988a057
JH
346 c.UI.App.SetFocus(c.UI.chatInput)
347 })
348 newsTextView.SetBorder(true).SetTitle("News")
349
350 c.UI.Pages.AddPage("news", newsTextView, true, true)
aebc4d36
JH
351 // c.UI.Pages.SwitchToPage("news")
352 // c.UI.App.SetFocus(newsTextView)
6988a057
JH
353 c.UI.App.Draw()
354
355 return res, err
356}
357
358func handleNotifyChangeUser(c *Client, t *Transaction) (res []Transaction, err error) {
359 newUser := User{
d005ef04
JH
360 ID: t.GetField(FieldUserID).Data,
361 Name: string(t.GetField(FieldUserName).Data),
362 Icon: t.GetField(FieldUserIconID).Data,
363 Flags: t.GetField(FieldUserFlags).Data,
6988a057
JH
364 }
365
366 // Possible cases:
367 // user is new to the server
368 // user is already on the server but has a new name
369
370 var oldName string
371 var newUserList []User
372 updatedUser := false
373 for _, u := range c.UserList {
374 c.Logger.Debugw("Comparing Users", "userToUpdate", newUser.ID, "myID", u.ID, "userToUpdateName", newUser.Name, "myname", u.Name)
375 if bytes.Equal(newUser.ID, u.ID) {
376 oldName = u.Name
377 u.Name = newUser.Name
378 if u.Name != newUser.Name {
379 _, _ = fmt.Fprintf(c.UI.chatBox, " <<< "+oldName+" is now known as "+newUser.Name+" >>>\n")
380 }
381 updatedUser = true
382 }
383 newUserList = append(newUserList, u)
384 }
385
386 if !updatedUser {
387 newUserList = append(newUserList, newUser)
388 }
389
390 c.UserList = newUserList
391
392 c.renderUserList()
393
394 return res, err
395}
396
397func handleNotifyDeleteUser(c *Client, t *Transaction) (res []Transaction, err error) {
d005ef04 398 exitUser := t.GetField(FieldUserID).Data
6988a057
JH
399
400 var newUserList []User
401 for _, u := range c.UserList {
402 if !bytes.Equal(exitUser, u.ID) {
403 newUserList = append(newUserList, u)
404 }
405 }
406
407 c.UserList = newUserList
408
409 c.renderUserList()
410
411 return res, err
412}
413
6988a057
JH
414func handleClientGetUserNameList(c *Client, t *Transaction) (res []Transaction, err error) {
415 var users []User
416 for _, field := range t.Fields {
d005ef04
JH
417 // The Hotline protocol docs say that ClientGetUserNameList should only return FieldUsernameWithInfo (300)
418 // fields, but shxd sneaks in FieldChatSubject (115) so it's important to filter explicitly for the expected
71c56068
JH
419 // field type. Probably a good idea to do everywhere.
420 if bytes.Equal(field.ID, []byte{0x01, 0x2c}) {
421 u, err := ReadUser(field.Data)
422 if err != nil {
423 return res, err
424 }
425 users = append(users, *u)
426 }
6988a057
JH
427 }
428 c.UserList = users
429
430 c.renderUserList()
431
432 return res, err
433}
434
435func (c *Client) renderUserList() {
436 c.UI.userList.Clear()
437 for _, u := range c.UserList {
438 flagBitmap := big.NewInt(int64(binary.BigEndian.Uint16(u.Flags)))
439 if flagBitmap.Bit(userFlagAdmin) == 1 {
5dd57308 440 _, _ = fmt.Fprintf(c.UI.userList, "[red::b]%s[-:-:-]\n", u.Name)
6988a057 441 } else {
5dd57308 442 _, _ = fmt.Fprintf(c.UI.userList, "%s\n", u.Name)
6988a057 443 }
b198b22b 444 // TODO: fade if user is away
6988a057
JH
445 }
446}
447
448func handleClientChatMsg(c *Client, t *Transaction) (res []Transaction, err error) {
d005ef04 449 if c.Pref.EnableBell {
89bbc565
JH
450 fmt.Println("\a")
451 }
452
d005ef04 453 _, _ = fmt.Fprintf(c.UI.chatBox, "%s \n", t.GetField(FieldData).Data)
6988a057
JH
454
455 return res, err
456}
457
458func handleClientTranUserAccess(c *Client, t *Transaction) (res []Transaction, err error) {
d005ef04 459 c.UserAccess = t.GetField(FieldUserAccess).Data
6988a057
JH
460
461 return res, err
462}
463
464func handleClientTranShowAgreement(c *Client, t *Transaction) (res []Transaction, err error) {
d005ef04 465 agreement := string(t.GetField(FieldData).Data)
6988a057
JH
466 agreement = strings.ReplaceAll(agreement, "\r", "\n")
467
72dd37f1 468 agreeModal := tview.NewModal().
6988a057
JH
469 SetText(agreement).
470 AddButtons([]string{"Agree", "Disagree"}).
471 SetDoneFunc(func(buttonIndex int, buttonLabel string) {
472 if buttonIndex == 0 {
473 res = append(res,
474 *NewTransaction(
d005ef04
JH
475 TranAgreed, nil,
476 NewField(FieldUserName, []byte(c.Pref.Username)),
477 NewField(FieldUserIconID, c.Pref.IconBytes()),
478 NewField(FieldUserFlags, []byte{0x00, 0x00}),
479 NewField(FieldOptions, []byte{0x00, 0x00}),
6988a057
JH
480 ),
481 )
6988a057
JH
482 c.UI.Pages.HidePage("agreement")
483 c.UI.App.SetFocus(c.UI.chatInput)
484 } else {
f7e36225 485 _ = c.Disconnect()
6988a057
JH
486 c.UI.Pages.SwitchToPage("home")
487 }
488 },
489 )
490
72dd37f1 491 c.UI.Pages.AddPage("agreement", agreeModal, false, true)
b198b22b 492
6988a057
JH
493 return res, err
494}
495
496func handleClientTranLogin(c *Client, t *Transaction) (res []Transaction, err error) {
497 if !bytes.Equal(t.ErrorCode, []byte{0, 0, 0, 0}) {
d005ef04 498 errMsg := string(t.GetField(FieldError).Data)
6988a057
JH
499 errModal := tview.NewModal()
500 errModal.SetText(errMsg)
501 errModal.AddButtons([]string{"Oh no"})
502 errModal.SetDoneFunc(func(buttonIndex int, buttonLabel string) {
503 c.UI.Pages.RemovePage("errModal")
504 })
505 c.UI.Pages.RemovePage("joinServer")
506 c.UI.Pages.AddPage("errModal", errModal, false, true)
507
508 c.UI.App.Draw() // TODO: errModal doesn't render without this. wtf?
509
d005ef04
JH
510 c.Logger.Error(string(t.GetField(FieldError).Data))
511 return nil, errors.New("login error: " + string(t.GetField(FieldError).Data))
6988a057 512 }
40afb444 513 c.UI.Pages.AddAndSwitchToPage(serverUIPage, c.UI.renderServerUI(), true)
6988a057
JH
514 c.UI.App.SetFocus(c.UI.chatInput)
515
d005ef04 516 if err := c.Send(*NewTransaction(TranGetUserNameList, nil)); err != nil {
6988a057
JH
517 c.Logger.Errorw("err", "err", err)
518 }
519 return res, err
520}
521
522// JoinServer connects to a Hotline server and completes the login flow
d005ef04 523func (c *Client) Connect(address, login, passwd string) (err error) {
6988a057 524 // Establish TCP connection to server
d005ef04
JH
525 c.Connection, err = net.DialTimeout("tcp", address, 5*time.Second)
526 if err != nil {
6988a057
JH
527 return err
528 }
529
530 // Send handshake sequence
531 if err := c.Handshake(); err != nil {
532 return err
533 }
534
d005ef04 535 // Authenticate (send TranLogin 107)
6988a057
JH
536 if err := c.LogIn(login, passwd); err != nil {
537 return err
538 }
539
9d41bcdf
JH
540 // start keepalive go routine
541 go func() { _ = c.keepalive() }()
542
6988a057
JH
543 return nil
544}
545
f85c2d08
JH
546const keepaliveInterval = 300 * time.Second
547
9d41bcdf
JH
548func (c *Client) keepalive() error {
549 for {
f85c2d08 550 time.Sleep(keepaliveInterval)
d005ef04 551 _ = c.Send(*NewTransaction(TranKeepAlive, nil))
9d41bcdf
JH
552 c.Logger.Infow("Sent keepalive ping")
553 }
554}
555
6988a057
JH
var (
	// ClientHandshake is the 12-byte handshake the client sends first:
	// protocol ID "TRTP", sub-protocol ID "HOTL", version 1, sub-version 2.
	ClientHandshake = []byte{
		'T', 'R', 'T', 'P', // protocol ID
		'H', 'O', 'T', 'L', // sub-protocol ID
		0x00, 0x01, // version
		0x00, 0x02, // sub-version
	}

	// ServerHandshake is the 8-byte reply expected from the server: "TRTP"
	// followed by a zero error code.
	ServerHandshake = []byte{
		'T', 'R', 'T', 'P', // protocol ID
		0x00, 0x00, 0x00, 0x00, // error code
	}
)
567
568func (c *Client) Handshake() error {
aebc4d36
JH
569 // Protocol ID 4 ‘TRTP’ 0x54 52 54 50
570 // Sub-protocol ID 4 User defined
571 // Version 2 1 Currently 1
572 // Sub-version 2 User defined
6988a057
JH
573 if _, err := c.Connection.Write(ClientHandshake); err != nil {
574 return fmt.Errorf("handshake write err: %s", err)
575 }
576
577 replyBuf := make([]byte, 8)
578 _, err := c.Connection.Read(replyBuf)
579 if err != nil {
580 return err
581 }
582
72dd37f1 583 if bytes.Equal(replyBuf, ServerHandshake) {
6988a057
JH
584 return nil
585 }
6988a057 586
b198b22b 587 // In the case of an error, client and server close the connection.
6988a057
JH
588 return fmt.Errorf("handshake response err: %s", err)
589}
590
591func (c *Client) LogIn(login string, password string) error {
592 return c.Send(
593 *NewTransaction(
d005ef04
JH
594 TranLogin, nil,
595 NewField(FieldUserName, []byte(c.Pref.Username)),
596 NewField(FieldUserIconID, c.Pref.IconBytes()),
597 NewField(FieldUserLogin, negateString([]byte(login))),
598 NewField(FieldUserPassword, negateString([]byte(password))),
6988a057
JH
599 ),
600 )
601}
602
6988a057
JH
603func (c *Client) Send(t Transaction) error {
604 requestNum := binary.BigEndian.Uint16(t.Type)
605 tID := binary.BigEndian.Uint32(t.ID)
606
aebc4d36 607 // handler := TransactionHandlers[requestNum]
6988a057
JH
608
609 // if transaction is NOT reply, add it to the list to transactions we're expecting a response for
610 if t.IsReply == 0 {
611 c.activeTasks[tID] = &t
612 }
613
614 var n int
615 var err error
72dd37f1
JH
616 b, err := t.MarshalBinary()
617 if err != nil {
618 return err
619 }
620 if n, err = c.Connection.Write(b); err != nil {
6988a057
JH
621 return err
622 }
623 c.Logger.Debugw("Sent Transaction",
624 "IsReply", t.IsReply,
625 "type", requestNum,
626 "sentBytes", n,
627 )
628 return nil
629}
630
631func (c *Client) HandleTransaction(t *Transaction) error {
632 var origT Transaction
633 if t.IsReply == 1 {
634 requestID := binary.BigEndian.Uint32(t.ID)
635 origT = *c.activeTasks[requestID]
636 t.Type = origT.Type
637 }
638
639 requestNum := binary.BigEndian.Uint16(t.Type)
0da28a1f 640 c.Logger.Debugw("Received Transaction", "RequestType", requestNum)
6988a057
JH
641
642 if handler, ok := c.Handlers[requestNum]; ok {
5978b74f 643 outT, _ := handler(c, t)
6988a057
JH
644 for _, t := range outT {
645 c.Send(t)
646 }
647 } else {
d005ef04 648 c.Logger.Debugw(
6988a057
JH
649 "Unimplemented transaction type received",
650 "RequestID", requestNum,
651 "TransactionID", t.ID,
652 )
653 }
654
655 return nil
656}
657
6988a057 658func (c *Client) Disconnect() error {
00d1ef67 659 return c.Connection.Close()
6988a057 660}
d005ef04
JH
661
662func (c *Client) HandleTransactions() error {
663 // Create a new scanner for parsing incoming bytes into transaction tokens
664 scanner := bufio.NewScanner(c.Connection)
665 scanner.Split(transactionScanner)
666
667 // Scan for new transactions and handle them as they come in.
668 for scanner.Scan() {
669 // Make a new []byte slice and copy the scanner bytes to it. This is critical to avoid a data race as the
670 // scanner re-uses the buffer for subsequent scans.
671 buf := make([]byte, len(scanner.Bytes()))
672 copy(buf, scanner.Bytes())
673
674 var t Transaction
675 _, err := t.Write(buf)
676 if err != nil {
677 break
678 }
679 if err := c.HandleTransaction(&t); err != nil {
680 c.Logger.Errorw("Error handling transaction", "err", err)
681 }
682 }
683
684 if scanner.Err() == nil {
685 return scanner.Err()
686 }
687 return nil
688}