diff options
Diffstat (limited to 'trie/trie.go')
-rw-r--r-- | trie/trie.go | 639 |
1 files changed, 346 insertions, 293 deletions
diff --git a/trie/trie.go b/trie/trie.go index abf48a850..aa8d39fe2 100644 --- a/trie/trie.go +++ b/trie/trie.go @@ -19,372 +19,425 @@ package trie import ( "bytes" - "container/list" + "errors" "fmt" - "sync" + "hash" "github.com/ethereum/go-ethereum/common" - "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/crypto/sha3" + "github.com/ethereum/go-ethereum/logger" + "github.com/ethereum/go-ethereum/logger/glog" + "github.com/ethereum/go-ethereum/rlp" ) -func ParanoiaCheck(t1 *Trie, backend Backend) (bool, *Trie) { - t2 := New(nil, backend) +const defaultCacheCapacity = 800 - it := t1.Iterator() - for it.Next() { - t2.Update(it.Key, it.Value) - } - - return bytes.Equal(t2.Hash(), t1.Hash()), t2 -} - -type Trie struct { - mu sync.Mutex - root Node - roothash []byte - cache *Cache - - revisions *list.List -} - -func New(root []byte, backend Backend) *Trie { - trie := &Trie{} - trie.revisions = list.New() - trie.roothash = root - if backend != nil { - trie.cache = NewCache(backend) - } +var ( + // The global cache stores decoded trie nodes by hash as they get loaded. + globalCache = newARC(defaultCacheCapacity) + // This is the known root hash of an empty trie. + emptyRoot = common.HexToHash("56e81f171bcc55a6ff8345e692c0f86e5b48e01b996cadc001622fb5e363b421") +) - if root != nil { - value := common.NewValueFromBytes(trie.cache.Get(root)) - trie.root = trie.mknode(value) - } +var ErrMissingRoot = errors.New("missing root node") - return trie +// Database must be implemented by backing stores for the trie. +type Database interface { + DatabaseWriter + // Get returns the value for key from the database. + Get(key []byte) (value []byte, err error) } -func (self *Trie) Iterator() *Iterator { - return NewIterator(self) +// DatabaseWriter wraps the Put method of a backing store for the trie. +type DatabaseWriter interface { + // Put stores the mapping key->value in the database. 
+ // Implementations must not hold onto the value bytes, the trie + // will reuse the slice across calls to Put. + Put(key, value []byte) error } -func (self *Trie) Copy() *Trie { - cpy := make([]byte, 32) - copy(cpy, self.roothash) // NOTE: cpy isn't being used anywhere? - trie := New(nil, nil) - trie.cache = self.cache.Copy() - if self.root != nil { - trie.root = self.root.Copy(trie) - } - - return trie +// Trie is a Merkle Patricia Trie. +// The zero value is an empty trie with no database. +// Use New to create a trie that sits on top of a database. +// +// Trie is not safe for concurrent use. +type Trie struct { + root node + db Database + *hasher } -// Legacy support -func (self *Trie) Root() []byte { return self.Hash() } -func (self *Trie) Hash() []byte { - var hash []byte - if self.root != nil { - t := self.root.Hash() - if byts, ok := t.([]byte); ok && len(byts) > 0 { - hash = byts - } else { - hash = crypto.Sha3(common.Encode(self.root.RlpData())) +// New creates a trie with an existing root node from db. +// +// If root is the zero hash or the sha3 hash of an empty string, the +// trie is initially empty and does not require a database. Otherwise, +// New panics if db is nil and returns ErrMissingRoot if root does not exist +// in the database. Accessing the trie loads nodes from db on demand. 
+func New(root common.Hash, db Database) (*Trie, error) { + trie := &Trie{db: db} + if (root != common.Hash{}) && root != emptyRoot { + if db == nil { + panic("trie.New: cannot use existing root without a database") } - } else { - hash = crypto.Sha3(common.Encode("")) - } - - if !bytes.Equal(hash, self.roothash) { - self.revisions.PushBack(self.roothash) - self.roothash = hash + if v, _ := trie.db.Get(root[:]); len(v) == 0 { + return nil, ErrMissingRoot + } + trie.root = hashNode(root.Bytes()) } - - return hash + return trie, nil } -func (self *Trie) Commit() { - self.mu.Lock() - defer self.mu.Unlock() - // Hash first - self.Hash() - - self.cache.Flush() +// Iterator returns an iterator over all mappings in the trie. +func (t *Trie) Iterator() *Iterator { + return NewIterator(t) } -// Reset should only be called if the trie has been hashed -func (self *Trie) Reset() { - self.mu.Lock() - defer self.mu.Unlock() - - self.cache.Reset() - - if self.revisions.Len() > 0 { - revision := self.revisions.Remove(self.revisions.Back()).([]byte) - self.roothash = revision +// Get returns the value for key stored in the trie. +// The value bytes must not be modified by the caller. 
+func (t *Trie) Get(key []byte) []byte { + key = compactHexDecode(key) + tn := t.root + for len(key) > 0 { + switch n := tn.(type) { + case shortNode: + if len(key) < len(n.Key) || !bytes.Equal(n.Key, key[:len(n.Key)]) { + return nil + } + tn = n.Val + key = key[len(n.Key):] + case fullNode: + tn = n[key[0]] + key = key[1:] + case nil: + return nil + case hashNode: + tn = t.resolveHash(n) + default: + panic(fmt.Sprintf("%T: invalid node: %v", tn, tn)) + } } - value := common.NewValueFromBytes(self.cache.Get(self.roothash)) - self.root = self.mknode(value) + return tn.(valueNode) } -func (self *Trie) UpdateString(key, value string) Node { return self.Update([]byte(key), []byte(value)) } -func (self *Trie) Update(key, value []byte) Node { - self.mu.Lock() - defer self.mu.Unlock() - - k := CompactHexDecode(key) - +// Update associates key with value in the trie. Subsequent calls to +// Get will return value. If value has length zero, any existing value +// is deleted from the trie and calls to Get will return nil. +// +// The value bytes must not be modified by the caller while they are +// stored in the trie. 
+func (t *Trie) Update(key, value []byte) { + k := compactHexDecode(key) if len(value) != 0 { - node := NewValueNode(self, value) - node.dirty = true - self.root = self.insert(self.root, k, node) + t.root = t.insert(t.root, k, valueNode(value)) } else { - self.root = self.delete(self.root, k) + t.root = t.delete(t.root, k) } - - return self.root -} - -func (self *Trie) GetString(key string) []byte { return self.Get([]byte(key)) } -func (self *Trie) Get(key []byte) []byte { - self.mu.Lock() - defer self.mu.Unlock() - - k := CompactHexDecode(key) - - n := self.get(self.root, k) - if n != nil { - return n.(*ValueNode).Val() - } - - return nil } -func (self *Trie) DeleteString(key string) Node { return self.Delete([]byte(key)) } -func (self *Trie) Delete(key []byte) Node { - self.mu.Lock() - defer self.mu.Unlock() - - k := CompactHexDecode(key) - self.root = self.delete(self.root, k) - - return self.root -} - -func (self *Trie) insert(node Node, key []byte, value Node) Node { +func (t *Trie) insert(n node, key []byte, value node) node { if len(key) == 0 { return value } - - if node == nil { - node := NewShortNode(self, key, value) - node.dirty = true - return node - } - - switch node := node.(type) { - case *ShortNode: - k := node.Key() - cnode := node.Value() - if bytes.Equal(k, key) { - node := NewShortNode(self, key, value) - node.dirty = true - return node - + switch n := n.(type) { + case shortNode: + matchlen := prefixLen(key, n.Key) + // If the whole key matches, keep this short node as is + // and only update the value. 
+ if matchlen == len(n.Key) { + return shortNode{n.Key, t.insert(n.Val, key[matchlen:], value)} } - - var n Node - matchlength := MatchingNibbleLength(key, k) - if matchlength == len(k) { - n = self.insert(cnode, key[matchlength:], value) - } else { - pnode := self.insert(nil, k[matchlength+1:], cnode) - nnode := self.insert(nil, key[matchlength+1:], value) - fulln := NewFullNode(self) - fulln.dirty = true - fulln.set(k[matchlength], pnode) - fulln.set(key[matchlength], nnode) - n = fulln - } - if matchlength == 0 { - return n + // Otherwise branch out at the index where they differ. + var branch fullNode + branch[n.Key[matchlen]] = t.insert(nil, n.Key[matchlen+1:], n.Val) + branch[key[matchlen]] = t.insert(nil, key[matchlen+1:], value) + // Replace this shortNode with the branch if it occurs at index 0. + if matchlen == 0 { + return branch } + // Otherwise, replace it with a short node leading up to the branch. + return shortNode{key[:matchlen], branch} - snode := NewShortNode(self, key[:matchlength], n) - snode.dirty = true - return snode + case fullNode: + n[key[0]] = t.insert(n[key[0]], key[1:], value) + return n - case *FullNode: - cpy := node.Copy(self).(*FullNode) - cpy.set(key[0], self.insert(node.branch(key[0]), key[1:], value)) - cpy.dirty = true + case nil: + return shortNode{key, value} - return cpy + case hashNode: + // We've hit a part of the trie that isn't loaded yet. Load + // the node and insert into it. This leaves all child nodes on + // the path to the value in the trie. + // + // TODO: track whether insertion changed the value and keep + // n as a hash node if it didn't. 
+ return t.insert(t.resolveHash(n), key, value) default: - panic(fmt.Sprintf("%T: invalid node: %v", node, node)) + panic(fmt.Sprintf("%T: invalid node: %v", n, n)) } } -func (self *Trie) get(node Node, key []byte) Node { - if len(key) == 0 { - return node - } - - if node == nil { - return nil - } - - switch node := node.(type) { - case *ShortNode: - k := node.Key() - cnode := node.Value() - - if len(key) >= len(k) && bytes.Equal(k, key[:len(k)]) { - return self.get(cnode, key[len(k):]) - } - - return nil - case *FullNode: - return self.get(node.branch(key[0]), key[1:]) - default: - panic(fmt.Sprintf("%T: invalid node: %v", node, node)) - } +// Delete removes any existing value for key from the trie. +func (t *Trie) Delete(key []byte) { + k := compactHexDecode(key) + t.root = t.delete(t.root, k) } -func (self *Trie) delete(node Node, key []byte) Node { - if len(key) == 0 && node == nil { - return nil - } - - switch node := node.(type) { - case *ShortNode: - k := node.Key() - cnode := node.Value() - if bytes.Equal(key, k) { - return nil - } else if bytes.Equal(key[:len(k)], k) { - child := self.delete(cnode, key[len(k):]) - - var n Node - switch child := child.(type) { - case *ShortNode: - nkey := append(k, child.Key()...) - n = NewShortNode(self, nkey, child.Value()) - n.(*ShortNode).dirty = true - case *FullNode: - sn := NewShortNode(self, node.Key(), child) - sn.dirty = true - sn.key = node.key - n = sn - } - - return n - } else { - return node +// delete returns the new root of the trie with key deleted. +// It reduces the trie to minimal form by simplifying +// nodes on the way up after deleting recursively. +func (t *Trie) delete(n node, key []byte) node { + switch n := n.(type) { + case shortNode: + matchlen := prefixLen(key, n.Key) + if matchlen < len(n.Key) { + return n // don't replace n on mismatch + } + if matchlen == len(key) { + return nil // remove n entirely for whole matches + } + // The key is longer than n.Key. 
Remove the remaining suffix + // from the subtrie. Child can never be nil here since the + // subtrie must contain at least two other values with keys + // longer than n.Key. + child := t.delete(n.Val, key[len(n.Key):]) + switch child := child.(type) { + case shortNode: + // Deleting from the subtrie reduced it to another + // short node. Merge the nodes to avoid creating a + // shortNode{..., shortNode{...}}. Use concat (which + // always creates a new slice) instead of append to + // avoid modifying n.Key since it might be shared with + // other nodes. + return shortNode{concat(n.Key, child.Key...), child.Val} + default: + return shortNode{n.Key, child} } - case *FullNode: - n := node.Copy(self).(*FullNode) - n.set(key[0], self.delete(n.branch(key[0]), key[1:])) - n.dirty = true - + case fullNode: + n[key[0]] = t.delete(n[key[0]], key[1:]) + // Check how many non-nil entries are left after deleting and + // reduce the full node to a short node if only one entry is + // left. Since n must've contained at least two children + // before deletion (otherwise it would not be a full node) n + // can never be reduced to nil. + // + // When the loop is done, pos contains the index of the single + // value that is left in n or -2 if n contains at least two + // values. pos := -1 - for i := 0; i < 17; i++ { - if n.branch(byte(i)) != nil { + for i, cld := range n { + if cld != nil { if pos == -1 { pos = i } else { pos = -2 + break } } } - - var nnode Node - if pos == 16 { - nnode = NewShortNode(self, []byte{16}, n.branch(byte(pos))) - nnode.(*ShortNode).dirty = true - } else if pos >= 0 { - cnode := n.branch(byte(pos)) - switch cnode := cnode.(type) { - case *ShortNode: - // Stitch keys - k := append([]byte{byte(pos)}, cnode.Key()...) 
- nnode = NewShortNode(self, k, cnode.Value()) - nnode.(*ShortNode).dirty = true - case *FullNode: - nnode = NewShortNode(self, []byte{byte(pos)}, n.branch(byte(pos))) - nnode.(*ShortNode).dirty = true + if pos >= 0 { + if pos != 16 { + // If the remaining entry is a short node, it replaces + // n and its key gets the missing nibble tacked to the + // front. This avoids creating an invalid + // shortNode{..., shortNode{...}}. Since the entry + // might not be loaded yet, resolve it just for this + // check. + cnode := t.resolve(n[pos]) + if cnode, ok := cnode.(shortNode); ok { + k := append([]byte{byte(pos)}, cnode.Key...) + return shortNode{k, cnode.Val} + } } - } else { - nnode = n + // Otherwise, n is replaced by a one-nibble short node + // containing the child. + return shortNode{[]byte{byte(pos)}, n[pos]} } + // n still contains at least two values and cannot be reduced. + return n - return nnode case nil: return nil + + case hashNode: + // We've hit a part of the trie that isn't loaded yet. Load + // the node and delete from it. This leaves all child nodes on + // the path to the value in the trie. + // + // TODO: track whether deletion actually hit a key and keep + // n as a hash node if it didn't. + return t.delete(t.resolveHash(n), key) + default: - panic(fmt.Sprintf("%T: invalid node: %v (%v)", node, node, key)) + panic(fmt.Sprintf("%T: invalid node: %v (%v)", n, n, key)) } } -// casting functions and cache storing -func (self *Trie) mknode(value *common.Value) Node { - l := value.Len() - switch l { - case 0: - return nil - case 2: - // A value node may consists of 2 bytes. 
- if value.Get(0).Len() != 0 { - key := CompactDecode(value.Get(0).Bytes()) - if key[len(key)-1] == 16 { - return NewShortNode(self, key, NewValueNode(self, value.Get(1).Bytes())) - } else { - return NewShortNode(self, key, self.mknode(value.Get(1))) - } - } - case 17: - if len(value.Bytes()) != 17 { - fnode := NewFullNode(self) - for i := 0; i < 16; i++ { - fnode.set(byte(i), self.mknode(value.Get(i))) - } - return fnode +func concat(s1 []byte, s2 ...byte) []byte { + r := make([]byte, len(s1)+len(s2)) + copy(r, s1) + copy(r[len(s1):], s2) + return r +} + +func (t *Trie) resolve(n node) node { + if n, ok := n.(hashNode); ok { + return t.resolveHash(n) + } + return n +} + +func (t *Trie) resolveHash(n hashNode) node { + if v, ok := globalCache.Get(n); ok { + return v + } + enc, err := t.db.Get(n) + if err != nil || enc == nil { + // TODO: This needs to be improved to properly distinguish errors. + // Disk I/O errors shouldn't produce nil (and cause a + // consensus failure or weird crash), but it is unclear how + // they could be handled because the entire stack above the trie isn't + // prepared to cope with missing state nodes. + if glog.V(logger.Error) { + glog.Errorf("Dangling hash node ref %x: %v", n, err) } - case 32: - return NewHash(value.Bytes(), self) + return nil + } + dec := mustDecodeNode(n, enc) + if dec != nil { + globalCache.Put(n, dec) } + return dec +} + +// Root returns the root hash of the trie. +// Deprecated: use Hash instead. +func (t *Trie) Root() []byte { return t.Hash().Bytes() } - return NewValueNode(self, value.Bytes()) +// Hash returns the root hash of the trie. It does not write to the +// database and can be used even if the trie doesn't have one. 
+func (t *Trie) Hash() common.Hash { + root, _ := t.hashRoot(nil) + return common.BytesToHash(root.(hashNode)) } -func (self *Trie) trans(node Node) Node { - switch node := node.(type) { - case *HashNode: - value := common.NewValueFromBytes(self.cache.Get(node.key)) - return self.mknode(value) - default: - return node +// Commit writes all nodes to the trie's database. +// Nodes are stored with their sha3 hash as the key. +// +// Committing flushes nodes from memory. +// Subsequent Get calls will load nodes from the database. +func (t *Trie) Commit() (root common.Hash, err error) { + if t.db == nil { + panic("Commit called on trie with nil database") } + return t.CommitTo(t.db) } -func (self *Trie) store(node Node) interface{} { - data := common.Encode(node) - if len(data) >= 32 { - key := crypto.Sha3(data) - if node.Dirty() { - //fmt.Println("save", node) - //fmt.Println() - self.cache.Put(key, data) - } +// CommitTo writes all nodes to the given database. +// Nodes are stored with their sha3 hash as the key. +// +// Committing flushes nodes from memory. Subsequent Get calls will +// load nodes from the trie's database. Calling code must ensure that +// the changes made to db are written back to the trie's attached +// database before using the trie. 
+func (t *Trie) CommitTo(db DatabaseWriter) (root common.Hash, err error) { + n, err := t.hashRoot(db) + if err != nil { + return (common.Hash{}), err + } + t.root = n + return common.BytesToHash(n.(hashNode)), nil +} - return key +func (t *Trie) hashRoot(db DatabaseWriter) (node, error) { + if t.root == nil { + return hashNode(emptyRoot.Bytes()), nil + } + if t.hasher == nil { + t.hasher = newHasher() } + return t.hasher.hash(t.root, db, true) +} - return node.RlpData() +type hasher struct { + tmp *bytes.Buffer + sha hash.Hash } -func (self *Trie) PrintRoot() { - fmt.Println(self.root) - fmt.Printf("root=%x\n", self.Root()) +func newHasher() *hasher { + return &hasher{tmp: new(bytes.Buffer), sha: sha3.NewKeccak256()} +} + +func (h *hasher) hash(n node, db DatabaseWriter, force bool) (node, error) { + hashed, err := h.replaceChildren(n, db) + if err != nil { + return hashNode{}, err + } + if n, err = h.store(hashed, db, force); err != nil { + return hashNode{}, err + } + return n, nil +} + +// replaceChildren replaces child nodes of n with their hashes if the encoded +// size of the child is larger than a hash. +func (h *hasher) replaceChildren(n node, db DatabaseWriter) (node, error) { + var err error + switch n := n.(type) { + case shortNode: + n.Key = compactEncode(n.Key) + if _, ok := n.Val.(valueNode); !ok { + if n.Val, err = h.hash(n.Val, db, false); err != nil { + return n, err + } + } + if n.Val == nil { + // Ensure that nil children are encoded as empty strings. + n.Val = valueNode(nil) + } + return n, nil + case fullNode: + for i := 0; i < 16; i++ { + if n[i] != nil { + if n[i], err = h.hash(n[i], db, false); err != nil { + return n, err + } + } else { + // Ensure that nil children are encoded as empty strings. + n[i] = valueNode(nil) + } + } + if n[16] == nil { + n[16] = valueNode(nil) + } + return n, nil + default: + return n, nil + } +} + +func (h *hasher) store(n node, db DatabaseWriter, force bool) (node, error) { + // Don't store hashes or empty nodes. 
+ if _, isHash := n.(hashNode); n == nil || isHash { + return n, nil + } + h.tmp.Reset() + if err := rlp.Encode(h.tmp, n); err != nil { + panic("encode error: " + err.Error()) + } + if h.tmp.Len() < 32 && !force { + // Nodes smaller than 32 bytes are stored inside their parent. + return n, nil + } + // Larger nodes are replaced by their hash and stored in the database. + h.sha.Reset() + h.sha.Write(h.tmp.Bytes()) + key := hashNode(h.sha.Sum(nil)) + if db != nil { + err := db.Put(key, h.tmp.Bytes()) + return key, err + } + return key, nil } |