]> git.r.bdr.sh - rbdr/mobius/blobdiff - hotline/server.go
Implement bufio.Scanner for transaction parsing
[rbdr/mobius] / hotline / server.go
index 0677a813a6d48da4a2aaf99fca228d0c1a28c366..1edfd4a9d03a65224bd66a52c2b5dba185621780 100644 (file)
@@ -133,7 +133,6 @@ func (s *Server) ServeFileTransfers(ctx context.Context, ln net.Listener) error
 }
 
 func (s *Server) sendTransaction(t Transaction) error {
-       requestNum := binary.BigEndian.Uint16(t.Type)
        clientID, err := byteToInt(*t.clientID)
        if err != nil {
                return err
@@ -141,31 +140,21 @@ func (s *Server) sendTransaction(t Transaction) error {
 
        s.mux.Lock()
        client := s.Clients[uint16(clientID)]
-       s.mux.Unlock()
        if client == nil {
                return fmt.Errorf("invalid client id %v", *t.clientID)
        }
-       userName := string(client.UserName)
-       login := client.Account.Login
 
-       handler := TransactionHandlers[requestNum]
+       s.mux.Unlock()
 
        b, err := t.MarshalBinary()
        if err != nil {
                return err
        }
-       var n int
-       if n, err = client.Connection.Write(b); err != nil {
+
+       if _, err := client.Connection.Write(b); err != nil {
                return err
        }
-       s.Logger.Debugw("Sent Transaction",
-               "name", userName,
-               "login", login,
-               "IsReply", t.IsReply,
-               "type", handler.Name,
-               "sentBytes", n,
-               "remoteAddr", client.RemoteAddr,
-       )
+
        return nil
 }
 
@@ -518,12 +507,7 @@ func (s *Server) loadConfig(path string) error {
        return nil
 }
 
-const (
-       minTransactionLen = 22 // minimum length of any transaction
-)
-
-// dontPanic recovers and logs panics instead of crashing
-// TODO: remove this after known issues are fixed
+// dontPanic logs panics instead of crashing
 func dontPanic(logger *zap.SugaredLogger) {
        if r := recover(); r != nil {
                fmt.Println("stacktrace from panic: \n" + string(debug.Stack()))
@@ -532,29 +516,25 @@ func dontPanic(logger *zap.SugaredLogger) {
 }
 
 // handleNewConnection takes a new net.Conn and performs the initial login sequence
-func (s *Server) handleNewConnection(ctx context.Context, conn io.ReadWriteCloser, remoteAddr string) error {
+func (s *Server) handleNewConnection(ctx context.Context, rwc io.ReadWriteCloser, remoteAddr string) error {
        defer dontPanic(s.Logger)
 
-       if err := Handshake(conn); err != nil {
+       if err := Handshake(rwc); err != nil {
                return err
        }
 
-       buf := make([]byte, 1024)
-       // TODO: fix potential short read with io.ReadFull
-       readLen, err := conn.Read(buf)
-       if readLen < minTransactionLen {
-               return err
-       }
-       if err != nil {
-               return err
-       }
+       // Create a new scanner for parsing incoming bytes into transaction tokens
+       scanner := bufio.NewScanner(rwc)
+       scanner.Split(transactionScanner)
+
+       scanner.Scan()
 
-       clientLogin, _, err := ReadTransaction(buf[:readLen])
+       clientLogin, _, err := ReadTransaction(scanner.Bytes())
        if err != nil {
-               return err
+               panic(err)
        }
 
-       c := s.NewClientConn(conn, remoteAddr)
+       c := s.NewClientConn(rwc, remoteAddr)
        defer c.Disconnect()
 
        encodedLogin := clientLogin.GetField(fieldUserLogin).Data
@@ -578,7 +558,7 @@ func (s *Server) handleNewConnection(ctx context.Context, conn io.ReadWriteClose
                if err != nil {
                        return err
                }
-               if _, err := conn.Write(b); err != nil {
+               if _, err := rwc.Write(b); err != nil {
                        return err
                }
 
@@ -634,34 +614,22 @@ func (s *Server) handleNewConnection(ctx context.Context, conn io.ReadWriteClose
 
        c.Server.Stats.LoginCount += 1
 
-       const readBuffSize = 1024000 // 1KB - TODO: what should this be?
-       tranBuff := make([]byte, 0)
-       tReadlen := 0
-       // Infinite loop where take action on incoming client requests until the connection is closed
-       for {
-               buf = make([]byte, readBuffSize)
-               tranBuff = tranBuff[tReadlen:]
+       // Scan for new transactions and handle them as they come in.
+       for scanner.Scan() {
+               // Make a new []byte slice and copy the scanner bytes to it.  This is critical to avoid a data race as the
+               // scanner re-uses the buffer for subsequent scans.
+               buf := make([]byte, len(scanner.Bytes()))
+               copy(buf, scanner.Bytes())
 
-               readLen, err := c.Connection.Read(buf)
+               t, _, err := ReadTransaction(buf)
                if err != nil {
-                       return err
+                       panic(err)
                }
-               tranBuff = append(tranBuff, buf[:readLen]...)
-
-               // We may have read multiple requests worth of bytes from Connection.Read.  readTransactions splits them
-               // into a slice of transactions
-               var transactions []Transaction
-               if transactions, tReadlen, err = readTransactions(tranBuff); err != nil {
+               if err := c.handleTransaction(*t); err != nil {
                        c.logger.Errorw("Error handling transaction", "err", err)
                }
-
-               // iterate over all the transactions that were parsed from the byte slice and handle them
-               for _, t := range transactions {
-                       if err := c.handleTransaction(&t); err != nil {
-                               c.logger.Errorw("Error handling transaction", "err", err)
-                       }
-               }
        }
+       return nil
 }
 
 func (s *Server) NewPrivateChat(cc *ClientConn) []byte {