]> git.r.bdr.sh - rbdr/mobius/blob - internal/mobius/ban_test.go
46860c6ae5aabfbc62d8aea941dcf6f3a73645e7
[rbdr/mobius] / internal / mobius / ban_test.go
1 package mobius
2
3 import (
4 "fmt"
5 "github.com/stretchr/testify/assert"
6 "os"
7 "path/filepath"
8 "sync"
9 "testing"
10 "time"
11 )
12
13 func TestNewBanFile(t *testing.T) {
14 cwd, _ := os.Getwd()
15 str := "2024-06-29T11:34:43.245899-07:00"
16 testTime, _ := time.Parse(time.RFC3339Nano, str)
17
18 type args struct {
19 path string
20 }
21 tests := []struct {
22 name string
23 args args
24 want *BanFile
25 wantErr assert.ErrorAssertionFunc
26 }{
27 {
28 name: "Valid path with valid content",
29 args: args{path: filepath.Join(cwd, "test", "config", "Banlist.yaml")},
30 want: &BanFile{
31 filePath: filepath.Join(cwd, "test", "config", "Banlist.yaml"),
32 banList: map[string]*time.Time{"192.168.86.29": &testTime},
33 },
34 wantErr: assert.NoError,
35 },
36 }
37 for _, tt := range tests {
38 t.Run(tt.name, func(t *testing.T) {
39 got, err := NewBanFile(tt.args.path)
40 if !tt.wantErr(t, err, fmt.Sprintf("NewBanFile(%v)", tt.args.path)) {
41 return
42 }
43 assert.Equalf(t, tt.want, got, "NewBanFile(%v)", tt.args.path)
44 })
45 }
46 }
47
48 // TestAdd tests the Add function.
49 func TestAdd(t *testing.T) {
50 // Create a temporary directory.
51 tmpDir, err := os.MkdirTemp("", "banfile_test")
52 if err != nil {
53 t.Fatalf("Failed to create temp directory: %v", err)
54 }
55 defer os.RemoveAll(tmpDir) // Clean up the temporary directory.
56
57 // Path to the temporary ban file.
58 tmpFilePath := filepath.Join(tmpDir, "banfile.yaml")
59
60 // Initialize BanFile.
61 bf := &BanFile{
62 filePath: tmpFilePath,
63 banList: make(map[string]*time.Time),
64 }
65
66 // Define the test cases.
67 tests := []struct {
68 name string
69 ip string
70 until *time.Time
71 expect map[string]*time.Time
72 }{
73 {
74 name: "Add IP with no expiration",
75 ip: "192.168.1.1",
76 until: nil,
77 expect: map[string]*time.Time{
78 "192.168.1.1": nil,
79 },
80 },
81 {
82 name: "Add IP with expiration",
83 ip: "192.168.1.2",
84 until: func() *time.Time { t := time.Date(2024, 6, 29, 11, 34, 43, 245899000, time.UTC); return &t }(),
85 expect: map[string]*time.Time{
86 "192.168.1.1": nil,
87 "192.168.1.2": func() *time.Time { t := time.Date(2024, 6, 29, 11, 34, 43, 245899000, time.UTC); return &t }(),
88 },
89 },
90 }
91
92 // Run the test cases.
93 for _, tt := range tests {
94 t.Run(tt.name, func(t *testing.T) {
95 err := bf.Add(tt.ip, tt.until)
96 assert.NoError(t, err, "Add() error")
97
98 // Load the file to check its contents.
99 loadedBanFile := &BanFile{filePath: tmpFilePath}
100 err = loadedBanFile.Load()
101 assert.NoError(t, err, "Load() error")
102 assert.Equal(t, tt.expect, loadedBanFile.banList, "Ban list does not match")
103 })
104 }
105 }
106
107 func TestBanFile_IsBanned(t *testing.T) {
108 type fields struct {
109 banList map[string]*time.Time
110 Mutex sync.Mutex
111 }
112 type args struct {
113 ip string
114 }
115 tests := []struct {
116 name string
117 fields fields
118 args args
119 want bool
120 want1 *time.Time
121 }{
122 {
123 name: "with permanent ban",
124 fields: fields{
125 banList: map[string]*time.Time{
126 "192.168.86.1": nil,
127 },
128 },
129 args: args{ip: "192.168.86.1"},
130 want: true,
131 want1: nil,
132 },
133 {
134 name: "with no ban",
135 fields: fields{
136 banList: map[string]*time.Time{},
137 },
138 args: args{ip: "192.168.86.1"},
139 want: false,
140 want1: nil,
141 },
142 }
143 for _, tt := range tests {
144 t.Run(tt.name, func(t *testing.T) {
145 bf := &BanFile{
146 banList: tt.fields.banList,
147 Mutex: sync.Mutex{},
148 }
149 got, got1 := bf.IsBanned(tt.args.ip)
150 assert.Equalf(t, tt.want, got, "IsBanned(%v)", tt.args.ip)
151 assert.Equalf(t, tt.want1, got1, "IsBanned(%v)", tt.args.ip)
152 })
153 }
154 }