X-Git-Url: https://git.r.bdr.sh/rbdr/mobius/blobdiff_plain/c6a3fa2cdbf7eb78964a5df7622482cdd8782d74..3178ae580a3fe97d6a1167b4346d209f04e9b7e3:/hotline/server.go diff --git a/hotline/server.go b/hotline/server.go index e685afd..1edfd4a 100644 --- a/hotline/server.go +++ b/hotline/server.go @@ -133,7 +133,6 @@ func (s *Server) ServeFileTransfers(ctx context.Context, ln net.Listener) error } func (s *Server) sendTransaction(t Transaction) error { - requestNum := binary.BigEndian.Uint16(t.Type) clientID, err := byteToInt(*t.clientID) if err != nil { return err @@ -141,31 +140,21 @@ func (s *Server) sendTransaction(t Transaction) error { s.mux.Lock() client := s.Clients[uint16(clientID)] - s.mux.Unlock() if client == nil { return fmt.Errorf("invalid client id %v", *t.clientID) } - userName := string(client.UserName) - login := client.Account.Login - handler := TransactionHandlers[requestNum] + s.mux.Unlock() b, err := t.MarshalBinary() if err != nil { return err } - var n int - if n, err = client.Connection.Write(b); err != nil { + + if _, err := client.Connection.Write(b); err != nil { return err } - s.Logger.Debugw("Sent Transaction", - "name", userName, - "login", login, - "IsReply", t.IsReply, - "type", handler.Name, - "sentBytes", n, - "remoteAddr", client.RemoteAddr, - ) + return nil } @@ -193,8 +182,9 @@ func (s *Server) Serve(ctx context.Context, ln net.Listener) error { }) go func() { + s.Logger.Infow("Connection established", "RemoteAddr", conn.RemoteAddr()) + if err := s.handleNewConnection(connCtx, conn, conn.RemoteAddr().String()); err != nil { - s.Logger.Infow("New client connection established", "RemoteAddr", conn.RemoteAddr()) if err == io.EOF { s.Logger.Infow("Client disconnected", "RemoteAddr", conn.RemoteAddr()) } else { @@ -517,12 +507,7 @@ func (s *Server) loadConfig(path string) error { return nil } -const ( - minTransactionLen = 22 // minimum length of any transaction -) - -// dontPanic recovers and logs panics instead of crashing -// TODO: remove this after known issues are fixed +// dontPanic logs panics instead of crashing func dontPanic(logger *zap.SugaredLogger) { if r := recover(); r != nil { fmt.Println("stacktrace from panic: \n" + string(debug.Stack())) @@ -531,29 +516,25 @@ func dontPanic(logger *zap.SugaredLogger) { } // handleNewConnection takes a new net.Conn and performs the initial login sequence -func (s *Server) handleNewConnection(ctx context.Context, conn io.ReadWriteCloser, remoteAddr string) error { +func (s *Server) handleNewConnection(ctx context.Context, rwc io.ReadWriteCloser, remoteAddr string) error { defer dontPanic(s.Logger) - if err := Handshake(conn); err != nil { + if err := Handshake(rwc); err != nil { return err } - buf := make([]byte, 1024) - // TODO: fix potential short read with io.ReadFull - readLen, err := conn.Read(buf) - if readLen < minTransactionLen { - return err - } - if err != nil { - return err - } + // Create a new scanner for parsing incoming bytes into transaction tokens + scanner := bufio.NewScanner(rwc) + scanner.Split(transactionScanner) - clientLogin, _, err := ReadTransaction(buf[:readLen]) + scanner.Scan() + + clientLogin, _, err := ReadTransaction(scanner.Bytes()) if err != nil { - return err + panic(err) } - c := s.NewClientConn(conn, remoteAddr) + c := s.NewClientConn(rwc, remoteAddr) defer c.Disconnect() encodedLogin := clientLogin.GetField(fieldUserLogin).Data @@ -568,6 +549,8 @@ func (s *Server) handleNewConnection(ctx context.Context, conn io.ReadWriteClose login = GuestAccount } + c.logger = s.Logger.With("remoteAddr", remoteAddr, "login", login) + // If authentication fails, send error reply and close connection if !c.Authenticate(login, encodedPassword) { t := c.NewErrReply(clientLogin, "Incorrect login.") @@ -575,10 +558,13 @@ func (s *Server) handleNewConnection(ctx context.Context, conn io.ReadWriteClose if err != nil { return err } - if _, err := conn.Write(b); err != nil { + if _, err := rwc.Write(b); err != nil { return err } - return fmt.Errorf("incorrect login") + + c.logger.Infow("Login failed", "clientVersion", fmt.Sprintf("%x", *c.Version)) + + return nil } if clientLogin.GetField(fieldUserName).Data != nil { @@ -595,10 +581,6 @@ func (s *Server) handleNewConnection(ctx context.Context, conn io.ReadWriteClose *c.Flags = []byte{0, 2} } - c.logger = s.Logger.With("remoteAddr", remoteAddr, "login", login) - - c.logger.Infow("Client connection received", "version", fmt.Sprintf("%x", *c.Version)) - s.outbox <- c.NewReply(clientLogin, NewField(fieldVersion, []byte{0x00, 0xbe}), NewField(fieldCommunityBannerID, []byte{0, 0}), @@ -615,6 +597,7 @@ func (s *Server) handleNewConnection(ctx context.Context, conn io.ReadWriteClose if *c.Version == nil || bytes.Equal(*c.Version, nostalgiaVersion) { c.Agreed = true c.logger = c.logger.With("name", string(c.UserName)) + c.logger.Infow("Login successful", "clientVersion", fmt.Sprintf("%x", *c.Version)) for _, t := range c.notifyOthers( *NewTransaction( @@ -631,34 +614,22 @@ func (s *Server) handleNewConnection(ctx context.Context, conn io.ReadWriteClose c.Server.Stats.LoginCount += 1 - const readBuffSize = 1024000 // 1KB - TODO: what should this be? - tranBuff := make([]byte, 0) - tReadlen := 0 - // Infinite loop where take action on incoming client requests until the connection is closed - for { - buf = make([]byte, readBuffSize) - tranBuff = tranBuff[tReadlen:] + // Scan for new transactions and handle them as they come in. + for scanner.Scan() { + // Make a new []byte slice and copy the scanner bytes to it. This is critical to avoid a data race as the + // scanner re-uses the buffer for subsequent scans. + buf := make([]byte, len(scanner.Bytes())) + copy(buf, scanner.Bytes()) - readLen, err := c.Connection.Read(buf) + t, _, err := ReadTransaction(buf) if err != nil { - return err + panic(err) } - tranBuff = append(tranBuff, buf[:readLen]...) - - // We may have read multiple requests worth of bytes from Connection.Read. readTransactions splits them - // into a slice of transactions - var transactions []Transaction - if transactions, tReadlen, err = readTransactions(tranBuff); err != nil { + if err := c.handleTransaction(*t); err != nil { c.logger.Errorw("Error handling transaction", "err", err) } - - // iterate over all the transactions that were parsed from the byte slice and handle them - for _, t := range transactions { - if err := c.handleTransaction(&t); err != nil { - c.logger.Errorw("Error handling transaction", "err", err) - } - } } + return nil } func (s *Server) NewPrivateChat(cc *ClientConn) []byte {