func HandleTranAgreed(cc *ClientConn, t *Transaction) (res []Transaction, err error) {
cc.Agreed = true
- cc.UserName = t.GetField(fieldUserName).Data
+
+ if t.GetField(fieldUserName).Data != nil {
+ if cc.Authorize(accessAnyName) {
+ cc.UserName = t.GetField(fieldUserName).Data
+ } else {
+ cc.UserName = []byte(cc.Account.Name)
+ }
+ }
+
*cc.Icon = t.GetField(fieldUserIconID).Data
cc.logger = cc.logger.With("name", string(cc.UserName))
cc.AutoReply = []byte{}
}
- for _, t := range cc.notifyOthers(
+ trans := cc.notifyOthers(
*NewTransaction(
tranNotifyChangeUser, nil,
NewField(fieldUserName, cc.UserName),
NewField(fieldUserIconID, *cc.Icon),
NewField(fieldUserFlags, *cc.Flags),
),
- ) {
- cc.Server.outbox <- t
- }
+ )
+ res = append(res, trans...)
if cc.Server.Config.BannerFile != "" {
- cc.Server.outbox <- *NewTransaction(tranServerBanner, cc.ID, NewField(fieldBannerType, []byte("JPEG")))
+ res = append(res, *NewTransaction(tranServerBanner, cc.ID, NewField(fieldBannerType, []byte("JPEG"))))
}
res = append(res, cc.NewReply(t))
})
}
}
+
+ // TestHandleTranAgreed verifies the transactions returned by HandleTranAgreed
+ // and the user name it assigns to the connection: the requested name when the
+ // account holds accessAnyName, the account's own name when it does not, and
+ // the pre-existing name when the request carries no user name field.
+ func TestHandleTranAgreed(t *testing.T) {
+ 	type args struct {
+ 		cc *ClientConn
+ 		t  *Transaction
+ 	}
+
+ 	// newConn builds a minimal ClientConn for an account with the given name
+ 	// and access bitmap, optionally pre-seeded with a user name. The Server has
+ 	// a banner configured and registers no other clients.
+ 	newConn := func(accountName string, access *[]byte, userName []byte) *ClientConn {
+ 		return &ClientConn{
+ 			Account: &Account{
+ 				Name:   accountName,
+ 				Access: access,
+ 			},
+ 			Icon:     &[]byte{0, 1},
+ 			Flags:    &[]byte{0, 1},
+ 			Version:  &[]byte{0, 1},
+ 			ID:       &[]byte{0, 1},
+ 			UserName: userName,
+ 			logger:   NewTestLogger(),
+ 			Server: &Server{
+ 				Config: &Config{
+ 					BannerFile: "banner.jpg",
+ 				},
+ 			},
+ 		}
+ 	}
+
+ 	// wantTrans returns the transactions every case expects: the server banner
+ 	// notice followed by the reply to the client. The test Server has no other
+ 	// clients, so no tranNotifyChangeUser transactions are expected. A fresh
+ 	// slice is built per call so cases cannot share or mutate expectations.
+ 	wantTrans := func() []Transaction {
+ 		return []Transaction{
+ 			{
+ 				clientID:  &[]byte{0, 1},
+ 				Flags:     0x00,
+ 				IsReply:   0x00,
+ 				Type:      []byte{0, 0x7a},
+ 				ID:        []byte{0, 0, 0, 0},
+ 				ErrorCode: []byte{0, 0, 0, 0},
+ 				Fields: []Field{
+ 					NewField(fieldBannerType, []byte("JPEG")),
+ 				},
+ 			},
+ 			{
+ 				clientID:  &[]byte{0, 1},
+ 				Flags:     0x00,
+ 				IsReply:   0x01,
+ 				Type:      []byte{0, 0x79},
+ 				ID:        []byte{0, 0, 0, 0},
+ 				ErrorCode: []byte{0, 0, 0, 0},
+ 				Fields:    []Field{},
+ 			},
+ 		}
+ 	}
+
+ 	tests := []struct {
+ 		name         string
+ 		args         args
+ 		wantRes      []Transaction
+ 		wantUserName []byte
+ 		wantErr      assert.ErrorAssertionFunc
+ 	}{
+ 		{
+ 			name: "normal request flow",
+ 			args: args{
+ 				cc: newConn(
+ 					"",
+ 					func() *[]byte {
+ 						var bits accessBitmap
+ 						bits.Set(accessDisconUser)
+ 						bits.Set(accessAnyName)
+ 						access := bits[:]
+ 						return &access
+ 					}(),
+ 					nil,
+ 				),
+ 				t: NewTransaction(
+ 					tranAgreed, nil,
+ 					NewField(fieldUserName, []byte("username")),
+ 					NewField(fieldUserIconID, []byte{0, 1}),
+ 					NewField(fieldOptions, []byte{0, 0}),
+ 				),
+ 			},
+ 			wantRes:      wantTrans(),
+ 			wantUserName: []byte("username"),
+ 			wantErr:      assert.NoError,
+ 		},
+ 		{
+ 			// Without accessAnyName the requested name is ignored in favor
+ 			// of the account's configured name.
+ 			name: "account without accessAnyName uses account name",
+ 			args: args{
+ 				cc: newConn(
+ 					"Guest",
+ 					func() *[]byte {
+ 						var bits accessBitmap
+ 						bits.Set(accessDisconUser)
+ 						access := bits[:]
+ 						return &access
+ 					}(),
+ 					nil,
+ 				),
+ 				t: NewTransaction(
+ 					tranAgreed, nil,
+ 					NewField(fieldUserName, []byte("username")),
+ 					NewField(fieldUserIconID, []byte{0, 1}),
+ 					NewField(fieldOptions, []byte{0, 0}),
+ 				),
+ 			},
+ 			wantRes:      wantTrans(),
+ 			wantUserName: []byte("Guest"),
+ 			wantErr:      assert.NoError,
+ 		},
+ 		{
+ 			// The handler only assigns a name when the user name field's data
+ 			// is non-nil; an omitted field should leave the current name alone.
+ 			name: "missing user name field keeps existing name",
+ 			args: args{
+ 				cc: newConn(
+ 					"",
+ 					func() *[]byte {
+ 						var bits accessBitmap
+ 						bits.Set(accessDisconUser)
+ 						bits.Set(accessAnyName)
+ 						access := bits[:]
+ 						return &access
+ 					}(),
+ 					[]byte("existing"),
+ 				),
+ 				t: NewTransaction(
+ 					tranAgreed, nil,
+ 					NewField(fieldUserIconID, []byte{0, 1}),
+ 					NewField(fieldOptions, []byte{0, 0}),
+ 				),
+ 			},
+ 			wantRes:      wantTrans(),
+ 			wantUserName: []byte("existing"),
+ 			wantErr:      assert.NoError,
+ 		},
+ 	}
+ 	for _, tt := range tests {
+ 		t.Run(tt.name, func(t *testing.T) {
+ 			gotRes, err := HandleTranAgreed(tt.args.cc, tt.args.t)
+ 			if !tt.wantErr(t, err, fmt.Sprintf("HandleTranAgreed(%v, %v)", tt.args.cc, tt.args.t)) {
+ 				return
+ 			}
+ 			tranAssertEqual(t, tt.wantRes, gotRes)
+
+ 			// HandleTranAgreed also mutates the connection in place; check the
+ 			// side effects that the returned transactions do not reveal.
+ 			assert.True(t, tt.args.cc.Agreed)
+ 			assert.Equal(t, tt.wantUserName, tt.args.cc.UserName)
+ 		})
+ 	}
+ }