diff --git a/go.mod b/go.mod index 2cd5884..3b41b9e 100644 --- a/go.mod +++ b/go.mod @@ -1,2 +1,3 @@ -module git.skdevstudios.com/specCon18/reforgerds-updater -go 1.16 +module git.skdevstudios.com/specCon18/reforgerds-updater + +go 1.18 diff --git a/internal/a2s/client.go b/internal/a2s/client.go new file mode 100644 index 0000000..30cbf3d --- /dev/null +++ b/internal/a2s/client.go @@ -0,0 +1,202 @@ +package a2s + +import ( + "encoding/binary" + "errors" + "fmt" + "net" + "time" +) + +// Errors +var ( + ErrPlayerRead = errors.New("failed to read player data") + ErrMultiPacketInvalid = errors.New("invalid multi-packet ID mismatch") + ErrMultiPacketMismatch = errors.New("multi-packet assembly failed") + errBzip2 = errors.New("bzip2 compressed response not supported") +) + +// Flag type for request/response types +type Flag byte + +const ( + PlayerRequest Flag = 0x55 + ChallengeResponse Flag = 0x41 + + DefaultBufferSize = 1400 + DefaultDeadlineTimeout = 5 + singlePacket uint32 = 0xFFFFFFFF +) + +// Client handles connection and options +type Client struct { + Conn *net.UDPConn + Address *net.UDPAddr + Timeout time.Duration + BufferSize uint16 +} + +// New creates a new Client and dials the connection +func New(ip string, port int) (*Client, error) { + return NewWithAddr(&net.UDPAddr{IP: net.ParseIP(ip), Port: port}) +} + +func NewWithString(addr string) (*Client, error) { + udpAddr, err := net.ResolveUDPAddr("udp", addr) + if err != nil { + return nil, err + } + return NewWithAddr(udpAddr) +} + +func NewWithAddr(addr *net.UDPAddr) (*Client, error) { + client := &Client{ + Address: addr, + Timeout: DefaultDeadlineTimeout * time.Second, + BufferSize: DefaultBufferSize, + } + return client, client.Dial() +} + +func (c *Client) Dial() error { + conn, err := net.DialUDP("udp", nil, c.Address) + if err != nil { + return err + } + c.Conn = conn + return nil +} + +func (c *Client) Close() error { + return c.Conn.Close() +} + +func (c *Client) SetBufferSize(size uint16) { + c.BufferSize = size +} + +func (c *Client) SetDeadlineTimeout(seconds int) { + c.Timeout = time.Duration(seconds) * time.Second +} + +func (c *Client) Get(requestType Flag) ([]byte, Flag, time.Duration, error) { + resp, duration, err := c.request(requestType, singlePacket) + if err != nil { + return nil, 0, 0, err + } + flag := Flag(resp[4]) + + if flag == ChallengeResponse { + challenge := binary.BigEndian.Uint32(resp[5:9]) + resp, _, err = c.request(requestType, challenge) + if err != nil { + return nil, 0, 0, err + } + flag = Flag(resp[4]) + } + + if err := validateResponseType(requestType, flag); err != nil { + return resp[5:], flag, duration, err + } + + return resp[5:], flag, duration, nil +} + +func (c *Client) request(requestType Flag, challenge uint32) ([]byte, time.Duration, error) { + req, err := createHeader(requestType, challenge) + if err != nil { + return nil, 0, err + } + + start := time.Now() + + if _, err := c.Conn.Write(req); err != nil { + return nil, 0, err + } + if err := c.Conn.SetReadDeadline(time.Now().Add(c.Timeout)); err != nil { + return nil, 0, err + } + + resp := make([]byte, c.BufferSize) + n, err := c.Conn.Read(resp) + if err != nil { + return nil, 0, err + } + + duration := time.Since(start) + + multi, err := isMultiPacket(resp) + if err != nil { + return resp, 0, err + } + + if !multi { + return resp[:n], duration, nil + } + + packetID := binary.LittleEndian.Uint32(resp[4:8]) + packetCount := int(resp[8] & 0x0F) + currentPacket := int(resp[9] & 0x0F) + + if (packetID & 0x80000000) != 0 { + return nil, 0, errBzip2 + } + + packets := make(map[int][]byte) + packets[currentPacket] = resp[12:n] + + for len(packets) < packetCount { + buf := make([]byte, c.BufferSize) + n, err := c.Conn.Read(buf) + if err != nil { + return nil, 0, err + } + if binary.LittleEndian.Uint32(buf[4:8]) != packetID { + return nil, 0, ErrMultiPacketInvalid + } + + currentPacket = int(buf[9] & 0x0F) + if _, exists := packets[currentPacket]; !exists { + packets[currentPacket] = buf[12:n] + } + } + + var assembledResp []byte + for i := 0; i < packetCount; i++ { + data, exists := packets[i] + if !exists { + return nil, 0, ErrMultiPacketMismatch + } + assembledResp = append(assembledResp, data...) + } + + return assembledResp, duration, nil +} + +// Helpers + +func createHeader(requestType Flag, challenge uint32) ([]byte, error) { + header := []byte{0xFF, 0xFF, 0xFF, 0xFF, byte(requestType)} + if challenge != singlePacket { + challengeBytes := make([]byte, 4) + binary.BigEndian.PutUint32(challengeBytes, challenge) + header = append(header, challengeBytes...) + } + return header, nil +} + +func validateResponseType(request, response Flag) error { + if request == PlayerRequest && response != 0x44 { + return fmt.Errorf("unexpected player response flag: 0x%X", response) + } + return nil +} + +func isMultiPacket(buf []byte) (bool, error) { + if len(buf) < 4 { + return false, errors.New("packet too short") + } + header := binary.LittleEndian.Uint32(buf[:4]) + return header == 0xFFFFFFFE, nil +} + diff --git a/internal/a2s/players.go b/internal/a2s/players.go index 6d6e54b..419d6d5 100644 --- a/internal/a2s/players.go +++ b/internal/a2s/players.go @@ -5,7 +5,7 @@ import ( "fmt" "time" - "reforgerds-updater/internal/bread" + "git.skdevstudios.com/specCon18/reforgerds-updater/internal/a2s/bread" ) // https://developer.valvesoftware.com/wiki/Server_queries#Response_Format_2 diff --git a/main.go b/main.go index 77065c8..1e0e8da 100644 --- a/main.go +++ b/main.go @@ -8,6 +8,8 @@ import ( "os" "os/exec" "strings" + "time" + "git.skdevstudios.com/specCon18/reforgerds-updater/internal/a2s" ) @@ -18,9 +20,13 @@ type Update struct { } const ( - updateURL = "http://127.0.0.1:3000/updates" - stateFilePath = "latest_version.txt" + updateURL = "http://127.0.0.1:3000/updates" + stateFilePath = "latest_version.txt" + + serverIP = "127.0.0.1" + serverPort = 17777 ) + func main() { resp, err := http.Get(updateURL) if err != nil { @@ -59,19 +65,38 @@ func main() { if versionCompare(latest, prevVersion) > 0 { fmt.Printf("New version found! %s > %s\n", latest, prevVersion) - // Run steamcmd with reforger_update script - fmt.Println("Running update command...") + // Always update the state file + err := os.WriteFile(stateFilePath, []byte(latest), 0644) + if err != nil { + fmt.Printf("Failed to write version file: %v\n", err) + } + + // Check for online players + players, err := fetchPlayers(serverIP, serverPort) + if err != nil { + fmt.Printf("Error checking players: %v\n", err) + return + } + + if len(players) > 0 { + fmt.Printf("Players are currently online (%d):\n", len(players)) + for _, p := range players { + fmt.Printf("- %-16s | Score: %d | Time: %s\n", p.Name, p.Score, formatDuration(p.Duration)) + } + fmt.Println("Skipping update while players are online.") + return + } + + // No players — run steamcmd + fmt.Println("No players online. Running update command...") cmd := exec.Command("./steamcmd.sh", "+runscript", "reforger_update") cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr - err := cmd.Run() + err = cmd.Run() if err != nil { fmt.Printf("Update command failed: %v\n", err) return } - - // Store the new latest version - _ = os.WriteFile(stateFilePath, []byte(latest), 0644) } else { fmt.Printf("No new version. Latest seen: %s\n", prevVersion) } @@ -105,3 +130,26 @@ func versionCompare(a, b string) int { return 0 } +func fetchPlayers(ip string, port int) ([]a2s.Player, error) { + client, err := a2s.New(ip, port) + if err != nil { + return nil, fmt.Errorf("create client: %w", err) + } + defer client.Close() + + client.SetBufferSize(2048) + client.SetDeadlineTimeout(3) + + players, err := client.GetPlayers() + if err != nil { + return nil, fmt.Errorf("get players: %w", err) + } + + return *players, nil +} + +func formatDuration(d time.Duration) string { + minutes := int(d.Minutes()) + seconds := int(d.Seconds()) % 60 + return fmt.Sprintf("%02d:%02d", minutes, seconds) +} diff --git a/scratch b/scratch new file mode 100644 index 0000000..f5546c2 --- /dev/null +++ b/scratch @@ -0,0 +1,16 @@ +client, err := a2s.New("127.0.0.1", 27016) +if err != nil { + panic(err) +} + +client.SetBufferSize(2048) +client.SetDeadlineTimeout(3) + +defer client.Close() + +players,err := client.GetPlayers() +if err != nil { + panic(err) +} else { + fmt.Printf("%+v\n", players) +}