]> git.r.bdr.sh - rbdr/mobius/blobdiff - hotline/transaction_handlers_test.go
Fix race in file upload handling that may cause panic
[rbdr/mobius] / hotline / transaction_handlers_test.go
index 7c5ee43268cad1433177115e637cafb5832bf656..5ffe45c1f37b3c7a52ebc8b835e91a9313777271 100644 (file)
@@ -1,10 +1,12 @@
 package hotline
 
 import (
 package hotline
 
 import (
+       "errors"
        "github.com/stretchr/testify/assert"
        "io/fs"
        "math/rand"
        "os"
        "github.com/stretchr/testify/assert"
        "io/fs"
        "math/rand"
        "os"
+       "strings"
        "testing"
 )
 
        "testing"
 )
 
@@ -577,7 +579,7 @@ func TestHandleNewFolder(t *testing.T) {
                                ),
                        },
                        setup: func() {
                                ),
                        },
                        setup: func() {
-                               mfs := MockFileStore{}
+                               mfs := &MockFileStore{}
                                mfs.On("Mkdir", "/Files/aaa/testFolder", fs.FileMode(0777)).Return(nil)
                                mfs.On("Stat", "/Files/aaa/testFolder").Return(nil, os.ErrNotExist)
                                FS = mfs
                                mfs.On("Mkdir", "/Files/aaa/testFolder", fs.FileMode(0777)).Return(nil)
                                mfs.On("Stat", "/Files/aaa/testFolder").Return(nil, os.ErrNotExist)
                                FS = mfs
@@ -611,7 +613,7 @@ func TestHandleNewFolder(t *testing.T) {
                                ),
                        },
                        setup: func() {
                                ),
                        },
                        setup: func() {
-                               mfs := MockFileStore{}
+                               mfs := &MockFileStore{}
                                mfs.On("Mkdir", "/Files/testFolder", fs.FileMode(0777)).Return(nil)
                                mfs.On("Stat", "/Files/testFolder").Return(nil, os.ErrNotExist)
                                FS = mfs
                                mfs.On("Mkdir", "/Files/testFolder", fs.FileMode(0777)).Return(nil)
                                mfs.On("Stat", "/Files/testFolder").Return(nil, os.ErrNotExist)
                                FS = mfs
@@ -648,7 +650,7 @@ func TestHandleNewFolder(t *testing.T) {
                                ),
                        },
                        setup: func() {
                                ),
                        },
                        setup: func() {
-                               mfs := MockFileStore{}
+                               mfs := &MockFileStore{}
                                mfs.On("Mkdir", "/Files/aaa/testFolder", fs.FileMode(0777)).Return(nil)
                                mfs.On("Stat", "/Files/aaa/testFolder").Return(nil, os.ErrNotExist)
                                FS = mfs
                                mfs.On("Mkdir", "/Files/aaa/testFolder", fs.FileMode(0777)).Return(nil)
                                mfs.On("Stat", "/Files/aaa/testFolder").Return(nil, os.ErrNotExist)
                                FS = mfs
@@ -673,7 +675,7 @@ func TestHandleNewFolder(t *testing.T) {
                                ),
                        },
                        setup: func() {
                                ),
                        },
                        setup: func() {
-                               mfs := MockFileStore{}
+                               mfs := &MockFileStore{}
                                mfs.On("Mkdir", "/Files/testFolder", fs.FileMode(0777)).Return(nil)
                                mfs.On("Stat", "/Files/testFolder").Return(nil, os.ErrNotExist)
                                FS = mfs
                                mfs.On("Mkdir", "/Files/testFolder", fs.FileMode(0777)).Return(nil)
                                mfs.On("Stat", "/Files/testFolder").Return(nil, os.ErrNotExist)
                                FS = mfs
@@ -715,7 +717,7 @@ func TestHandleNewFolder(t *testing.T) {
                                ),
                        },
                        setup: func() {
                                ),
                        },
                        setup: func() {
-                               mfs := MockFileStore{}
+                               mfs := &MockFileStore{}
                                mfs.On("Mkdir", "/Files/foo/testFolder", fs.FileMode(0777)).Return(nil)
                                mfs.On("Stat", "/Files/foo/testFolder").Return(nil, os.ErrNotExist)
                                FS = mfs
                                mfs.On("Mkdir", "/Files/foo/testFolder", fs.FileMode(0777)).Return(nil)
                                mfs.On("Stat", "/Files/foo/testFolder").Return(nil, os.ErrNotExist)
                                FS = mfs
@@ -849,9 +851,193 @@ func TestHandleUploadFile(t *testing.T) {
                                t.Errorf("HandleUploadFile() error = %v, wantErr %v", err, tt.wantErr)
                                return
                        }
                                t.Errorf("HandleUploadFile() error = %v, wantErr %v", err, tt.wantErr)
                                return
                        }
