diff options
Diffstat (limited to 'trie/trie.go')
-rw-r--r-- | trie/trie.go | 65 |
1 files changed, 29 insertions, 36 deletions
diff --git a/trie/trie.go b/trie/trie.go index 759718400..1c1112a7f 100644 --- a/trie/trie.go +++ b/trie/trie.go @@ -6,32 +6,31 @@ import ( "fmt" "sync" - "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/common" ) func ParanoiaCheck(t1 *Trie, backend Backend) (bool, *Trie) { - t2 := New(common.Hash{}, backend) + t2 := New(nil, backend) it := t1.Iterator() for it.Next() { t2.Update(it.Key, it.Value) } - a, b := t2.Hash(), t1.Hash() - return bytes.Equal(a[:], b[:]), t2 + return bytes.Equal(t2.Hash(), t1.Hash()), t2 } type Trie struct { mu sync.Mutex root Node - roothash common.Hash + roothash []byte cache *Cache revisions *list.List } -func New(root common.Hash, backend Backend) *Trie { +func New(root []byte, backend Backend) *Trie { trie := &Trie{} trie.revisions = list.New() trie.roothash = root @@ -39,8 +38,8 @@ func New(root common.Hash, backend Backend) *Trie { trie.cache = NewCache(backend) } - if (root != common.Hash{}) { - value := common.NewValueFromBytes(trie.cache.Get(root[:])) + if root != nil { + value := common.NewValueFromBytes(trie.cache.Get(root)) trie.root = trie.mknode(value) } @@ -52,13 +51,9 @@ func (self *Trie) Iterator() *Iterator { } func (self *Trie) Copy() *Trie { - //cpy := make([]byte, 32) - //copy(cpy, self.roothash) - - // cheap copying method - var cpy common.Hash - cpy.Set(self.roothash) - trie := New(common.Hash{}, nil) + cpy := make([]byte, 32) + copy(cpy, self.roothash) + trie := New(nil, nil) trie.cache = self.cache.Copy() if self.root != nil { trie.root = self.root.Copy(trie) @@ -68,21 +63,21 @@ func (self *Trie) Copy() *Trie { } // Legacy support -func (self *Trie) Root() common.Hash { return self.Hash() } -func (self *Trie) Hash() common.Hash { - var hash common.Hash +func (self *Trie) Root() []byte { return self.Hash() } +func (self *Trie) Hash() []byte { + var hash []byte if self.root != nil { t := self.root.Hash() - if h, ok := t.(common.Hash); ok && (h != common.Hash{}) { - hash = h + if byts, ok := t.([]byte); ok && len(byts) > 0 { + hash = byts } else { - hash = common.BytesToHash(crypto.Sha3(common.Encode(self.root.RlpData()))) + hash = crypto.Sha3(common.Encode(self.root.RlpData())) } } else { - hash = common.BytesToHash(crypto.Sha3(common.Encode(""))) + hash = crypto.Sha3(common.Encode("")) } - if hash != self.roothash { + if !bytes.Equal(hash, self.roothash) { self.revisions.PushBack(self.roothash) self.roothash = hash } @@ -107,21 +102,19 @@ func (self *Trie) Reset() { self.cache.Reset() if self.revisions.Len() > 0 { - revision := self.revisions.Remove(self.revisions.Back()).(common.Hash) + revision := self.revisions.Remove(self.revisions.Back()).([]byte) self.roothash = revision } - value := common.NewValueFromBytes(self.cache.Get(self.roothash[:])) + value := common.NewValueFromBytes(self.cache.Get(self.roothash)) self.root = self.mknode(value) } -func (self *Trie) UpdateString(key, value string) Node { - return self.Update(common.StringToHash(key), []byte(value)) -} -func (self *Trie) Update(key common.Hash, value []byte) Node { +func (self *Trie) UpdateString(key, value string) Node { return self.Update([]byte(key), []byte(value)) } +func (self *Trie) Update(key, value []byte) Node { self.mu.Lock() defer self.mu.Unlock() - k := CompactHexDecode(key.Str()) + k := CompactHexDecode(string(key)) if len(value) != 0 { self.root = self.insert(self.root, k, &ValueNode{self, value}) @@ -132,12 +125,12 @@ func (self *Trie) Update(key common.Hash, value []byte) Node { return self.root } -func (self *Trie) GetString(key string) []byte { return self.Get(common.StringToHash(key)) } -func (self *Trie) Get(key common.Hash) []byte { +func (self *Trie) GetString(key string) []byte { return self.Get([]byte(key)) } +func (self *Trie) Get(key []byte) []byte { self.mu.Lock() defer self.mu.Unlock() - k := CompactHexDecode(key.Str()) + k := CompactHexDecode(string(key)) n := self.get(self.root, k) if n != nil { @@ -147,12 +140,12 @@ func (self *Trie) Get(key common.Hash) []byte { return nil } -func (self *Trie) DeleteString(key string) Node { return self.Delete(common.StringToHash(key)) } -func (self *Trie) Delete(key common.Hash) Node { +func (self *Trie) DeleteString(key string) Node { return self.Delete([]byte(key)) } +func (self *Trie) Delete(key []byte) Node { self.mu.Lock() defer self.mu.Unlock() - k := CompactHexDecode(key.Str()) + k := CompactHexDecode(string(key)) self.root = self.delete(self.root, k) return self.root |