import (
"encoding/binary"
+ "go.uber.org/zap"
"golang.org/x/crypto/bcrypt"
+ "io"
"math/big"
- "net"
+ "sort"
)
+// byClientID orders client connections by ascending uint16ID(); sortedClients
+// passes it to sort.Sort to get a deterministic client ordering.
type byClientID []*ClientConn
return s[i].uint16ID() < s[j].uint16ID()
}
+// template is a plain-text layout that appears to describe a connected client
+// and its transfers. Its five %s verbs are, in order: nickname, name, account,
+// address, and the file download listing. The folder download, file upload,
+// folder upload, and waiting download sections are fixed at "None.".
+// NOTE(review): presumably filled in with fmt.Sprintf at a call site not visible
+// here; the literal also appears to end with a trailing space before the
+// closing backtick — confirm both are intended.
+const template = `Nickname: %s
+Name: %s
+Account: %s
+Address: %s
+
+-------- File Downloads ---------
+
+%s
+
+------- Folder Downloads --------
+
+None.
+
+--------- File Uploads ----------
+
+None.
+
+-------- Folder Uploads ---------
+
+None.
+
+------- Waiting Downloads -------
+
+None.
+
+ `
+
// ClientConn represents a client connected to a Server
type ClientConn struct {
- Connection net.Conn
+ // Connection is the client's byte stream; Disconnect closes it.
+ Connection io.ReadWriteCloser
+ // RemoteAddr holds the client's remote address as a string. Connection is
+ // now an io.ReadWriteCloser, which has no RemoteAddr() method, so the
+ // address is stored separately for logging.
+ RemoteAddr string
ID *[]byte
Icon *[]byte
Flags *[]byte
AutoReply []byte
Transfers map[int][]*FileTransfer
Agreed bool
+ // logger is used for this connection's transaction logging (Errorw/Infow).
+ // NOTE(review): call sites no longer pass Account/UserName fields, so this
+ // logger is presumably pre-populated with them — confirm where it is built.
+ logger *zap.SugaredLogger
}
func (cc *ClientConn) sendAll(t int, fields ...Field) {
// Validate that required field is present
if field.ID == nil {
- cc.Server.Logger.Errorw(
+ cc.logger.Errorw(
"Missing required field",
- "Account", cc.Account.Login, "UserName", string(cc.UserName), "RequestType", handler.Name, "FieldID", reqField.ID,
+ "RequestType", handler.Name, "FieldID", reqField.ID,
)
return nil
}
if len(field.Data) < reqField.minLen {
- cc.Server.Logger.Infow(
+ cc.logger.Infow(
"Field does not meet minLen",
- "Account", cc.Account.Login, "UserName", string(cc.UserName), "RequestType", handler.Name, "FieldID", reqField.ID,
+ "RequestType", handler.Name, "FieldID", reqField.ID,
)
return nil
}
}
- if !authorize(cc.Account.Access, handler.Access) {
- cc.Server.Logger.Infow(
- "Unauthorized Action",
- "Account", cc.Account.Login, "UserName", string(cc.UserName), "RequestType", handler.Name,
- )
- cc.Server.outbox <- cc.NewErrReply(transaction, handler.DenyMsg)
-
- return nil
- }
- cc.Server.Logger.Infow(
- "Received Transaction",
- "login", cc.Account.Login,
- "name", string(cc.UserName),
- "RequestType", handler.Name,
- )
+ cc.logger.Infow("Received Transaction", "RequestType", handler.Name)
transactions, err := handler.Handler(cc, transaction)
if err != nil {
cc.Server.outbox <- t
}
} else {
- cc.Server.Logger.Errorw(
- "Unimplemented transaction type received",
- "UserName", string(cc.UserName), "RequestID", requestNum,
- )
+ cc.logger.Errorw(
+ "Unimplemented transaction type received", "RequestID", requestNum)
}
cc.Server.mux.Lock()
return true
}
- accessBitmap := big.NewInt(int64(binary.BigEndian.Uint64(*cc.Account.Access)))
+ i := big.NewInt(int64(binary.BigEndian.Uint64(*cc.Account.Access)))
- return accessBitmap.Bit(63-access) == 1
+ return i.Bit(63-access) == 1
}
// Disconnect notifies other clients that a client has disconnected
+//
+// It removes cc from cc.Server.Clients, queues a tranNotifyDeleteUser
+// transaction on cc.Server.outbox for every other agreed client, and then
+// closes the underlying connection. cc.Server.mux is held for the whole call.
+// A failure to close the connection is logged together with its error.
+// NOTE(review): the outbox sends happen while cc.Server.mux is held; if the
+// outbox consumer ever acquires that lock this can deadlock — confirm.
-func (cc ClientConn) Disconnect() {
+func (cc *ClientConn) Disconnect() {
cc.Server.mux.Lock()
defer cc.Server.mux.Unlock()
delete(cc.Server.Clients, binary.BigEndian.Uint16(*cc.ID))
- cc.NotifyOthers(*NewTransaction(tranNotifyDeleteUser, nil, NewField(fieldUserID, *cc.ID)))
+ for _, t := range cc.notifyOthers(*NewTransaction(tranNotifyDeleteUser, nil, NewField(fieldUserID, *cc.ID))) {
+ cc.Server.outbox <- t
+ }
if err := cc.Connection.Close(); err != nil {
- cc.Server.Logger.Errorw("error closing client connection", "RemoteAddr", cc.Connection.RemoteAddr())
+ cc.Server.Logger.Errorw("error closing client connection", "RemoteAddr", cc.RemoteAddr, "err", err)
}
}
-// NotifyOthers sends transaction t to other clients connected to the server
-func (cc ClientConn) NotifyOthers(t Transaction) {
+// notifyOthers builds, but does not send, one copy of transaction t for each
+// other client connected to the server that has Agreed, with clientID set to
+// that recipient. The caller is responsible for delivering the returned
+// transactions. It returns nil when there are no such clients.
+//
+// Clients are visited in ascending ID order (via sortedClients), so the result
+// order is deterministic. It reads cc.Server.Clients without locking itself;
+// Disconnect calls it while holding cc.Server.mux.
+// NOTE(review): c.ID != cc.ID compares *[]byte pointers (identity), not the
+// ID bytes — confirm each client owns a distinct ID pointer.
+func (cc *ClientConn) notifyOthers(t Transaction) (trans []Transaction) {
for _, c := range sortedClients(cc.Server.Clients) {
if c.ID != cc.ID && c.Agreed {
t.clientID = c.ID
- cc.Server.outbox <- t
+ trans = append(trans, t)
}
}
+ return trans
}
// NewReply returns a reply Transaction with fields for the ClientConn
},
}
}
+
+// sortedClients returns the values of unsortedClients as a slice ordered by
+// ascending client ID (via byClientID). Go map iteration order is random, so
+// sorting makes client ordering deterministic, which test assertions rely on.
+// The slice is pre-sized to the map's length to avoid repeated growth; an empty
+// map yields an empty slice.
+func sortedClients(unsortedClients map[uint16]*ClientConn) (clients []*ClientConn) {
+ clients = make([]*ClientConn, 0, len(unsortedClients))
+ for _, c := range unsortedClients {
+ clients = append(clients, c)
+ }
+ sort.Sort(byClientID(clients))
+ return clients
+}