diff options
Diffstat (limited to 'p2p/netutil')
-rw-r--r-- | p2p/netutil/net.go | 131 | ||||
-rw-r--r-- | p2p/netutil/net_test.go | 89 |
2 files changed, 220 insertions, 0 deletions
diff --git a/p2p/netutil/net.go b/p2p/netutil/net.go index f6005afd2..656abb682 100644 --- a/p2p/netutil/net.go +++ b/p2p/netutil/net.go @@ -18,8 +18,11 @@ package netutil import ( + "bytes" "errors" + "fmt" "net" + "sort" "strings" ) @@ -189,3 +192,131 @@ func CheckRelayIP(sender, addr net.IP) error { } return nil } + +// SameNet reports whether two IP addresses have an equal prefix of the given bit length. +func SameNet(bits uint, ip, other net.IP) bool { + ip4, other4 := ip.To4(), other.To4() + switch { + case (ip4 == nil) != (other4 == nil): + return false + case ip4 != nil: + return sameNet(bits, ip4, other4) + default: + return sameNet(bits, ip.To16(), other.To16()) + } +} + +func sameNet(bits uint, ip, other net.IP) bool { + nb := int(bits / 8) + mask := ^byte(0xFF >> (bits % 8)) + if mask != 0 && nb < len(ip) && ip[nb]&mask != other[nb]&mask { + return false + } + return nb <= len(ip) && bytes.Equal(ip[:nb], other[:nb]) +} + +// DistinctNetSet tracks IPs, ensuring that at most N of them +// fall into the same network range. +type DistinctNetSet struct { + Subnet uint // number of common prefix bits + Limit uint // maximum number of IPs in each subnet + + members map[string]uint + buf net.IP +} + +// Add adds an IP address to the set. It returns false (and doesn't add the IP) if the +// number of existing IPs in the defined range exceeds the limit. +func (s *DistinctNetSet) Add(ip net.IP) bool { + key := s.key(ip) + n := s.members[string(key)] + if n < s.Limit { + s.members[string(key)] = n + 1 + return true + } + return false +} + +// Remove removes an IP from the set. +func (s *DistinctNetSet) Remove(ip net.IP) { + key := s.key(ip) + if n, ok := s.members[string(key)]; ok { + if n == 1 { + delete(s.members, string(key)) + } else { + s.members[string(key)] = n - 1 + } + } +} + +// Contains whether the given IP is contained in the set. +func (s DistinctNetSet) Contains(ip net.IP) bool { + key := s.key(ip) + _, ok := s.members[string(key)] + return ok +} + +// Len returns the number of tracked IPs. +func (s DistinctNetSet) Len() int { + n := uint(0) + for _, i := range s.members { + n += i + } + return int(n) +} + +// key encodes the map key for an address into a temporary buffer. +// +// The first byte of key is '4' or '6' to distinguish IPv4/IPv6 address types. +// The remainder of the key is the IP, truncated to the number of bits. +func (s *DistinctNetSet) key(ip net.IP) net.IP { + // Lazily initialize storage. + if s.members == nil { + s.members = make(map[string]uint) + s.buf = make(net.IP, 17) + } + // Canonicalize ip and bits. + typ := byte('6') + if ip4 := ip.To4(); ip4 != nil { + typ, ip = '4', ip4 + } + bits := s.Subnet + if bits > uint(len(ip)*8) { + bits = uint(len(ip) * 8) + } + // Encode the prefix into s.buf. + nb := int(bits / 8) + mask := ^byte(0xFF >> (bits % 8)) + s.buf[0] = typ + buf := append(s.buf[:1], ip[:nb]...) + if nb < len(ip) && mask != 0 { + buf = append(buf, ip[nb]&mask) + } + return buf +} + +// String implements fmt.Stringer +func (s DistinctNetSet) String() string { + var buf bytes.Buffer + buf.WriteString("{") + keys := make([]string, 0, len(s.members)) + for k := range s.members { + keys = append(keys, k) + } + sort.Strings(keys) + for i, k := range keys { + var ip net.IP + if k[0] == '4' { + ip = make(net.IP, 4) + } else { + ip = make(net.IP, 16) + } + copy(ip, k[1:]) + fmt.Fprintf(&buf, "%vĂ—%d", ip, s.members[k]) + if i != len(keys)-1 { + buf.WriteString(" ") + } + } + buf.WriteString("}") + return buf.String() +} diff --git a/p2p/netutil/net_test.go b/p2p/netutil/net_test.go index 1ee1fcb4d..3a6aa081f 100644 --- a/p2p/netutil/net_test.go +++ b/p2p/netutil/net_test.go @@ -17,9 +17,11 @@ package netutil import ( + "fmt" "net" "reflect" "testing" + "testing/quick" "github.com/davecgh/go-spew/spew" ) @@ -171,3 +173,90 @@ func BenchmarkCheckRelayIP(b *testing.B) { CheckRelayIP(sender, addr) } } + +func TestSameNet(t *testing.T) { + tests := []struct { + ip, other string + bits uint + want bool + }{ + {"0.0.0.0", "0.0.0.0", 32, true}, + {"0.0.0.0", "0.0.0.1", 0, true}, + {"0.0.0.0", "0.0.0.1", 31, true}, + {"0.0.0.0", "0.0.0.1", 32, false}, + {"0.33.0.1", "0.34.0.2", 8, true}, + {"0.33.0.1", "0.34.0.2", 13, true}, + {"0.33.0.1", "0.34.0.2", 15, false}, + } + + for _, test := range tests { + if ok := SameNet(test.bits, parseIP(test.ip), parseIP(test.other)); ok != test.want { + t.Errorf("SameNet(%d, %s, %s) == %t, want %t", test.bits, test.ip, test.other, ok, test.want) + } + } +} + +func ExampleSameNet() { + // This returns true because the IPs are in the same /24 network: + fmt.Println(SameNet(24, net.IP{127, 0, 0, 1}, net.IP{127, 0, 0, 3})) + // This call returns false: + fmt.Println(SameNet(24, net.IP{127, 3, 0, 1}, net.IP{127, 5, 0, 3})) + // Output: + // true + // false +} + +func TestDistinctNetSet(t *testing.T) { + ops := []struct { + add, remove string + fails bool + }{ + {add: "127.0.0.1"}, + {add: "127.0.0.2"}, + {add: "127.0.0.3", fails: true}, + {add: "127.32.0.1"}, + {add: "127.32.0.2"}, + {add: "127.32.0.3", fails: true}, + {add: "127.33.0.1", fails: true}, + {add: "127.34.0.1"}, + {add: "127.34.0.2"}, + {add: "127.34.0.3", fails: true}, + // Make room for an address, then add again. + {remove: "127.0.0.1"}, + {add: "127.0.0.3"}, + {add: "127.0.0.3", fails: true}, + } + + set := DistinctNetSet{Subnet: 15, Limit: 2} + for _, op := range ops { + var desc string + if op.add != "" { + desc = fmt.Sprintf("Add(%s)", op.add) + if ok := set.Add(parseIP(op.add)); ok != !op.fails { + t.Errorf("%s == %t, want %t", desc, ok, !op.fails) + } + } else { + desc = fmt.Sprintf("Remove(%s)", op.remove) + set.Remove(parseIP(op.remove)) + } + t.Logf("%s: %v", desc, set) + } +} + +func TestDistinctNetSetAddRemove(t *testing.T) { + cfg := &quick.Config{} + fn := func(ips []net.IP) bool { + s := DistinctNetSet{Limit: 3, Subnet: 2} + for _, ip := range ips { + s.Add(ip) + } + for _, ip := range ips { + s.Remove(ip) + } + return s.Len() == 0 + } + + if err := quick.Check(fn, cfg); err != nil { + t.Fatal(err) + } +} |