diff options
Diffstat (limited to 'trie')
-rw-r--r-- | trie/iterator.go | 14 | ||||
-rw-r--r-- | trie/iterator_test.go | 2 | ||||
-rw-r--r-- | trie/secure_trie.go | 26 | ||||
-rw-r--r-- | trie/trie.go | 65 | ||||
-rw-r--r-- | trie/trie_test.go | 51 |
5 files changed, 68 insertions, 90 deletions
diff --git a/trie/iterator.go b/trie/iterator.go index aff614f95..fda7c6cbe 100644 --- a/trie/iterator.go +++ b/trie/iterator.go @@ -2,19 +2,17 @@ package trie import ( "bytes" - - "github.com/ethereum/go-ethereum/common" ) type Iterator struct { trie *Trie - Key common.Hash + Key []byte Value []byte } func NewIterator(trie *Trie) *Iterator { - return &Iterator{trie: trie} + return &Iterator{trie: trie, Key: nil} } func (self *Iterator) Next() bool { @@ -22,15 +20,15 @@ func (self *Iterator) Next() bool { defer self.trie.mu.Unlock() isIterStart := false - if (self.Key == common.Hash{}) { + if self.Key == nil { isIterStart = true - //self.Key = make([]byte, 32) + self.Key = make([]byte, 32) } - key := RemTerm(CompactHexDecode(self.Key.Str())) + key := RemTerm(CompactHexDecode(string(self.Key))) k := self.next(self.trie.root, key, isIterStart) - self.Key = common.StringToHash(DecodeCompact(k)) + self.Key = []byte(DecodeCompact(k)) return len(k) > 0 } diff --git a/trie/iterator_test.go b/trie/iterator_test.go index 5f95caa68..74d9e903c 100644 --- a/trie/iterator_test.go +++ b/trie/iterator_test.go @@ -22,7 +22,7 @@ func TestIterator(t *testing.T) { it := trie.Iterator() for it.Next() { - v[it.Key.Str()] = true + v[string(it.Key)] = true } for k, found := range v { diff --git a/trie/secure_trie.go b/trie/secure_trie.go index b31791cad..b9fa376b8 100644 --- a/trie/secure_trie.go +++ b/trie/secure_trie.go @@ -1,38 +1,34 @@ package trie -import ( - "github.com/ethereum/go-ethereum/common" - "github.com/ethereum/go-ethereum/crypto" -) +import "github.com/ethereum/go-ethereum/crypto" type SecureTrie struct { *Trie } -func NewSecure(root common.Hash, backend Backend) *SecureTrie { +func NewSecure(root []byte, backend Backend) *SecureTrie { return &SecureTrie{New(root, backend)} } -func (self *SecureTrie) Update(key common.Hash, value []byte) Node { - return self.Trie.Update(common.BytesToHash(crypto.Sha3(key[:])), value) +func (self *SecureTrie) Update(key, value []byte) Node { + return self.Trie.Update(crypto.Sha3(key), value) } - func (self *SecureTrie) UpdateString(key, value string) Node { - return self.Update(common.StringToHash(key), []byte(value)) + return self.Update([]byte(key), []byte(value)) } -func (self *SecureTrie) Get(key common.Hash) []byte { - return self.Trie.Get(common.BytesToHash(crypto.Sha3(key[:]))) +func (self *SecureTrie) Get(key []byte) []byte { + return self.Trie.Get(crypto.Sha3(key)) } func (self *SecureTrie) GetString(key string) []byte { - return self.Get(common.StringToHash(key)) + return self.Get([]byte(key)) } -func (self *SecureTrie) Delete(key common.Hash) Node { - return self.Trie.Delete(common.BytesToHash(crypto.Sha3(key[:]))) +func (self *SecureTrie) Delete(key []byte) Node { + return self.Trie.Delete(crypto.Sha3(key)) } func (self *SecureTrie) DeleteString(key string) Node { - return self.Delete(common.StringToHash(key)) + return self.Delete([]byte(key)) } func (self *SecureTrie) Copy() *SecureTrie { diff --git a/trie/trie.go b/trie/trie.go index 759718400..1c1112a7f 100644 --- a/trie/trie.go +++ b/trie/trie.go @@ -6,32 +6,31 @@ import ( "fmt" "sync" - "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/common" ) func ParanoiaCheck(t1 *Trie, backend Backend) (bool, *Trie) { - t2 := New(common.Hash{}, backend) + t2 := New(nil, backend) it := t1.Iterator() for it.Next() { t2.Update(it.Key, it.Value) } - a, b := t2.Hash(), t1.Hash() - return bytes.Equal(a[:], b[:]), t2 + return bytes.Equal(t2.Hash(), t1.Hash()), t2 } type Trie struct { mu sync.Mutex root Node - roothash common.Hash + roothash []byte cache *Cache revisions *list.List } -func New(root common.Hash, backend Backend) *Trie { +func New(root []byte, backend Backend) *Trie { trie := &Trie{} trie.revisions = list.New() trie.roothash = root @@ -39,8 +38,8 @@ func New(root common.Hash, backend Backend) *Trie { trie.cache = NewCache(backend) } - if (root != common.Hash{}) { - value := common.NewValueFromBytes(trie.cache.Get(root[:])) + if root != nil { + value := common.NewValueFromBytes(trie.cache.Get(root)) trie.root = trie.mknode(value) } @@ -52,13 +51,9 @@ func (self *Trie) Iterator() *Iterator { } func (self *Trie) Copy() *Trie { - //cpy := make([]byte, 32) - //copy(cpy, self.roothash) - - // cheap copying method - var cpy common.Hash - cpy.Set(self.roothash) - trie := New(common.Hash{}, nil) + cpy := make([]byte, 32) + copy(cpy, self.roothash) + trie := New(nil, nil) trie.cache = self.cache.Copy() if self.root != nil { trie.root = self.root.Copy(trie) @@ -68,21 +63,21 @@ func (self *Trie) Copy() *Trie { } // Legacy support -func (self *Trie) Root() common.Hash { return self.Hash() } -func (self *Trie) Hash() common.Hash { - var hash common.Hash +func (self *Trie) Root() []byte { return self.Hash() } +func (self *Trie) Hash() []byte { + var hash []byte if self.root != nil { t := self.root.Hash() - if h, ok := t.(common.Hash); ok && (h != common.Hash{}) { - hash = h + if byts, ok := t.([]byte); ok && len(byts) > 0 { + hash = byts } else { - hash = common.BytesToHash(crypto.Sha3(common.Encode(self.root.RlpData()))) + hash = crypto.Sha3(common.Encode(self.root.RlpData())) } } else { - hash = common.BytesToHash(crypto.Sha3(common.Encode(""))) + hash = crypto.Sha3(common.Encode("")) } - if hash != self.roothash { + if !bytes.Equal(hash, self.roothash) { self.revisions.PushBack(self.roothash) self.roothash = hash } @@ -107,21 +102,19 @@ func (self *Trie) Reset() { self.cache.Reset() if self.revisions.Len() > 0 { - revision := self.revisions.Remove(self.revisions.Back()).(common.Hash) + revision := self.revisions.Remove(self.revisions.Back()).([]byte) self.roothash = revision } - value := common.NewValueFromBytes(self.cache.Get(self.roothash[:])) + value := common.NewValueFromBytes(self.cache.Get(self.roothash)) self.root = self.mknode(value) } -func (self *Trie) UpdateString(key, value string) Node { - return self.Update(common.StringToHash(key), []byte(value)) -} -func (self *Trie) Update(key common.Hash, value []byte) Node { +func (self *Trie) UpdateString(key, value string) Node { return self.Update([]byte(key), []byte(value)) } +func (self *Trie) Update(key, value []byte) Node { self.mu.Lock() defer self.mu.Unlock() - k := CompactHexDecode(key.Str()) + k := CompactHexDecode(string(key)) if len(value) != 0 { self.root = self.insert(self.root, k, &ValueNode{self, value}) @@ -132,12 +125,12 @@ func (self *Trie) Update(key common.Hash, value []byte) Node { return self.root } -func (self *Trie) GetString(key string) []byte { return self.Get(common.StringToHash(key)) } -func (self *Trie) Get(key common.Hash) []byte { +func (self *Trie) GetString(key string) []byte { return self.Get([]byte(key)) } +func (self *Trie) Get(key []byte) []byte { self.mu.Lock() defer self.mu.Unlock() - k := CompactHexDecode(key.Str()) + k := CompactHexDecode(string(key)) n := self.get(self.root, k) if n != nil { @@ -147,12 +140,12 @@ func (self *Trie) Get(key common.Hash) []byte { return nil } -func (self *Trie) DeleteString(key string) Node { return self.Delete(common.StringToHash(key)) } -func (self *Trie) Delete(key common.Hash) Node { +func (self *Trie) DeleteString(key string) Node { return self.Delete([]byte(key)) } +func (self *Trie) Delete(key []byte) Node { self.mu.Lock() defer self.mu.Unlock() - k := CompactHexDecode(key.Str()) + k := CompactHexDecode(string(key)) self.root = self.delete(self.root, k) return self.root diff --git a/trie/trie_test.go b/trie/trie_test.go index f5d17c3da..1393e0c97 100644 --- a/trie/trie_test.go +++ b/trie/trie_test.go @@ -5,8 +5,8 @@ import ( "fmt" "testing" - "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/common" ) type Db map[string][]byte @@ -16,18 +16,18 @@ func (self Db) Put(k, v []byte) { self[string(k)] = v } // Used for testing func NewEmpty() *Trie { - return New(common.Hash{}, make(Db)) + return New(nil, make(Db)) } func NewEmptySecure() *SecureTrie { - return NewSecure(common.Hash{}, make(Db)) + return NewSecure(nil, make(Db)) } func TestEmptyTrie(t *testing.T) { trie := NewEmpty() res := trie.Hash() exp := crypto.Sha3(common.Encode("")) - if !bytes.Equal(res[:], exp[:]) { + if !bytes.Equal(res, exp) { t.Errorf("expected %x got %x", exp, res) } } @@ -41,7 +41,7 @@ func TestInsert(t *testing.T) { exp := common.Hex2Bytes("8aad789dff2f538bca5d8ea56e8abe10f4c7ba3a5dea95fea4cd6e7c3a1168d3") root := trie.Hash() - if !bytes.Equal(root[:], exp[:]) { + if !bytes.Equal(root, exp) { t.Errorf("exp %x got %x", exp, root) } @@ -50,7 +50,7 @@ func TestInsert(t *testing.T) { exp = common.Hex2Bytes("d23786fb4a010da3ce639d66d5e904a11dbc02746d1ce25029e53290cabf28ab") root = trie.Hash() - if !bytes.Equal(root[:], exp) { + if !bytes.Equal(root, exp) { t.Errorf("exp %x got %x", exp, root) } } @@ -96,7 +96,7 @@ func TestDelete(t *testing.T) { hash := trie.Hash() exp := common.Hex2Bytes("5991bb8c6514148a29db676a14ac506cd2cd5775ace63c30a4fe457715e9ac84") - if !bytes.Equal(hash[:], exp) { + if !bytes.Equal(hash, exp) { t.Errorf("expected %x got %x", exp, hash) } } @@ -120,7 +120,7 @@ func TestEmptyValues(t *testing.T) { hash := trie.Hash() exp := common.Hex2Bytes("5991bb8c6514148a29db676a14ac506cd2cd5775ace63c30a4fe457715e9ac84") - if !bytes.Equal(hash[:], exp) { + if !bytes.Equal(hash, exp) { t.Errorf("expected %x got %x", exp, hash) } } @@ -150,7 +150,7 @@ func TestReplication(t *testing.T) { hash := trie2.Hash() exp := trie.Hash() - if !bytes.Equal(hash[:], exp[:]) { + if !bytes.Equal(hash, exp) { t.Errorf("root failure. expected %x got %x", exp, hash) } @@ -168,9 +168,7 @@ func TestReset(t *testing.T) { } trie.Commit() - var before common.Hash - before.Set(trie.roothash) - + before := common.CopyBytes(trie.roothash) trie.UpdateString("should", "revert") trie.Hash() // Should have no effect @@ -179,11 +177,9 @@ func TestReset(t *testing.T) { // ### trie.Reset() + after := common.CopyBytes(trie.roothash) - var after common.Hash - after.Set(trie.roothash) - - if before != after { + if !bytes.Equal(before, after) { t.Errorf("expected roots to be equal. %x - %x", before, after) } } @@ -252,7 +248,7 @@ func BenchmarkGets(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - trie.GetString("horse") + trie.Get([]byte("horse")) } } @@ -267,9 +263,8 @@ func BenchmarkUpdate(b *testing.B) { } type kv struct { - k common.Hash - v []byte - t bool + k, v []byte + t bool } func TestLargeData(t *testing.T) { @@ -277,21 +272,17 @@ func TestLargeData(t *testing.T) { vals := make(map[string]*kv) for i := byte(0); i < 255; i++ { - var k1 common.Hash - k1.SetBytes([]byte{i}) - var k2 common.Hash - k2.SetBytes([]byte{10, i}) - value := &kv{k1, []byte{i}, false} - value2 := &kv{k2, []byte{i}, false} + value := &kv{common.LeftPadBytes([]byte{i}, 32), []byte{i}, false} + value2 := &kv{common.LeftPadBytes([]byte{10, i}, 32), []byte{i}, false} trie.Update(value.k, value.v) trie.Update(value2.k, value2.v) - vals[value.k.Str()] = value - vals[value2.k.Str()] = value2 + vals[string(value.k)] = value + vals[string(value2.k)] = value2 } it := trie.Iterator() for it.Next() { - vals[it.Key.Str()].t = true + vals[string(it.Key)].t = true } var untouched []*kv @@ -332,7 +323,7 @@ func TestSecureDelete(t *testing.T) { hash := trie.Hash() exp := common.Hex2Bytes("29b235a58c3c25ab83010c327d5932bcf05324b7d6b1185e650798034783ca9d") - if !bytes.Equal(hash[:], exp) { + if !bytes.Equal(hash, exp) { t.Errorf("expected %x got %x", exp, hash) } } |