]> git.r.bdr.sh - rbdr/mobius/blobdiff - hotline/server.go
Fix intermittently failing banner download
[rbdr/mobius] / hotline / server.go
index ba0a6689c619251620e28a4dab4c195ea024fb04..21b1316d4f233b0e36ae11e6d12b928e9d0f0bc6 100644 (file)
@@ -167,25 +167,31 @@ func (s *Server) sendTransaction(t Transaction) error {
        return nil
 }
 
+func (s *Server) processOutbox() {
+       for {
+               t := <-s.outbox
+               go func() {
+                       if err := s.sendTransaction(t); err != nil {
+                               s.Logger.Errorw("error sending transaction", "err", err)
+                       }
+               }()
+       }
+}
+
 func (s *Server) Serve(ctx context.Context, ln net.Listener) error {
+       go s.processOutbox()
+
        for {
                conn, err := ln.Accept()
                if err != nil {
                        s.Logger.Errorw("error accepting connection", "err", err)
                }
+               connCtx := context.WithValue(ctx, contextKeyReq, requestCtx{
+                       remoteAddr: conn.RemoteAddr().String(),
+               })
 
                go func() {
-                       for {
-                               t := <-s.outbox
-                               go func() {
-                                       if err := s.sendTransaction(t); err != nil {
-                                               s.Logger.Errorw("error sending transaction", "err", err)
-                                       }
-                               }()
-                       }
-               }()
-               go func() {
-                       if err := s.handleNewConnection(ctx, conn, conn.RemoteAddr().String()); err != nil {
+                       if err := s.handleNewConnection(connCtx, conn, conn.RemoteAddr().String()); err != nil {
                                s.Logger.Infow("New client connection established", "RemoteAddr", conn.RemoteAddr())
                                if err == io.EOF {
                                        s.Logger.Infow("Client disconnected", "RemoteAddr", conn.RemoteAddr())
@@ -335,7 +341,7 @@ func (s *Server) writeThreadedNews() error {
        return err
 }
 
-func (s *Server) NewClientConn(conn net.Conn, remoteAddr string) *ClientConn {
+func (s *Server) NewClientConn(conn io.ReadWriteCloser, remoteAddr string) *ClientConn {
        s.mux.Lock()
        defer s.mux.Unlock()
 
@@ -515,7 +521,7 @@ func dontPanic(logger *zap.SugaredLogger) {
 }
 
 // handleNewConnection takes a new net.Conn and performs the initial login sequence
-func (s *Server) handleNewConnection(ctx context.Context, conn net.Conn, remoteAddr string) error {
+func (s *Server) handleNewConnection(ctx context.Context, conn io.ReadWriteCloser, remoteAddr string) error {
        defer dontPanic(s.Logger)
 
        if err := Handshake(conn); err != nil {
@@ -579,11 +585,13 @@ func (s *Server) handleNewConnection(ctx context.Context, conn net.Conn, remoteA
                *c.Flags = []byte{0, 2}
        }
 
-       s.Logger.Infow("Client connection received", "login", login, "version", *c.Version, "RemoteAddr", remoteAddr)
+       c.logger = s.Logger.With("remoteAddr", remoteAddr, "login", login)
+
+       c.logger.Infow("Client connection received", "version", fmt.Sprintf("%x", *c.Version))
 
        s.outbox <- c.NewReply(clientLogin,
                NewField(fieldVersion, []byte{0x00, 0xbe}),
-               NewField(fieldCommunityBannerID, []byte{0x00, 0x01}),
+               NewField(fieldCommunityBannerID, []byte{0, 0}),
                NewField(fieldServerName, []byte(s.Config.Name)),
        )
 
@@ -596,8 +604,9 @@ func (s *Server) handleNewConnection(ctx context.Context, conn net.Conn, remoteA
        // Used simplified hotline v1.2.3 login flow for clients that do not send login info in tranAgreed
        if *c.Version == nil || bytes.Equal(*c.Version, nostalgiaVersion) {
                c.Agreed = true
+               c.logger = c.logger.With("name", string(c.UserName))
 
-               c.notifyOthers(
+               for _, t := range c.notifyOthers(
                        *NewTransaction(
                                tranNotifyChangeUser, nil,
                                NewField(fieldUserName, c.UserName),
@@ -605,7 +614,9 @@ func (s *Server) handleNewConnection(ctx context.Context, conn net.Conn, remoteA
                                NewField(fieldUserIconID, *c.Icon),
                                NewField(fieldUserFlags, *c.Flags),
                        ),
-               )
+               ) {
+                       c.Server.outbox <- t
+               }
        }
 
        c.Server.Stats.LoginCount += 1
@@ -628,13 +639,13 @@ func (s *Server) handleNewConnection(ctx context.Context, conn net.Conn, remoteA
                // into a slice of transactions
                var transactions []Transaction
                if transactions, tReadlen, err = readTransactions(tranBuff); err != nil {
-                       c.Server.Logger.Errorw("Error handling transaction", "err", err)
+                       c.logger.Errorw("Error handling transaction", "err", err)
                }
 
                // iterate over all the transactions that were parsed from the byte slice and handle them
                for _, t := range transactions {
                        if err := c.handleTransaction(&t); err != nil {
-                               c.Server.Logger.Errorw("Error handling transaction", "err", err)
+                               c.logger.Errorw("Error handling transaction", "err", err)
                        }
                }
        }
@@ -705,6 +716,10 @@ func (s *Server) handleFileTransfer(ctx context.Context, rwc io.ReadWriter) erro
        )
 
        switch fileTransfer.Type {
+       case bannerDownload:
+               if err := s.bannerDownload(rwc); err != nil {
+                       return err
+               }
        case FileDownload:
                s.Stats.DownloadCounter += 1
 
@@ -765,6 +780,9 @@ func (s *Server) handleFileTransfer(ctx context.Context, rwc io.ReadWriter) erro
                }
 
                err = sendFile(wr, rFile, int(dataOffset))
+               if err != nil {
+                       return err
+               }
 
                if err := wr.Flush(); err != nil {
                        return err
@@ -831,7 +849,7 @@ func (s *Server) handleFileTransfer(ctx context.Context, rwc io.ReadWriter) erro
                if err := file.Close(); err != nil {
                        return err
                }
-               
+
                if err := s.FS.Rename(destinationFile+".incomplete", destinationFile); err != nil {
                        return err
                }
@@ -1004,6 +1022,10 @@ func (s *Server) handleFileTransfer(ctx context.Context, rwc io.ReadWriter) erro
                        return nil
                })
 
+               if err != nil {
+                       return err
+               }
+
        case FolderUpload:
                dstPath, err := readPath(s.Config.FileRoot, fileTransfer.FilePath, fileTransfer.FileName)
                if err != nil {