diff options
Diffstat (limited to 'trie/node.go')
-rw-r--r-- | trie/node.go | 174 |
1 files changed, 150 insertions, 24 deletions
diff --git a/trie/node.go b/trie/node.go index 9d49029de..0bfa21dc4 100644 --- a/trie/node.go +++ b/trie/node.go @@ -16,46 +16,172 @@ package trie -import "fmt" +import ( + "fmt" + "io" + "strings" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/rlp" +) var indices = []string{"0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "a", "b", "c", "d", "e", "f", "[17]"} -type Node interface { - Value() Node - Copy(*Trie) Node // All nodes, for now, return them self - Dirty() bool +type node interface { fstring(string) string - Hash() interface{} - RlpData() interface{} - setDirty(dirty bool) } -// Value node -func (self *ValueNode) String() string { return self.fstring("") } -func (self *FullNode) String() string { return self.fstring("") } -func (self *ShortNode) String() string { return self.fstring("") } -func (self *ValueNode) fstring(ind string) string { return fmt.Sprintf("%x ", self.data) } +type ( + fullNode [17]node + shortNode struct { + Key []byte + Val node + } + hashNode []byte + valueNode []byte +) -//func (self *HashNode) fstring(ind string) string { return fmt.Sprintf("< %x > ", self.key) } -func (self *HashNode) fstring(ind string) string { - return fmt.Sprintf("%v", self.trie.trans(self)) -} +// Pretty printing. +func (n fullNode) String() string { return n.fstring("") } +func (n shortNode) String() string { return n.fstring("") } +func (n hashNode) String() string { return n.fstring("") } +func (n valueNode) String() string { return n.fstring("") } -// Full node -func (self *FullNode) fstring(ind string) string { +func (n fullNode) fstring(ind string) string { resp := fmt.Sprintf("[\n%s ", ind) - for i, node := range self.nodes { + for i, node := range n { if node == nil { resp += fmt.Sprintf("%s: <nil> ", indices[i]) } else { resp += fmt.Sprintf("%s: %v", indices[i], node.fstring(ind+" ")) } } - return resp + fmt.Sprintf("\n%s] ", ind) } +func (n shortNode) fstring(ind string) string { + return fmt.Sprintf("{%x: %v} ", n.Key, n.Val.fstring(ind+" ")) +} +func (n hashNode) fstring(ind string) string { + return fmt.Sprintf("<%x> ", []byte(n)) +} +func (n valueNode) fstring(ind string) string { + return fmt.Sprintf("%x ", []byte(n)) +} + +func mustDecodeNode(dbkey, buf []byte) node { + n, err := decodeNode(buf) + if err != nil { + panic(fmt.Sprintf("node %x: %v", dbkey, err)) + } + return n +} + +// decodeNode parses the RLP encoding of a trie node. +func decodeNode(buf []byte) (node, error) { + if len(buf) == 0 { + return nil, io.ErrUnexpectedEOF + } + elems, _, err := rlp.SplitList(buf) + if err != nil { + return nil, fmt.Errorf("decode error: %v", err) + } + switch c, _ := rlp.CountValues(elems); c { + case 2: + n, err := decodeShort(elems) + return n, wrapError(err, "short") + case 17: + n, err := decodeFull(elems) + return n, wrapError(err, "full") + default: + return nil, fmt.Errorf("invalid number of list elements: %v", c) + } +} + +func decodeShort(buf []byte) (node, error) { + kbuf, rest, err := rlp.SplitString(buf) + if err != nil { + return nil, err + } + key := compactDecode(kbuf) + if key[len(key)-1] == 16 { + // value node + val, _, err := rlp.SplitString(rest) + if err != nil { + return nil, fmt.Errorf("invalid value node: %v", err) + } + return shortNode{key, valueNode(val)}, nil + } + r, _, err := decodeRef(rest) + if err != nil { + return nil, wrapError(err, "val") + } + return shortNode{key, r}, nil +} + +func decodeFull(buf []byte) (fullNode, error) { + var n fullNode + for i := 0; i < 16; i++ { + cld, rest, err := decodeRef(buf) + if err != nil { + return n, wrapError(err, fmt.Sprintf("[%d]", i)) + } + n[i], buf = cld, rest + } + val, _, err := rlp.SplitString(buf) + if err != nil { + return n, err + } + if len(val) > 0 { + n[16] = valueNode(val) + } + return n, nil +} + +const hashLen = len(common.Hash{}) + +func decodeRef(buf []byte) (node, []byte, error) { + kind, val, rest, err := rlp.Split(buf) + if err != nil { + return nil, buf, err + } + switch { + case kind == rlp.List: + // 'embedded' node reference. The encoding must be smaller + // than a hash in order to be valid. + if size := len(buf) - len(rest); size > hashLen { + err := fmt.Errorf("oversized embedded node (size is %d bytes, want size < %d)", size, hashLen) + return nil, buf, err + } + n, err := decodeNode(buf) + return n, rest, err + case kind == rlp.String && len(val) == 0: + // empty node + return nil, rest, nil + case kind == rlp.String && len(val) == 32: + return hashNode(val), rest, nil + default: + return nil, nil, fmt.Errorf("invalid RLP string size %d (want 0 or 32)", len(val)) + } +} + +// wraps a decoding error with information about the path to the +// invalid child node (for debugging encoding issues). +type decodeError struct { + what error + stack []string +} + +func wrapError(err error, ctx string) error { + if err == nil { + return nil + } + if decErr, ok := err.(*decodeError); ok { + decErr.stack = append(decErr.stack, ctx) + return decErr + } + return &decodeError{err, []string{ctx}} +} -// Short node -func (self *ShortNode) fstring(ind string) string { - return fmt.Sprintf("[ %x: %v ] ", self.key, self.value.fstring(ind+" ")) +func (err *decodeError) Error() string { + return fmt.Sprintf("%v (decode path: %s)", err.what, strings.Join(err.stack, "<-")) } |