X-Git-Url: https://git.r.bdr.sh/rbdr/mobius/blobdiff_plain/b8b0a6c9cb43d52aece4020163b612cfe87de898..5cc444c89968dda9060d4e2f458b1babb9b603cd:/hotline/server.go?ds=sidebyside diff --git a/hotline/server.go b/hotline/server.go index aca2221..b5bb0d5 100644 --- a/hotline/server.go +++ b/hotline/server.go @@ -2,6 +2,7 @@ package hotline import ( "bufio" + "bytes" "context" "encoding/binary" "errors" @@ -48,6 +49,7 @@ type Server struct { Config *Config ConfigDir string Logger *zap.SugaredLogger + banner []byte PrivateChatsMu sync.Mutex PrivateChats map[uint32]*PrivateChat @@ -270,6 +272,11 @@ func NewServer(configDir, netInterface string, netPort int, logger *zap.SugaredL server.Config.FileRoot = filepath.Join(configDir, server.Config.FileRoot) } + server.banner, err = os.ReadFile(filepath.Join(server.ConfigDir, server.Config.BannerFile)) + if err != nil { + return nil, fmt.Errorf("error opening banner: %w", err) + } + *server.NextGuestID = 1 if server.Config.EnableTrackerRegistration { @@ -283,7 +290,7 @@ func NewServer(configDir, netInterface string, netPort int, logger *zap.SugaredL for { tr := &TrackerRegistration{ UserCount: server.userCount(), - PassID: server.TrackerPassID[:], + PassID: server.TrackerPassID, Name: server.Config.Name, Description: server.Config.Description, } @@ -494,13 +501,16 @@ func (s *Server) connectedUsers() []Field { var connectedUsers []Field for _, c := range sortedClients(s.Clients) { - user := User{ + b, err := io.ReadAll(&User{ ID: *c.ID, Icon: c.Icon, Flags: c.Flags, Name: string(c.UserName), + }) + if err != nil { + return nil } - connectedUsers = append(connectedUsers, NewField(FieldUsernameWithInfo, user.Payload())) + connectedUsers = append(connectedUsers, NewField(FieldUsernameWithInfo, b)) } return connectedUsers } @@ -837,8 +847,8 @@ func (s *Server) handleFileTransfer(ctx context.Context, rwc io.ReadWriter) erro switch fileTransfer.Type { case bannerDownload: - if err := s.bannerDownload(rwc); err != nil { - return err + if _, err := io.Copy(rwc, bytes.NewBuffer(s.banner)); err != nil { + return fmt.Errorf("error sending banner: %w", err) } case FileDownload: s.Stats.DownloadCounter += 1 @@ -861,8 +871,8 @@ func (s *Server) handleFileTransfer(ctx context.Context, rwc io.ReadWriter) erro // if file transfer options are included, that means this is a "quick preview" request from a 1.5+ client if fileTransfer.options == nil { - // Start by sending flat file object to client - if _, err := rwc.Write(fw.ffo.BinaryMarshal()); err != nil { + _, err = io.Copy(rwc, fw.ffo) + if err != nil { return err } } @@ -1027,11 +1037,8 @@ func (s *Server) handleFileTransfer(ctx context.Context, rwc io.ReadWriter) erro } fileHeader := NewFileHeader(subPath, info.IsDir()) - - // Send the fileWrapper header to client - if _, err := rwc.Write(fileHeader.Payload()); err != nil { - s.Logger.Errorf("error sending file header: %v", err) - return err + if _, err := io.Copy(rwc, &fileHeader); err != nil { + return fmt.Errorf("error sending file header: %w", err) } // Read the client's Next Action request @@ -1083,8 +1090,8 @@ func (s *Server) handleFileTransfer(ctx context.Context, rwc io.ReadWriter) erro } // Send ffo bytes to client - if _, err := rwc.Write(hlFile.ffo.BinaryMarshal()); err != nil { - s.Logger.Error(err) + _, err = io.Copy(rwc, hlFile.ffo) + if err != nil { return err }