golang.org/x/sys v0.22.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4=
golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI=
+golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk=
+golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc=
"errors"
"fmt"
"golang.org/x/text/encoding/charmap"
+ "golang.org/x/time/rate"
"io"
"log"
"log/slog"
NetInterface string
Port int
+ rateLimiters map[string]*rate.Limiter
+
handlers map[TranType]HandlerFunc
Config Config
server := Server{
handlers: make(map[TranType]HandlerFunc),
outbox: make(chan Transaction),
+ rateLimiters: make(map[string]*rate.Limiter),
FS: &OSFileStore{},
ChatMgr: NewMemChatManager(),
ClientMgr: NewMemClientMgr(),
}
}
+// perIPRateLimit is the sustained rate, in connections per second, at which a
+// single IP address may connect before further attempts are throttled.
+// A value of 0.5 allows one connection every 2 seconds once any burst
+// allowance on the limiter has been used up.
+const perIPRateLimit rate.Limit = 0.5
+
func (s *Server) Serve(ctx context.Context, ln net.Listener) error {
for {
select {
}
go func() {
+ ipAddr := strings.Split(conn.RemoteAddr().(*net.TCPAddr).String(), ":")[0]
+
connCtx := context.WithValue(ctx, contextKeyReq, requestCtx{
remoteAddr: conn.RemoteAddr().String(),
})
- s.Logger.Info("Connection established", "addr", conn.RemoteAddr())
+ s.Logger.Info("Connection established", "ip", ipAddr)
defer conn.Close()
+ // Check if we have an existing rate limit for the IP and create one if we do not.
+ rl, ok := s.rateLimiters[ipAddr]
+ if !ok {
+ rl = rate.NewLimiter(perIPRateLimit, 1)
+ s.rateLimiters[ipAddr] = rl
+ }
+
+ // Check if the rate limit is exceeded and close the connection if so.
+ if !rl.Allow() {
+ s.Logger.Info("Rate limit exceeded", "RemoteAddr", conn.RemoteAddr())
+ conn.Close()
+ return
+ }
+
if err := s.handleNewConnection(connCtx, conn, conn.RemoteAddr().String()); err != nil {
if err == io.EOF {
s.Logger.Info("Client disconnected", "RemoteAddr", conn.RemoteAddr())