aboutsummaryrefslogtreecommitdiffstats
path: root/p2p
diff options
context:
space:
mode:
Diffstat (limited to 'p2p')
-rw-r--r--p2p/discover/node.go83
-rw-r--r--p2p/discover/table.go3
2 files changed, 73 insertions, 13 deletions
diff --git a/p2p/discover/node.go b/p2p/discover/node.go
index 6662a6cb7..d8a5cc351 100644
--- a/p2p/discover/node.go
+++ b/p2p/discover/node.go
@@ -1,8 +1,10 @@
package discover
import (
+ "bytes"
"crypto/ecdsa"
"crypto/elliptic"
+ "encoding/binary"
"encoding/hex"
"errors"
"fmt"
@@ -11,13 +13,16 @@ import (
"math/rand"
"net"
"net/url"
+ "os"
"strconv"
"strings"
- "sync"
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/crypto/secp256k1"
"github.com/ethereum/go-ethereum/rlp"
+ "github.com/syndtr/goleveldb/leveldb"
+ "github.com/syndtr/goleveldb/leveldb/opt"
+ "github.com/syndtr/goleveldb/leveldb/storage"
)
const nodeIDBits = 512
@@ -308,23 +313,77 @@ func randomID(a NodeID, n int) (b NodeID) {
// nodeDB stores all nodes we know about.
type nodeDB struct {
- mu sync.RWMutex
- byID map[NodeID]*Node
+ ldb *leveldb.DB
+}
+
+var dbVersionKey = []byte("pv")
+
+// Opens the backing LevelDB. If path is "", we use an in-memory database.
+func newNodeDB(path string, version int64) (db *nodeDB, err error) {
+ db = new(nodeDB)
+ opts := new(opt.Options)
+ if path == "" {
+ db.ldb, err = leveldb.Open(storage.NewMemStorage(), opts)
+ } else {
+ db.ldb, err = openLDB(path, opts, version)
+ }
+ return db, err
+}
+
+func openLDB(path string, opts *opt.Options, version int64) (*leveldb.DB, error) {
+ ldb, err := leveldb.OpenFile(path, opts)
+ if _, iscorrupted := err.(leveldb.ErrCorrupted); iscorrupted {
+ ldb, err = leveldb.RecoverFile(path, opts)
+ }
+ if err != nil {
+ return nil, err
+ }
+ // The nodes contained in the database correspond to a certain
+ // protocol version. Flush all nodes if the DB version doesn't match.
+ // There is no need to do this for memory databases because they
+ // won't ever be used with a different protocol version.
+ shouldVal := make([]byte, binary.MaxVarintLen64)
+ shouldVal = shouldVal[:binary.PutVarint(shouldVal, version)]
+ val, err := ldb.Get(dbVersionKey, nil)
+ if err == leveldb.ErrNotFound {
+ err = ldb.Put(dbVersionKey, shouldVal, nil)
+ } else if err == nil && !bytes.Equal(val, shouldVal) {
+ // Delete and start over.
+ ldb.Close()
+ if err = os.RemoveAll(path); err != nil {
+ return nil, err
+ }
+ return openLDB(path, opts, version)
+ }
+ if err != nil {
+ ldb.Close()
+ ldb = nil
+ }
+ return ldb, err
}
func (db *nodeDB) get(id NodeID) *Node {
- db.mu.RLock()
- defer db.mu.RUnlock()
- return db.byID[id]
+ v, err := db.ldb.Get(id[:], nil)
+ if err != nil {
+ return nil
+ }
+ n := new(Node)
+ if err := rlp.DecodeBytes(v, n); err != nil {
+ return nil
+ }
+ return n
}
-func (db *nodeDB) add(id NodeID, addr *net.UDPAddr, tcpPort uint16) *Node {
- db.mu.Lock()
- defer db.mu.Unlock()
- if db.byID == nil {
- db.byID = make(map[NodeID]*Node)
+func (db *nodeDB) update(n *Node) error {
+ v, err := rlp.EncodeToBytes(n)
+ if err != nil {
+ return err
}
+ return db.ldb.Put(n.ID[:], v, nil)
+}
+
+func (db *nodeDB) add(id NodeID, addr *net.UDPAddr, tcpPort uint16) *Node {
n := &Node{ID: id, IP: addr.IP, DiscPort: addr.Port, TCPPort: int(tcpPort)}
- db.byID[n.ID] = n
+ db.update(n)
return n
}
diff --git a/p2p/discover/table.go b/p2p/discover/table.go
index e2e846456..ba2f9b8ec 100644
--- a/p2p/discover/table.go
+++ b/p2p/discover/table.go
@@ -59,9 +59,10 @@ type bucket struct {
}
func newTable(t transport, ourID NodeID, ourAddr *net.UDPAddr) *Table {
+ db, _ := newNodeDB("", Version)
tab := &Table{
net: t,
- db: new(nodeDB),
+ db: db,
self: newNode(ourID, ourAddr),
bonding: make(map[NodeID]*bondproc),
bondslots: make(chan struct{}, maxBondingPingPongs),