aboutsummaryrefslogtreecommitdiffstats
path: root/p2p
diff options
context:
space:
mode:
Diffstat (limited to 'p2p')
-rw-r--r--p2p/dial.go45
-rw-r--r--p2p/dial_test.go41
-rw-r--r--p2p/discover/table_test.go1
-rw-r--r--p2p/discover/udp.go50
-rw-r--r--p2p/discover/udp_test.go59
-rw-r--r--p2p/discv5/net.go31
-rw-r--r--p2p/discv5/net_test.go29
-rw-r--r--p2p/discv5/sim_test.go2
-rw-r--r--p2p/discv5/udp.go23
-rw-r--r--p2p/discv5/udp_test.go50
-rw-r--r--p2p/discv5/udp_windows.go40
-rw-r--r--p2p/netutil/error.go (renamed from p2p/discv5/udp_notwindows.go)15
-rw-r--r--p2p/netutil/error_test.go73
-rw-r--r--p2p/netutil/net.go166
-rw-r--r--p2p/netutil/net_test.go173
-rw-r--r--p2p/netutil/toobig_notwindows.go (renamed from p2p/discover/udp_notwindows.go)4
-rw-r--r--p2p/netutil/toobig_windows.go (renamed from p2p/discover/udp_windows.go)4
-rw-r--r--p2p/server.go25
18 files changed, 603 insertions, 228 deletions
diff --git a/p2p/dial.go b/p2p/dial.go
index 691b8539e..57fba136a 100644
--- a/p2p/dial.go
+++ b/p2p/dial.go
@@ -19,6 +19,7 @@ package p2p
import (
"container/heap"
"crypto/rand"
+ "errors"
"fmt"
"net"
"time"
@@ -26,6 +27,7 @@ import (
"github.com/ethereum/go-ethereum/logger"
"github.com/ethereum/go-ethereum/logger/glog"
"github.com/ethereum/go-ethereum/p2p/discover"
+ "github.com/ethereum/go-ethereum/p2p/netutil"
)
const (
@@ -48,6 +50,7 @@ const (
type dialstate struct {
maxDynDials int
ntab discoverTable
+ netrestrict *netutil.Netlist
lookupRunning bool
dialing map[discover.NodeID]connFlag
@@ -100,10 +103,11 @@ type waitExpireTask struct {
time.Duration
}
-func newDialState(static []*discover.Node, ntab discoverTable, maxdyn int) *dialstate {
+func newDialState(static []*discover.Node, ntab discoverTable, maxdyn int, netrestrict *netutil.Netlist) *dialstate {
s := &dialstate{
maxDynDials: maxdyn,
ntab: ntab,
+ netrestrict: netrestrict,
static: make(map[discover.NodeID]*dialTask),
dialing: make(map[discover.NodeID]connFlag),
randomNodes: make([]*discover.Node, maxdyn/2),
@@ -128,12 +132,9 @@ func (s *dialstate) removeStatic(n *discover.Node) {
func (s *dialstate) newTasks(nRunning int, peers map[discover.NodeID]*Peer, now time.Time) []task {
var newtasks []task
- isDialing := func(id discover.NodeID) bool {
- _, found := s.dialing[id]
- return found || peers[id] != nil || s.hist.contains(id)
- }
addDial := func(flag connFlag, n *discover.Node) bool {
- if isDialing(n.ID) {
+ if err := s.checkDial(n, peers); err != nil {
+ glog.V(logger.Debug).Infof("skipping dial candidate %x@%v:%d: %v", n.ID[:8], n.IP, n.TCP, err)
return false
}
s.dialing[n.ID] = flag
@@ -159,7 +160,12 @@ func (s *dialstate) newTasks(nRunning int, peers map[discover.NodeID]*Peer, now
// Create dials for static nodes if they are not connected.
for id, t := range s.static {
- if !isDialing(id) {
+ err := s.checkDial(t.dest, peers)
+ switch err {
+ case errNotWhitelisted, errSelf:
+ glog.V(logger.Debug).Infof("removing static dial candidate %x@%v:%d: %v", t.dest.ID[:8], t.dest.IP, t.dest.TCP, err)
+ delete(s.static, t.dest.ID)
+ case nil:
s.dialing[id] = t.flags
newtasks = append(newtasks, t)
}
@@ -202,6 +208,31 @@ func (s *dialstate) newTasks(nRunning int, peers map[discover.NodeID]*Peer, now
return newtasks
}
+var (
+ errSelf = errors.New("is self")
+ errAlreadyDialing = errors.New("already dialing")
+ errAlreadyConnected = errors.New("already connected")
+ errRecentlyDialed = errors.New("recently dialed")
+ errNotWhitelisted = errors.New("not contained in netrestrict whitelist")
+)
+
+func (s *dialstate) checkDial(n *discover.Node, peers map[discover.NodeID]*Peer) error {
+ _, dialing := s.dialing[n.ID]
+ switch {
+ case dialing:
+ return errAlreadyDialing
+ case peers[n.ID] != nil:
+ return errAlreadyConnected
+ case s.ntab != nil && n.ID == s.ntab.Self().ID:
+ return errSelf
+ case s.netrestrict != nil && !s.netrestrict.Contains(n.IP):
+ return errNotWhitelisted
+ case s.hist.contains(n.ID):
+ return errRecentlyDialed
+ }
+ return nil
+}
+
func (s *dialstate) taskDone(t task, now time.Time) {
switch t := t.(type) {
case *dialTask:
diff --git a/p2p/dial_test.go b/p2p/dial_test.go
index 05d9b7562..c850233db 100644
--- a/p2p/dial_test.go
+++ b/p2p/dial_test.go
@@ -25,6 +25,7 @@ import (
"github.com/davecgh/go-spew/spew"
"github.com/ethereum/go-ethereum/p2p/discover"
+ "github.com/ethereum/go-ethereum/p2p/netutil"
)
func init() {
@@ -86,7 +87,7 @@ func (t fakeTable) ReadRandomNodes(buf []*discover.Node) int { return copy(buf,
// This test checks that dynamic dials are launched from discovery results.
func TestDialStateDynDial(t *testing.T) {
runDialTest(t, dialtest{
- init: newDialState(nil, fakeTable{}, 5),
+ init: newDialState(nil, fakeTable{}, 5, nil),
rounds: []round{
// A discovery query is launched.
{
@@ -233,7 +234,7 @@ func TestDialStateDynDialFromTable(t *testing.T) {
}
runDialTest(t, dialtest{
- init: newDialState(nil, table, 10),
+ init: newDialState(nil, table, 10, nil),
rounds: []round{
// 5 out of 8 of the nodes returned by ReadRandomNodes are dialed.
{
@@ -313,6 +314,36 @@ func TestDialStateDynDialFromTable(t *testing.T) {
})
}
+// This test checks that candidates that do not match the netrestrict list are not dialed.
+func TestDialStateNetRestrict(t *testing.T) {
+ // This table always returns the same random nodes
+ // in the order given below.
+ table := fakeTable{
+ {ID: uintID(1), IP: net.ParseIP("127.0.0.1")},
+ {ID: uintID(2), IP: net.ParseIP("127.0.0.2")},
+ {ID: uintID(3), IP: net.ParseIP("127.0.0.3")},
+ {ID: uintID(4), IP: net.ParseIP("127.0.0.4")},
+ {ID: uintID(5), IP: net.ParseIP("127.0.2.5")},
+ {ID: uintID(6), IP: net.ParseIP("127.0.2.6")},
+ {ID: uintID(7), IP: net.ParseIP("127.0.2.7")},
+ {ID: uintID(8), IP: net.ParseIP("127.0.2.8")},
+ }
+ restrict := new(netutil.Netlist)
+ restrict.Add("127.0.2.0/24")
+
+ runDialTest(t, dialtest{
+ init: newDialState(nil, table, 10, restrict),
+ rounds: []round{
+ {
+ new: []task{
+ &dialTask{flags: dynDialedConn, dest: table[4]},
+ &discoverTask{},
+ },
+ },
+ },
+ })
+}
+
// This test checks that static dials are launched.
func TestDialStateStaticDial(t *testing.T) {
wantStatic := []*discover.Node{
@@ -324,7 +355,7 @@ func TestDialStateStaticDial(t *testing.T) {
}
runDialTest(t, dialtest{
- init: newDialState(wantStatic, fakeTable{}, 0),
+ init: newDialState(wantStatic, fakeTable{}, 0, nil),
rounds: []round{
// Static dials are launched for the nodes that
// aren't yet connected.
@@ -405,7 +436,7 @@ func TestDialStateCache(t *testing.T) {
}
runDialTest(t, dialtest{
- init: newDialState(wantStatic, fakeTable{}, 0),
+ init: newDialState(wantStatic, fakeTable{}, 0, nil),
rounds: []round{
// Static dials are launched for the nodes that
// aren't yet connected.
@@ -467,7 +498,7 @@ func TestDialStateCache(t *testing.T) {
func TestDialResolve(t *testing.T) {
resolved := discover.NewNode(uintID(1), net.IP{127, 0, 55, 234}, 3333, 4444)
table := &resolveMock{answer: resolved}
- state := newDialState(nil, table, 0)
+ state := newDialState(nil, table, 0, nil)
// Check that the task is generated with an incomplete ID.
dest := discover.NewNode(uintID(1), nil, 0, 0)
diff --git a/p2p/discover/table_test.go b/p2p/discover/table_test.go
index 1a2405740..102c7c2d1 100644
--- a/p2p/discover/table_test.go
+++ b/p2p/discover/table_test.go
@@ -146,6 +146,7 @@ func fillBucket(tab *Table, ld int) (last *Node) {
func nodeAtDistance(base common.Hash, ld int) (n *Node) {
n = new(Node)
n.sha = hashAtDistance(base, ld)
+ n.IP = net.IP{10, 0, 2, byte(ld)}
copy(n.ID[:], n.sha[:]) // ensure the node still has a unique ID
return n
}
diff --git a/p2p/discover/udp.go b/p2p/discover/udp.go
index 74758b6fd..e09c63ffb 100644
--- a/p2p/discover/udp.go
+++ b/p2p/discover/udp.go
@@ -29,6 +29,7 @@ import (
"github.com/ethereum/go-ethereum/logger"
"github.com/ethereum/go-ethereum/logger/glog"
"github.com/ethereum/go-ethereum/p2p/nat"
+ "github.com/ethereum/go-ethereum/p2p/netutil"
"github.com/ethereum/go-ethereum/rlp"
)
@@ -126,8 +127,16 @@ func makeEndpoint(addr *net.UDPAddr, tcpPort uint16) rpcEndpoint {
return rpcEndpoint{IP: ip, UDP: uint16(addr.Port), TCP: tcpPort}
}
-func nodeFromRPC(rn rpcNode) (*Node, error) {
- // TODO: don't accept localhost, LAN addresses from internet hosts
+func (t *udp) nodeFromRPC(sender *net.UDPAddr, rn rpcNode) (*Node, error) {
+ if rn.UDP <= 1024 {
+ return nil, errors.New("low port")
+ }
+ if err := netutil.CheckRelayIP(sender.IP, rn.IP); err != nil {
+ return nil, err
+ }
+ if t.netrestrict != nil && !t.netrestrict.Contains(rn.IP) {
+ return nil, errors.New("not contained in netrestrict whitelist")
+ }
n := NewNode(rn.ID, rn.IP, rn.UDP, rn.TCP)
err := n.validateComplete()
return n, err
@@ -151,6 +160,7 @@ type conn interface {
// udp implements the RPC protocol.
type udp struct {
conn conn
+ netrestrict *netutil.Netlist
priv *ecdsa.PrivateKey
ourEndpoint rpcEndpoint
@@ -201,7 +211,7 @@ type reply struct {
}
// ListenUDP returns a new table that listens for UDP packets on laddr.
-func ListenUDP(priv *ecdsa.PrivateKey, laddr string, natm nat.Interface, nodeDBPath string) (*Table, error) {
+func ListenUDP(priv *ecdsa.PrivateKey, laddr string, natm nat.Interface, nodeDBPath string, netrestrict *netutil.Netlist) (*Table, error) {
addr, err := net.ResolveUDPAddr("udp", laddr)
if err != nil {
return nil, err
@@ -210,7 +220,7 @@ func ListenUDP(priv *ecdsa.PrivateKey, laddr string, natm nat.Interface, nodeDBP
if err != nil {
return nil, err
}
- tab, _, err := newUDP(priv, conn, natm, nodeDBPath)
+ tab, _, err := newUDP(priv, conn, natm, nodeDBPath, netrestrict)
if err != nil {
return nil, err
}
@@ -218,13 +228,14 @@ func ListenUDP(priv *ecdsa.PrivateKey, laddr string, natm nat.Interface, nodeDBP
return tab, nil
}
-func newUDP(priv *ecdsa.PrivateKey, c conn, natm nat.Interface, nodeDBPath string) (*Table, *udp, error) {
+func newUDP(priv *ecdsa.PrivateKey, c conn, natm nat.Interface, nodeDBPath string, netrestrict *netutil.Netlist) (*Table, *udp, error) {
udp := &udp{
- conn: c,
- priv: priv,
- closing: make(chan struct{}),
- gotreply: make(chan reply),
- addpending: make(chan *pending),
+ conn: c,
+ priv: priv,
+ netrestrict: netrestrict,
+ closing: make(chan struct{}),
+ gotreply: make(chan reply),
+ addpending: make(chan *pending),
}
realaddr := c.LocalAddr().(*net.UDPAddr)
if natm != nil {
@@ -281,9 +292,12 @@ func (t *udp) findnode(toid NodeID, toaddr *net.UDPAddr, target NodeID) ([]*Node
reply := r.(*neighbors)
for _, rn := range reply.Nodes {
nreceived++
- if n, err := nodeFromRPC(rn); err == nil {
- nodes = append(nodes, n)
+ n, err := t.nodeFromRPC(toaddr, rn)
+ if err != nil {
+ glog.V(logger.Detail).Infof("invalid neighbor node (%v) from %v: %v", rn.IP, toaddr, err)
+ continue
}
+ nodes = append(nodes, n)
}
return nreceived >= bucketSize
})
@@ -479,13 +493,6 @@ func encodePacket(priv *ecdsa.PrivateKey, ptype byte, req interface{}) ([]byte,
return packet, nil
}
-func isTemporaryError(err error) bool {
- tempErr, ok := err.(interface {
- Temporary() bool
- })
- return ok && tempErr.Temporary() || isPacketTooBig(err)
-}
-
// readLoop runs in its own goroutine. it handles incoming UDP packets.
func (t *udp) readLoop() {
defer t.conn.Close()
@@ -495,7 +502,7 @@ func (t *udp) readLoop() {
buf := make([]byte, 1280)
for {
nbytes, from, err := t.conn.ReadFromUDP(buf)
- if isTemporaryError(err) {
+ if netutil.IsTemporaryError(err) {
// Ignore temporary read errors.
glog.V(logger.Debug).Infof("Temporary read error: %v", err)
continue
@@ -602,6 +609,9 @@ func (req *findnode) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte
// Send neighbors in chunks with at most maxNeighbors per packet
// to stay below the 1280 byte limit.
for i, n := range closest {
+ if netutil.CheckRelayIP(from.IP, n.IP) != nil {
+ continue
+ }
p.Nodes = append(p.Nodes, nodeToRPC(n))
if len(p.Nodes) == maxNeighbors || i == len(closest)-1 {
t.send(from, neighborsPacket, p)
diff --git a/p2p/discover/udp_test.go b/p2p/discover/udp_test.go
index f43bf3726..53cfac6f9 100644
--- a/p2p/discover/udp_test.go
+++ b/p2p/discover/udp_test.go
@@ -43,56 +43,6 @@ func init() {
spew.Config.DisableMethods = true
}
-// This test checks that isPacketTooBig correctly identifies
-// errors that result from receiving a UDP packet larger
-// than the supplied receive buffer.
-func TestIsPacketTooBig(t *testing.T) {
- listener, err := net.ListenPacket("udp", "127.0.0.1:0")
- if err != nil {
- t.Fatal(err)
- }
- defer listener.Close()
- sender, err := net.Dial("udp", listener.LocalAddr().String())
- if err != nil {
- t.Fatal(err)
- }
- defer sender.Close()
-
- sendN := 1800
- recvN := 300
- for i := 0; i < 20; i++ {
- go func() {
- buf := make([]byte, sendN)
- for i := range buf {
- buf[i] = byte(i)
- }
- sender.Write(buf)
- }()
-
- buf := make([]byte, recvN)
- listener.SetDeadline(time.Now().Add(1 * time.Second))
- n, _, err := listener.ReadFrom(buf)
- if err != nil {
- if nerr, ok := err.(net.Error); ok && nerr.Timeout() {
- continue
- }
- if !isPacketTooBig(err) {
- t.Fatal("unexpected read error:", spew.Sdump(err))
- }
- continue
- }
- if n != recvN {
- t.Fatalf("short read: %d, want %d", n, recvN)
- }
- for i := range buf {
- if buf[i] != byte(i) {
- t.Fatalf("error in pattern")
- break
- }
- }
- }
-}
-
// shared test variables
var (
futureExp = uint64(time.Now().Add(10 * time.Hour).Unix())
@@ -118,9 +68,9 @@ func newUDPTest(t *testing.T) *udpTest {
pipe: newpipe(),
localkey: newkey(),
remotekey: newkey(),
- remoteaddr: &net.UDPAddr{IP: net.IP{1, 2, 3, 4}, Port: 30303},
+ remoteaddr: &net.UDPAddr{IP: net.IP{10, 0, 1, 99}, Port: 30303},
}
- test.table, test.udp, _ = newUDP(test.localkey, test.pipe, nil, "")
+ test.table, test.udp, _ = newUDP(test.localkey, test.pipe, nil, "", nil)
return test
}
@@ -362,8 +312,9 @@ func TestUDP_findnodeMultiReply(t *testing.T) {
// check that the sent neighbors are all returned by findnode
select {
case result := <-resultc:
- if !reflect.DeepEqual(result, list) {
- t.Errorf("neighbors mismatch:\n got: %v\n want: %v", result, list)
+ want := append(list[:2], list[3:]...)
+ if !reflect.DeepEqual(result, want) {
+ t.Errorf("neighbors mismatch:\n got: %v\n want: %v", result, want)
}
case err := <-errc:
t.Errorf("findnode error: %v", err)
diff --git a/p2p/discv5/net.go b/p2p/discv5/net.go
index 7ad6f1e5b..d1c48904e 100644
--- a/p2p/discv5/net.go
+++ b/p2p/discv5/net.go
@@ -31,6 +31,7 @@ import (
"github.com/ethereum/go-ethereum/logger"
"github.com/ethereum/go-ethereum/logger/glog"
"github.com/ethereum/go-ethereum/p2p/nat"
+ "github.com/ethereum/go-ethereum/p2p/netutil"
"github.com/ethereum/go-ethereum/rlp"
)
@@ -45,6 +46,7 @@ const (
bucketRefreshInterval = 1 * time.Minute
seedCount = 30
seedMaxAge = 5 * 24 * time.Hour
+ lowPort = 1024
)
const testTopic = "foo"
@@ -62,8 +64,9 @@ func debugLog(s string) {
// Network manages the table and all protocol interaction.
type Network struct {
- db *nodeDB // database of known nodes
- conn transport
+ db *nodeDB // database of known nodes
+ conn transport
+ netrestrict *netutil.Netlist
closed chan struct{} // closed when loop is done
closeReq chan struct{} // 'request to close'
@@ -132,7 +135,7 @@ type timeoutEvent struct {
node *Node
}
-func newNetwork(conn transport, ourPubkey ecdsa.PublicKey, natm nat.Interface, dbPath string) (*Network, error) {
+func newNetwork(conn transport, ourPubkey ecdsa.PublicKey, natm nat.Interface, dbPath string, netrestrict *netutil.Netlist) (*Network, error) {
ourID := PubkeyID(&ourPubkey)
var db *nodeDB
@@ -147,6 +150,7 @@ func newNetwork(conn transport, ourPubkey ecdsa.PublicKey, natm nat.Interface, d
net := &Network{
db: db,
conn: conn,
+ netrestrict: netrestrict,
tab: tab,
topictab: newTopicTable(db, tab.self),
ticketStore: newTicketStore(),
@@ -684,16 +688,22 @@ func (net *Network) internNodeFromDB(dbn *Node) *Node {
return n
}
-func (net *Network) internNodeFromNeighbours(rn rpcNode) (n *Node, err error) {
+func (net *Network) internNodeFromNeighbours(sender *net.UDPAddr, rn rpcNode) (n *Node, err error) {
if rn.ID == net.tab.self.ID {
return nil, errors.New("is self")
}
+ if rn.UDP <= lowPort {
+ return nil, errors.New("low port")
+ }
n = net.nodes[rn.ID]
if n == nil {
// We haven't seen this node before.
- n, err = nodeFromRPC(rn)
- n.state = unknown
+ n, err = nodeFromRPC(sender, rn)
+ if net.netrestrict != nil && !net.netrestrict.Contains(n.IP) {
+ return n, errors.New("not contained in netrestrict whitelist")
+ }
if err == nil {
+ n.state = unknown
net.nodes[n.ID] = n
}
return n, err
@@ -1095,7 +1105,7 @@ func (net *Network) handleQueryEvent(n *Node, ev nodeEvent, pkt *ingressPacket)
net.conn.sendNeighbours(n, results)
return n.state, nil
case neighborsPacket:
- err := net.handleNeighboursPacket(n, pkt.data.(*neighbors))
+ err := net.handleNeighboursPacket(n, pkt)
return n.state, err
case neighboursTimeout:
if n.pendingNeighbours != nil {
@@ -1182,17 +1192,18 @@ func rlpHash(x interface{}) (h common.Hash) {
return h
}
-func (net *Network) handleNeighboursPacket(n *Node, req *neighbors) error {
+func (net *Network) handleNeighboursPacket(n *Node, pkt *ingressPacket) error {
if n.pendingNeighbours == nil {
return errNoQuery
}
net.abortTimedEvent(n, neighboursTimeout)
+ req := pkt.data.(*neighbors)
nodes := make([]*Node, len(req.Nodes))
for i, rn := range req.Nodes {
- nn, err := net.internNodeFromNeighbours(rn)
+ nn, err := net.internNodeFromNeighbours(pkt.remoteAddr, rn)
if err != nil {
- glog.V(logger.Debug).Infof("invalid neighbour from %x: %v", n.ID[:8], err)
+ glog.V(logger.Debug).Infof("invalid neighbour (%v) from %x@%v: %v", rn.IP, n.ID[:8], pkt.remoteAddr, err)
continue
}
nodes[i] = nn
diff --git a/p2p/discv5/net_test.go b/p2p/discv5/net_test.go
index 422daa33b..327457c7c 100644
--- a/p2p/discv5/net_test.go
+++ b/p2p/discv5/net_test.go
@@ -28,7 +28,7 @@ import (
func TestNetwork_Lookup(t *testing.T) {
key, _ := crypto.GenerateKey()
- network, err := newNetwork(lookupTestnet, key.PublicKey, nil, "")
+ network, err := newNetwork(lookupTestnet, key.PublicKey, nil, "", nil)
if err != nil {
t.Fatal(err)
}
@@ -40,7 +40,7 @@ func TestNetwork_Lookup(t *testing.T) {
// t.Fatalf("lookup on empty table returned %d results: %#v", len(results), results)
// }
// seed table with initial node (otherwise lookup will terminate immediately)
- seeds := []*Node{NewNode(lookupTestnet.dists[256][0], net.IP{}, 256, 999)}
+ seeds := []*Node{NewNode(lookupTestnet.dists[256][0], net.IP{10, 0, 2, 99}, lowPort+256, 999)}
if err := network.SetFallbackNodes(seeds); err != nil {
t.Fatal(err)
}
@@ -272,13 +272,13 @@ func (tn *preminedTestnet) sendFindnode(to *Node, target NodeID) {
func (tn *preminedTestnet) sendFindnodeHash(to *Node, target common.Hash) {
// current log distance is encoded in port number
// fmt.Println("findnode query at dist", toaddr.Port)
- if to.UDP == 0 {
- panic("query to node at distance 0")
+ if to.UDP <= lowPort {
+ panic("query to node at or below distance 0")
}
next := to.UDP - 1
var result []rpcNode
- for i, id := range tn.dists[to.UDP] {
- result = append(result, nodeToRPC(NewNode(id, net.ParseIP("127.0.0.1"), next, uint16(i)+1)))
+ for i, id := range tn.dists[to.UDP-lowPort] {
+ result = append(result, nodeToRPC(NewNode(id, net.ParseIP("10.0.2.99"), next, uint16(i)+1+lowPort)))
}
injectResponse(tn.net, to, neighborsPacket, &neighbors{Nodes: result})
}
@@ -296,14 +296,14 @@ func (tn *preminedTestnet) send(to *Node, ptype nodeEvent, data interface{}) (ha
// ignored
case findnodeHashPacket:
// current log distance is encoded in port number
- // fmt.Println("findnode query at dist", toaddr.Port)
- if to.UDP == 0 {
- panic("query to node at distance 0")
+ // fmt.Println("findnode query at dist", toaddr.Port-lowPort)
+ if to.UDP <= lowPort {
+ panic("query to node at or below distance 0")
}
next := to.UDP - 1
var result []rpcNode
- for i, id := range tn.dists[to.UDP] {
- result = append(result, nodeToRPC(NewNode(id, net.ParseIP("127.0.0.1"), next, uint16(i)+1)))
+ for i, id := range tn.dists[to.UDP-lowPort] {
+ result = append(result, nodeToRPC(NewNode(id, net.ParseIP("10.0.2.99"), next, uint16(i)+1+lowPort)))
}
injectResponse(tn.net, to, neighborsPacket, &neighbors{Nodes: result})
default:
@@ -328,8 +328,11 @@ func (tn *preminedTestnet) sendTopicRegister(to *Node, topics []Topic, idx int,
panic("sendTopicRegister called")
}
-func (*preminedTestnet) Close() {}
-func (*preminedTestnet) localAddr() *net.UDPAddr { return new(net.UDPAddr) }
+func (*preminedTestnet) Close() {}
+
+func (*preminedTestnet) localAddr() *net.UDPAddr {
+ return &net.UDPAddr{IP: net.ParseIP("10.0.1.1"), Port: 40000}
+}
// mine generates a testnet struct literal with nodes at
// various distances to the given target.
diff --git a/p2p/discv5/sim_test.go b/p2p/discv5/sim_test.go
index 2e232fbaa..cb64d7fa0 100644
--- a/p2p/discv5/sim_test.go
+++ b/p2p/discv5/sim_test.go
@@ -290,7 +290,7 @@ func (s *simulation) launchNode(log bool) *Network {
addr := &net.UDPAddr{IP: ip, Port: 30303}
transport := &simTransport{joinTime: time.Now(), sender: id, senderAddr: addr, sim: s, priv: key}
- net, err := newNetwork(transport, key.PublicKey, nil, "<no database>")
+ net, err := newNetwork(transport, key.PublicKey, nil, "<no database>", nil)
if err != nil {
panic("cannot launch new node: " + err.Error())
}
diff --git a/p2p/discv5/udp.go b/p2p/discv5/udp.go
index 46d3200bf..a6114e032 100644
--- a/p2p/discv5/udp.go
+++ b/p2p/discv5/udp.go
@@ -29,6 +29,7 @@ import (
"github.com/ethereum/go-ethereum/logger"
"github.com/ethereum/go-ethereum/logger/glog"
"github.com/ethereum/go-ethereum/p2p/nat"
+ "github.com/ethereum/go-ethereum/p2p/netutil"
"github.com/ethereum/go-ethereum/rlp"
)
@@ -198,8 +199,10 @@ func (e1 rpcEndpoint) equal(e2 rpcEndpoint) bool {
return e1.UDP == e2.UDP && e1.TCP == e2.TCP && bytes.Equal(e1.IP, e2.IP)
}
-func nodeFromRPC(rn rpcNode) (*Node, error) {
- // TODO: don't accept localhost, LAN addresses from internet hosts
+func nodeFromRPC(sender *net.UDPAddr, rn rpcNode) (*Node, error) {
+ if err := netutil.CheckRelayIP(sender.IP, rn.IP); err != nil {
+ return nil, err
+ }
n := NewNode(rn.ID, rn.IP, rn.UDP, rn.TCP)
err := n.validateComplete()
return n, err
@@ -235,12 +238,12 @@ type udp struct {
}
// ListenUDP returns a new table that listens for UDP packets on laddr.
-func ListenUDP(priv *ecdsa.PrivateKey, laddr string, natm nat.Interface, nodeDBPath string) (*Network, error) {
+func ListenUDP(priv *ecdsa.PrivateKey, laddr string, natm nat.Interface, nodeDBPath string, netrestrict *netutil.Netlist) (*Network, error) {
transport, err := listenUDP(priv, laddr)
if err != nil {
return nil, err
}
- net, err := newNetwork(transport, priv.PublicKey, natm, nodeDBPath)
+ net, err := newNetwork(transport, priv.PublicKey, natm, nodeDBPath, netrestrict)
if err != nil {
return nil, err
}
@@ -327,6 +330,9 @@ func (t *udp) sendTopicNodes(remote *Node, queryHash common.Hash, nodes []*Node)
return
}
for i, result := range nodes {
+ if netutil.CheckRelayIP(remote.IP, result.IP) != nil {
+ continue
+ }
p.Nodes = append(p.Nodes, nodeToRPC(result))
if len(p.Nodes) == maxTopicNodes || i == len(nodes)-1 {
t.sendPacket(remote.ID, remote.addr(), byte(topicNodesPacket), p)
@@ -385,7 +391,7 @@ func (t *udp) readLoop() {
buf := make([]byte, 1280)
for {
nbytes, from, err := t.conn.ReadFromUDP(buf)
- if isTemporaryError(err) {
+ if netutil.IsTemporaryError(err) {
// Ignore temporary read errors.
glog.V(logger.Debug).Infof("Temporary read error: %v", err)
continue
@@ -398,13 +404,6 @@ func (t *udp) readLoop() {
}
}
-func isTemporaryError(err error) bool {
- tempErr, ok := err.(interface {
- Temporary() bool
- })
- return ok && tempErr.Temporary() || isPacketTooBig(err)
-}
-
func (t *udp) handlePacket(from *net.UDPAddr, buf []byte) error {
pkt := ingressPacket{remoteAddr: from}
if err := decodePacket(buf, &pkt); err != nil {
diff --git a/p2p/discv5/udp_test.go b/p2p/discv5/udp_test.go
index cacc0f004..98c737669 100644
--- a/p2p/discv5/udp_test.go
+++ b/p2p/discv5/udp_test.go
@@ -36,56 +36,6 @@ func init() {
spew.Config.DisableMethods = true
}
-// This test checks that isPacketTooBig correctly identifies
-// errors that result from receiving a UDP packet larger
-// than the supplied receive buffer.
-func TestIsPacketTooBig(t *testing.T) {
- listener, err := net.ListenPacket("udp", "127.0.0.1:0")
- if err != nil {
- t.Fatal(err)
- }
- defer listener.Close()
- sender, err := net.Dial("udp", listener.LocalAddr().String())
- if err != nil {
- t.Fatal(err)
- }
- defer sender.Close()
-
- sendN := 1800
- recvN := 300
- for i := 0; i < 20; i++ {
- go func() {
- buf := make([]byte, sendN)
- for i := range buf {
- buf[i] = byte(i)
- }
- sender.Write(buf)
- }()
-
- buf := make([]byte, recvN)
- listener.SetDeadline(time.Now().Add(1 * time.Second))
- n, _, err := listener.ReadFrom(buf)
- if err != nil {
- if nerr, ok := err.(net.Error); ok && nerr.Timeout() {
- continue
- }
- if !isPacketTooBig(err) {
- t.Fatal("unexpected read error:", spew.Sdump(err))
- }
- continue
- }
- if n != recvN {
- t.Fatalf("short read: %d, want %d", n, recvN)
- }
- for i := range buf {
- if buf[i] != byte(i) {
- t.Fatalf("error in pattern")
- break
- }
- }
- }
-}
-
// shared test variables
var (
futureExp = uint64(time.Now().Add(10 * time.Hour).Unix())
diff --git a/p2p/discv5/udp_windows.go b/p2p/discv5/udp_windows.go
deleted file mode 100644
index 1ab9d655e..000000000
--- a/p2p/discv5/udp_windows.go
+++ /dev/null
@@ -1,40 +0,0 @@
-// Copyright 2016 The go-ethereum Authors
-// This file is part of the go-ethereum library.
-//
-// The go-ethereum library is free software: you can redistribute it and/or modify
-// it under the terms of the GNU Lesser General Public License as published by
-// the Free Software Foundation, either version 3 of the License, or
-// (at your option) any later version.
-//
-// The go-ethereum library is distributed in the hope that it will be useful,
-// but WITHOUT ANY WARRANTY; without even the implied warranty of
-// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-// GNU Lesser General Public License for more details.
-//
-// You should have received a copy of the GNU Lesser General Public License
-// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
-
-//+build windows
-
-package discv5
-
-import (
- "net"
- "os"
- "syscall"
-)
-
-const _WSAEMSGSIZE = syscall.Errno(10040)
-
-// reports whether err indicates that a UDP packet didn't
-// fit the receive buffer. On Windows, WSARecvFrom returns
-// code WSAEMSGSIZE and no data if this happens.
-func isPacketTooBig(err error) bool {
- if opErr, ok := err.(*net.OpError); ok {
- if scErr, ok := opErr.Err.(*os.SyscallError); ok {
- return scErr.Err == _WSAEMSGSIZE
- }
- return opErr.Err == _WSAEMSGSIZE
- }
- return false
-}
diff --git a/p2p/discv5/udp_notwindows.go b/p2p/netutil/error.go
index 4da18d0f6..cb21b9cd4 100644
--- a/p2p/discv5/udp_notwindows.go
+++ b/p2p/netutil/error.go
@@ -14,13 +14,12 @@
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
-//+build !windows
+package netutil
-package discv5
-
-// reports whether err indicates that a UDP packet didn't
-// fit the receive buffer. There is no such error on
-// non-Windows platforms.
-func isPacketTooBig(err error) bool {
- return false
+// IsTemporaryError checks whether the given error should be considered temporary.
+func IsTemporaryError(err error) bool {
+ tempErr, ok := err.(interface {
+ Temporary() bool
+ })
+ return ok && tempErr.Temporary() || isPacketTooBig(err)
}
diff --git a/p2p/netutil/error_test.go b/p2p/netutil/error_test.go
new file mode 100644
index 000000000..645e48f83
--- /dev/null
+++ b/p2p/netutil/error_test.go
@@ -0,0 +1,73 @@
+// Copyright 2016 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package netutil
+
+import (
+ "net"
+ "testing"
+ "time"
+)
+
+// This test checks that isPacketTooBig correctly identifies
+// errors that result from receiving a UDP packet larger
+// than the supplied receive buffer.
+func TestIsPacketTooBig(t *testing.T) {
+ listener, err := net.ListenPacket("udp", "127.0.0.1:0")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer listener.Close()
+ sender, err := net.Dial("udp", listener.LocalAddr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer sender.Close()
+
+ sendN := 1800
+ recvN := 300
+ for i := 0; i < 20; i++ {
+ go func() {
+ buf := make([]byte, sendN)
+ for i := range buf {
+ buf[i] = byte(i)
+ }
+ sender.Write(buf)
+ }()
+
+ buf := make([]byte, recvN)
+ listener.SetDeadline(time.Now().Add(1 * time.Second))
+ n, _, err := listener.ReadFrom(buf)
+ if err != nil {
+ if nerr, ok := err.(net.Error); ok && nerr.Timeout() {
+ continue
+ }
+ if !isPacketTooBig(err) {
+ t.Fatalf("unexpected read error: %v", err)
+ }
+ continue
+ }
+ if n != recvN {
+ t.Fatalf("short read: %d, want %d", n, recvN)
+ }
+ for i := range buf {
+ if buf[i] != byte(i) {
+ t.Fatalf("error in pattern")
+ break
+ }
+ }
+ }
+}
diff --git a/p2p/netutil/net.go b/p2p/netutil/net.go
new file mode 100644
index 000000000..3c3715788
--- /dev/null
+++ b/p2p/netutil/net.go
@@ -0,0 +1,166 @@
+// Copyright 2016 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+// Package netutil contains extensions to the net package.
+package netutil
+
+import (
+ "errors"
+ "net"
+ "strings"
+)
+
+var lan4, lan6, special4, special6 Netlist
+
+func init() {
+ // Lists from RFC 5735, RFC 5156,
+ // https://www.iana.org/assignments/iana-ipv4-special-registry/
+ lan4.Add("0.0.0.0/8") // "This" network
+ lan4.Add("10.0.0.0/8") // Private Use
+ lan4.Add("172.16.0.0/12") // Private Use
+ lan4.Add("192.168.0.0/16") // Private Use
+ lan6.Add("fe80::/10") // Link-Local
+ lan6.Add("fc00::/7") // Unique-Local
+ special4.Add("192.0.0.0/29") // IPv4 Service Continuity
+ special4.Add("192.0.0.9/32") // PCP Anycast
+ special4.Add("192.0.0.170/32") // NAT64/DNS64 Discovery
+ special4.Add("192.0.0.171/32") // NAT64/DNS64 Discovery
+ special4.Add("192.0.2.0/24") // TEST-NET-1
+ special4.Add("192.31.196.0/24") // AS112
+ special4.Add("192.52.193.0/24") // AMT
+ special4.Add("192.88.99.0/24") // 6to4 Relay Anycast
+ special4.Add("192.175.48.0/24") // AS112
+ special4.Add("198.18.0.0/15") // Device Benchmark Testing
+ special4.Add("198.51.100.0/24") // TEST-NET-2
+ special4.Add("203.0.113.0/24") // TEST-NET-3
+ special4.Add("255.255.255.255/32") // Limited Broadcast
+
+ // http://www.iana.org/assignments/iana-ipv6-special-registry/
+ special6.Add("100::/64")
+ special6.Add("2001::/32")
+ special6.Add("2001:1::1/128")
+ special6.Add("2001:2::/48")
+ special6.Add("2001:3::/32")
+ special6.Add("2001:4:112::/48")
+ special6.Add("2001:5::/32")
+ special6.Add("2001:10::/28")
+ special6.Add("2001:20::/28")
+ special6.Add("2001:db8::/32")
+ special6.Add("2002::/16")
+}
+
+// Netlist is a list of IP networks.
+type Netlist []net.IPNet
+
+// ParseNetlist parses a comma-separated list of CIDR masks.
+// Whitespace and extra commas are ignored.
+func ParseNetlist(s string) (*Netlist, error) {
+ ws := strings.NewReplacer(" ", "", "\n", "", "\t", "")
+ masks := strings.Split(ws.Replace(s), ",")
+ l := make(Netlist, 0)
+ for _, mask := range masks {
+ if mask == "" {
+ continue
+ }
+ _, n, err := net.ParseCIDR(mask)
+ if err != nil {
+ return nil, err
+ }
+ l = append(l, *n)
+ }
+ return &l, nil
+}
+
+// Add parses a CIDR mask and appends it to the list. It panics for invalid masks and is
+// intended to be used for setting up static lists.
+func (l *Netlist) Add(cidr string) {
+ _, n, err := net.ParseCIDR(cidr)
+ if err != nil {
+ panic(err)
+ }
+ *l = append(*l, *n)
+}
+
+// Contains reports whether the given IP is contained in the list.
+func (l *Netlist) Contains(ip net.IP) bool {
+ if l == nil {
+ return false
+ }
+ for _, net := range *l {
+ if net.Contains(ip) {
+ return true
+ }
+ }
+ return false
+}
+
+// IsLAN reports whether an IP is a local network address.
+func IsLAN(ip net.IP) bool {
+ if ip.IsLoopback() {
+ return true
+ }
+ if v4 := ip.To4(); v4 != nil {
+ return lan4.Contains(v4)
+ }
+ return lan6.Contains(ip)
+}
+
+// IsSpecialNetwork reports whether an IP is located in a special-use network range
+// This includes broadcast, multicast and documentation addresses.
+func IsSpecialNetwork(ip net.IP) bool {
+ if ip.IsMulticast() {
+ return true
+ }
+ if v4 := ip.To4(); v4 != nil {
+ return special4.Contains(v4)
+ }
+ return special6.Contains(ip)
+}
+
+var (
+ errInvalid = errors.New("invalid IP")
+ errUnspecified = errors.New("zero address")
+ errSpecial = errors.New("special network")
+ errLoopback = errors.New("loopback address from non-loopback host")
+ errLAN = errors.New("LAN address from WAN host")
+)
+
+// CheckRelayIP reports whether an IP relayed from the given sender IP
+// is a valid connection target.
+//
+// There are four rules:
+// - Special network addresses are never valid.
+// - Loopback addresses are OK if relayed by a loopback host.
+// - LAN addresses are OK if relayed by a LAN host.
+// - All other addresses are always acceptable.
+func CheckRelayIP(sender, addr net.IP) error {
+ if len(addr) != net.IPv4len && len(addr) != net.IPv6len {
+ return errInvalid
+ }
+ if addr.IsUnspecified() {
+ return errUnspecified
+ }
+ if IsSpecialNetwork(addr) {
+ return errSpecial
+ }
+ if addr.IsLoopback() && !sender.IsLoopback() {
+ return errLoopback
+ }
+ if IsLAN(addr) && !IsLAN(sender) {
+ return errLAN
+ }
+ return nil
+}
diff --git a/p2p/netutil/net_test.go b/p2p/netutil/net_test.go
new file mode 100644
index 000000000..1ee1fcb4d
--- /dev/null
+++ b/p2p/netutil/net_test.go
@@ -0,0 +1,173 @@
+// Copyright 2016 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package netutil
+
+import (
+ "net"
+ "reflect"
+ "testing"
+
+ "github.com/davecgh/go-spew/spew"
+)
+
+func TestParseNetlist(t *testing.T) {
+ var tests = []struct {
+ input string
+ wantErr error
+ wantList *Netlist
+ }{
+ {
+ input: "",
+ wantList: &Netlist{},
+ },
+ {
+ input: "127.0.0.0/8",
+ wantErr: nil,
+ wantList: &Netlist{{IP: net.IP{127, 0, 0, 0}, Mask: net.CIDRMask(8, 32)}},
+ },
+ {
+ input: "127.0.0.0/44",
+ wantErr: &net.ParseError{Type: "CIDR address", Text: "127.0.0.0/44"},
+ },
+ {
+ input: "127.0.0.0/16, 23.23.23.23/24,",
+ wantList: &Netlist{
+ {IP: net.IP{127, 0, 0, 0}, Mask: net.CIDRMask(16, 32)},
+ {IP: net.IP{23, 23, 23, 0}, Mask: net.CIDRMask(24, 32)},
+ },
+ },
+ }
+
+ for _, test := range tests {
+ l, err := ParseNetlist(test.input)
+ if !reflect.DeepEqual(err, test.wantErr) {
+ t.Errorf("%q: got error %q, want %q", test.input, err, test.wantErr)
+ continue
+ }
+ if !reflect.DeepEqual(l, test.wantList) {
+ spew.Dump(l)
+ spew.Dump(test.wantList)
+ t.Errorf("%q: got %v, want %v", test.input, l, test.wantList)
+ }
+ }
+}
+
+func TestNilNetListContains(t *testing.T) {
+ var list *Netlist
+ checkContains(t, list.Contains, nil, []string{"1.2.3.4"})
+}
+
+func TestIsLAN(t *testing.T) {
+ checkContains(t, IsLAN,
+ []string{ // included
+ "0.0.0.0",
+ "0.2.0.8",
+ "127.0.0.1",
+ "10.0.1.1",
+ "10.22.0.3",
+ "172.31.252.251",
+ "192.168.1.4",
+ "fe80::f4a1:8eff:fec5:9d9d",
+ "febf::ab32:2233",
+ "fc00::4",
+ },
+ []string{ // excluded
+ "192.0.2.1",
+ "1.0.0.0",
+ "172.32.0.1",
+ "fec0::2233",
+ },
+ )
+}
+
+func TestIsSpecialNetwork(t *testing.T) {
+ checkContains(t, IsSpecialNetwork,
+ []string{ // included
+ "192.0.2.1",
+ "192.0.2.44",
+ "2001:db8:85a3:8d3:1319:8a2e:370:7348",
+ "255.255.255.255",
+ "224.0.0.22", // IPv4 multicast
+ "ff05::1:3", // IPv6 multicast
+ },
+ []string{ // excluded
+ "192.0.3.1",
+ "1.0.0.0",
+ "172.32.0.1",
+ "fec0::2233",
+ },
+ )
+}
+
+func checkContains(t *testing.T, fn func(net.IP) bool, inc, exc []string) {
+ for _, s := range inc {
+ if !fn(parseIP(s)) {
+ t.Error("returned false for included address", s)
+ }
+ }
+ for _, s := range exc {
+ if fn(parseIP(s)) {
+ t.Error("returned true for excluded address", s)
+ }
+ }
+}
+
+func parseIP(s string) net.IP {
+ ip := net.ParseIP(s)
+ if ip == nil {
+ panic("invalid " + s)
+ }
+ return ip
+}
+
+func TestCheckRelayIP(t *testing.T) {
+ tests := []struct {
+ sender, addr string
+ want error
+ }{
+ {"127.0.0.1", "0.0.0.0", errUnspecified},
+ {"192.168.0.1", "0.0.0.0", errUnspecified},
+ {"23.55.1.242", "0.0.0.0", errUnspecified},
+ {"127.0.0.1", "255.255.255.255", errSpecial},
+ {"192.168.0.1", "255.255.255.255", errSpecial},
+ {"23.55.1.242", "255.255.255.255", errSpecial},
+ {"192.168.0.1", "127.0.2.19", errLoopback},
+ {"23.55.1.242", "192.168.0.1", errLAN},
+
+ {"127.0.0.1", "127.0.2.19", nil},
+ {"127.0.0.1", "192.168.0.1", nil},
+ {"127.0.0.1", "23.55.1.242", nil},
+ {"192.168.0.1", "192.168.0.1", nil},
+ {"192.168.0.1", "23.55.1.242", nil},
+ {"23.55.1.242", "23.55.1.242", nil},
+ }
+
+ for _, test := range tests {
+ err := CheckRelayIP(parseIP(test.sender), parseIP(test.addr))
+ if err != test.want {
+ t.Errorf("%s from %s: got %q, want %q", test.addr, test.sender, err, test.want)
+ }
+ }
+}
+
+func BenchmarkCheckRelayIP(b *testing.B) {
+ sender := parseIP("23.55.1.242")
+ addr := parseIP("23.55.1.2")
+ for i := 0; i < b.N; i++ {
+ CheckRelayIP(sender, addr)
+ }
+}
diff --git a/p2p/discover/udp_notwindows.go b/p2p/netutil/toobig_notwindows.go
index e9de83aa9..47b643857 100644
--- a/p2p/discover/udp_notwindows.go
+++ b/p2p/netutil/toobig_notwindows.go
@@ -16,9 +16,9 @@
//+build !windows
-package discover
+package netutil
-// reports whether err indicates that a UDP packet didn't
+// isPacketTooBig reports whether err indicates that a UDP packet didn't
// fit the receive buffer. There is no such error on
// non-Windows platforms.
func isPacketTooBig(err error) bool {
diff --git a/p2p/discover/udp_windows.go b/p2p/netutil/toobig_windows.go
index 66bbf9597..dfbb6d44f 100644
--- a/p2p/discover/udp_windows.go
+++ b/p2p/netutil/toobig_windows.go
@@ -16,7 +16,7 @@
//+build windows
-package discover
+package netutil
import (
"net"
@@ -26,7 +26,7 @@ import (
const _WSAEMSGSIZE = syscall.Errno(10040)
-// reports whether err indicates that a UDP packet didn't
+// isPacketTooBig reports whether err indicates that a UDP packet didn't
// fit the receive buffer. On Windows, WSARecvFrom returns
// code WSAEMSGSIZE and no data if this happens.
func isPacketTooBig(err error) bool {
diff --git a/p2p/server.go b/p2p/server.go
index 7381127dc..cf9672e2d 100644
--- a/p2p/server.go
+++ b/p2p/server.go
@@ -30,6 +30,7 @@ import (
"github.com/ethereum/go-ethereum/p2p/discover"
"github.com/ethereum/go-ethereum/p2p/discv5"
"github.com/ethereum/go-ethereum/p2p/nat"
+ "github.com/ethereum/go-ethereum/p2p/netutil"
)
const (
@@ -101,6 +102,11 @@ type Config struct {
// allowed to connect, even above the peer limit.
TrustedNodes []*discover.Node
+ // Connectivity can be restricted to certain IP networks.
+ // If this option is set to a non-nil value, only hosts which match one of the
+ // IP networks contained in the list are considered.
+ NetRestrict *netutil.Netlist
+
// NodeDatabase is the path to the database containing the previously seen
// live nodes in the network.
NodeDatabase string
@@ -356,7 +362,7 @@ func (srv *Server) Start() (err error) {
// node table
if srv.Discovery {
- ntab, err := discover.ListenUDP(srv.PrivateKey, srv.ListenAddr, srv.NAT, srv.NodeDatabase)
+ ntab, err := discover.ListenUDP(srv.PrivateKey, srv.ListenAddr, srv.NAT, srv.NodeDatabase, srv.NetRestrict)
if err != nil {
return err
}
@@ -367,7 +373,7 @@ func (srv *Server) Start() (err error) {
}
if srv.DiscoveryV5 {
- ntab, err := discv5.ListenUDP(srv.PrivateKey, srv.DiscoveryV5Addr, srv.NAT, "") //srv.NodeDatabase)
+ ntab, err := discv5.ListenUDP(srv.PrivateKey, srv.DiscoveryV5Addr, srv.NAT, "", srv.NetRestrict) //srv.NodeDatabase)
if err != nil {
return err
}
@@ -381,7 +387,7 @@ func (srv *Server) Start() (err error) {
if !srv.Discovery {
dynPeers = 0
}
- dialer := newDialState(srv.StaticNodes, srv.ntab, dynPeers)
+ dialer := newDialState(srv.StaticNodes, srv.ntab, dynPeers, srv.NetRestrict)
// handshake
srv.ourHandshake = &protoHandshake{Version: baseProtocolVersion, Name: srv.Name, ID: discover.PubkeyID(&srv.PrivateKey.PublicKey)}
@@ -634,8 +640,19 @@ func (srv *Server) listenLoop() {
}
break
}
+
+ // Reject connections that do not match NetRestrict.
+ if srv.NetRestrict != nil {
+ if tcp, ok := fd.RemoteAddr().(*net.TCPAddr); ok && !srv.NetRestrict.Contains(tcp.IP) {
+ glog.V(logger.Debug).Infof("Rejected conn %v because it is not whitelisted in NetRestrict", fd.RemoteAddr())
+ fd.Close()
+ slots <- struct{}{}
+ continue
+ }
+ }
+
fd = newMeteredConn(fd, true)
- glog.V(logger.Debug).Infof("Accepted conn %v\n", fd.RemoteAddr())
+ glog.V(logger.Debug).Infof("Accepted conn %v", fd.RemoteAddr())
// Spawn the handler. It will give the slot back when the connection
// has been established.