git.r.bdr.sh - rbdr/mobius/commitdiff
Add client connection rate limit
author Jeff Halter <redacted>
Wed, 24 Jul 2024 00:52:10 +0000 (17:52 -0700)
committer Jeff Halter <redacted>
Wed, 24 Jul 2024 00:54:36 +0000 (17:54 -0700)
go.mod
go.sum
hotline/server.go

diff --git a/go.mod b/go.mod
index 378e22979eae5e79cfb2428d6961263de13c12e9..65a454d9b292b2581a2711ef87d8b8726894a56e 100644 (file)
--- a/go.mod
+++ b/go.mod
@@ -7,6 +7,7 @@ require (
        github.com/stretchr/testify v1.9.0
        golang.org/x/crypto v0.25.0
        golang.org/x/text v0.16.0
        github.com/stretchr/testify v1.9.0
        golang.org/x/crypto v0.25.0
        golang.org/x/text v0.16.0
+       golang.org/x/time v0.5.0
        gopkg.in/natefinch/lumberjack.v2 v2.2.1
        gopkg.in/yaml.v3 v3.0.1
 )
        gopkg.in/natefinch/lumberjack.v2 v2.2.1
        gopkg.in/yaml.v3 v3.0.1
 )
diff --git a/go.sum b/go.sum
index 73ed644f96d0a4d7099413570cb1e77a6598c1b8..aed79d00359ac3556eeb76bca2abfab1b675cf51 100644 (file)
--- a/go.sum
+++ b/go.sum
@@ -26,6 +26,8 @@ golang.org/x/sys v0.22.0 h1:RI27ohtqKCnwULzJLqkv897zojh5/DwS/ENaMzUOaWI=
 golang.org/x/sys v0.22.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
 golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4=
 golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI=
 golang.org/x/sys v0.22.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
 golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4=
 golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI=
+golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk=
+golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
 gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
 gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
 gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc=
 gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
 gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
 gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc=
index 5531bd3c59512916f927c705cd0ea4dc74d0a631..758f7ff8a3fa9aeacfb2c9b9a4bb65b46f715141 100644 (file)
@@ -9,6 +9,7 @@ import (
        "errors"
        "fmt"
        "golang.org/x/text/encoding/charmap"
        "errors"
        "fmt"
        "golang.org/x/text/encoding/charmap"
+       "golang.org/x/time/rate"
        "io"
        "log"
        "log/slog"
        "io"
        "log"
        "log/slog"
@@ -37,6 +38,8 @@ type Server struct {
        NetInterface string
        Port         int
 
        NetInterface string
        Port         int
 
+       rateLimiters map[string]*rate.Limiter
+
        handlers map[TranType]HandlerFunc
 
        Config Config
        handlers map[TranType]HandlerFunc
 
        Config Config
@@ -98,6 +101,7 @@ func NewServer(options ...Option) (*Server, error) {
        server := Server{
                handlers:        make(map[TranType]HandlerFunc),
                outbox:          make(chan Transaction),
        server := Server{
                handlers:        make(map[TranType]HandlerFunc),
                outbox:          make(chan Transaction),
+               rateLimiters:    make(map[string]*rate.Limiter),
                FS:              &OSFileStore{},
                ChatMgr:         NewMemChatManager(),
                ClientMgr:       NewMemClientMgr(),
                FS:              &OSFileStore{},
                ChatMgr:         NewMemChatManager(),
                ClientMgr:       NewMemClientMgr(),
@@ -202,6 +206,10 @@ func (s *Server) processOutbox() {
        }
 }
 
        }
 }
 
+// perIPRateLimit controls how frequently an IP address can connect before being throttled.
+// 0.5 = 1 connection every 2 seconds
+const perIPRateLimit = rate.Limit(0.5)
+
 func (s *Server) Serve(ctx context.Context, ln net.Listener) error {
        for {
                select {
 func (s *Server) Serve(ctx context.Context, ln net.Listener) error {
        for {
                select {
@@ -216,13 +224,29 @@ func (s *Server) Serve(ctx context.Context, ln net.Listener) error {
                        }
 
                        go func() {
                        }
 
                        go func() {
+                               ipAddr := strings.Split(conn.RemoteAddr().(*net.TCPAddr).String(), ":")[0]
+
                                connCtx := context.WithValue(ctx, contextKeyReq, requestCtx{
                                        remoteAddr: conn.RemoteAddr().String(),
                                })
 
                                connCtx := context.WithValue(ctx, contextKeyReq, requestCtx{
                                        remoteAddr: conn.RemoteAddr().String(),
                                })
 
-                               s.Logger.Info("Connection established", "addr", conn.RemoteAddr())
+                               s.Logger.Info("Connection established", "ip", ipAddr)
                                defer conn.Close()
 
                                defer conn.Close()
 
+                               // Check if we have an existing rate limit for the IP and create one if we do not.
+                               rl, ok := s.rateLimiters[ipAddr]
+                               if !ok {
+                                       rl = rate.NewLimiter(perIPRateLimit, 1)
+                                       s.rateLimiters[ipAddr] = rl
+                               }
+
+                               // Check if the rate limit is exceeded and close the connection if so.
+                               if !rl.Allow() {
+                                       s.Logger.Info("Rate limit exceeded", "RemoteAddr", conn.RemoteAddr())
+                                       conn.Close()
+                                       return
+                               }
+
                                if err := s.handleNewConnection(connCtx, conn, conn.RemoteAddr().String()); err != nil {
                                        if err == io.EOF {
                                                s.Logger.Info("Client disconnected", "RemoteAddr", conn.RemoteAddr())
                                if err := s.handleNewConnection(connCtx, conn, conn.RemoteAddr().String()); err != nil {
                                        if err == io.EOF {
                                                s.Logger.Info("Client disconnected", "RemoteAddr", conn.RemoteAddr())