// Copyright 2017 The go-ethereum Authors // This file is part of the go-ethereum library. // // The go-ethereum library is free software: you can redistribute it and/or modify // it under the terms of the GNU Lesser General Public License as published by // the Free Software Foundation, either version 3 of the License, or // (at your option) any later version. // // The go-ethereum library is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Lesser General Public License for more details. // // You should have received a copy of the GNU Lesser General Public License // along with the go-ethereum library. If not, see . package enr import ( "bytes" "encoding/binary" "fmt" "math/rand" "testing" "time" "github.com/ethereum/go-ethereum/rlp" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) var rnd = rand.New(rand.NewSource(time.Now().UnixNano())) func randomString(strlen int) string { b := make([]byte, strlen) rnd.Read(b) return string(b) } // TestGetSetID tests encoding/decoding and setting/getting of the ID key. func TestGetSetID(t *testing.T) { id := ID("someid") var r Record r.Set(id) var id2 ID require.NoError(t, r.Load(&id2)) assert.Equal(t, id, id2) } // TestGetSetIP4 tests encoding/decoding and setting/getting of the IP key. func TestGetSetIP4(t *testing.T) { ip := IP{192, 168, 0, 3} var r Record r.Set(ip) var ip2 IP require.NoError(t, r.Load(&ip2)) assert.Equal(t, ip, ip2) } // TestGetSetIP6 tests encoding/decoding and setting/getting of the IP key. func TestGetSetIP6(t *testing.T) { ip := IP{0x20, 0x01, 0x48, 0x60, 0, 0, 0x20, 0x01, 0, 0, 0, 0, 0, 0, 0x00, 0x68} var r Record r.Set(ip) var ip2 IP require.NoError(t, r.Load(&ip2)) assert.Equal(t, ip, ip2) } // TestGetSetDiscPort tests encoding/decoding and setting/getting of the DiscPort key. func TestGetSetUDP(t *testing.T) { port := UDP(30309) var r Record r.Set(port) var port2 UDP require.NoError(t, r.Load(&port2)) assert.Equal(t, port, port2) } func TestLoadErrors(t *testing.T) { var r Record ip4 := IP{127, 0, 0, 1} r.Set(ip4) // Check error for missing keys. var udp UDP err := r.Load(&udp) if !IsNotFound(err) { t.Error("IsNotFound should return true for missing key") } assert.Equal(t, &KeyError{Key: udp.ENRKey(), Err: errNotFound}, err) // Check error for invalid keys. var list []uint err = r.Load(WithEntry(ip4.ENRKey(), &list)) kerr, ok := err.(*KeyError) if !ok { t.Fatalf("expected KeyError, got %T", err) } assert.Equal(t, kerr.Key, ip4.ENRKey()) assert.Error(t, kerr.Err) if IsNotFound(err) { t.Error("IsNotFound should return false for decoding errors") } } // TestSortedGetAndSet tests that Set produced a sorted pairs slice. func TestSortedGetAndSet(t *testing.T) { type pair struct { k string v uint32 } for _, tt := range []struct { input []pair want []pair }{ { input: []pair{{"a", 1}, {"c", 2}, {"b", 3}}, want: []pair{{"a", 1}, {"b", 3}, {"c", 2}}, }, { input: []pair{{"a", 1}, {"c", 2}, {"b", 3}, {"d", 4}, {"a", 5}, {"bb", 6}}, want: []pair{{"a", 5}, {"b", 3}, {"bb", 6}, {"c", 2}, {"d", 4}}, }, { input: []pair{{"c", 2}, {"b", 3}, {"d", 4}, {"a", 5}, {"bb", 6}}, want: []pair{{"a", 5}, {"b", 3}, {"bb", 6}, {"c", 2}, {"d", 4}}, }, } { var r Record for _, i := range tt.input { r.Set(WithEntry(i.k, &i.v)) } for i, w := range tt.want { // set got's key from r.pair[i], so that we preserve order of pairs got := pair{k: r.pairs[i].k} assert.NoError(t, r.Load(WithEntry(w.k, &got.v))) assert.Equal(t, w, got) } } } // TestDirty tests record signature removal on setting of new key/value pair in record. func TestDirty(t *testing.T) { var r Record if _, err := rlp.EncodeToBytes(r); err != errEncodeUnsigned { t.Errorf("expected errEncodeUnsigned, got %#v", err) } require.NoError(t, signTest([]byte{5}, &r)) if len(r.signature) == 0 { t.Error("record is not signed") } _, err := rlp.EncodeToBytes(r) assert.NoError(t, err) r.SetSeq(3) if len(r.signature) != 0 { t.Error("signature still set after modification") } if _, err := rlp.EncodeToBytes(r); err != errEncodeUnsigned { t.Errorf("expected errEncodeUnsigned, got %#v", err) } } func TestSeq(t *testing.T) { var r Record assert.Equal(t, uint64(0), r.Seq()) r.Set(UDP(1)) assert.Equal(t, uint64(0), r.Seq()) signTest([]byte{5}, &r) assert.Equal(t, uint64(0), r.Seq()) r.Set(UDP(2)) assert.Equal(t, uint64(1), r.Seq()) } // TestGetSetOverwrite tests value overwrite when setting a new value with an existing key in record. func TestGetSetOverwrite(t *testing.T) { var r Record ip := IP{192, 168, 0, 3} r.Set(ip) ip2 := IP{192, 168, 0, 4} r.Set(ip2) var ip3 IP require.NoError(t, r.Load(&ip3)) assert.Equal(t, ip2, ip3) } // TestSignEncodeAndDecode tests signing, RLP encoding and RLP decoding of a record. func TestSignEncodeAndDecode(t *testing.T) { var r Record r.Set(UDP(30303)) r.Set(IP{127, 0, 0, 1}) require.NoError(t, signTest([]byte{5}, &r)) blob, err := rlp.EncodeToBytes(r) require.NoError(t, err) var r2 Record require.NoError(t, rlp.DecodeBytes(blob, &r2)) assert.Equal(t, r, r2) blob2, err := rlp.EncodeToBytes(r2) require.NoError(t, err) assert.Equal(t, blob, blob2) } // TestRecordTooBig tests that records bigger than SizeLimit bytes cannot be signed. func TestRecordTooBig(t *testing.T) { var r Record key := randomString(10) // set a big value for random key, expect error r.Set(WithEntry(key, randomString(SizeLimit))) if err := signTest([]byte{5}, &r); err != errTooBig { t.Fatalf("expected to get errTooBig, got %#v", err) } // set an acceptable value for random key, expect no error r.Set(WithEntry(key, randomString(100))) require.NoError(t, signTest([]byte{5}, &r)) } // TestSignEncodeAndDecodeRandom tests encoding/decoding of records containing random key/value pairs. func TestSignEncodeAndDecodeRandom(t *testing.T) { var r Record // random key/value pairs for testing pairs := map[string]uint32{} for i := 0; i < 10; i++ { key := randomString(7) value := rnd.Uint32() pairs[key] = value r.Set(WithEntry(key, &value)) } require.NoError(t, signTest([]byte{5}, &r)) _, err := rlp.EncodeToBytes(r) require.NoError(t, err) for k, v := range pairs { desc := fmt.Sprintf("key %q", k) var got uint32 buf := WithEntry(k, &got) require.NoError(t, r.Load(buf), desc) require.Equal(t, v, got, desc) } } type testSig struct{} type testID []byte func (id testID) ENRKey() string { return "testid" } func signTest(id []byte, r *Record) error { r.Set(ID("test")) r.Set(testID(id)) return r.SetSig(testSig{}, makeTestSig(id, r.Seq())) } func makeTestSig(id []byte, seq uint64) []byte { sig := make([]byte, 8, len(id)+8) binary.BigEndian.PutUint64(sig[:8], seq) sig = append(sig, id...) return sig } func (testSig) Verify(r *Record, sig []byte) error { var id []byte if err := r.Load((*testID)(&id)); err != nil { return err } if !bytes.Equal(sig, makeTestSig(id, r.Seq())) { return ErrInvalidSig } return nil } func (testSig) NodeAddr(r *Record) []byte { var id []byte if err := r.Load((*testID)(&id)); err != nil { return nil } return id }