]> git.r.bdr.sh - rbdr/mobius/blobdiff - hotline/server.go
Refactor and cleanup to improve testability
[rbdr/mobius] / hotline / server.go
index f164df1c8ff9e2b760f390f3242e8d45baa19246..fb06001f3aaa16fb453f8db262bccbbf47a42844 100644 (file)
@@ -131,7 +131,7 @@ func (s *Server) sendTransaction(t Transaction) error {
                "IsReply", t.IsReply,
                "type", handler.Name,
                "sentBytes", n,
-               "remoteAddr", client.Connection.RemoteAddr(),
+               "remoteAddr", client.RemoteAddr,
        )
        return nil
 }
@@ -156,7 +156,7 @@ func (s *Server) Serve(ctx context.Context, cancelRoot context.CancelFunc, ln ne
                        }
                }()
                go func() {
-                       if err := s.handleNewConnection(conn); err != nil {
+                       if err := s.handleNewConnection(conn, conn.RemoteAddr().String()); err != nil {
                                if err == io.EOF {
                                        s.Logger.Infow("Client disconnected", "RemoteAddr", conn.RemoteAddr())
                                } else {
@@ -319,7 +319,7 @@ func (s *Server) writeThreadedNews() error {
        return err
 }
 
-func (s *Server) NewClientConn(conn net.Conn) *ClientConn {
+func (s *Server) NewClientConn(conn net.Conn, remoteAddr string) *ClientConn {
        s.mux.Lock()
        defer s.mux.Unlock()
 
@@ -334,6 +334,7 @@ func (s *Server) NewClientConn(conn net.Conn) *ClientConn {
                AutoReply:  []byte{},
                Transfers:  make(map[int][]*FileTransfer),
                Agreed:     false,
+               RemoteAddr: remoteAddr,
        }
        *s.NextGuestID++
        ID := *s.NextGuestID
@@ -493,14 +494,10 @@ func dontPanic(logger *zap.SugaredLogger) {
 }
 
 // handleNewConnection takes a new net.Conn and performs the initial login sequence
-func (s *Server) handleNewConnection(conn net.Conn) error {
+func (s *Server) handleNewConnection(conn net.Conn, remoteAddr string) error {
        defer dontPanic(s.Logger)
 
-       handshakeBuf := make([]byte, 12)
-       if _, err := io.ReadFull(conn, handshakeBuf); err != nil {
-               return err
-       }
-       if err := Handshake(conn, handshakeBuf); err != nil {
+       if err := Handshake(conn); err != nil {
                return err
        }
 
@@ -519,7 +516,7 @@ func (s *Server) handleNewConnection(conn net.Conn) error {
                return err
        }
 
-       c := s.NewClientConn(conn)
+       c := s.NewClientConn(conn, remoteAddr)
        defer c.Disconnect()
 
        encodedLogin := clientLogin.GetField(fieldUserLogin).Data
@@ -561,7 +558,7 @@ func (s *Server) handleNewConnection(conn net.Conn) error {
                *c.Flags = []byte{0, 2}
        }
 
-       s.Logger.Infow("Client connection received", "login", login, "version", *c.Version, "RemoteAddr", conn.RemoteAddr().String())
+       s.Logger.Infow("Client connection received", "login", login, "version", *c.Version, "RemoteAddr", remoteAddr)
 
        s.outbox <- c.NewReply(clientLogin,
                NewField(fieldVersion, []byte{0x00, 0xbe}),