// Copyright 2023 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. //go:build go1.21 package quic import ( "bytes" "encoding/binary" "testing" ) func TestDecodePacketNumber(t *testing.T) { for _, test := range []struct { largest packetNumber truncated packetNumber want packetNumber size int }{{ largest: 0, truncated: 1, size: 4, want: 1, }, { largest: 0, truncated: 0, size: 1, want: 0, }, { largest: 0x00, truncated: 0x01, size: 1, want: 0x01, }, { largest: 0x00, truncated: 0xff, size: 1, want: 0xff, }, { largest: 0xff, truncated: 0x01, size: 1, want: 0x101, }, { largest: 0x1000, truncated: 0xff, size: 1, want: 0xfff, }, { largest: 0xa82f30ea, truncated: 0x9b32, size: 2, want: 0xa82f9b32, }} { got := decodePacketNumber(test.largest, test.truncated, test.size) if got != test.want { t.Errorf("decodePacketNumber(largest=0x%x, truncated=0x%x, size=%v) = 0x%x, want 0x%x", test.largest, test.truncated, test.size, got, test.want) } } } func TestEncodePacketNumber(t *testing.T) { for _, test := range []struct { largestAcked packetNumber pnum packetNumber wantSize int }{{ largestAcked: -1, pnum: 0, wantSize: 1, }, { largestAcked: 1000, pnum: 1000 + 0x7f, wantSize: 1, }, { largestAcked: 1000, pnum: 1000 + 0x80, // 0x468 wantSize: 2, }, { largestAcked: 0x12345678, pnum: 0x12345678 + 0x7fff, // 0x305452663 wantSize: 2, }, { largestAcked: 0x12345678, pnum: 0x12345678 + 0x8000, wantSize: 3, }, { largestAcked: 0, pnum: 0x7fffff, wantSize: 3, }, { largestAcked: 0, pnum: 0x800000, wantSize: 4, }, { largestAcked: 0xabe8bc, pnum: 0xac5c02, wantSize: 2, }, { largestAcked: 0xabe8bc, pnum: 0xace8fe, wantSize: 3, }} { size := packetNumberLength(test.pnum, test.largestAcked) if got, want := size, test.wantSize; got != want { t.Errorf("packetNumberLength(num=%x, maxAck=%x) = %v, want %v", test.pnum, test.largestAcked, got, want) } var enc packetNumber switch size { case 1: enc = test.pnum & 0xff case 2: enc = test.pnum & 0xffff case 3: enc = test.pnum & 0xffffff case 4: enc = test.pnum & 0xffffffff } wantBytes := binary.BigEndian.AppendUint32(nil, uint32(enc))[4-size:] gotBytes := appendPacketNumber(nil, test.pnum, test.largestAcked) if !bytes.Equal(gotBytes, wantBytes) { t.Errorf("appendPacketNumber(num=%v, maxAck=%x) = {%x}, want {%x}", test.pnum, test.largestAcked, gotBytes, wantBytes) } gotNum := decodePacketNumber(test.largestAcked, enc, size) if got, want := gotNum, test.pnum; got != want { t.Errorf("packetNumberLength(num=%x, maxAck=%x) = %v, but decoded number=%x", test.pnum, test.largestAcked, size, got) } } } func FuzzPacketNumber(f *testing.F) { truncatedNumber := func(in []byte) packetNumber { var truncated packetNumber for _, b := range in { truncated = (truncated << 8) | packetNumber(b) } return truncated } f.Fuzz(func(t *testing.T, in []byte, largestAckedInt64 int64) { largestAcked := packetNumber(largestAckedInt64) if len(in) < 1 || len(in) > 4 || largestAcked < 0 || largestAcked > maxPacketNumber { return } truncatedIn := truncatedNumber(in) decoded := decodePacketNumber(largestAcked, truncatedIn, len(in)) // Check that the decoded packet number's least significant bits match the input. var mask packetNumber for i := 0; i < len(in); i++ { mask = (mask << 8) | 0xff } if truncatedIn != decoded&mask { t.Fatalf("decoding mismatch: input=%x largestAcked=%v decoded=0x%x", in, largestAcked, decoded) } // We don't support encoding packet numbers less than largestAcked (since packet numbers // never decrease), so skip the encoder tests if this would make us go backwards. if decoded < largestAcked { return } // We might encode this number using a different length than we received, // but the common portions should match. encoded := appendPacketNumber(nil, decoded, largestAcked) a, b := in, encoded if len(b) < len(a) { a, b = b, a } for len(a) < len(b) { b = b[1:] } if len(a) == 0 || !bytes.Equal(a, b) { t.Fatalf("encoding mismatch: input=%x largestAcked=%v decoded=%v reencoded=%x", in, largestAcked, decoded, encoded) } if g := decodePacketNumber(largestAcked, truncatedNumber(encoded), len(encoded)); g != decoded { t.Fatalf("packet encode/decode mismatch: pnum=%v largestAcked=%v encoded=%x got=%v", decoded, largestAcked, encoded, g) } if l := packetNumberLength(decoded, largestAcked); l != len(encoded) { t.Fatalf("packet number length mismatch: pnum=%v largestAcked=%v encoded=%x len=%v", decoded, largestAcked, encoded, l) } }) }