"errors"
"fmt"
"golang.org/x/text/encoding/charmap"
+ "golang.org/x/time/rate"
"io"
"log"
"log/slog"
NetInterface string
Port int
+ rateLimiters map[string]*rate.Limiter
+
handlers map[TranType]HandlerFunc
Config Config
server := Server{
handlers: make(map[TranType]HandlerFunc),
outbox: make(chan Transaction),
+ rateLimiters: make(map[string]*rate.Limiter),
FS: &OSFileStore{},
ChatMgr: NewMemChatManager(),
ClientMgr: NewMemClientMgr(),
)
if err != nil {
- s.Logger.Error("file transfer error", "reason", err)
+ s.Logger.Error("file transfer error", "err", err)
}
}()
}
}
}
+// perIPRateLimit is the sustained rate, in connections per second, at which a
+// single IP address may open new connections before further attempts are throttled.
+// The value 0.5 works out to one connection every two seconds.
+const perIPRateLimit rate.Limit = 0.5
+
func (s *Server) Serve(ctx context.Context, ln net.Listener) error {
for {
select {
}
go func() {
+ ipAddr := strings.Split(conn.RemoteAddr().(*net.TCPAddr).String(), ":")[0]
+
connCtx := context.WithValue(ctx, contextKeyReq, requestCtx{
remoteAddr: conn.RemoteAddr().String(),
})
- s.Logger.Info("Connection established", "addr", conn.RemoteAddr())
+ s.Logger.Info("Connection established", "ip", ipAddr)
defer conn.Close()
+ // Check if we have an existing rate limit for the IP and create one if we do not.
+ rl, ok := s.rateLimiters[ipAddr]
+ if !ok {
+ rl = rate.NewLimiter(perIPRateLimit, 1)
+ s.rateLimiters[ipAddr] = rl
+ }
+
+ // Check if the rate limit is exceeded and close the connection if so.
+ if !rl.Allow() {
+ s.Logger.Info("Rate limit exceeded", "RemoteAddr", conn.RemoteAddr())
+ conn.Close()
+ return
+ }
+
if err := s.handleNewConnection(connCtx, conn, conn.RemoteAddr().String()); err != nil {
if err == io.EOF {
s.Logger.Info("Client disconnected", "RemoteAddr", conn.RemoteAddr())