import (
"bufio"
+ "bytes"
"context"
"encoding/binary"
"errors"
Config *Config
ConfigDir string
Logger *zap.SugaredLogger
+ banner []byte
PrivateChatsMu sync.Mutex
PrivateChats map[uint32]*PrivateChat
server.Config.FileRoot = filepath.Join(configDir, server.Config.FileRoot)
}
+ server.banner, err = os.ReadFile(filepath.Join(server.ConfigDir, server.Config.BannerFile))
+ if err != nil {
+ return nil, fmt.Errorf("error opening banner: %w", err)
+ }
+
*server.NextGuestID = 1
if server.Config.EnableTrackerRegistration {
switch fileTransfer.Type {
case bannerDownload:
- if err := s.bannerDownload(rwc); err != nil {
- return err
+ if _, err := io.Copy(rwc, bytes.NewBuffer(s.banner)); err != nil {
+ return fmt.Errorf("error sending banner: %w", err)
}
case FileDownload:
s.Stats.DownloadCounter += 1
}
}
-// TestHandleDownloadBanner checks that HandleDownloadBanner, given a client whose
-// transfers map has a bannerDownload entry and whose server uses a mock file store,
-// returns no error and a single reply transaction carrying FieldRefNum and
-// FieldTransferSize.
-//
-// NOTE(review): this test is deleted in the same change that loads the banner file
-// into Server.banner via os.ReadFile and serves it from memory; no replacement test
-// for the new path is visible in this diff — confirm coverage exists elsewhere.
-func TestHandleDownloadBanner(t *testing.T) {
- type args struct {
- cc *ClientConn
- t *Transaction
- }
- tests := []struct {
- name string
- args args
- wantRes []Transaction
- wantErr assert.ErrorAssertionFunc
- }{
- {
- name: "returns expected response",
- args: args{
- cc: &ClientConn{
- ID: &[]byte{0, 1},
- transfers: map[int]map[[4]byte]*FileTransfer{
- bannerDownload: {},
- },
- Server: &Server{
- ConfigDir: "/config",
- Config: &Config{
- BannerFile: "banner.jpg",
- },
- fileTransfers: map[[4]byte]*FileTransfer{},
- // Mock store: Stat on ConfigDir + BannerFile ("/config/banner.jpg")
- // reports a size of 100, which wantRes expects as FieldTransferSize 0x64.
- FS: func() *MockFileStore {
- mfi := &MockFileInfo{}
- mfi.On("Size").Return(int64(100))
-
- mfs := &MockFileStore{}
- mfs.On("Stat", "/config/banner.jpg").Return(mfi, nil)
- return mfs
- }(),
- },
- },
- t: NewTransaction(TranDownloadBanner, nil),
- },
- wantRes: []Transaction{
- {
- clientID: &[]byte{0, 1},
- Flags: 0x00,
- IsReply: 0x01,
- Type: []byte{0, 0},
- ID: []byte{0, 0, 0, 0},
- ErrorCode: []byte{0, 0, 0, 0},
- Fields: []Field{
- // NOTE(review): {1, 2, 3, 4} looks like a placeholder ref num; this
- // assumes tranAssertEqual does not compare it byte-for-byte — confirm.
- NewField(FieldRefNum, []byte{1, 2, 3, 4}),
- NewField(FieldTransferSize, []byte{0, 0, 0, 0x64}),
- },
- },
- },
- wantErr: assert.NoError,
- },
- }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- gotRes, err := HandleDownloadBanner(tt.args.cc, tt.args.t)
- if !tt.wantErr(t, err, fmt.Sprintf("HandleDownloadBanner(%v, %v)", tt.args.cc, tt.args.t)) {
- return
- }
-
- tranAssertEqual(t, tt.wantRes, gotRes)
- })
- }
-}
-
func TestHandleTranOldPostNews(t *testing.T) {
type args struct {
cc *ClientConn