-                       if !tranAssertEqual(t, tt.wantRes, gotRes) {
-                               t.Errorf("HandleUploadFile() gotRes = %v, want %v", gotRes, tt.wantRes)
+
+                       tranAssertEqual(t, tt.wantRes, gotRes)
+
+               })
+       }
+}
+
+// TestHandleMakeAlias drives HandleMakeAlias through three table cases:
+// a successful Symlink on a MockFileStore, a Symlink that returns an
+// error, and a caller whose access bitmap lacks accessMakeAlias.
+func TestHandleMakeAlias(t *testing.T) {
+       type args struct {
+               cc *ClientConn
+               t  *Transaction
+       }
+       tests := []struct {
+               name    string
+               setup   func()
+               args    args
+               wantRes []Transaction
+               wantErr bool
+       }{
+               {
+                       name: "with valid input and required permissions",
+                       setup: func() {
+                               // Expect Symlink(<FileRoot>/foo/testFile, <FileRoot>/bar/testFile) and let it succeed.
+                               mfs := &MockFileStore{}
+                               path, _ := os.Getwd()
+                               mfs.On(
+                                       "Symlink",
+                                       path+"/test/config/Files/foo/testFile",
+                                       path+"/test/config/Files/bar/testFile",
+                               ).Return(nil)
+                               FS = mfs
+                       },
+                       args: args{
+                               cc: &ClientConn{
+                                       Account: &Account{
+                                               Access: func() *[]byte {
+                                                       var bits accessBitmap
+                                                       bits.Set(accessMakeAlias)
+                                                       access := bits[:]
+                                                       return &access
+                                               }(),
+                                       },
+                                       Server: &Server{
+                                               Config: &Config{
+                                                       FileRoot: func() string {
+                                                               path, _ := os.Getwd()
+                                                               return path + "/test/config/Files"
+                                                       }(),
+                                               },
+                                               Logger: NewTestLogger(),
+                                       },
+                               },
+                               t: NewTransaction(
+                                       tranMakeFileAlias, &[]byte{0, 1},
+                                       NewField(fieldFileName, []byte("testFile")),
+                                       NewField(fieldFilePath, EncodeFilePath(strings.Join([]string{"foo"}, "/"))),
+                                       NewField(fieldFileNewPath, EncodeFilePath(strings.Join([]string{"bar"}, "/"))),
+                               ),
+                       },
+                       wantRes: []Transaction{
+                               {
+                                       Flags:     0x00,
+                                       IsReply:   0x01,
+                                       Type:      []byte{0, 0xd1},
+                                       ID:        []byte{0x9a, 0xcb, 0x04, 0x42},
+                                       ErrorCode: []byte{0, 0, 0, 0},
+                                       Fields:    []Field(nil),
+                               },
+                       },
+                       wantErr: false,
+               },
+               {
+                       name: "when symlink returns an error",
+                       setup: func() {
+                               // Same Symlink paths as the success case, but the mock fails the call.
+                               mfs := &MockFileStore{}
+                               path, _ := os.Getwd()
+                               mfs.On(
+                                       "Symlink",
+                                       path+"/test/config/Files/foo/testFile",
+                                       path+"/test/config/Files/bar/testFile",
+                               ).Return(errors.New("ohno"))
+                               FS = mfs
+                       },
+                       args: args{
+                               cc: &ClientConn{
+                                       Account: &Account{
+                                               Access: func() *[]byte {
+                                                       var bits accessBitmap
+                                                       bits.Set(accessMakeAlias)
+                                                       access := bits[:]
+                                                       return &access
+                                               }(),
+                                       },
+                                       Server: &Server{
+                                               Config: &Config{
+                                                       FileRoot: func() string {
+                                                               path, _ := os.Getwd()
+                                                               return path + "/test/config/Files"
+                                                       }(),
+                                               },
+                                               Logger: NewTestLogger(),
+                                       },
+                               },
+                               t: NewTransaction(
+                                       tranMakeFileAlias, &[]byte{0, 1},
+                                       NewField(fieldFileName, []byte("testFile")),
+                                       NewField(fieldFilePath, EncodeFilePath(strings.Join([]string{"foo"}, "/"))),
+                                       NewField(fieldFileNewPath, EncodeFilePath(strings.Join([]string{"bar"}, "/"))),
+                               ),
+                       },
+                       // The handler is expected to turn the Symlink failure into an error reply, not a Go error.
+                       wantRes: []Transaction{
+                               {
+                                       Flags:     0x00,
+                                       IsReply:   0x01,
+                                       Type:      []byte{0, 0x00},
+                                       ID:        []byte{0x9a, 0xcb, 0x04, 0x42},
+                                       ErrorCode: []byte{0, 0, 0, 1},
+                                       Fields: []Field{
+                                               NewField(fieldError, []byte("Error creating alias")),
+                                       },
+                               },
+                       },
+                       wantErr: false,
+               },
+               {
+                       // setup installs no mock, so the package-global FS keeps whatever the
+                       // previous case assigned; the expected reply is a permission error.
+                       name:  "when user does not have required permission",
+                       setup: func() {},
+                       args: args{
+                               cc: &ClientConn{
+                                       Account: &Account{
+                                               Access: func() *[]byte {
+                                                       var bits accessBitmap
+                                                       access := bits[:]
+                                                       return &access
+                                               }(),
+                                       },
+                                       Server: &Server{
+                                               Config: &Config{
+                                                       FileRoot: func() string {
+                                                               path, _ := os.Getwd()
+                                                               return path + "/test/config/Files"
+                                                       }(),
+                                               },
+                                       },
+                               },
+                               t: NewTransaction(
+                                       tranMakeFileAlias, &[]byte{0, 1},
+                                       NewField(fieldFileName, []byte("testFile")),
+                                       NewField(fieldFilePath, []byte{
+                                               0x00, 0x01,
+                                               0x00, 0x00,
+                                               0x03,
+                                               0x2e, 0x2e, 0x2e,
+                                       }),
+                                       NewField(fieldFileNewPath, []byte{
+                                               0x00, 0x01,
+                                               0x00, 0x00,
+                                               0x03,
+                                               0x2e, 0x2e, 0x2e,
+                                       }),
+                               ),
+                       },
+                       wantRes: []Transaction{
+                               {
+                                       Flags:     0x00,
+                                       IsReply:   0x01,
+                                       Type:      []byte{0, 0x00},
+                                       ID:        []byte{0x9a, 0xcb, 0x04, 0x42},
+                                       ErrorCode: []byte{0, 0, 0, 1},
+                                       Fields: []Field{
+                                               NewField(fieldError, []byte("You are not allowed to make aliases.")),
+                                       },
+                               },
+                       },
+                       wantErr: false,
+               },
+       }
+       for _, tt := range tests {
+               t.Run(tt.name, func(t *testing.T) {
+                       tt.setup()
+
+                       gotRes, err := HandleMakeAlias(tt.args.cc, tt.args.t)
+                       if (err != nil) != tt.wantErr {
+                               // Report the actual error and the expectation, matching the other handler tests.
+                               t.Errorf("HandleMakeAlias() error = %v, wantErr %v", err, tt.wantErr)
+                               return
                        }
                        }
+
+                       tranAssertEqual(t, tt.wantRes, gotRes)
                })
        }
 }
                })
        }
 }