From: Jeff Halter
Date: Sun, 9 Jun 2024 02:25:12 +0000 (-0700)
Subject: Refactor TrackerRegistration Read to io.Reader interface
X-Git-Url: https://git.r.bdr.sh/rbdr/mobius/commitdiff_plain/9a75b7cbfa6fa37ed10732bbff8b722d652a1cdc?hp=2b7fabb5a1ff6092dd3ea62bd882dd0c02951b81

Refactor TrackerRegistration Read to io.Reader interface
---

diff --git a/hotline/tracker.go b/hotline/tracker.go
index bd7aa00..a56059f 100644
--- a/hotline/tracker.go
+++ b/hotline/tracker.go
@@ -4,49 +4,70 @@ import (
 	"bufio"
 	"encoding/binary"
 	"fmt"
-	"github.com/jhalter/mobius/concat"
+	"io"
 	"net"
+	"slices"
 	"strconv"
 	"time"
 )
 
 // TrackerRegistration represents the payload a Hotline server sends to a Tracker to register
 type TrackerRegistration struct {
-	Port        [2]byte // Server listening port number
+	Port        [2]byte // Server's listening TCP port number
 	UserCount   int     // Number of users connected to this particular server
-	PassID      []byte  // Random number generated by the server
+	PassID      [4]byte // Random number generated by the server
 	Name        string  // Server name
 	Description string  // Description of the server
 }
 
-// TODO: reimplement as io.Reader
-func (tr *TrackerRegistration) Read() []byte {
+// Read implements io.Reader by copying the complete tracker registration
+// payload into p and returning io.EOF.
+//
+// Read keeps no read offset: every call yields the whole payload, so the same
+// TrackerRegistration can be sent to several trackers. register sends the
+// payload as a single UDP datagram, so it is never split across calls; if p is
+// too small to hold it, Read returns io.ErrShortBuffer rather than silently
+// truncating it.
+func (tr *TrackerRegistration) Read(p []byte) (int, error) {
 	userCount := make([]byte, 2)
 	binary.BigEndian.PutUint16(userCount, uint16(tr.UserCount))
 
-	return concat.Slices(
-		[]byte{0x00, 0x01},
+	// Name and Description are each prefixed by a one-byte length, so they
+	// are truncated to 255 bytes to keep that prefix consistent with the data.
+	name := tr.Name[:min(len(tr.Name), 255)]
+	desc := tr.Description[:min(len(tr.Description), 255)]
+
+	payload := slices.Concat(
+		[]byte{0x00, 0x01}, // Magic number, always 1
 		tr.Port[:],
 		userCount,
-		[]byte{0x00, 0x00},
-		tr.PassID,
-		[]byte{uint8(len(tr.Name))},
-		[]byte(tr.Name),
-		[]byte{uint8(len(tr.Description))},
-		[]byte(tr.Description),
+		[]byte{0x00, 0x00}, // Magic number, always 0
+		tr.PassID[:],
+		[]byte{uint8(len(name))},
+		[]byte(name),
+		[]byte{uint8(len(desc))},
+		[]byte(desc),
 	)
+
+	if len(p) < len(payload) {
+		return 0, io.ErrShortBuffer
+	}
+
+	return copy(p, payload), io.EOF
 }
 
 func register(tracker string, tr *TrackerRegistration) error {
 	conn, err := net.Dial("udp", tracker)
 	if err != nil {
-		return err
+		return fmt.Errorf("failed to dial tracker: %w", err)
 	}
+	defer conn.Close()
 
-	b := tr.Read()
-
-	if _, err := conn.Write(b); err != nil {
-		return err
+	// The payload is at most 524 bytes, well under io.Copy's 32 KiB buffer, so
+	// Read hands it over in one call and io.Copy forwards it with a single
+	// Write, i.e. as one UDP datagram.
+	if _, err := io.Copy(conn, tr); err != nil {
+		return fmt.Errorf("failed to write to connection: %w", err)
 	}
 
 	return nil
diff --git a/hotline/tracker_test.go b/hotline/tracker_test.go
index 481f916..b5c5c55 100644
--- a/hotline/tracker_test.go
+++ b/hotline/tracker_test.go
@@ -3,6 +3,7 @@ package hotline
 import (
 	"fmt"
 	"github.com/stretchr/testify/assert"
+	"io"
 	"reflect"
 	"testing"
 )
@@ -11,7 +12,7 @@ func TestTrackerRegistration_Payload(t *testing.T) {
 	type fields struct {
 		Port        [2]byte
 		UserCount   int
-		PassID      []byte
+		PassID      [4]byte
 		Name        string
 		Description string
 	}
@@ -25,7 +26,7 @@ func TestTrackerRegistration_Payload(t *testing.T) {
 			fields: fields{
 				Port:        [2]byte{0x00, 0x10},
 				UserCount:   2,
-				PassID:      []byte{0x00, 0x00, 0x00, 0x01},
+				PassID:      [4]byte{0x00, 0x00, 0x00, 0x01},
 				Name:        "Test Serv",
 				Description: "Fooz",
 			},
@@ -51,7 +52,12 @@ func TestTrackerRegistration_Payload(t *testing.T) {
 				Name:        tt.fields.Name,
 				Description: tt.fields.Description,
 			}
-			if got := tr.Read(); !reflect.DeepEqual(got, tt.want) {
+
+			got, err := io.ReadAll(tr)
+			if err != nil {
+				t.Fatalf("Read() error = %v", err)
+			}
+			if !reflect.DeepEqual(got, tt.want) {
 				t.Errorf("Read() = %v, want %v", got, tt.want)
 			}
 		})