X-Git-Url: https://git.r.bdr.sh/rbdr/mobius/blobdiff_plain/d2810ae9038b057e8f18e8b495a57d8f96ae5be6..40414f9295dd301f1c714bc3067392b7cc8f8427:/hotline/server.go?ds=sidebyside diff --git a/hotline/server.go b/hotline/server.go index aebff6e..d5da458 100644 --- a/hotline/server.go +++ b/hotline/server.go @@ -5,6 +5,7 @@ import ( "encoding/binary" "errors" "fmt" + "github.com/go-playground/validator/v10" "go.uber.org/zap" "io" "io/fs" @@ -43,7 +44,7 @@ type Server struct { Logger *zap.SugaredLogger PrivateChats map[uint32]*PrivateChat NextGuestID *uint16 - TrackerPassID []byte + TrackerPassID [4]byte Stats *Stats APIListener net.Listener @@ -131,7 +132,7 @@ func (s *Server) sendTransaction(t Transaction) error { "IsReply", t.IsReply, "type", handler.Name, "sentBytes", n, - "remoteAddr", client.Connection.RemoteAddr(), + "remoteAddr", client.RemoteAddr, ) return nil } @@ -156,7 +157,7 @@ func (s *Server) Serve(ctx context.Context, cancelRoot context.CancelFunc, ln ne } }() go func() { - if err := s.handleNewConnection(conn); err != nil { + if err := s.handleNewConnection(conn, conn.RemoteAddr().String()); err != nil { if err == io.EOF { s.Logger.Infow("Client disconnected", "RemoteAddr", conn.RemoteAddr()) } else { @@ -186,7 +187,6 @@ func NewServer(configDir, netInterface string, netPort int, logger *zap.SugaredL outbox: make(chan Transaction), Stats: &Stats{StartTime: time.Now()}, ThreadedNews: &ThreadedNews{}, - TrackerPassID: make([]byte, 4), } ln, err := net.Listen("tcp", fmt.Sprintf("%s:%v", netInterface, netPort)) @@ -206,7 +206,7 @@ func NewServer(configDir, netInterface string, netPort int, logger *zap.SugaredL } // generate a new random passID for tracker registration - if _, err := rand.Read(server.TrackerPassID); err != nil { + if _, err := rand.Read(server.TrackerPassID[:]); err != nil { return nil, err } @@ -236,21 +236,26 @@ func NewServer(configDir, netInterface string, netPort int, logger *zap.SugaredL *server.NextGuestID = 1 if server.Config.EnableTrackerRegistration { + server.Logger.Infow( + "Tracker registration enabled", + "frequency", fmt.Sprintf("%vs", trackerUpdateFrequency), + "trackers", server.Config.Trackers, + ) + go func() { for { - tr := TrackerRegistration{ - Port: []byte{0x15, 0x7c}, + tr := &TrackerRegistration{ UserCount: server.userCount(), - PassID: server.TrackerPassID, + PassID: server.TrackerPassID[:], Name: server.Config.Name, Description: server.Config.Description, } + binary.BigEndian.PutUint16(tr.Port[:], uint16(server.Port)) for _, t := range server.Config.Trackers { - server.Logger.Infof("Registering with tracker %v", t) - if err := register(t, tr); err != nil { server.Logger.Errorw("unable to register with tracker %v", "error", err) } + server.Logger.Infow("Sent Tracker registration", "data", tr) } time.Sleep(trackerUpdateFrequency * time.Second) @@ -314,7 +319,7 @@ func (s *Server) writeThreadedNews() error { return err } -func (s *Server) NewClientConn(conn net.Conn) *ClientConn { +func (s *Server) NewClientConn(conn net.Conn, remoteAddr string) *ClientConn { s.mux.Lock() defer s.mux.Unlock() @@ -329,6 +334,7 @@ func (s *Server) NewClientConn(conn net.Conn) *ClientConn { AutoReply: []byte{}, Transfers: make(map[int][]*FileTransfer), Agreed: false, + RemoteAddr: remoteAddr, } *s.NextGuestID++ ID := *s.NextGuestID @@ -471,6 +477,12 @@ func (s *Server) loadConfig(path string) error { if err != nil { return err } + + validate := validator.New() + err = validate.Struct(s.Config) + if err != nil { + return err + } return nil } @@ -478,17 +490,25 @@ const ( minTransactionLen = 22 // minimum length of any transaction ) -// handleNewConnection takes a new net.Conn and performs the initial login sequence -func (s *Server) handleNewConnection(conn net.Conn) error { - handshakeBuf := make([]byte, 12) // handshakes are always 12 bytes in length - if _, err := conn.Read(handshakeBuf); err != nil { - return err +// dontPanic recovers and logs panics instead of crashing +// TODO: remove this after known issues are fixed +func dontPanic(logger *zap.SugaredLogger) { + if r := recover(); r != nil { + fmt.Println("stacktrace from panic: \n" + string(debug.Stack())) + logger.Errorw("PANIC", "err", r, "trace", string(debug.Stack())) } - if err := Handshake(conn, handshakeBuf[:12]); err != nil { +} + +// handleNewConnection takes a new net.Conn and performs the initial login sequence +func (s *Server) handleNewConnection(conn net.Conn, remoteAddr string) error { + defer dontPanic(s.Logger) + + if err := Handshake(conn); err != nil { return err } buf := make([]byte, 1024) + // TODO: fix potential short read with io.ReadFull readLen, err := conn.Read(buf) if readLen < minTransactionLen { return err @@ -502,15 +522,8 @@ func (s *Server) handleNewConnection(conn net.Conn) error { return err } - c := s.NewClientConn(conn) + c := s.NewClientConn(conn, remoteAddr) defer c.Disconnect() - defer func() { - if r := recover(); r != nil { - fmt.Println("stacktrace from panic: \n" + string(debug.Stack())) - c.Server.Logger.Errorw("PANIC", "err", r, "trace", string(debug.Stack())) - c.Disconnect() - } - }() encodedLogin := clientLogin.GetField(fieldUserLogin).Data encodedPassword := clientLogin.GetField(fieldUserPassword).Data @@ -551,7 +564,7 @@ func (s *Server) handleNewConnection(conn net.Conn) error { *c.Flags = []byte{0, 2} } - s.Logger.Infow("Client connection received", "login", login, "version", *c.Version, "RemoteAddr", conn.RemoteAddr().String()) + s.Logger.Infow("Client connection received", "login", login, "version", *c.Version, "RemoteAddr", remoteAddr) s.outbox <- c.NewReply(clientLogin, NewField(fieldVersion, []byte{0x00, 0xbe}), @@ -646,28 +659,37 @@ const dlFldrActionNextFile = 3 // handleFileTransfer receives a client net.Conn from the file transfer server, performs the requested transfer type, then closes the connection func (s *Server) handleFileTransfer(conn io.ReadWriteCloser) error { defer func() { + if err := conn.Close(); err != nil { s.Logger.Errorw("error closing connection", "error", err) } }() + defer dontPanic(s.Logger) + txBuf := make([]byte, 16) - _, err := conn.Read(txBuf) - if err != nil { + if _, err := io.ReadFull(conn, txBuf); err != nil { return err } var t transfer - _, err = t.Write(txBuf) - if err != nil { + if _, err := t.Write(txBuf); err != nil { return err } transferRefNum := binary.BigEndian.Uint32(t.ReferenceNumber[:]) - fileTransfer := s.FileTransfers[transferRefNum] + defer func() { + s.mux.Lock() + delete(s.FileTransfers, transferRefNum) + s.mux.Unlock() + }() - // delete single use transferRefNum - delete(s.FileTransfers, transferRefNum) + s.mux.Lock() + fileTransfer, ok := s.FileTransfers[transferRefNum] + s.mux.Unlock() + if !ok { + return errors.New("invalid transaction ID") + } switch fileTransfer.Type { case FileDownload: @@ -688,9 +710,11 @@ func (s *Server) handleFileTransfer(conn io.ReadWriteCloser) error { s.Logger.Infow("File download started", "filePath", fullFilePath, "transactionRef", fileTransfer.ReferenceNumber) - // Start by sending flat file object to client - if _, err := conn.Write(ffo.BinaryMarshal()); err != nil { - return err + if fileTransfer.options == nil { + // Start by sending flat file object to client + if _, err := conn.Write(ffo.BinaryMarshal()); err != nil { + return err + } } file, err := FS.Open(fullFilePath) @@ -721,11 +745,27 @@ func (s *Server) handleFileTransfer(conn io.ReadWriteCloser) error { } case FileUpload: destinationFile := s.Config.FileRoot + ReadFilePath(fileTransfer.FilePath) + "/" + string(fileTransfer.FileName) - tmpFile := destinationFile + ".incomplete" - file, err := effectiveFile(destinationFile) + var file *os.File + + // A file upload has three possible cases: + // 1) Upload a new file + // 2) Resume a partially transferred file + // 3) Replace a fully uploaded file + // Unfortunately we have to infer which case applies by inspecting what is already on the file system + + // 1) Check for existing file: + _, err := os.Stat(destinationFile) + if err == nil { + // If found, that means this upload is intended to replace the file + if err = os.Remove(destinationFile); err != nil { + return err + } + file, err = os.Create(destinationFile + incompleteFileSuffix) + } if errors.Is(err, fs.ErrNotExist) { - file, err = FS.Create(tmpFile) + // If not found, open or create a new incomplete file + file, err = os.OpenFile(destinationFile+incompleteFileSuffix, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0644) if err != nil { return err } @@ -783,7 +823,7 @@ func (s *Server) handleFileTransfer(conn io.ReadWriteCloser) error { s.Logger.Infow("Start folder download", "path", fullFilePath, "ReferenceNumber", fileTransfer.ReferenceNumber) nextAction := make([]byte, 2) - if _, err := conn.Read(nextAction); err != nil { + if _, err := io.ReadFull(conn, nextAction); err != nil { return err } @@ -809,7 +849,7 @@ func (s *Server) handleFileTransfer(conn io.ReadWriteCloser) error { } // Read the client's Next Action request - if _, err := conn.Read(nextAction); err != nil { + if _, err := io.ReadFull(conn, nextAction); err != nil { return err } @@ -822,13 +862,13 @@ func (s *Server) handleFileTransfer(conn io.ReadWriteCloser) error { // client asked to resume this file var frd FileResumeData // get size of resumeData - if _, err := conn.Read(nextAction); err != nil { + if _, err := io.ReadFull(conn, nextAction); err != nil { return err } resumeDataLen := binary.BigEndian.Uint16(nextAction) resumeDataBytes := make([]byte, resumeDataLen) - if _, err := conn.Read(resumeDataBytes); err != nil { + if _, err := io.ReadFull(conn, resumeDataBytes); err != nil { return err } @@ -905,7 +945,7 @@ func (s *Server) handleFileTransfer(conn io.ReadWriteCloser) error { // TODO: optionally send resource fork header and resource fork data // Read the client's Next Action request. This is always 3, I think? - if _, err := conn.Read(nextAction); err != nil { + if _, err := io.ReadFull(conn, nextAction); err != nil { return err } @@ -941,7 +981,7 @@ func (s *Server) handleFileTransfer(conn io.ReadWriteCloser) error { readBuffer := make([]byte, 1024) for i := 0; i < fileTransfer.ItemCount(); i++ { - + // TODO: fix potential short read with io.ReadFull _, err := conn.Read(readBuffer) if err != nil { return err @@ -1019,7 +1059,7 @@ func (s *Server) handleFileTransfer(conn io.ReadWriteCloser) error { return err } - if _, err := conn.Read(fileSize); err != nil { + if _, err := io.ReadFull(conn, fileSize); err != nil { return err }