From: Jeff Halter Date: Wed, 24 Jul 2024 00:52:10 +0000 (-0700) Subject: Add client connection rate limit X-Git-Url: https://git.r.bdr.sh/rbdr/mobius/commitdiff_plain/23c1295a4d9b13fdb955d2041a49b9b5d6daa060?hp=318eb6bc17ee406e26c1cfe0835409fd446bad75 Add client connection rate limit --- diff --git a/go.mod b/go.mod index 378e229..65a454d 100644 --- a/go.mod +++ b/go.mod @@ -7,6 +7,7 @@ require ( github.com/stretchr/testify v1.9.0 golang.org/x/crypto v0.25.0 golang.org/x/text v0.16.0 + golang.org/x/time v0.5.0 gopkg.in/natefinch/lumberjack.v2 v2.2.1 gopkg.in/yaml.v3 v3.0.1 ) diff --git a/go.sum b/go.sum index 73ed644..aed79d0 100644 --- a/go.sum +++ b/go.sum @@ -26,6 +26,8 @@ golang.org/x/sys v0.22.0 h1:RI27ohtqKCnwULzJLqkv897zojh5/DwS/ENaMzUOaWI= golang.org/x/sys v0.22.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4= golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI= +golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk= +golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc= diff --git a/hotline/server.go b/hotline/server.go index 5531bd3..758f7ff 100644 --- a/hotline/server.go +++ b/hotline/server.go @@ -9,6 +9,7 @@ import ( "errors" "fmt" "golang.org/x/text/encoding/charmap" + "golang.org/x/time/rate" "io" "log" "log/slog" @@ -37,6 +38,8 @@ type Server struct { NetInterface string Port int + rateLimiters map[string]*rate.Limiter + handlers map[TranType]HandlerFunc Config Config @@ -98,6 +101,7 @@ func NewServer(options ...Option) (*Server, error) { server := Server{ handlers: make(map[TranType]HandlerFunc), outbox: make(chan Transaction), + rateLimiters: make(map[string]*rate.Limiter), FS: &OSFileStore{}, ChatMgr: NewMemChatManager(), ClientMgr: NewMemClientMgr(), @@ -202,6 +206,10 @@ func (s *Server) processOutbox() { } } +// perIPRateLimit controls how frequently an IP address can connect before being throttled. +// 0.5 = 1 connection every 2 seconds +const perIPRateLimit = rate.Limit(0.5) + func (s *Server) Serve(ctx context.Context, ln net.Listener) error { for { select { @@ -216,13 +224,29 @@ func (s *Server) Serve(ctx context.Context, ln net.Listener) error { } go func() { + ipAddr := strings.Split(conn.RemoteAddr().(*net.TCPAddr).String(), ":")[0] + connCtx := context.WithValue(ctx, contextKeyReq, requestCtx{ remoteAddr: conn.RemoteAddr().String(), }) - s.Logger.Info("Connection established", "addr", conn.RemoteAddr()) + s.Logger.Info("Connection established", "ip", ipAddr) defer conn.Close() + // Check if we have an existing rate limit for the IP and create one if we do not. + rl, ok := s.rateLimiters[ipAddr] + if !ok { + rl = rate.NewLimiter(perIPRateLimit, 1) + s.rateLimiters[ipAddr] = rl + } + + // Check if the rate limit is exceeded and close the connection if so. + if !rl.Allow() { + s.Logger.Info("Rate limit exceeded", "RemoteAddr", conn.RemoteAddr()) + conn.Close() + return + } + if err := s.handleNewConnection(connCtx, conn, conn.RemoteAddr().String()); err != nil { if err == io.EOF { s.Logger.Info("Client disconnected", "RemoteAddr", conn.RemoteAddr())