// TestHandshakeWrite checks handshake.Write: a well-formed client handshake
// ("TRTP" + "HOTL" + Version + SubVersion) must be decoded into the matching
// struct fields, and too-short or empty input must be rejected with
// "invalid handshake size".
8 func TestHandshakeWrite(t *testing.T) {
// Well-formed input: TRTP, HOTL, Version 0x0001, SubVersion 0x0002.
16 name: "Valid Handshake",
17 input: []byte{0x54, 0x52, 0x54, 0x50, 0x48, 0x4F, 0x54, 0x4C, 0x00, 0x01, 0x00, 0x02},
19 Protocol: [4]byte{0x54, 0x52, 0x54, 0x50},
20 SubProtocol: [4]byte{0x48, 0x4F, 0x54, 0x4C},
21 Version: [2]byte{0x00, 0x01},
22 SubVersion: [2]byte{0x00, 0x02},
// Input that is only the 4-byte protocol prefix must be refused.
27 name: "Invalid Handshake Size",
28 input: []byte{0x54, 0x52, 0x54, 0x50},
29 expected: handshake{},
30 expectedError: "invalid handshake size",
33 name: "Empty Handshake Data",
35 expected: handshake{},
36 expectedError: "invalid handshake size",
40 for _, tt := range tests {
41 t.Run(tt.name, func(t *testing.T) {
43 n, err := h.Write(tt.input)
45 if tt.expectedError != "" {
46 if err == nil || err.Error() != tt.expectedError {
// err may be nil on this path; %q on a nil error prints "%!q(<nil>)", so use %v.
47 t.Fatalf("expected error %q, got %v", tt.expectedError, err)
51 t.Fatalf("unexpected error: %v", err)
// On success Write must report the whole handshake as consumed.
53 if n != handshakeSize {
54 t.Fatalf("expected %d bytes written, got %d", handshakeSize, n)
57 t.Fatalf("expected handshake %+v, got %+v", tt.expected, h)
// TestHandshakeValid is a table-driven test of handshake.Valid. Cases combine a
// correct ("TRTP"/"HOTL") or all-zero Protocol and SubProtocol. One further case
// keeps the correct protocols but changes Version/SubVersion.
// NOTE(review): the expected results are not visible in this excerpt. The case
// names suggest only Protocol/SubProtocol decide validity — confirm.
64 func TestHandshakeValid(t *testing.T) {
71 name: "Valid Handshake",
73 Protocol: [4]byte{0x54, 0x52, 0x54, 0x50}, // TRTP
74 SubProtocol: [4]byte{0x48, 0x4F, 0x54, 0x4C}, // HOTL
75 Version: [2]byte{0x00, 0x01},
76 SubVersion: [2]byte{0x00, 0x02},
81 name: "Invalid Protocol",
83 Protocol: [4]byte{0x00, 0x00, 0x00, 0x00},
84 SubProtocol: [4]byte{0x48, 0x4F, 0x54, 0x4C}, // HOTL
85 Version: [2]byte{0x00, 0x01},
86 SubVersion: [2]byte{0x00, 0x02},
91 name: "Invalid SubProtocol",
93 Protocol: [4]byte{0x54, 0x52, 0x54, 0x50}, // TRTP
94 SubProtocol: [4]byte{0x00, 0x00, 0x00, 0x00},
95 Version: [2]byte{0x00, 0x01},
96 SubVersion: [2]byte{0x00, 0x02},
101 name: "Invalid Protocol and SubProtocol",
103 Protocol: [4]byte{0x00, 0x00, 0x00, 0x00},
104 SubProtocol: [4]byte{0x00, 0x00, 0x00, 0x00},
105 Version: [2]byte{0x00, 0x01},
106 SubVersion: [2]byte{0x00, 0x02},
// Same protocols as the valid case, but Version 0x0002 / SubVersion 0x0003.
111 name: "Valid Handshake with Different Version",
113 Protocol: [4]byte{0x54, 0x52, 0x54, 0x50}, // TRTP
114 SubProtocol: [4]byte{0x48, 0x4F, 0x54, 0x4C}, // HOTL
115 Version: [2]byte{0x00, 0x02},
116 SubVersion: [2]byte{0x00, 0x03},
122 for _, tt := range tests {
123 t.Run(tt.name, func(t *testing.T) {
// Each subtest calls Valid on the case's handshake and compares with tt.expected.
124 result := tt.input.Valid()
125 if result != tt.expected {
126 t.Fatalf("expected %v, got %v", tt.expected, result)
132 // readWriteBuffer combines input and output buffers to implement io.ReadWriter
133 type readWriteBuffer struct {
// Read serves bytes from the input side, i.e. the data the code under test reads.
138 func (rw *readWriteBuffer) Read(p []byte) (int, error) {
139 return rw.input.Read(p)
// Write records bytes on the output side so a test can inspect what was sent back.
142 func (rw *readWriteBuffer) Write(p []byte) (int, error) {
143 return rw.output.Write(p)
146 func TestPerformHandshake(t *testing.T) {
150 expectedOutput []byte
154 name: "Valid Handshake",
156 0x54, 0x52, 0x54, 0x50, // TRTP
157 0x48, 0x4F, 0x54, 0x4C, // HOTL
158 0x00, 0x01, 0x00, 0x02, // Version 1, SubVersion 2
160 expectedOutput: []byte{0x54, 0x52, 0x54, 0x50, 0x00, 0x00, 0x00, 0x00},
164 name: "Invalid Handshake Size",
166 0x54, 0x52, 0x54, 0x50, // TRTP
169 expectedError: "failed to read handshake data: invalid handshake size",
172 name: "Invalid Protocol",
174 0x00, 0x00, 0x00, 0x00, // Invalid protocol
175 0x48, 0x4F, 0x54, 0x4C, // HOTL
176 0x00, 0x01, 0x00, 0x02, // Version 1, SubVersion 2
179 expectedError: "invalid protocol or sub-protocol in handshake",
182 name: "Invalid SubProtocol",
184 0x54, 0x52, 0x54, 0x50, // TRTP
185 0x00, 0x00, 0x00, 0x00, // Invalid sub-protocol
186 0x00, 0x01, 0x00, 0x02, // Version 1, SubVersion 2
189 expectedError: "invalid protocol or sub-protocol in handshake",
192 name: "Binary Read Error",
194 0xFF, 0xFF, 0xFF, 0xFF, // Invalid data
195 0xFF, 0xFF, 0xFF, 0xFF,
196 0xFF, 0xFF, 0xFF, 0xFF,
199 expectedError: "invalid protocol or sub-protocol in handshake",
203 for _, tt := range tests {
204 t.Run(tt.name, func(t *testing.T) {
205 inputBuffer := bytes.NewBuffer(tt.input)
206 outputBuffer := &bytes.Buffer{}
207 rw := &readWriteBuffer{
209 output: outputBuffer,
212 err := performHandshake(rw)
214 if tt.expectedError != "" {
215 if err == nil || err.Error() != tt.expectedError {
216 t.Fatalf("expected error %q, got %q", tt.expectedError, err)
220 t.Fatalf("unexpected error: %v", err)
222 output := outputBuffer.Bytes()
223 if !bytes.Equal(output, tt.expectedOutput) {
224 t.Fatalf("expected output %v, got %v", tt.expectedOutput, output)