aboutsummaryrefslogtreecommitdiffstats
path: root/trie/node.go
diff options
context:
space:
mode:
Diffstat (limited to 'trie/node.go')
-rw-r--r--trie/node.go60
1 files changed, 39 insertions, 21 deletions
diff --git a/trie/node.go b/trie/node.go
index 0bfa21dc4..b97d370be 100644
--- a/trie/node.go
+++ b/trie/node.go
@@ -29,18 +29,36 @@ var indices = []string{"0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "a", "b
type node interface {
fstring(string) string
+ cache() (hashNode, bool)
}
type (
- fullNode [17]node
+ fullNode struct {
+ Children [17]node // Actual trie node data to encode/decode (needs custom encoder)
+ hash hashNode // Cached hash of the node to prevent rehashing (may be nil)
+ dirty bool // Cached flag whether the node's new or already stored
+ }
shortNode struct {
- Key []byte
- Val node
+ Key []byte
+ Val node
+ hash hashNode // Cached hash of the node to prevent rehashing (may be nil)
+ dirty bool // Cached flag whether the node's new or already stored
}
hashNode []byte
valueNode []byte
)
+// EncodeRLP encodes a full node into the consensus RLP format.
+func (n fullNode) EncodeRLP(w io.Writer) error {
+ return rlp.Encode(w, n.Children)
+}
+
+// Cache accessors to retrieve precalculated values (avoid lengthy type switches).
+func (n fullNode) cache() (hashNode, bool) { return n.hash, n.dirty }
+func (n shortNode) cache() (hashNode, bool) { return n.hash, n.dirty }
+func (n hashNode) cache() (hashNode, bool) { return nil, true }
+func (n valueNode) cache() (hashNode, bool) { return nil, true }
+
// Pretty printing.
func (n fullNode) String() string { return n.fstring("") }
func (n shortNode) String() string { return n.fstring("") }
@@ -49,7 +67,7 @@ func (n valueNode) String() string { return n.fstring("") }
func (n fullNode) fstring(ind string) string {
resp := fmt.Sprintf("[\n%s ", ind)
- for i, node := range n {
+ for i, node := range n.Children {
if node == nil {
resp += fmt.Sprintf("%s: <nil> ", indices[i])
} else {
@@ -68,16 +86,16 @@ func (n valueNode) fstring(ind string) string {
return fmt.Sprintf("%x ", []byte(n))
}
-func mustDecodeNode(dbkey, buf []byte) node {
- n, err := decodeNode(buf)
+func mustDecodeNode(hash, buf []byte) node {
+ n, err := decodeNode(hash, buf)
if err != nil {
- panic(fmt.Sprintf("node %x: %v", dbkey, err))
+ panic(fmt.Sprintf("node %x: %v", hash, err))
}
return n
}
// decodeNode parses the RLP encoding of a trie node.
-func decodeNode(buf []byte) (node, error) {
+func decodeNode(hash, buf []byte) (node, error) {
if len(buf) == 0 {
return nil, io.ErrUnexpectedEOF
}
@@ -87,18 +105,18 @@ func decodeNode(buf []byte) (node, error) {
}
switch c, _ := rlp.CountValues(elems); c {
case 2:
- n, err := decodeShort(elems)
+ n, err := decodeShort(hash, buf, elems)
return n, wrapError(err, "short")
case 17:
- n, err := decodeFull(elems)
+ n, err := decodeFull(hash, buf, elems)
return n, wrapError(err, "full")
default:
return nil, fmt.Errorf("invalid number of list elements: %v", c)
}
}
-func decodeShort(buf []byte) (node, error) {
- kbuf, rest, err := rlp.SplitString(buf)
+func decodeShort(hash, buf, elems []byte) (node, error) {
+ kbuf, rest, err := rlp.SplitString(elems)
if err != nil {
return nil, err
}
@@ -109,30 +127,30 @@ func decodeShort(buf []byte) (node, error) {
if err != nil {
return nil, fmt.Errorf("invalid value node: %v", err)
}
- return shortNode{key, valueNode(val)}, nil
+ return shortNode{key, valueNode(val), hash, false}, nil
}
r, _, err := decodeRef(rest)
if err != nil {
return nil, wrapError(err, "val")
}
- return shortNode{key, r}, nil
+ return shortNode{key, r, hash, false}, nil
}
-func decodeFull(buf []byte) (fullNode, error) {
- var n fullNode
+func decodeFull(hash, buf, elems []byte) (fullNode, error) {
+ n := fullNode{hash: hash}
for i := 0; i < 16; i++ {
- cld, rest, err := decodeRef(buf)
+ cld, rest, err := decodeRef(elems)
if err != nil {
return n, wrapError(err, fmt.Sprintf("[%d]", i))
}
- n[i], buf = cld, rest
+ n.Children[i], elems = cld, rest
}
- val, _, err := rlp.SplitString(buf)
+ val, _, err := rlp.SplitString(elems)
if err != nil {
return n, err
}
if len(val) > 0 {
- n[16] = valueNode(val)
+ n.Children[16] = valueNode(val)
}
return n, nil
}
@@ -152,7 +170,7 @@ func decodeRef(buf []byte) (node, []byte, error) {
err := fmt.Errorf("oversized embedded node (size is %d bytes, want size < %d)", size, hashLen)
return nil, buf, err
}
- n, err := decodeNode(buf)
+ n, err := decodeNode(nil, buf)
return n, rest, err
case kind == rlp.String && len(val) == 0:
// empty node