aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorJeffrey Wilcke <jeffrey@ethereum.org>2015-10-01 19:34:38 +0800
committerJeffrey Wilcke <jeffrey@ethereum.org>2015-10-01 19:34:38 +0800
commit49ae53850622f3ea051184dccc867fbfec4c9ecb (patch)
tree669f4b161773c9b95c6631e83a3c34d446c33fef
parent581c0901af22d678aedd9eefae6144582c23e1a0 (diff)
parentc1a352c1085baa5c5f7650d331603bbb5532dea4 (diff)
downloadgo-tangerine-49ae53850622f3ea051184dccc867fbfec4c9ecb.tar
go-tangerine-49ae53850622f3ea051184dccc867fbfec4c9ecb.tar.gz
go-tangerine-49ae53850622f3ea051184dccc867fbfec4c9ecb.tar.bz2
go-tangerine-49ae53850622f3ea051184dccc867fbfec4c9ecb.tar.lz
go-tangerine-49ae53850622f3ea051184dccc867fbfec4c9ecb.tar.xz
go-tangerine-49ae53850622f3ea051184dccc867fbfec4c9ecb.tar.zst
go-tangerine-49ae53850622f3ea051184dccc867fbfec4c9ecb.zip
Merge pull request #1405 from fjl/lean-trie
core, trie: new trie
-rw-r--r--build/update-license.go3
-rw-r--r--core/block_processor.go20
-rw-r--r--core/chain_makers.go14
-rw-r--r--core/chain_manager.go8
-rw-r--r--core/genesis.go16
-rw-r--r--core/state/state_object.go23
-rw-r--r--core/state/state_test.go3
-rw-r--r--core/state/statedb.go104
-rw-r--r--core/types/derive_sha.go15
-rw-r--r--miner/worker.go5
-rw-r--r--tests/block_test_util.go12
-rw-r--r--tests/state_test_util.go8
-rw-r--r--trie/arc.go194
-rw-r--r--trie/cache.go78
-rw-r--r--trie/encoding.go72
-rw-r--r--trie/encoding_test.go36
-rw-r--r--trie/fullnode.go94
-rw-r--r--trie/hashnode.go46
-rw-r--r--trie/iterator.go64
-rw-r--r--trie/iterator_test.go6
-rw-r--r--trie/node.go174
-rw-r--r--trie/proof.go122
-rw-r--r--trie/proof_test.go139
-rw-r--r--trie/secure_trie.go97
-rw-r--r--trie/secure_trie_test.go74
-rw-r--r--trie/shortnode.go57
-rw-r--r--trie/slice.go69
-rw-r--r--trie/trie.go639
-rw-r--r--trie/trie_test.go340
-rw-r--r--trie/valuenode.go42
30 files changed, 1515 insertions, 1059 deletions
diff --git a/build/update-license.go b/build/update-license.go
index e28005cbd..04f52a13c 100644
--- a/build/update-license.go
+++ b/build/update-license.go
@@ -46,9 +46,10 @@ var (
skipPrefixes = []string{
// boring stuff
"Godeps/", "tests/files/", "build/",
- // don't relicense vendored packages
+ // don't relicense vendored sources
"crypto/sha3/", "crypto/ecies/", "logger/glog/",
"crypto/curve.go",
+ "trie/arc.go",
}
// paths with this prefix are licensed as GPL. all other files are LGPL.
diff --git a/core/block_processor.go b/core/block_processor.go
index 238b2db95..40590bdc5 100644
--- a/core/block_processor.go
+++ b/core/block_processor.go
@@ -100,10 +100,8 @@ func (self *BlockProcessor) ApplyTransaction(gp GasPool, statedb *state.StateDB,
}
// Update the state with pending changes
- statedb.SyncIntermediate()
-
usedGas.Add(usedGas, gas)
- receipt := types.NewReceipt(statedb.Root().Bytes(), usedGas)
+ receipt := types.NewReceipt(statedb.IntermediateRoot().Bytes(), usedGas)
receipt.TxHash = tx.Hash()
receipt.GasUsed = new(big.Int).Set(gas)
if MessageCreatesContract(tx) {
@@ -265,16 +263,16 @@ func (sm *BlockProcessor) processWithParent(block, parent *types.Block) (logs st
// Accumulate static rewards; block reward, uncle's and uncle inclusion.
AccumulateRewards(state, header, uncles)
- // Commit state objects/accounts to a temporary trie (does not save)
- // used to calculate the state root.
- state.SyncObjects()
- if header.Root != state.Root() {
- err = fmt.Errorf("invalid merkle root. received=%x got=%x", header.Root, state.Root())
- return
+ // Commit state objects/accounts to a database batch and calculate
+ // the state root. The database is not modified if the root
+ // doesn't match.
+ root, batch := state.CommitBatch()
+ if header.Root != root {
+ return nil, nil, fmt.Errorf("invalid merkle root: header=%x computed=%x", header.Root, root)
}
- // Sync the current block's state to the database
- state.Sync()
+ // Execute the database writes.
+ batch.Write()
return state.Logs(), receipts, nil
}
diff --git a/core/chain_makers.go b/core/chain_makers.go
index 70233438d..3af9b0b89 100644
--- a/core/chain_makers.go
+++ b/core/chain_makers.go
@@ -17,6 +17,7 @@
package core
import (
+ "fmt"
"math/big"
"github.com/ethereum/go-ethereum/common"
@@ -94,9 +95,9 @@ func (b *BlockGen) AddTx(tx *types.Transaction) {
if err != nil {
panic(err)
}
- b.statedb.SyncIntermediate()
+ root := b.statedb.IntermediateRoot()
b.header.GasUsed.Add(b.header.GasUsed, gas)
- receipt := types.NewReceipt(b.statedb.Root().Bytes(), b.header.GasUsed)
+ receipt := types.NewReceipt(root.Bytes(), b.header.GasUsed)
logs := b.statedb.GetLogs(tx.Hash())
receipt.SetLogs(logs)
receipt.Bloom = types.CreateBloom(types.Receipts{receipt})
@@ -163,8 +164,11 @@ func GenerateChain(parent *types.Block, db ethdb.Database, n int, gen func(int,
gen(i, b)
}
AccumulateRewards(statedb, h, b.uncles)
- statedb.SyncIntermediate()
- h.Root = statedb.Root()
+ root, err := statedb.Commit()
+ if err != nil {
+ panic(fmt.Sprintf("state write error: %v", err))
+ }
+ h.Root = root
return types.NewBlock(h, b.txs, b.uncles, b.receipts)
}
for i := 0; i < n; i++ {
@@ -184,7 +188,7 @@ func makeHeader(parent *types.Block, state *state.StateDB) *types.Header {
time = new(big.Int).Add(parent.Time(), big.NewInt(10)) // block time is fixed at 10 seconds
}
return &types.Header{
- Root: state.Root(),
+ Root: state.IntermediateRoot(),
ParentHash: parent.Hash(),
Coinbase: parent.Coinbase(),
Difficulty: CalcDifficulty(time.Uint64(), new(big.Int).Sub(time, big.NewInt(10)).Uint64(), parent.Number(), parent.Difficulty()),
diff --git a/core/chain_manager.go b/core/chain_manager.go
index 383fce70c..55ef3fcad 100644
--- a/core/chain_manager.go
+++ b/core/chain_manager.go
@@ -840,8 +840,8 @@ out:
}
func blockErr(block *types.Block, err error) {
- h := block.Header()
- glog.V(logger.Error).Infof("Bad block #%v (%x)\n", h.Number, h.Hash().Bytes())
- glog.V(logger.Error).Infoln(err)
- glog.V(logger.Debug).Infoln(verifyNonces)
+ if glog.V(logger.Error) {
+ glog.Errorf("Bad block #%v (%s)\n", block.Number(), block.Hash().Hex())
+ glog.Errorf(" %v", err)
+ }
}
diff --git a/core/genesis.go b/core/genesis.go
index b2346da65..bf97da2e2 100644
--- a/core/genesis.go
+++ b/core/genesis.go
@@ -69,7 +69,7 @@ func WriteGenesisBlock(chainDb ethdb.Database, reader io.Reader) (*types.Block,
statedb.SetState(address, common.HexToHash(key), common.HexToHash(value))
}
}
- statedb.SyncObjects()
+ root, stateBatch := statedb.CommitBatch()
difficulty := common.String2Big(genesis.Difficulty)
block := types.NewBlock(&types.Header{
@@ -81,7 +81,7 @@ func WriteGenesisBlock(chainDb ethdb.Database, reader io.Reader) (*types.Block,
Difficulty: difficulty,
MixDigest: common.HexToHash(genesis.Mixhash),
Coinbase: common.HexToAddress(genesis.Coinbase),
- Root: statedb.Root(),
+ Root: root,
}, nil, nil, nil)
if block := GetBlock(chainDb, block.Hash()); block != nil {
@@ -92,8 +92,10 @@ func WriteGenesisBlock(chainDb ethdb.Database, reader io.Reader) (*types.Block,
}
return block, nil
}
- statedb.Sync()
+ if err := stateBatch.Write(); err != nil {
+ return nil, fmt.Errorf("cannot write state: %v", err)
+ }
if err := WriteTd(chainDb, block.Hash(), difficulty); err != nil {
return nil, err
}
@@ -115,12 +117,14 @@ func GenesisBlockForTesting(db ethdb.Database, addr common.Address, balance *big
statedb := state.New(common.Hash{}, db)
obj := statedb.GetOrNewStateObject(addr)
obj.SetBalance(balance)
- statedb.SyncObjects()
- statedb.Sync()
+ root, err := statedb.Commit()
+ if err != nil {
+ panic(fmt.Sprintf("cannot write state: %v", err))
+ }
block := types.NewBlock(&types.Header{
Difficulty: params.GenesisDifficulty,
GasLimit: params.GenesisGasLimit,
- Root: statedb.Root(),
+ Root: root,
}, nil, nil, nil)
return block
}
diff --git a/core/state/state_object.go b/core/state/state_object.go
index 353f2357b..40af9ed9c 100644
--- a/core/state/state_object.go
+++ b/core/state/state_object.go
@@ -90,15 +90,13 @@ type StateObject struct {
func NewStateObject(address common.Address, db ethdb.Database) *StateObject {
object := &StateObject{db: db, address: address, balance: new(big.Int), gasPool: new(big.Int), dirty: true}
- object.trie = trie.NewSecure((common.Hash{}).Bytes(), db)
+ object.trie, _ = trie.NewSecure(common.Hash{}, db)
object.storage = make(Storage)
object.gasPool = new(big.Int)
-
return object
}
func NewStateObjectFromBytes(address common.Address, data []byte, db ethdb.Database) *StateObject {
- // TODO clean me up
var extobject struct {
Nonce uint64
Balance *big.Int
@@ -107,7 +105,13 @@ func NewStateObjectFromBytes(address common.Address, data []byte, db ethdb.Datab
}
err := rlp.Decode(bytes.NewReader(data), &extobject)
if err != nil {
- fmt.Println(err)
+ glog.Errorf("can't decode state object %x: %v", address, err)
+ return nil
+ }
+ trie, err := trie.NewSecure(extobject.Root, db)
+ if err != nil {
+ // TODO: bubble this up or panic
+ glog.Errorf("can't create account trie with root %x: %v", extobject.Root[:], err)
return nil
}
@@ -115,11 +119,10 @@ func NewStateObjectFromBytes(address common.Address, data []byte, db ethdb.Datab
object.nonce = extobject.Nonce
object.balance = extobject.Balance
object.codeHash = extobject.CodeHash
- object.trie = trie.NewSecure(extobject.Root[:], db)
+ object.trie = trie
object.storage = make(map[string]common.Hash)
object.gasPool = new(big.Int)
object.code, _ = db.Get(extobject.CodeHash)
-
return object
}
@@ -215,6 +218,7 @@ func (c *StateObject) ReturnGas(gas, price *big.Int) {}
func (self *StateObject) SetGasLimit(gasLimit *big.Int) {
self.gasPool = new(big.Int).Set(gasLimit)
+ self.dirty = true
if glog.V(logger.Core) {
glog.Infof("%x: gas (+ %v)", self.Address(), self.gasPool)
@@ -225,19 +229,14 @@ func (self *StateObject) SubGas(gas, price *big.Int) error {
if self.gasPool.Cmp(gas) < 0 {
return GasLimitError(self.gasPool, gas)
}
-
self.gasPool.Sub(self.gasPool, gas)
-
- rGas := new(big.Int).Set(gas)
- rGas.Mul(rGas, price)
-
self.dirty = true
-
return nil
}
func (self *StateObject) AddGas(gas, price *big.Int) {
self.gasPool.Add(self.gasPool, gas)
+ self.dirty = true
}
func (self *StateObject) Copy() *StateObject {
diff --git a/core/state/state_test.go b/core/state/state_test.go
index 60836738e..b5a7f4081 100644
--- a/core/state/state_test.go
+++ b/core/state/state_test.go
@@ -89,8 +89,7 @@ func TestNull(t *testing.T) {
//value := common.FromHex("0x823140710bf13990e4500136726d8b55")
var value common.Hash
state.SetState(address, common.Hash{}, value)
- state.SyncIntermediate()
- state.Sync()
+ state.Commit()
value = state.GetState(address, common.Hash{})
if !common.EmptyHash(value) {
t.Errorf("expected empty hash. got %x", value)
diff --git a/core/state/statedb.go b/core/state/statedb.go
index 24f97e32a..4233c763b 100644
--- a/core/state/statedb.go
+++ b/core/state/statedb.go
@@ -35,7 +35,6 @@ import (
type StateDB struct {
db ethdb.Database
trie *trie.SecureTrie
- root common.Hash
stateObjects map[string]*StateObject
@@ -49,12 +48,19 @@ type StateDB struct {
// Create a new state from a given trie
func New(root common.Hash, db ethdb.Database) *StateDB {
- trie := trie.NewSecure(root[:], db)
- return &StateDB{root: root, db: db, trie: trie, stateObjects: make(map[string]*StateObject), refund: new(big.Int), logs: make(map[common.Hash]Logs)}
-}
-
-func (self *StateDB) PrintRoot() {
- self.trie.Trie.PrintRoot()
+ tr, err := trie.NewSecure(root, db)
+ if err != nil {
+ // TODO: bubble this up
+ tr, _ = trie.NewSecure(common.Hash{}, db)
+ glog.Errorf("can't create state trie with root %x: %v", root[:], err)
+ }
+ return &StateDB{
+ db: db,
+ trie: tr,
+ stateObjects: make(map[string]*StateObject),
+ refund: new(big.Int),
+ logs: make(map[common.Hash]Logs),
+ }
}
func (self *StateDB) StartRecord(thash, bhash common.Hash, ti int) {
@@ -196,7 +202,6 @@ func (self *StateDB) UpdateStateObject(stateObject *StateObject) {
if len(stateObject.CodeHash()) > 0 {
self.db.Put(stateObject.CodeHash(), stateObject.code)
}
-
addr := stateObject.Address()
self.trie.Update(addr[:], stateObject.RlpEncode())
}
@@ -207,6 +212,7 @@ func (self *StateDB) DeleteStateObject(stateObject *StateObject) {
addr := stateObject.Address()
self.trie.Delete(addr[:])
+ //delete(self.stateObjects, addr.Str())
}
// Retrieve a state object given my the address. Nil if not found
@@ -303,65 +309,67 @@ func (self *StateDB) Set(state *StateDB) {
self.logSize = state.logSize
}
-func (s *StateDB) Root() common.Hash {
- return common.BytesToHash(s.trie.Root())
-}
-
-// Syncs the trie and all siblings
-func (s *StateDB) Sync() {
- // Sync all nested states
+// IntermediateRoot computes the current root hash of the state trie.
+// It is called in between transactions to get the root hash that
+// goes into transaction receipts.
+func (s *StateDB) IntermediateRoot() common.Hash {
+ s.refund = new(big.Int)
for _, stateObject := range s.stateObjects {
- stateObject.trie.Commit()
- }
-
- s.trie.Commit()
-
- s.Empty()
-}
-
-func (self *StateDB) Empty() {
- self.stateObjects = make(map[string]*StateObject)
- self.refund = new(big.Int)
-}
-
-func (self *StateDB) Refunds() *big.Int {
- return self.refund
-}
-
-// SyncIntermediate updates the intermediate state and all mid steps
-func (self *StateDB) SyncIntermediate() {
- self.refund = new(big.Int)
-
- for _, stateObject := range self.stateObjects {
if stateObject.dirty {
if stateObject.remove {
- self.DeleteStateObject(stateObject)
+ s.DeleteStateObject(stateObject)
} else {
stateObject.Update()
-
- self.UpdateStateObject(stateObject)
+ s.UpdateStateObject(stateObject)
}
stateObject.dirty = false
}
}
+ return s.trie.Hash()
}
-// SyncObjects syncs the changed objects to the trie
-func (self *StateDB) SyncObjects() {
- self.trie = trie.NewSecure(self.root[:], self.db)
+// Commit commits all state changes to the database.
+func (s *StateDB) Commit() (root common.Hash, err error) {
+ return s.commit(s.db)
+}
- self.refund = new(big.Int)
+// CommitBatch commits all state changes to a write batch but does not
+// execute the batch. It is used to validate state changes against
+// the root hash stored in a block.
+func (s *StateDB) CommitBatch() (root common.Hash, batch ethdb.Batch) {
+ batch = s.db.NewBatch()
+ root, _ = s.commit(batch)
+ return root, batch
+}
- for _, stateObject := range self.stateObjects {
+func (s *StateDB) commit(db trie.DatabaseWriter) (common.Hash, error) {
+ s.refund = new(big.Int)
+
+ for _, stateObject := range s.stateObjects {
if stateObject.remove {
- self.DeleteStateObject(stateObject)
+ // If the object has been removed, don't bother syncing it
+ // and just mark it for deletion in the trie.
+ s.DeleteStateObject(stateObject)
} else {
+ // Write any storage changes in the state object to its trie.
stateObject.Update()
-
- self.UpdateStateObject(stateObject)
+ // Commit the trie of the object to the batch.
+ // This updates the trie root internally, so
+ // getting the root hash of the storage trie
+ // through UpdateStateObject is fast.
+ if _, err := stateObject.trie.CommitTo(db); err != nil {
+ return common.Hash{}, err
+ }
+ // Update the object in the account trie.
+ s.UpdateStateObject(stateObject)
}
stateObject.dirty = false
}
+ return s.trie.CommitTo(db)
+}
+
+func (self *StateDB) Refunds() *big.Int {
+ return self.refund
}
// Debug stuff
diff --git a/core/types/derive_sha.go b/core/types/derive_sha.go
index 478edb0e8..00c42c5bc 100644
--- a/core/types/derive_sha.go
+++ b/core/types/derive_sha.go
@@ -17,8 +17,9 @@
package types
import (
+ "bytes"
+
"github.com/ethereum/go-ethereum/common"
- "github.com/ethereum/go-ethereum/ethdb"
"github.com/ethereum/go-ethereum/rlp"
"github.com/ethereum/go-ethereum/trie"
)
@@ -29,12 +30,12 @@ type DerivableList interface {
}
func DeriveSha(list DerivableList) common.Hash {
- db, _ := ethdb.NewMemDatabase()
- trie := trie.New(nil, db)
+ keybuf := new(bytes.Buffer)
+ trie := new(trie.Trie)
for i := 0; i < list.Len(); i++ {
- key, _ := rlp.EncodeToBytes(uint(i))
- trie.Update(key, list.GetRlp(i))
+ keybuf.Reset()
+ rlp.Encode(keybuf, uint(i))
+ trie.Update(keybuf.Bytes(), list.GetRlp(i))
}
-
- return common.BytesToHash(trie.Root())
+ return trie.Hash()
}
diff --git a/miner/worker.go b/miner/worker.go
index 22d0b9b6e..098f42a72 100644
--- a/miner/worker.go
+++ b/miner/worker.go
@@ -266,7 +266,6 @@ func (self *worker) wait() {
block := result.Block
work := result.Work
- work.state.Sync()
if self.fullValidation {
if _, err := self.chain.InsertChain(types.Blocks{block}); err != nil {
glog.V(logger.Error).Infoln("mining err", err)
@@ -274,6 +273,7 @@ func (self *worker) wait() {
}
go self.mux.Post(core.NewMinedBlockEvent{block})
} else {
+ work.state.Commit()
parent := self.chain.GetBlock(block.ParentHash())
if parent == nil {
glog.V(logger.Error).Infoln("Invalid block found during mining")
@@ -528,8 +528,7 @@ func (self *worker) commitNewWork() {
if atomic.LoadInt32(&self.mining) == 1 {
// commit state root after all state transitions.
core.AccumulateRewards(work.state, header, uncles)
- work.state.SyncObjects()
- header.Root = work.state.Root()
+ header.Root = work.state.IntermediateRoot()
}
// create the new block whose nonce will be mined.
diff --git a/tests/block_test_util.go b/tests/block_test_util.go
index 3ca00bae8..33577cf55 100644
--- a/tests/block_test_util.go
+++ b/tests/block_test_util.go
@@ -253,13 +253,13 @@ func (t *BlockTest) InsertPreState(ethereum *eth.Ethereum) (*state.StateDB, erro
statedb.SetState(common.HexToAddress(addrString), common.HexToHash(k), common.HexToHash(v))
}
}
- // sync objects to trie
- statedb.SyncObjects()
- // sync trie to disk
- statedb.Sync()
- if !bytes.Equal(t.Genesis.Root().Bytes(), statedb.Root().Bytes()) {
- return nil, fmt.Errorf("computed state root does not match genesis block %x %x", t.Genesis.Root().Bytes()[:4], statedb.Root().Bytes()[:4])
+ root, err := statedb.Commit()
+ if err != nil {
+ return nil, fmt.Errorf("error writing state: %v", err)
+ }
+ if t.Genesis.Root() != root {
+ return nil, fmt.Errorf("computed state root does not match genesis block: genesis=%x computed=%x", t.Genesis.Root().Bytes()[:4], root.Bytes()[:4])
}
return statedb, nil
}
diff --git a/tests/state_test_util.go b/tests/state_test_util.go
index 95ecdd0a8..3d8dfca31 100644
--- a/tests/state_test_util.go
+++ b/tests/state_test_util.go
@@ -201,9 +201,9 @@ func runStateTest(test VmTest) error {
}
}
- statedb.Sync()
- if common.HexToHash(test.PostStateRoot) != statedb.Root() {
- return fmt.Errorf("Post state root error. Expected %s, got %x", test.PostStateRoot, statedb.Root())
+ root, _ := statedb.Commit()
+ if common.HexToHash(test.PostStateRoot) != root {
+ return fmt.Errorf("Post state root error. Expected %s, got %x", test.PostStateRoot, root)
}
// check logs
@@ -247,7 +247,7 @@ func RunState(statedb *state.StateDB, env, tx map[string]string) ([]byte, state.
if core.IsNonceErr(err) || core.IsInvalidTxErr(err) || state.IsGasLimitErr(err) {
statedb.Set(snapshot)
}
- statedb.SyncObjects()
+ statedb.Commit()
return ret, vmenv.state.Logs(), vmenv.Gas, err
}
diff --git a/trie/arc.go b/trie/arc.go
new file mode 100644
index 000000000..9da012e16
--- /dev/null
+++ b/trie/arc.go
@@ -0,0 +1,194 @@
+// Copyright (c) 2015 Hans Alexander Gugel <alexander.gugel@gmail.com>
+//
+// Permission is hereby granted, free of charge, to any person obtaining a copy
+// of this software and associated documentation files (the "Software"), to deal
+// in the Software without restriction, including without limitation the rights
+// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+// copies of the Software, and to permit persons to whom the Software is
+// furnished to do so, subject to the following conditions:
+//
+// The above copyright notice and this permission notice shall be included in all
+// copies or substantial portions of the Software.
+//
+// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+// SOFTWARE.
+
+// This file contains a modified version of package arc from
+// https://github.com/alexanderGugel/arc
+//
+// It implements the ARC (Adaptive Replacement Cache) algorithm as detailed in
+// https://www.usenix.org/legacy/event/fast03/tech/full_papers/megiddo/megiddo.pdf
+
+package trie
+
+import (
+ "container/list"
+ "sync"
+)
+
+type arc struct {
+ p int
+ c int
+ t1 *list.List
+ b1 *list.List
+ t2 *list.List
+ b2 *list.List
+ cache map[string]*entry
+ mutex sync.Mutex
+}
+
+type entry struct {
+ key hashNode
+ value node
+ ll *list.List
+ el *list.Element
+}
+
+// newARC returns a new Adaptive Replacement Cache with the
+// given capacity.
+func newARC(c int) *arc {
+ return &arc{
+ c: c,
+ t1: list.New(),
+ b1: list.New(),
+ t2: list.New(),
+ b2: list.New(),
+ cache: make(map[string]*entry, c),
+ }
+}
+
+// Put inserts a new key-value pair into the cache.
+// This optimizes future access to this entry (side effect).
+func (a *arc) Put(key hashNode, value node) bool {
+ a.mutex.Lock()
+ defer a.mutex.Unlock()
+ ent, ok := a.cache[string(key)]
+ if ok != true {
+ ent = &entry{key: key, value: value}
+ a.req(ent)
+ a.cache[string(key)] = ent
+ } else {
+ ent.value = value
+ a.req(ent)
+ }
+ return ok
+}
+
+// Get retrieves a previously via Set inserted entry.
+// This optimizes future access to this entry (side effect).
+func (a *arc) Get(key hashNode) (value node, ok bool) {
+ a.mutex.Lock()
+ defer a.mutex.Unlock()
+ ent, ok := a.cache[string(key)]
+ if ok {
+ a.req(ent)
+ return ent.value, ent.value != nil
+ }
+ return nil, false
+}
+
+func (a *arc) req(ent *entry) {
+ if ent.ll == a.t1 || ent.ll == a.t2 {
+ // Case I
+ ent.setMRU(a.t2)
+ } else if ent.ll == a.b1 {
+ // Case II
+ // Cache Miss in t1 and t2
+
+ // Adaptation
+ var d int
+ if a.b1.Len() >= a.b2.Len() {
+ d = 1
+ } else {
+ d = a.b2.Len() / a.b1.Len()
+ }
+ a.p = a.p + d
+ if a.p > a.c {
+ a.p = a.c
+ }
+
+ a.replace(ent)
+ ent.setMRU(a.t2)
+ } else if ent.ll == a.b2 {
+ // Case III
+ // Cache Miss in t1 and t2
+
+ // Adaptation
+ var d int
+ if a.b2.Len() >= a.b1.Len() {
+ d = 1
+ } else {
+ d = a.b1.Len() / a.b2.Len()
+ }
+ a.p = a.p - d
+ if a.p < 0 {
+ a.p = 0
+ }
+
+ a.replace(ent)
+ ent.setMRU(a.t2)
+ } else if ent.ll == nil {
+ // Case IV
+
+ if a.t1.Len()+a.b1.Len() == a.c {
+ // Case A
+ if a.t1.Len() < a.c {
+ a.delLRU(a.b1)
+ a.replace(ent)
+ } else {
+ a.delLRU(a.t1)
+ }
+ } else if a.t1.Len()+a.b1.Len() < a.c {
+ // Case B
+ if a.t1.Len()+a.t2.Len()+a.b1.Len()+a.b2.Len() >= a.c {
+ if a.t1.Len()+a.t2.Len()+a.b1.Len()+a.b2.Len() == 2*a.c {
+ a.delLRU(a.b2)
+ }
+ a.replace(ent)
+ }
+ }
+
+ ent.setMRU(a.t1)
+ }
+}
+
+func (a *arc) delLRU(list *list.List) {
+ lru := list.Back()
+ list.Remove(lru)
+ delete(a.cache, string(lru.Value.(*entry).key))
+}
+
+func (a *arc) replace(ent *entry) {
+ if a.t1.Len() > 0 && ((a.t1.Len() > a.p) || (ent.ll == a.b2 && a.t1.Len() == a.p)) {
+ lru := a.t1.Back().Value.(*entry)
+ lru.value = nil
+ lru.setMRU(a.b1)
+ } else {
+ lru := a.t2.Back().Value.(*entry)
+ lru.value = nil
+ lru.setMRU(a.b2)
+ }
+}
+
+func (e *entry) setLRU(list *list.List) {
+ e.detach()
+ e.ll = list
+ e.el = e.ll.PushBack(e)
+}
+
+func (e *entry) setMRU(list *list.List) {
+ e.detach()
+ e.ll = list
+ e.el = e.ll.PushFront(e)
+}
+
+func (e *entry) detach() {
+ if e.ll != nil {
+ e.ll.Remove(e.el)
+ }
+}
diff --git a/trie/cache.go b/trie/cache.go
deleted file mode 100644
index e475fc861..000000000
--- a/trie/cache.go
+++ /dev/null
@@ -1,78 +0,0 @@
-// Copyright 2014 The go-ethereum Authors
-// This file is part of the go-ethereum library.
-//
-// The go-ethereum library is free software: you can redistribute it and/or modify
-// it under the terms of the GNU Lesser General Public License as published by
-// the Free Software Foundation, either version 3 of the License, or
-// (at your option) any later version.
-//
-// The go-ethereum library is distributed in the hope that it will be useful,
-// but WITHOUT ANY WARRANTY; without even the implied warranty of
-// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-// GNU Lesser General Public License for more details.
-//
-// You should have received a copy of the GNU Lesser General Public License
-// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
-
-package trie
-
-import (
- "github.com/ethereum/go-ethereum/ethdb"
- "github.com/ethereum/go-ethereum/logger/glog"
- "github.com/syndtr/goleveldb/leveldb"
-)
-
-type Backend interface {
- Get([]byte) ([]byte, error)
- Put([]byte, []byte) error
-}
-
-type Cache struct {
- batch *leveldb.Batch
- store map[string][]byte
- backend Backend
-}
-
-func NewCache(backend Backend) *Cache {
- return &Cache{new(leveldb.Batch), make(map[string][]byte), backend}
-}
-
-func (self *Cache) Get(key []byte) []byte {
- data := self.store[string(key)]
- if data == nil {
- data, _ = self.backend.Get(key)
- }
-
- return data
-}
-
-func (self *Cache) Put(key []byte, data []byte) {
- self.batch.Put(key, data)
- self.store[string(key)] = data
-}
-
-// Flush flushes the trie to the backing layer. If this is a leveldb instance
-// we'll use a batched write, otherwise we'll use regular put.
-func (self *Cache) Flush() {
- if db, ok := self.backend.(*ethdb.LDBDatabase); ok {
- if err := db.LDB().Write(self.batch, nil); err != nil {
- glog.Fatal("db write err:", err)
- }
- } else {
- for k, v := range self.store {
- self.backend.Put([]byte(k), v)
- }
- }
-}
-
-func (self *Cache) Copy() *Cache {
- cache := NewCache(self.backend)
- for k, v := range self.store {
- cache.store[k] = v
- }
- return cache
-}
-
-func (self *Cache) Reset() {
- //self.store = make(map[string][]byte)
-}
diff --git a/trie/encoding.go b/trie/encoding.go
index 9c862d78f..3c172b843 100644
--- a/trie/encoding.go
+++ b/trie/encoding.go
@@ -16,34 +16,36 @@
package trie
-func CompactEncode(hexSlice []byte) []byte {
- terminator := 0
+func compactEncode(hexSlice []byte) []byte {
+ terminator := byte(0)
if hexSlice[len(hexSlice)-1] == 16 {
terminator = 1
- }
-
- if terminator == 1 {
hexSlice = hexSlice[:len(hexSlice)-1]
}
-
- oddlen := len(hexSlice) % 2
- flags := byte(2*terminator + oddlen)
- if oddlen != 0 {
- hexSlice = append([]byte{flags}, hexSlice...)
- } else {
- hexSlice = append([]byte{flags, 0}, hexSlice...)
+ var (
+ odd = byte(len(hexSlice) % 2)
+ buflen = len(hexSlice)/2 + 1
+ bi, hi = 0, 0 // indices
+ hs = byte(0) // shift: flips between 0 and 4
+ )
+ if odd == 0 {
+ bi = 1
+ hs = 4
}
-
- l := len(hexSlice) / 2
- var buf = make([]byte, l)
- for i := 0; i < l; i++ {
- buf[i] = 16*hexSlice[2*i] + hexSlice[2*i+1]
+ buf := make([]byte, buflen)
+ buf[0] = terminator<<5 | byte(odd)<<4
+ for bi < len(buf) && hi < len(hexSlice) {
+ buf[bi] |= hexSlice[hi] << hs
+ if hs == 0 {
+ bi++
+ }
+ hi, hs = hi+1, hs^(1<<2)
}
return buf
}
-func CompactDecode(str []byte) []byte {
- base := CompactHexDecode(str)
+func compactDecode(str []byte) []byte {
+ base := compactHexDecode(str)
base = base[:len(base)-1]
if base[0] >= 2 {
base = append(base, 16)
@@ -53,11 +55,10 @@ func CompactDecode(str []byte) []byte {
} else {
base = base[2:]
}
-
return base
}
-func CompactHexDecode(str []byte) []byte {
+func compactHexDecode(str []byte) []byte {
l := len(str)*2 + 1
var nibbles = make([]byte, l)
for i, b := range str {
@@ -68,7 +69,7 @@ func CompactHexDecode(str []byte) []byte {
return nibbles
}
-func DecodeCompact(key []byte) []byte {
+func decodeCompact(key []byte) []byte {
l := len(key) / 2
var res = make([]byte, l)
for i := 0; i < l; i++ {
@@ -77,3 +78,30 @@ func DecodeCompact(key []byte) []byte {
}
return res
}
+
+// prefixLen returns the length of the common prefix of a and b.
+func prefixLen(a, b []byte) int {
+ var i, length = 0, len(a)
+ if len(b) < length {
+ length = len(b)
+ }
+ for ; i < length; i++ {
+ if a[i] != b[i] {
+ break
+ }
+ }
+ return i
+}
+
+func hasTerm(s []byte) bool {
+ return s[len(s)-1] == 16
+}
+
+func remTerm(s []byte) []byte {
+ if hasTerm(s) {
+ b := make([]byte, len(s)-1)
+ copy(b, s)
+ return b
+ }
+ return s
+}
diff --git a/trie/encoding_test.go b/trie/encoding_test.go
index e49b57ef0..061d48d58 100644
--- a/trie/encoding_test.go
+++ b/trie/encoding_test.go
@@ -23,7 +23,7 @@ import (
checker "gopkg.in/check.v1"
)
-func Test(t *testing.T) { checker.TestingT(t) }
+func TestEncoding(t *testing.T) { checker.TestingT(t) }
type TrieEncodingSuite struct{}
@@ -32,64 +32,64 @@ var _ = checker.Suite(&TrieEncodingSuite{})
func (s *TrieEncodingSuite) TestCompactEncode(c *checker.C) {
// even compact encode
test1 := []byte{1, 2, 3, 4, 5}
- res1 := CompactEncode(test1)
+ res1 := compactEncode(test1)
c.Assert(res1, checker.DeepEquals, []byte("\x11\x23\x45"))
// odd compact encode
test2 := []byte{0, 1, 2, 3, 4, 5}
- res2 := CompactEncode(test2)
+ res2 := compactEncode(test2)
c.Assert(res2, checker.DeepEquals, []byte("\x00\x01\x23\x45"))
//odd terminated compact encode
test3 := []byte{0, 15, 1, 12, 11, 8 /*term*/, 16}
- res3 := CompactEncode(test3)
+ res3 := compactEncode(test3)
c.Assert(res3, checker.DeepEquals, []byte("\x20\x0f\x1c\xb8"))
// even terminated compact encode
test4 := []byte{15, 1, 12, 11, 8 /*term*/, 16}
- res4 := CompactEncode(test4)
+ res4 := compactEncode(test4)
c.Assert(res4, checker.DeepEquals, []byte("\x3f\x1c\xb8"))
}
func (s *TrieEncodingSuite) TestCompactHexDecode(c *checker.C) {
exp := []byte{7, 6, 6, 5, 7, 2, 6, 2, 16}
- res := CompactHexDecode([]byte("verb"))
+ res := compactHexDecode([]byte("verb"))
c.Assert(res, checker.DeepEquals, exp)
}
func (s *TrieEncodingSuite) TestCompactDecode(c *checker.C) {
// odd compact decode
exp := []byte{1, 2, 3, 4, 5}
- res := CompactDecode([]byte("\x11\x23\x45"))
+ res := compactDecode([]byte("\x11\x23\x45"))
c.Assert(res, checker.DeepEquals, exp)
// even compact decode
exp = []byte{0, 1, 2, 3, 4, 5}
- res = CompactDecode([]byte("\x00\x01\x23\x45"))
+ res = compactDecode([]byte("\x00\x01\x23\x45"))
c.Assert(res, checker.DeepEquals, exp)
// even terminated compact decode
exp = []byte{0, 15, 1, 12, 11, 8 /*term*/, 16}
- res = CompactDecode([]byte("\x20\x0f\x1c\xb8"))
+ res = compactDecode([]byte("\x20\x0f\x1c\xb8"))
c.Assert(res, checker.DeepEquals, exp)
// even terminated compact decode
exp = []byte{15, 1, 12, 11, 8 /*term*/, 16}
- res = CompactDecode([]byte("\x3f\x1c\xb8"))
+ res = compactDecode([]byte("\x3f\x1c\xb8"))
c.Assert(res, checker.DeepEquals, exp)
}
func (s *TrieEncodingSuite) TestDecodeCompact(c *checker.C) {
exp, _ := hex.DecodeString("012345")
- res := DecodeCompact([]byte{0, 1, 2, 3, 4, 5})
+ res := decodeCompact([]byte{0, 1, 2, 3, 4, 5})
c.Assert(res, checker.DeepEquals, exp)
exp, _ = hex.DecodeString("012345")
- res = DecodeCompact([]byte{0, 1, 2, 3, 4, 5, 16})
+ res = decodeCompact([]byte{0, 1, 2, 3, 4, 5, 16})
c.Assert(res, checker.DeepEquals, exp)
exp, _ = hex.DecodeString("abcdef")
- res = DecodeCompact([]byte{10, 11, 12, 13, 14, 15})
+ res = decodeCompact([]byte{10, 11, 12, 13, 14, 15})
c.Assert(res, checker.DeepEquals, exp)
}
@@ -97,29 +97,27 @@ func BenchmarkCompactEncode(b *testing.B) {
testBytes := []byte{0, 15, 1, 12, 11, 8 /*term*/, 16}
for i := 0; i < b.N; i++ {
- CompactEncode(testBytes)
+ compactEncode(testBytes)
}
}
func BenchmarkCompactDecode(b *testing.B) {
testBytes := []byte{0, 15, 1, 12, 11, 8 /*term*/, 16}
for i := 0; i < b.N; i++ {
- CompactDecode(testBytes)
+ compactDecode(testBytes)
}
}
func BenchmarkCompactHexDecode(b *testing.B) {
testBytes := []byte{7, 6, 6, 5, 7, 2, 6, 2, 16}
for i := 0; i < b.N; i++ {
- CompactHexDecode(testBytes)
+ compactHexDecode(testBytes)
}
-
}
func BenchmarkDecodeCompact(b *testing.B) {
testBytes := []byte{7, 6, 6, 5, 7, 2, 6, 2, 16}
for i := 0; i < b.N; i++ {
- DecodeCompact(testBytes)
+ decodeCompact(testBytes)
}
-
}
diff --git a/trie/fullnode.go b/trie/fullnode.go
deleted file mode 100644
index 8ff019ec4..000000000
--- a/trie/fullnode.go
+++ /dev/null
@@ -1,94 +0,0 @@
-// Copyright 2014 The go-ethereum Authors
-// This file is part of the go-ethereum library.
-//
-// The go-ethereum library is free software: you can redistribute it and/or modify
-// it under the terms of the GNU Lesser General Public License as published by
-// the Free Software Foundation, either version 3 of the License, or
-// (at your option) any later version.
-//
-// The go-ethereum library is distributed in the hope that it will be useful,
-// but WITHOUT ANY WARRANTY; without even the implied warranty of
-// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-// GNU Lesser General Public License for more details.
-//
-// You should have received a copy of the GNU Lesser General Public License
-// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
-
-package trie
-
-type FullNode struct {
- trie *Trie
- nodes [17]Node
- dirty bool
-}
-
-func NewFullNode(t *Trie) *FullNode {
- return &FullNode{trie: t}
-}
-
-func (self *FullNode) Dirty() bool { return self.dirty }
-func (self *FullNode) Value() Node {
- self.nodes[16] = self.trie.trans(self.nodes[16])
- return self.nodes[16]
-}
-func (self *FullNode) Branches() []Node {
- return self.nodes[:16]
-}
-
-func (self *FullNode) Copy(t *Trie) Node {
- nnode := NewFullNode(t)
- for i, node := range self.nodes {
- if node != nil {
- nnode.nodes[i] = node
- }
- }
- nnode.dirty = true
-
- return nnode
-}
-
-// Returns the length of non-nil nodes
-func (self *FullNode) Len() (amount int) {
- for _, node := range self.nodes {
- if node != nil {
- amount++
- }
- }
-
- return
-}
-
-func (self *FullNode) Hash() interface{} {
- return self.trie.store(self)
-}
-
-func (self *FullNode) RlpData() interface{} {
- t := make([]interface{}, 17)
- for i, node := range self.nodes {
- if node != nil {
- t[i] = node.Hash()
- } else {
- t[i] = ""
- }
- }
-
- return t
-}
-
-func (self *FullNode) set(k byte, value Node) {
- self.nodes[int(k)] = value
- self.dirty = true
-}
-
-func (self *FullNode) branch(i byte) Node {
- if self.nodes[int(i)] != nil {
- self.nodes[int(i)] = self.trie.trans(self.nodes[int(i)])
-
- return self.nodes[int(i)]
- }
- return nil
-}
-
-func (self *FullNode) setDirty(dirty bool) {
- self.dirty = dirty
-}
diff --git a/trie/hashnode.go b/trie/hashnode.go
deleted file mode 100644
index d4a0bc7ec..000000000
--- a/trie/hashnode.go
+++ /dev/null
@@ -1,46 +0,0 @@
-// Copyright 2014 The go-ethereum Authors
-// This file is part of the go-ethereum library.
-//
-// The go-ethereum library is free software: you can redistribute it and/or modify
-// it under the terms of the GNU Lesser General Public License as published by
-// the Free Software Foundation, either version 3 of the License, or
-// (at your option) any later version.
-//
-// The go-ethereum library is distributed in the hope that it will be useful,
-// but WITHOUT ANY WARRANTY; without even the implied warranty of
-// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-// GNU Lesser General Public License for more details.
-//
-// You should have received a copy of the GNU Lesser General Public License
-// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
-
-package trie
-
-import "github.com/ethereum/go-ethereum/common"
-
-type HashNode struct {
- key []byte
- trie *Trie
- dirty bool
-}
-
-func NewHash(key []byte, trie *Trie) *HashNode {
- return &HashNode{key, trie, false}
-}
-
-func (self *HashNode) RlpData() interface{} {
- return self.key
-}
-
-func (self *HashNode) Hash() interface{} {
- return self.key
-}
-
-func (self *HashNode) setDirty(dirty bool) {
- self.dirty = dirty
-}
-
-// These methods will never be called but we have to satisfy Node interface
-func (self *HashNode) Value() Node { return nil }
-func (self *HashNode) Dirty() bool { return true }
-func (self *HashNode) Copy(t *Trie) Node { return NewHash(common.CopyBytes(self.key), t) }
diff --git a/trie/iterator.go b/trie/iterator.go
index 9c4c7fbe5..38555fe08 100644
--- a/trie/iterator.go
+++ b/trie/iterator.go
@@ -16,9 +16,7 @@
package trie
-import (
- "bytes"
-)
+import "bytes"
type Iterator struct {
trie *Trie
@@ -32,32 +30,29 @@ func NewIterator(trie *Trie) *Iterator {
}
func (self *Iterator) Next() bool {
- self.trie.mu.Lock()
- defer self.trie.mu.Unlock()
-
isIterStart := false
if self.Key == nil {
isIterStart = true
self.Key = make([]byte, 32)
}
- key := RemTerm(CompactHexDecode(self.Key))
+ key := remTerm(compactHexDecode(self.Key))
k := self.next(self.trie.root, key, isIterStart)
- self.Key = []byte(DecodeCompact(k))
+ self.Key = []byte(decodeCompact(k))
return len(k) > 0
}
-func (self *Iterator) next(node Node, key []byte, isIterStart bool) []byte {
+func (self *Iterator) next(node interface{}, key []byte, isIterStart bool) []byte {
if node == nil {
return nil
}
switch node := node.(type) {
- case *FullNode:
+ case fullNode:
if len(key) > 0 {
- k := self.next(node.branch(key[0]), key[1:], isIterStart)
+ k := self.next(node[key[0]], key[1:], isIterStart)
if k != nil {
return append([]byte{key[0]}, k...)
}
@@ -69,31 +64,31 @@ func (self *Iterator) next(node Node, key []byte, isIterStart bool) []byte {
}
for i := r; i < 16; i++ {
- k := self.key(node.branch(byte(i)))
+ k := self.key(node[i])
if k != nil {
return append([]byte{i}, k...)
}
}
- case *ShortNode:
- k := RemTerm(node.Key())
- if vnode, ok := node.Value().(*ValueNode); ok {
+ case shortNode:
+ k := remTerm(node.Key)
+ if vnode, ok := node.Val.(valueNode); ok {
switch bytes.Compare([]byte(k), key) {
case 0:
if isIterStart {
- self.Value = vnode.Val()
+ self.Value = vnode
return k
}
case 1:
- self.Value = vnode.Val()
+ self.Value = vnode
return k
}
} else {
- cnode := node.Value()
+ cnode := node.Val
var ret []byte
skey := key[len(k):]
- if BeginsWith(key, k) {
+ if bytes.HasPrefix(key, k) {
ret = self.next(cnode, skey, isIterStart)
} else if bytes.Compare(k, key[:len(k)]) > 0 {
return self.key(node)
@@ -103,37 +98,36 @@ func (self *Iterator) next(node Node, key []byte, isIterStart bool) []byte {
return append(k, ret...)
}
}
- }
+ case hashNode:
+ return self.next(self.trie.resolveHash(node), key, isIterStart)
+ }
return nil
}
-func (self *Iterator) key(node Node) []byte {
+func (self *Iterator) key(node interface{}) []byte {
switch node := node.(type) {
- case *ShortNode:
+ case shortNode:
// Leaf node
- if vnode, ok := node.Value().(*ValueNode); ok {
- k := RemTerm(node.Key())
- self.Value = vnode.Val()
-
+ k := remTerm(node.Key)
+ if vnode, ok := node.Val.(valueNode); ok {
+ self.Value = vnode
return k
- } else {
- k := RemTerm(node.Key())
- return append(k, self.key(node.Value())...)
}
- case *FullNode:
- if node.Value() != nil {
- self.Value = node.Value().(*ValueNode).Val()
-
+ return append(k, self.key(node.Val)...)
+ case fullNode:
+ if node[16] != nil {
+ self.Value = node[16].(valueNode)
return []byte{16}
}
-
for i := 0; i < 16; i++ {
- k := self.key(node.branch(byte(i)))
+ k := self.key(node[i])
if k != nil {
return append([]byte{byte(i)}, k...)
}
}
+ case hashNode:
+ return self.key(self.trie.resolveHash(node))
}
return nil
diff --git a/trie/iterator_test.go b/trie/iterator_test.go
index 148f9adf9..fdc60b412 100644
--- a/trie/iterator_test.go
+++ b/trie/iterator_test.go
@@ -19,7 +19,7 @@ package trie
import "testing"
func TestIterator(t *testing.T) {
- trie := NewEmpty()
+ trie := newEmpty()
vals := []struct{ k, v string }{
{"do", "verb"},
{"ether", "wookiedoo"},
@@ -32,11 +32,11 @@ func TestIterator(t *testing.T) {
v := make(map[string]bool)
for _, val := range vals {
v[val.k] = false
- trie.UpdateString(val.k, val.v)
+ trie.Update([]byte(val.k), []byte(val.v))
}
trie.Commit()
- it := trie.Iterator()
+ it := NewIterator(trie)
for it.Next() {
v[string(it.Key)] = true
}
diff --git a/trie/node.go b/trie/node.go
index 9d49029de..0bfa21dc4 100644
--- a/trie/node.go
+++ b/trie/node.go
@@ -16,46 +16,172 @@
package trie
-import "fmt"
+import (
+ "fmt"
+ "io"
+ "strings"
+
+ "github.com/ethereum/go-ethereum/common"
+ "github.com/ethereum/go-ethereum/rlp"
+)
var indices = []string{"0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "a", "b", "c", "d", "e", "f", "[17]"}
-type Node interface {
- Value() Node
- Copy(*Trie) Node // All nodes, for now, return them self
- Dirty() bool
+type node interface {
fstring(string) string
- Hash() interface{}
- RlpData() interface{}
- setDirty(dirty bool)
}
-// Value node
-func (self *ValueNode) String() string { return self.fstring("") }
-func (self *FullNode) String() string { return self.fstring("") }
-func (self *ShortNode) String() string { return self.fstring("") }
-func (self *ValueNode) fstring(ind string) string { return fmt.Sprintf("%x ", self.data) }
+type (
+ fullNode [17]node
+ shortNode struct {
+ Key []byte
+ Val node
+ }
+ hashNode []byte
+ valueNode []byte
+)
-//func (self *HashNode) fstring(ind string) string { return fmt.Sprintf("< %x > ", self.key) }
-func (self *HashNode) fstring(ind string) string {
- return fmt.Sprintf("%v", self.trie.trans(self))
-}
+// Pretty printing.
+func (n fullNode) String() string { return n.fstring("") }
+func (n shortNode) String() string { return n.fstring("") }
+func (n hashNode) String() string { return n.fstring("") }
+func (n valueNode) String() string { return n.fstring("") }
-// Full node
-func (self *FullNode) fstring(ind string) string {
+func (n fullNode) fstring(ind string) string {
resp := fmt.Sprintf("[\n%s ", ind)
- for i, node := range self.nodes {
+ for i, node := range n {
if node == nil {
resp += fmt.Sprintf("%s: <nil> ", indices[i])
} else {
resp += fmt.Sprintf("%s: %v", indices[i], node.fstring(ind+" "))
}
}
-
return resp + fmt.Sprintf("\n%s] ", ind)
}
+func (n shortNode) fstring(ind string) string {
+ return fmt.Sprintf("{%x: %v} ", n.Key, n.Val.fstring(ind+" "))
+}
+func (n hashNode) fstring(ind string) string {
+ return fmt.Sprintf("<%x> ", []byte(n))
+}
+func (n valueNode) fstring(ind string) string {
+ return fmt.Sprintf("%x ", []byte(n))
+}
+
+func mustDecodeNode(dbkey, buf []byte) node {
+ n, err := decodeNode(buf)
+ if err != nil {
+ panic(fmt.Sprintf("node %x: %v", dbkey, err))
+ }
+ return n
+}
+
+// decodeNode parses the RLP encoding of a trie node.
+func decodeNode(buf []byte) (node, error) {
+ if len(buf) == 0 {
+ return nil, io.ErrUnexpectedEOF
+ }
+ elems, _, err := rlp.SplitList(buf)
+ if err != nil {
+ return nil, fmt.Errorf("decode error: %v", err)
+ }
+ switch c, _ := rlp.CountValues(elems); c {
+ case 2:
+ n, err := decodeShort(elems)
+ return n, wrapError(err, "short")
+ case 17:
+ n, err := decodeFull(elems)
+ return n, wrapError(err, "full")
+ default:
+ return nil, fmt.Errorf("invalid number of list elements: %v", c)
+ }
+}
+
+func decodeShort(buf []byte) (node, error) {
+ kbuf, rest, err := rlp.SplitString(buf)
+ if err != nil {
+ return nil, err
+ }
+ key := compactDecode(kbuf)
+ if key[len(key)-1] == 16 {
+ // value node
+ val, _, err := rlp.SplitString(rest)
+ if err != nil {
+ return nil, fmt.Errorf("invalid value node: %v", err)
+ }
+ return shortNode{key, valueNode(val)}, nil
+ }
+ r, _, err := decodeRef(rest)
+ if err != nil {
+ return nil, wrapError(err, "val")
+ }
+ return shortNode{key, r}, nil
+}
+
+func decodeFull(buf []byte) (fullNode, error) {
+ var n fullNode
+ for i := 0; i < 16; i++ {
+ cld, rest, err := decodeRef(buf)
+ if err != nil {
+ return n, wrapError(err, fmt.Sprintf("[%d]", i))
+ }
+ n[i], buf = cld, rest
+ }
+ val, _, err := rlp.SplitString(buf)
+ if err != nil {
+ return n, err
+ }
+ if len(val) > 0 {
+ n[16] = valueNode(val)
+ }
+ return n, nil
+}
+
+const hashLen = len(common.Hash{})
+
+func decodeRef(buf []byte) (node, []byte, error) {
+ kind, val, rest, err := rlp.Split(buf)
+ if err != nil {
+ return nil, buf, err
+ }
+ switch {
+ case kind == rlp.List:
+ // 'embedded' node reference. The encoding must be smaller
+ // than a hash in order to be valid.
+ if size := len(buf) - len(rest); size > hashLen {
+ err := fmt.Errorf("oversized embedded node (size is %d bytes, want size < %d)", size, hashLen)
+ return nil, buf, err
+ }
+ n, err := decodeNode(buf)
+ return n, rest, err
+ case kind == rlp.String && len(val) == 0:
+ // empty node
+ return nil, rest, nil
+ case kind == rlp.String && len(val) == 32:
+ return hashNode(val), rest, nil
+ default:
+ return nil, nil, fmt.Errorf("invalid RLP string size %d (want 0 or 32)", len(val))
+ }
+}
+
+// wraps a decoding error with information about the path to the
+// invalid child node (for debugging encoding issues).
+type decodeError struct {
+ what error
+ stack []string
+}
+
+func wrapError(err error, ctx string) error {
+ if err == nil {
+ return nil
+ }
+ if decErr, ok := err.(*decodeError); ok {
+ decErr.stack = append(decErr.stack, ctx)
+ return decErr
+ }
+ return &decodeError{err, []string{ctx}}
+}
-// Short node
-func (self *ShortNode) fstring(ind string) string {
- return fmt.Sprintf("[ %x: %v ] ", self.key, self.value.fstring(ind+" "))
+func (err *decodeError) Error() string {
+ return fmt.Sprintf("%v (decode path: %s)", err.what, strings.Join(err.stack, "<-"))
}
diff --git a/trie/proof.go b/trie/proof.go
new file mode 100644
index 000000000..a705c49db
--- /dev/null
+++ b/trie/proof.go
@@ -0,0 +1,122 @@
+package trie
+
+import (
+ "bytes"
+ "errors"
+ "fmt"
+
+ "github.com/ethereum/go-ethereum/common"
+ "github.com/ethereum/go-ethereum/crypto/sha3"
+ "github.com/ethereum/go-ethereum/rlp"
+)
+
+// Prove constructs a merkle proof for key. The result contains all
+// encoded nodes on the path to the value at key. The value itself is
+// also included in the last node and can be retrieved by verifying
+// the proof.
+//
+// The returned proof is nil if the trie does not contain a value for key.
+// For existing keys, the proof will have at least one element.
+func (t *Trie) Prove(key []byte) []rlp.RawValue {
+ // Collect all nodes on the path to key.
+ key = compactHexDecode(key)
+ nodes := []node{}
+ tn := t.root
+ for len(key) > 0 {
+ switch n := tn.(type) {
+ case shortNode:
+ if len(key) < len(n.Key) || !bytes.Equal(n.Key, key[:len(n.Key)]) {
+ // The trie doesn't contain the key.
+ return nil
+ }
+ tn = n.Val
+ key = key[len(n.Key):]
+ nodes = append(nodes, n)
+ case fullNode:
+ tn = n[key[0]]
+ key = key[1:]
+ nodes = append(nodes, n)
+ case nil:
+ return nil
+ case hashNode:
+ tn = t.resolveHash(n)
+ default:
+ panic(fmt.Sprintf("%T: invalid node: %v", tn, tn))
+ }
+ }
+ if t.hasher == nil {
+ t.hasher = newHasher()
+ }
+ proof := make([]rlp.RawValue, 0, len(nodes))
+ for i, n := range nodes {
+ // Don't bother checking for errors here since hasher panics
+ // if encoding doesn't work and we're not writing to any database.
+ n, _ = t.hasher.replaceChildren(n, nil)
+ hn, _ := t.hasher.store(n, nil, false)
+ if _, ok := hn.(hashNode); ok || i == 0 {
+ // If the node's database encoding is a hash (or is the
+ // root node), it becomes a proof element.
+ enc, _ := rlp.EncodeToBytes(n)
+ proof = append(proof, enc)
+ }
+ }
+ return proof
+}
+
+// VerifyProof checks merkle proofs. The given proof must contain the
+// value for key in a trie with the given root hash. VerifyProof
+// returns an error if the proof contains invalid trie nodes or the
+// wrong value.
+func VerifyProof(rootHash common.Hash, key []byte, proof []rlp.RawValue) (value []byte, err error) {
+ key = compactHexDecode(key)
+ sha := sha3.NewKeccak256()
+ wantHash := rootHash.Bytes()
+ for i, buf := range proof {
+ sha.Reset()
+ sha.Write(buf)
+ if !bytes.Equal(sha.Sum(nil), wantHash) {
+ return nil, fmt.Errorf("bad proof node %d: hash mismatch", i)
+ }
+ n, err := decodeNode(buf)
+ if err != nil {
+ return nil, fmt.Errorf("bad proof node %d: %v", i, err)
+ }
+ keyrest, cld := get(n, key)
+ switch cld := cld.(type) {
+ case nil:
+ return nil, fmt.Errorf("key mismatch at proof node %d", i)
+ case hashNode:
+ key = keyrest
+ wantHash = cld
+ case valueNode:
+ if i != len(proof)-1 {
+ return nil, errors.New("additional nodes at end of proof")
+ }
+ return cld, nil
+ }
+ }
+ return nil, errors.New("unexpected end of proof")
+}
+
+func get(tn node, key []byte) ([]byte, node) {
+ for len(key) > 0 {
+ switch n := tn.(type) {
+ case shortNode:
+ if len(key) < len(n.Key) || !bytes.Equal(n.Key, key[:len(n.Key)]) {
+ return nil, nil
+ }
+ tn = n.Val
+ key = key[len(n.Key):]
+ case fullNode:
+ tn = n[key[0]]
+ key = key[1:]
+ case hashNode:
+ return key, n
+ case nil:
+ return key, nil
+ default:
+ panic(fmt.Sprintf("%T: invalid node: %v", tn, tn))
+ }
+ }
+ return nil, tn.(valueNode)
+}
diff --git a/trie/proof_test.go b/trie/proof_test.go
new file mode 100644
index 000000000..6b5bef05c
--- /dev/null
+++ b/trie/proof_test.go
@@ -0,0 +1,139 @@
+package trie
+
+import (
+ "bytes"
+ crand "crypto/rand"
+ mrand "math/rand"
+ "testing"
+ "time"
+
+ "github.com/ethereum/go-ethereum/common"
+ "github.com/ethereum/go-ethereum/rlp"
+)
+
+func init() {
+ mrand.Seed(time.Now().Unix())
+}
+
+func TestProof(t *testing.T) {
+ trie, vals := randomTrie(500)
+ root := trie.Hash()
+ for _, kv := range vals {
+ proof := trie.Prove(kv.k)
+ if proof == nil {
+ t.Fatalf("missing key %x while constructing proof", kv.k)
+ }
+ val, err := VerifyProof(root, kv.k, proof)
+ if err != nil {
+ t.Fatalf("VerifyProof error for key %x: %v\nraw proof: %x", kv.k, err, proof)
+ }
+ if !bytes.Equal(val, kv.v) {
+ t.Fatalf("VerifyProof returned wrong value for key %x: got %x, want %x", kv.k, val, kv.v)
+ }
+ }
+}
+
+func TestOneElementProof(t *testing.T) {
+ trie := new(Trie)
+ updateString(trie, "k", "v")
+ proof := trie.Prove([]byte("k"))
+ if proof == nil {
+ t.Fatal("nil proof")
+ }
+ if len(proof) != 1 {
+ t.Error("proof should have one element")
+ }
+ val, err := VerifyProof(trie.Hash(), []byte("k"), proof)
+ if err != nil {
+ t.Fatalf("VerifyProof error: %v\nraw proof: %x", err, proof)
+ }
+ if !bytes.Equal(val, []byte("v")) {
+ t.Fatalf("VerifyProof returned wrong value: got %x, want 'k'", val)
+ }
+}
+
+func TestVerifyBadProof(t *testing.T) {
+ trie, vals := randomTrie(800)
+ root := trie.Hash()
+ for _, kv := range vals {
+ proof := trie.Prove(kv.k)
+ if proof == nil {
+ t.Fatal("nil proof")
+ }
+ mutateByte(proof[mrand.Intn(len(proof))])
+ if _, err := VerifyProof(root, kv.k, proof); err == nil {
+ t.Fatalf("expected proof to fail for key %x", kv.k)
+ }
+ }
+}
+
+// mutateByte changes one byte in b.
+func mutateByte(b []byte) {
+ for r := mrand.Intn(len(b)); ; {
+ new := byte(mrand.Intn(255))
+ if new != b[r] {
+ b[r] = new
+ break
+ }
+ }
+}
+
+func BenchmarkProve(b *testing.B) {
+ trie, vals := randomTrie(100)
+ var keys []string
+ for k := range vals {
+ keys = append(keys, k)
+ }
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ kv := vals[keys[i%len(keys)]]
+ if trie.Prove(kv.k) == nil {
+ b.Fatalf("nil proof for %x", kv.k)
+ }
+ }
+}
+
+func BenchmarkVerifyProof(b *testing.B) {
+ trie, vals := randomTrie(100)
+ root := trie.Hash()
+ var keys []string
+ var proofs [][]rlp.RawValue
+ for k := range vals {
+ keys = append(keys, k)
+ proofs = append(proofs, trie.Prove([]byte(k)))
+ }
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ im := i % len(keys)
+ if _, err := VerifyProof(root, []byte(keys[im]), proofs[im]); err != nil {
+ b.Fatalf("key %x: error", keys[im], err)
+ }
+ }
+}
+
+func randomTrie(n int) (*Trie, map[string]*kv) {
+ trie := new(Trie)
+ vals := make(map[string]*kv)
+ for i := byte(0); i < 100; i++ {
+ value := &kv{common.LeftPadBytes([]byte{i}, 32), []byte{i}, false}
+ value2 := &kv{common.LeftPadBytes([]byte{i + 10}, 32), []byte{i}, false}
+ trie.Update(value.k, value.v)
+ trie.Update(value2.k, value2.v)
+ vals[string(value.k)] = value
+ vals[string(value2.k)] = value2
+ }
+ for i := 0; i < n; i++ {
+ value := &kv{randBytes(32), randBytes(20), false}
+ trie.Update(value.k, value.v)
+ vals[string(value.k)] = value
+ }
+ return trie, vals
+}
+
+func randBytes(n int) []byte {
+ r := make([]byte, n)
+ crand.Read(r)
+ return r
+}
diff --git a/trie/secure_trie.go b/trie/secure_trie.go
index 47c7542bb..47d1934d0 100644
--- a/trie/secure_trie.go
+++ b/trie/secure_trie.go
@@ -16,46 +16,93 @@
package trie
-import "github.com/ethereum/go-ethereum/crypto"
+import (
+ "hash"
-var keyPrefix = []byte("secure-key-")
+ "github.com/ethereum/go-ethereum/common"
+ "github.com/ethereum/go-ethereum/crypto/sha3"
+)
+var secureKeyPrefix = []byte("secure-key-")
+
+// SecureTrie wraps a trie with key hashing. In a secure trie, all
+// access operations hash the key using keccak256. This prevents
+// calling code from creating long chains of nodes that
+// increase the access time.
+//
+// Contrary to a regular trie, a SecureTrie can only be created with
+// New and must have an attached database. The database also stores
+// the preimage of each key.
+//
+// SecureTrie is not safe for concurrent use.
type SecureTrie struct {
*Trie
-}
-func NewSecure(root []byte, backend Backend) *SecureTrie {
- return &SecureTrie{New(root, backend)}
+ hash hash.Hash
+ secKeyBuf []byte
+ hashKeyBuf []byte
}
-func (self *SecureTrie) Update(key, value []byte) Node {
- shaKey := crypto.Sha3(key)
- self.Trie.cache.Put(append(keyPrefix, shaKey...), key)
-
- return self.Trie.Update(shaKey, value)
-}
-func (self *SecureTrie) UpdateString(key, value string) Node {
- return self.Update([]byte(key), []byte(value))
+// NewSecure creates a trie with an existing root node from db.
+//
+// If root is the zero hash or the sha3 hash of an empty string, the
+// trie is initially empty. Otherwise, New will panics if db is nil
+// and returns ErrMissingRoot if the root node cannpt be found.
+// Accessing the trie loads nodes from db on demand.
+func NewSecure(root common.Hash, db Database) (*SecureTrie, error) {
+ if db == nil {
+ panic("NewSecure called with nil database")
+ }
+ trie, err := New(root, db)
+ if err != nil {
+ return nil, err
+ }
+ return &SecureTrie{Trie: trie}, nil
}
-func (self *SecureTrie) Get(key []byte) []byte {
- return self.Trie.Get(crypto.Sha3(key))
+// Get returns the value for key stored in the trie.
+// The value bytes must not be modified by the caller.
+func (t *SecureTrie) Get(key []byte) []byte {
+ return t.Trie.Get(t.hashKey(key))
}
-func (self *SecureTrie) GetString(key string) []byte {
- return self.Get([]byte(key))
+
+// Update associates key with value in the trie. Subsequent calls to
+// Get will return value. If value has length zero, any existing value
+// is deleted from the trie and calls to Get will return nil.
+//
+// The value bytes must not be modified by the caller while they are
+// stored in the trie.
+func (t *SecureTrie) Update(key, value []byte) {
+ hk := t.hashKey(key)
+ t.Trie.Update(hk, value)
+ t.Trie.db.Put(t.secKey(hk), key)
}
-func (self *SecureTrie) Delete(key []byte) Node {
- return self.Trie.Delete(crypto.Sha3(key))
+// Delete removes any existing value for key from the trie.
+func (t *SecureTrie) Delete(key []byte) {
+ t.Trie.Delete(t.hashKey(key))
}
-func (self *SecureTrie) DeleteString(key string) Node {
- return self.Delete([]byte(key))
+
+// GetKey returns the sha3 preimage of a hashed key that was
+// previously used to store a value.
+func (t *SecureTrie) GetKey(shaKey []byte) []byte {
+ key, _ := t.Trie.db.Get(t.secKey(shaKey))
+ return key
}
-func (self *SecureTrie) Copy() *SecureTrie {
- return &SecureTrie{self.Trie.Copy()}
+func (t *SecureTrie) secKey(key []byte) []byte {
+ t.secKeyBuf = append(t.secKeyBuf[:0], secureKeyPrefix...)
+ t.secKeyBuf = append(t.secKeyBuf, key...)
+ return t.secKeyBuf
}
-func (self *SecureTrie) GetKey(shaKey []byte) []byte {
- return self.Trie.cache.Get(append(keyPrefix, shaKey...))
+func (t *SecureTrie) hashKey(key []byte) []byte {
+ if t.hash == nil {
+ t.hash = sha3.NewKeccak256()
+ t.hashKeyBuf = make([]byte, 32)
+ }
+ t.hash.Reset()
+ t.hash.Write(key)
+ t.hashKeyBuf = t.hash.Sum(t.hashKeyBuf[:0])
+ return t.hashKeyBuf
}
diff --git a/trie/secure_trie_test.go b/trie/secure_trie_test.go
new file mode 100644
index 000000000..13c6cd02e
--- /dev/null
+++ b/trie/secure_trie_test.go
@@ -0,0 +1,74 @@
+// Copyright 2015 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package trie
+
+import (
+ "bytes"
+ "testing"
+
+ "github.com/ethereum/go-ethereum/common"
+ "github.com/ethereum/go-ethereum/crypto"
+ "github.com/ethereum/go-ethereum/ethdb"
+)
+
+func newEmptySecure() *SecureTrie {
+ db, _ := ethdb.NewMemDatabase()
+ trie, _ := NewSecure(common.Hash{}, db)
+ return trie
+}
+
+func TestSecureDelete(t *testing.T) {
+ trie := newEmptySecure()
+ vals := []struct{ k, v string }{
+ {"do", "verb"},
+ {"ether", "wookiedoo"},
+ {"horse", "stallion"},
+ {"shaman", "horse"},
+ {"doge", "coin"},
+ {"ether", ""},
+ {"dog", "puppy"},
+ {"shaman", ""},
+ }
+ for _, val := range vals {
+ if val.v != "" {
+ trie.Update([]byte(val.k), []byte(val.v))
+ } else {
+ trie.Delete([]byte(val.k))
+ }
+ }
+ hash := trie.Hash()
+ exp := common.HexToHash("29b235a58c3c25ab83010c327d5932bcf05324b7d6b1185e650798034783ca9d")
+ if hash != exp {
+ t.Errorf("expected %x got %x", exp, hash)
+ }
+}
+
+func TestSecureGetKey(t *testing.T) {
+ trie := newEmptySecure()
+ trie.Update([]byte("foo"), []byte("bar"))
+
+ key := []byte("foo")
+ value := []byte("bar")
+ seckey := crypto.Sha3(key)
+
+ if !bytes.Equal(trie.Get(key), value) {
+ t.Errorf("Get did not return bar")
+ }
+ if k := trie.GetKey(seckey); !bytes.Equal(k, key) {
+ t.Errorf("GetKey returned %q, want %q", k, key)
+ }
+}
diff --git a/trie/shortnode.go b/trie/shortnode.go
deleted file mode 100644
index 569d5f109..000000000
--- a/trie/shortnode.go
+++ /dev/null
@@ -1,57 +0,0 @@
-// Copyright 2014 The go-ethereum Authors
-// This file is part of the go-ethereum library.
-//
-// The go-ethereum library is free software: you can redistribute it and/or modify
-// it under the terms of the GNU Lesser General Public License as published by
-// the Free Software Foundation, either version 3 of the License, or
-// (at your option) any later version.
-//
-// The go-ethereum library is distributed in the hope that it will be useful,
-// but WITHOUT ANY WARRANTY; without even the implied warranty of
-// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-// GNU Lesser General Public License for more details.
-//
-// You should have received a copy of the GNU Lesser General Public License
-// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
-
-package trie
-
-import "github.com/ethereum/go-ethereum/common"
-
-type ShortNode struct {
- trie *Trie
- key []byte
- value Node
- dirty bool
-}
-
-func NewShortNode(t *Trie, key []byte, value Node) *ShortNode {
- return &ShortNode{t, CompactEncode(key), value, false}
-}
-func (self *ShortNode) Value() Node {
- self.value = self.trie.trans(self.value)
-
- return self.value
-}
-func (self *ShortNode) Dirty() bool { return self.dirty }
-func (self *ShortNode) Copy(t *Trie) Node {
- node := &ShortNode{t, nil, self.value.Copy(t), self.dirty}
- node.key = common.CopyBytes(self.key)
- node.dirty = true
- return node
-}
-
-func (self *ShortNode) RlpData() interface{} {
- return []interface{}{self.key, self.value.Hash()}
-}
-func (self *ShortNode) Hash() interface{} {
- return self.trie.store(self)
-}
-
-func (self *ShortNode) Key() []byte {
- return CompactDecode(self.key)
-}
-
-func (self *ShortNode) setDirty(dirty bool) {
- self.dirty = dirty
-}
diff --git a/trie/slice.go b/trie/slice.go
deleted file mode 100644
index ccefbd064..000000000
--- a/trie/slice.go
+++ /dev/null
@@ -1,69 +0,0 @@
-// Copyright 2014 The go-ethereum Authors
-// This file is part of the go-ethereum library.
-//
-// The go-ethereum library is free software: you can redistribute it and/or modify
-// it under the terms of the GNU Lesser General Public License as published by
-// the Free Software Foundation, either version 3 of the License, or
-// (at your option) any later version.
-//
-// The go-ethereum library is distributed in the hope that it will be useful,
-// but WITHOUT ANY WARRANTY; without even the implied warranty of
-// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-// GNU Lesser General Public License for more details.
-//
-// You should have received a copy of the GNU Lesser General Public License
-// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
-
-package trie
-
-import (
- "bytes"
- "math"
-)
-
-// Helper function for comparing slices
-func CompareIntSlice(a, b []int) bool {
- if len(a) != len(b) {
- return false
- }
- for i, v := range a {
- if v != b[i] {
- return false
- }
- }
- return true
-}
-
-// Returns the amount of nibbles that match each other from 0 ...
-func MatchingNibbleLength(a, b []byte) int {
- var i, length = 0, int(math.Min(float64(len(a)), float64(len(b))))
-
- for i < length {
- if a[i] != b[i] {
- break
- }
- i++
- }
-
- return i
-}
-
-func HasTerm(s []byte) bool {
- return s[len(s)-1] == 16
-}
-
-func RemTerm(s []byte) []byte {
- if HasTerm(s) {
- return s[:len(s)-1]
- }
-
- return s
-}
-
-func BeginsWith(a, b []byte) bool {
- if len(b) > len(a) {
- return false
- }
-
- return bytes.Equal(a[:len(b)], b)
-}
diff --git a/trie/trie.go b/trie/trie.go
index abf48a850..aa8d39fe2 100644
--- a/trie/trie.go
+++ b/trie/trie.go
@@ -19,372 +19,425 @@ package trie
import (
"bytes"
- "container/list"
+ "errors"
"fmt"
- "sync"
+ "hash"
"github.com/ethereum/go-ethereum/common"
- "github.com/ethereum/go-ethereum/crypto"
+ "github.com/ethereum/go-ethereum/crypto/sha3"
+ "github.com/ethereum/go-ethereum/logger"
+ "github.com/ethereum/go-ethereum/logger/glog"
+ "github.com/ethereum/go-ethereum/rlp"
)
-func ParanoiaCheck(t1 *Trie, backend Backend) (bool, *Trie) {
- t2 := New(nil, backend)
+const defaultCacheCapacity = 800
- it := t1.Iterator()
- for it.Next() {
- t2.Update(it.Key, it.Value)
- }
-
- return bytes.Equal(t2.Hash(), t1.Hash()), t2
-}
-
-type Trie struct {
- mu sync.Mutex
- root Node
- roothash []byte
- cache *Cache
-
- revisions *list.List
-}
-
-func New(root []byte, backend Backend) *Trie {
- trie := &Trie{}
- trie.revisions = list.New()
- trie.roothash = root
- if backend != nil {
- trie.cache = NewCache(backend)
- }
+var (
+ // The global cache stores decoded trie nodes by hash as they get loaded.
+ globalCache = newARC(defaultCacheCapacity)
+ // This is the known root hash of an empty trie.
+ emptyRoot = common.HexToHash("56e81f171bcc55a6ff8345e692c0f86e5b48e01b996cadc001622fb5e363b421")
+)
- if root != nil {
- value := common.NewValueFromBytes(trie.cache.Get(root))
- trie.root = trie.mknode(value)
- }
+var ErrMissingRoot = errors.New("missing root node")
- return trie
+// Database must be implemented by backing stores for the trie.
+type Database interface {
+ DatabaseWriter
+ // Get returns the value for key from the database.
+ Get(key []byte) (value []byte, err error)
}
-func (self *Trie) Iterator() *Iterator {
- return NewIterator(self)
+// DatabaseWriter wraps the Put method of a backing store for the trie.
+type DatabaseWriter interface {
+ // Put stores the mapping key->value in the database.
+ // Implementations must not hold onto the value bytes, the trie
+ // will reuse the slice across calls to Put.
+ Put(key, value []byte) error
}
-func (self *Trie) Copy() *Trie {
- cpy := make([]byte, 32)
- copy(cpy, self.roothash) // NOTE: cpy isn't being used anywhere?
- trie := New(nil, nil)
- trie.cache = self.cache.Copy()
- if self.root != nil {
- trie.root = self.root.Copy(trie)
- }
-
- return trie
+// Trie is a Merkle Patricia Trie.
+// The zero value is an empty trie with no database.
+// Use New to create a trie that sits on top of a database.
+//
+// Trie is not safe for concurrent use.
+type Trie struct {
+ root node
+ db Database
+ *hasher
}
-// Legacy support
-func (self *Trie) Root() []byte { return self.Hash() }
-func (self *Trie) Hash() []byte {
- var hash []byte
- if self.root != nil {
- t := self.root.Hash()
- if byts, ok := t.([]byte); ok && len(byts) > 0 {
- hash = byts
- } else {
- hash = crypto.Sha3(common.Encode(self.root.RlpData()))
+// New creates a trie with an existing root node from db.
+//
+// If root is the zero hash or the sha3 hash of an empty string, the
+// trie is initially empty and does not require a database. Otherwise,
+// New will panics if db is nil or root does not exist in the
+// database. Accessing the trie loads nodes from db on demand.
+func New(root common.Hash, db Database) (*Trie, error) {
+ trie := &Trie{db: db}
+ if (root != common.Hash{}) && root != emptyRoot {
+ if db == nil {
+ panic("trie.New: cannot use existing root without a database")
}
- } else {
- hash = crypto.Sha3(common.Encode(""))
- }
-
- if !bytes.Equal(hash, self.roothash) {
- self.revisions.PushBack(self.roothash)
- self.roothash = hash
+ if v, _ := trie.db.Get(root[:]); len(v) == 0 {
+ return nil, ErrMissingRoot
+ }
+ trie.root = hashNode(root.Bytes())
}
-
- return hash
+ return trie, nil
}
-func (self *Trie) Commit() {
- self.mu.Lock()
- defer self.mu.Unlock()
- // Hash first
- self.Hash()
-
- self.cache.Flush()
+// Iterator returns an iterator over all mappings in the trie.
+func (t *Trie) Iterator() *Iterator {
+ return NewIterator(t)
}
-// Reset should only be called if the trie has been hashed
-func (self *Trie) Reset() {
- self.mu.Lock()
- defer self.mu.Unlock()
-
- self.cache.Reset()
-
- if self.revisions.Len() > 0 {
- revision := self.revisions.Remove(self.revisions.Back()).([]byte)
- self.roothash = revision
+// Get returns the value for key stored in the trie.
+// The value bytes must not be modified by the caller.
+func (t *Trie) Get(key []byte) []byte {
+ key = compactHexDecode(key)
+ tn := t.root
+ for len(key) > 0 {
+ switch n := tn.(type) {
+ case shortNode:
+ if len(key) < len(n.Key) || !bytes.Equal(n.Key, key[:len(n.Key)]) {
+ return nil
+ }
+ tn = n.Val
+ key = key[len(n.Key):]
+ case fullNode:
+ tn = n[key[0]]
+ key = key[1:]
+ case nil:
+ return nil
+ case hashNode:
+ tn = t.resolveHash(n)
+ default:
+ panic(fmt.Sprintf("%T: invalid node: %v", tn, tn))
+ }
}
- value := common.NewValueFromBytes(self.cache.Get(self.roothash))
- self.root = self.mknode(value)
+ return tn.(valueNode)
}
-func (self *Trie) UpdateString(key, value string) Node { return self.Update([]byte(key), []byte(value)) }
-func (self *Trie) Update(key, value []byte) Node {
- self.mu.Lock()
- defer self.mu.Unlock()
-
- k := CompactHexDecode(key)
-
+// Update associates key with value in the trie. Subsequent calls to
+// Get will return value. If value has length zero, any existing value
+// is deleted from the trie and calls to Get will return nil.
+//
+// The value bytes must not be modified by the caller while they are
+// stored in the trie.
+func (t *Trie) Update(key, value []byte) {
+ k := compactHexDecode(key)
if len(value) != 0 {
- node := NewValueNode(self, value)
- node.dirty = true
- self.root = self.insert(self.root, k, node)
+ t.root = t.insert(t.root, k, valueNode(value))
} else {
- self.root = self.delete(self.root, k)
+ t.root = t.delete(t.root, k)
}
-
- return self.root
-}
-
-func (self *Trie) GetString(key string) []byte { return self.Get([]byte(key)) }
-func (self *Trie) Get(key []byte) []byte {
- self.mu.Lock()
- defer self.mu.Unlock()
-
- k := CompactHexDecode(key)
-
- n := self.get(self.root, k)
- if n != nil {
- return n.(*ValueNode).Val()
- }
-
- return nil
}
-func (self *Trie) DeleteString(key string) Node { return self.Delete([]byte(key)) }
-func (self *Trie) Delete(key []byte) Node {
- self.mu.Lock()
- defer self.mu.Unlock()
-
- k := CompactHexDecode(key)
- self.root = self.delete(self.root, k)
-
- return self.root
-}
-
-func (self *Trie) insert(node Node, key []byte, value Node) Node {
+func (t *Trie) insert(n node, key []byte, value node) node {
if len(key) == 0 {
return value
}
-
- if node == nil {
- node := NewShortNode(self, key, value)
- node.dirty = true
- return node
- }
-
- switch node := node.(type) {
- case *ShortNode:
- k := node.Key()
- cnode := node.Value()
- if bytes.Equal(k, key) {
- node := NewShortNode(self, key, value)
- node.dirty = true
- return node
-
+ switch n := n.(type) {
+ case shortNode:
+ matchlen := prefixLen(key, n.Key)
+ // If the whole key matches, keep this short node as is
+ // and only update the value.
+ if matchlen == len(n.Key) {
+ return shortNode{n.Key, t.insert(n.Val, key[matchlen:], value)}
}
-
- var n Node
- matchlength := MatchingNibbleLength(key, k)
- if matchlength == len(k) {
- n = self.insert(cnode, key[matchlength:], value)
- } else {
- pnode := self.insert(nil, k[matchlength+1:], cnode)
- nnode := self.insert(nil, key[matchlength+1:], value)
- fulln := NewFullNode(self)
- fulln.dirty = true
- fulln.set(k[matchlength], pnode)
- fulln.set(key[matchlength], nnode)
- n = fulln
- }
- if matchlength == 0 {
- return n
+ // Otherwise branch out at the index where they differ.
+ var branch fullNode
+ branch[n.Key[matchlen]] = t.insert(nil, n.Key[matchlen+1:], n.Val)
+ branch[key[matchlen]] = t.insert(nil, key[matchlen+1:], value)
+ // Replace this shortNode with the branch if it occurs at index 0.
+ if matchlen == 0 {
+ return branch
}
+ // Otherwise, replace it with a short node leading up to the branch.
+ return shortNode{key[:matchlen], branch}
- snode := NewShortNode(self, key[:matchlength], n)
- snode.dirty = true
- return snode
+ case fullNode:
+ n[key[0]] = t.insert(n[key[0]], key[1:], value)
+ return n
- case *FullNode:
- cpy := node.Copy(self).(*FullNode)
- cpy.set(key[0], self.insert(node.branch(key[0]), key[1:], value))
- cpy.dirty = true
+ case nil:
+ return shortNode{key, value}
- return cpy
+ case hashNode:
+ // We've hit a part of the trie that isn't loaded yet. Load
+ // the node and insert into it. This leaves all child nodes on
+ // the path to the value in the trie.
+ //
+ // TODO: track whether insertion changed the value and keep
+ // n as a hash node if it didn't.
+ return t.insert(t.resolveHash(n), key, value)
default:
- panic(fmt.Sprintf("%T: invalid node: %v", node, node))
+ panic(fmt.Sprintf("%T: invalid node: %v", n, n))
}
}
-func (self *Trie) get(node Node, key []byte) Node {
- if len(key) == 0 {
- return node
- }
-
- if node == nil {
- return nil
- }
-
- switch node := node.(type) {
- case *ShortNode:
- k := node.Key()
- cnode := node.Value()
-
- if len(key) >= len(k) && bytes.Equal(k, key[:len(k)]) {
- return self.get(cnode, key[len(k):])
- }
-
- return nil
- case *FullNode:
- return self.get(node.branch(key[0]), key[1:])
- default:
- panic(fmt.Sprintf("%T: invalid node: %v", node, node))
- }
+// Delete removes any existing value for key from the trie.
+func (t *Trie) Delete(key []byte) {
+ k := compactHexDecode(key)
+ t.root = t.delete(t.root, k)
}
-func (self *Trie) delete(node Node, key []byte) Node {
- if len(key) == 0 && node == nil {
- return nil
- }
-
- switch node := node.(type) {
- case *ShortNode:
- k := node.Key()
- cnode := node.Value()
- if bytes.Equal(key, k) {
- return nil
- } else if bytes.Equal(key[:len(k)], k) {
- child := self.delete(cnode, key[len(k):])
-
- var n Node
- switch child := child.(type) {
- case *ShortNode:
- nkey := append(k, child.Key()...)
- n = NewShortNode(self, nkey, child.Value())
- n.(*ShortNode).dirty = true
- case *FullNode:
- sn := NewShortNode(self, node.Key(), child)
- sn.dirty = true
- sn.key = node.key
- n = sn
- }
-
- return n
- } else {
- return node
+// delete returns the new root of the trie with key deleted.
+// It reduces the trie to minimal form by simplifying
+// nodes on the way up after deleting recursively.
+func (t *Trie) delete(n node, key []byte) node {
+ switch n := n.(type) {
+ case shortNode:
+ matchlen := prefixLen(key, n.Key)
+ if matchlen < len(n.Key) {
+ return n // don't replace n on mismatch
+ }
+ if matchlen == len(key) {
+ return nil // remove n entirely for whole matches
+ }
+ // The key is longer than n.Key. Remove the remaining suffix
+ // from the subtrie. Child can never be nil here since the
+ // subtrie must contain at least two other values with keys
+ // longer than n.Key.
+ child := t.delete(n.Val, key[len(n.Key):])
+ switch child := child.(type) {
+ case shortNode:
+ // Deleting from the subtrie reduced it to another
+ // short node. Merge the nodes to avoid creating a
+ // shortNode{..., shortNode{...}}. Use concat (which
+ // always creates a new slice) instead of append to
+ // avoid modifying n.Key since it might be shared with
+ // other nodes.
+ return shortNode{concat(n.Key, child.Key...), child.Val}
+ default:
+ return shortNode{n.Key, child}
}
- case *FullNode:
- n := node.Copy(self).(*FullNode)
- n.set(key[0], self.delete(n.branch(key[0]), key[1:]))
- n.dirty = true
-
+ case fullNode:
+ n[key[0]] = t.delete(n[key[0]], key[1:])
+ // Check how many non-nil entries are left after deleting and
+ // reduce the full node to a short node if only one entry is
+ // left. Since n must've contained at least two children
+ // before deletion (otherwise it would not be a full node) n
+ // can never be reduced to nil.
+ //
+ // When the loop is done, pos contains the index of the single
+ // value that is left in n or -2 if n contains at least two
+ // values.
pos := -1
- for i := 0; i < 17; i++ {
- if n.branch(byte(i)) != nil {
+ for i, cld := range n {
+ if cld != nil {
if pos == -1 {
pos = i
} else {
pos = -2
+ break
}
}
}
-
- var nnode Node
- if pos == 16 {
- nnode = NewShortNode(self, []byte{16}, n.branch(byte(pos)))
- nnode.(*ShortNode).dirty = true
- } else if pos >= 0 {
- cnode := n.branch(byte(pos))
- switch cnode := cnode.(type) {
- case *ShortNode:
- // Stitch keys
- k := append([]byte{byte(pos)}, cnode.Key()...)
- nnode = NewShortNode(self, k, cnode.Value())
- nnode.(*ShortNode).dirty = true
- case *FullNode:
- nnode = NewShortNode(self, []byte{byte(pos)}, n.branch(byte(pos)))
- nnode.(*ShortNode).dirty = true
+ if pos >= 0 {
+ if pos != 16 {
+ // If the remaining entry is a short node, it replaces
+ // n and its key gets the missing nibble tacked to the
+ // front. This avoids creating an invalid
+ // shortNode{..., shortNode{...}}. Since the entry
+ // might not be loaded yet, resolve it just for this
+ // check.
+ cnode := t.resolve(n[pos])
+ if cnode, ok := cnode.(shortNode); ok {
+ k := append([]byte{byte(pos)}, cnode.Key...)
+ return shortNode{k, cnode.Val}
+ }
}
- } else {
- nnode = n
+ // Otherwise, n is replaced by a one-nibble short node
+ // containing the child.
+ return shortNode{[]byte{byte(pos)}, n[pos]}
}
+ // n still contains at least two values and cannot be reduced.
+ return n
- return nnode
case nil:
return nil
+
+ case hashNode:
+ // We've hit a part of the trie that isn't loaded yet. Load
+ // the node and delete from it. This leaves all child nodes on
+ // the path to the value in the trie.
+ //
+ // TODO: track whether deletion actually hit a key and keep
+ // n as a hash node if it didn't.
+ return t.delete(t.resolveHash(n), key)
+
default:
- panic(fmt.Sprintf("%T: invalid node: %v (%v)", node, node, key))
+ panic(fmt.Sprintf("%T: invalid node: %v (%v)", n, n, key))
}
}
-// casting functions and cache storing
-func (self *Trie) mknode(value *common.Value) Node {
- l := value.Len()
- switch l {
- case 0:
- return nil
- case 2:
- // A value node may consists of 2 bytes.
- if value.Get(0).Len() != 0 {
- key := CompactDecode(value.Get(0).Bytes())
- if key[len(key)-1] == 16 {
- return NewShortNode(self, key, NewValueNode(self, value.Get(1).Bytes()))
- } else {
- return NewShortNode(self, key, self.mknode(value.Get(1)))
- }
- }
- case 17:
- if len(value.Bytes()) != 17 {
- fnode := NewFullNode(self)
- for i := 0; i < 16; i++ {
- fnode.set(byte(i), self.mknode(value.Get(i)))
- }
- return fnode
+func concat(s1 []byte, s2 ...byte) []byte {
+ r := make([]byte, len(s1)+len(s2))
+ copy(r, s1)
+ copy(r[len(s1):], s2)
+ return r
+}
+
+func (t *Trie) resolve(n node) node {
+ if n, ok := n.(hashNode); ok {
+ return t.resolveHash(n)
+ }
+ return n
+}
+
+func (t *Trie) resolveHash(n hashNode) node {
+ if v, ok := globalCache.Get(n); ok {
+ return v
+ }
+ enc, err := t.db.Get(n)
+ if err != nil || enc == nil {
+ // TODO: This needs to be improved to properly distinguish errors.
+ // Disk I/O errors shouldn't produce nil (and cause a
+ // consensus failure or weird crash), but it is unclear how
+ // they could be handled because the entire stack above the trie isn't
+ // prepared to cope with missing state nodes.
+ if glog.V(logger.Error) {
+ glog.Errorf("Dangling hash node ref %x: %v", n, err)
}
- case 32:
- return NewHash(value.Bytes(), self)
+ return nil
+ }
+ dec := mustDecodeNode(n, enc)
+ if dec != nil {
+ globalCache.Put(n, dec)
}
+ return dec
+}
+
+// Root returns the root hash of the trie.
+// Deprecated: use Hash instead.
+func (t *Trie) Root() []byte { return t.Hash().Bytes() }
- return NewValueNode(self, value.Bytes())
+// Hash returns the root hash of the trie. It does not write to the
+// database and can be used even if the trie doesn't have one.
+func (t *Trie) Hash() common.Hash {
+ root, _ := t.hashRoot(nil)
+ return common.BytesToHash(root.(hashNode))
}
-func (self *Trie) trans(node Node) Node {
- switch node := node.(type) {
- case *HashNode:
- value := common.NewValueFromBytes(self.cache.Get(node.key))
- return self.mknode(value)
- default:
- return node
+// Commit writes all nodes to the trie's database.
+// Nodes are stored with their sha3 hash as the key.
+//
+// Committing flushes nodes from memory.
+// Subsequent Get calls will load nodes from the database.
+func (t *Trie) Commit() (root common.Hash, err error) {
+ if t.db == nil {
+ panic("Commit called on trie with nil database")
}
+ return t.CommitTo(t.db)
}
-func (self *Trie) store(node Node) interface{} {
- data := common.Encode(node)
- if len(data) >= 32 {
- key := crypto.Sha3(data)
- if node.Dirty() {
- //fmt.Println("save", node)
- //fmt.Println()
- self.cache.Put(key, data)
- }
+// CommitTo writes all nodes to the given database.
+// Nodes are stored with their sha3 hash as the key.
+//
+// Committing flushes nodes from memory. Subsequent Get calls will
+// load nodes from the trie's database. Calling code must ensure that
+// the changes made to db are written back to the trie's attached
+// database before using the trie.
+func (t *Trie) CommitTo(db DatabaseWriter) (root common.Hash, err error) {
+ n, err := t.hashRoot(db)
+ if err != nil {
+ return (common.Hash{}), err
+ }
+ t.root = n
+ return common.BytesToHash(n.(hashNode)), nil
+}
- return key
+func (t *Trie) hashRoot(db DatabaseWriter) (node, error) {
+ if t.root == nil {
+ return hashNode(emptyRoot.Bytes()), nil
+ }
+ if t.hasher == nil {
+ t.hasher = newHasher()
}
+ return t.hasher.hash(t.root, db, true)
+}
- return node.RlpData()
+type hasher struct {
+ tmp *bytes.Buffer
+ sha hash.Hash
}
-func (self *Trie) PrintRoot() {
- fmt.Println(self.root)
- fmt.Printf("root=%x\n", self.Root())
+func newHasher() *hasher {
+ return &hasher{tmp: new(bytes.Buffer), sha: sha3.NewKeccak256()}
+}
+
+func (h *hasher) hash(n node, db DatabaseWriter, force bool) (node, error) {
+ hashed, err := h.replaceChildren(n, db)
+ if err != nil {
+ return hashNode{}, err
+ }
+ if n, err = h.store(hashed, db, force); err != nil {
+ return hashNode{}, err
+ }
+ return n, nil
+}
+
+// hashChildren replaces child nodes of n with their hashes if the encoded
+// size of the child is larger than a hash.
+func (h *hasher) replaceChildren(n node, db DatabaseWriter) (node, error) {
+ var err error
+ switch n := n.(type) {
+ case shortNode:
+ n.Key = compactEncode(n.Key)
+ if _, ok := n.Val.(valueNode); !ok {
+ if n.Val, err = h.hash(n.Val, db, false); err != nil {
+ return n, err
+ }
+ }
+ if n.Val == nil {
+ // Ensure that nil children are encoded as empty strings.
+ n.Val = valueNode(nil)
+ }
+ return n, nil
+ case fullNode:
+ for i := 0; i < 16; i++ {
+ if n[i] != nil {
+ if n[i], err = h.hash(n[i], db, false); err != nil {
+ return n, err
+ }
+ } else {
+ // Ensure that nil children are encoded as empty strings.
+ n[i] = valueNode(nil)
+ }
+ }
+ if n[16] == nil {
+ n[16] = valueNode(nil)
+ }
+ return n, nil
+ default:
+ return n, nil
+ }
+}
+
+func (h *hasher) store(n node, db DatabaseWriter, force bool) (node, error) {
+ // Don't store hashes or empty nodes.
+ if _, isHash := n.(hashNode); n == nil || isHash {
+ return n, nil
+ }
+ h.tmp.Reset()
+ if err := rlp.Encode(h.tmp, n); err != nil {
+ panic("encode error: " + err.Error())
+ }
+ if h.tmp.Len() < 32 && !force {
+ // Nodes smaller than 32 bytes are stored inside their parent.
+ return n, nil
+ }
+ // Larger nodes are replaced by their hash and stored in the database.
+ h.sha.Reset()
+ h.sha.Write(h.tmp.Bytes())
+ key := hashNode(h.sha.Sum(nil))
+ if db != nil {
+ err := db.Put(key, h.tmp.Bytes())
+ return key, err
+ }
+ return key, nil
}
diff --git a/trie/trie_test.go b/trie/trie_test.go
index ae4e5efe4..c96861bed 100644
--- a/trie/trie_test.go
+++ b/trie/trie_test.go
@@ -18,89 +18,109 @@ package trie
import (
"bytes"
+ "encoding/binary"
"fmt"
+ "io/ioutil"
+ "os"
"testing"
+ "github.com/davecgh/go-spew/spew"
"github.com/ethereum/go-ethereum/common"
- "github.com/ethereum/go-ethereum/crypto"
+ "github.com/ethereum/go-ethereum/ethdb"
)
-type Db map[string][]byte
-
-func (self Db) Get(k []byte) ([]byte, error) { return self[string(k)], nil }
-func (self Db) Put(k, v []byte) error { self[string(k)] = v; return nil }
-
-// Used for testing
-func NewEmpty() *Trie {
- return New(nil, make(Db))
+func init() {
+ spew.Config.Indent = " "
+ spew.Config.DisableMethods = true
}
-func NewEmptySecure() *SecureTrie {
- return NewSecure(nil, make(Db))
+// Used for testing
+func newEmpty() *Trie {
+ db, _ := ethdb.NewMemDatabase()
+ trie, _ := New(common.Hash{}, db)
+ return trie
}
func TestEmptyTrie(t *testing.T) {
- trie := NewEmpty()
+ var trie Trie
res := trie.Hash()
- exp := crypto.Sha3(common.Encode(""))
- if !bytes.Equal(res, exp) {
+ exp := emptyRoot
+ if res != common.Hash(exp) {
t.Errorf("expected %x got %x", exp, res)
}
}
func TestNull(t *testing.T) {
- trie := NewEmpty()
-
+ var trie Trie
key := make([]byte, 32)
value := common.FromHex("0x823140710bf13990e4500136726d8b55")
trie.Update(key, value)
value = trie.Get(key)
}
+func TestMissingRoot(t *testing.T) {
+ db, _ := ethdb.NewMemDatabase()
+ trie, err := New(common.HexToHash("0beec7b5ea3f0fdbc95d0dd47f3c5bc275da8a33"), db)
+ if trie != nil {
+ t.Error("New returned non-nil trie for invalid root")
+ }
+ if err != ErrMissingRoot {
+ t.Error("New returned wrong error: %v", err)
+ }
+}
+
func TestInsert(t *testing.T) {
- trie := NewEmpty()
+ trie := newEmpty()
- trie.UpdateString("doe", "reindeer")
- trie.UpdateString("dog", "puppy")
- trie.UpdateString("dogglesworth", "cat")
+ updateString(trie, "doe", "reindeer")
+ updateString(trie, "dog", "puppy")
+ updateString(trie, "dogglesworth", "cat")
- exp := common.Hex2Bytes("8aad789dff2f538bca5d8ea56e8abe10f4c7ba3a5dea95fea4cd6e7c3a1168d3")
+ exp := common.HexToHash("8aad789dff2f538bca5d8ea56e8abe10f4c7ba3a5dea95fea4cd6e7c3a1168d3")
root := trie.Hash()
- if !bytes.Equal(root, exp) {
+ if root != exp {
t.Errorf("exp %x got %x", exp, root)
}
- trie = NewEmpty()
- trie.UpdateString("A", "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa")
+ trie = newEmpty()
+ updateString(trie, "A", "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa")
- exp = common.Hex2Bytes("d23786fb4a010da3ce639d66d5e904a11dbc02746d1ce25029e53290cabf28ab")
- root = trie.Hash()
- if !bytes.Equal(root, exp) {
+ exp = common.HexToHash("d23786fb4a010da3ce639d66d5e904a11dbc02746d1ce25029e53290cabf28ab")
+ root, err := trie.Commit()
+ if err != nil {
+ t.Fatalf("commit error: %v", err)
+ }
+ if root != exp {
t.Errorf("exp %x got %x", exp, root)
}
}
func TestGet(t *testing.T) {
- trie := NewEmpty()
-
- trie.UpdateString("doe", "reindeer")
- trie.UpdateString("dog", "puppy")
- trie.UpdateString("dogglesworth", "cat")
+ trie := newEmpty()
+ updateString(trie, "doe", "reindeer")
+ updateString(trie, "dog", "puppy")
+ updateString(trie, "dogglesworth", "cat")
+
+ for i := 0; i < 2; i++ {
+ res := getString(trie, "dog")
+ if !bytes.Equal(res, []byte("puppy")) {
+ t.Errorf("expected puppy got %x", res)
+ }
- res := trie.GetString("dog")
- if !bytes.Equal(res, []byte("puppy")) {
- t.Errorf("expected puppy got %x", res)
- }
+ unknown := getString(trie, "unknown")
+ if unknown != nil {
+ t.Errorf("expected nil got %x", unknown)
+ }
- unknown := trie.GetString("unknown")
- if unknown != nil {
- t.Errorf("expected nil got %x", unknown)
+ if i == 1 {
+ return
+ }
+ trie.Commit()
}
}
func TestDelete(t *testing.T) {
- trie := NewEmpty()
-
+ trie := newEmpty()
vals := []struct{ k, v string }{
{"do", "verb"},
{"ether", "wookiedoo"},
@@ -113,21 +133,21 @@ func TestDelete(t *testing.T) {
}
for _, val := range vals {
if val.v != "" {
- trie.UpdateString(val.k, val.v)
+ updateString(trie, val.k, val.v)
} else {
- trie.DeleteString(val.k)
+ deleteString(trie, val.k)
}
}
hash := trie.Hash()
- exp := common.Hex2Bytes("5991bb8c6514148a29db676a14ac506cd2cd5775ace63c30a4fe457715e9ac84")
- if !bytes.Equal(hash, exp) {
+ exp := common.HexToHash("5991bb8c6514148a29db676a14ac506cd2cd5775ace63c30a4fe457715e9ac84")
+ if hash != exp {
t.Errorf("expected %x got %x", exp, hash)
}
}
func TestEmptyValues(t *testing.T) {
- trie := NewEmpty()
+ trie := newEmpty()
vals := []struct{ k, v string }{
{"do", "verb"},
@@ -140,78 +160,85 @@ func TestEmptyValues(t *testing.T) {
{"shaman", ""},
}
for _, val := range vals {
- trie.UpdateString(val.k, val.v)
+ updateString(trie, val.k, val.v)
}
hash := trie.Hash()
- exp := common.Hex2Bytes("5991bb8c6514148a29db676a14ac506cd2cd5775ace63c30a4fe457715e9ac84")
- if !bytes.Equal(hash, exp) {
+ exp := common.HexToHash("5991bb8c6514148a29db676a14ac506cd2cd5775ace63c30a4fe457715e9ac84")
+ if hash != exp {
t.Errorf("expected %x got %x", exp, hash)
}
}
func TestReplication(t *testing.T) {
- trie := NewEmpty()
+ trie := newEmpty()
vals := []struct{ k, v string }{
{"do", "verb"},
{"ether", "wookiedoo"},
{"horse", "stallion"},
{"shaman", "horse"},
{"doge", "coin"},
- {"ether", ""},
{"dog", "puppy"},
- {"shaman", ""},
{"somethingveryoddindeedthis is", "myothernodedata"},
}
for _, val := range vals {
- trie.UpdateString(val.k, val.v)
+ updateString(trie, val.k, val.v)
}
- trie.Commit()
-
- trie2 := New(trie.Root(), trie.cache.backend)
- if string(trie2.GetString("horse")) != "stallion" {
- t.Error("expected to have horse => stallion")
+ exp, err := trie.Commit()
+ if err != nil {
+ t.Fatalf("commit error: %v", err)
}
- hash := trie2.Hash()
- exp := trie.Hash()
- if !bytes.Equal(hash, exp) {
+ // create a new trie on top of the database and check that lookups work.
+ trie2, err := New(exp, trie.db)
+ if err != nil {
+ t.Fatalf("can't recreate trie at %x: %v", exp, err)
+ }
+ for _, kv := range vals {
+ if string(getString(trie2, kv.k)) != kv.v {
+ t.Errorf("trie2 doesn't have %q => %q", kv.k, kv.v)
+ }
+ }
+ hash, err := trie2.Commit()
+ if err != nil {
+ t.Fatalf("commit error: %v", err)
+ }
+ if hash != exp {
t.Errorf("root failure. expected %x got %x", exp, hash)
}
-}
-
-func TestReset(t *testing.T) {
- trie := NewEmpty()
- vals := []struct{ k, v string }{
+ // perform some insertions on the new trie.
+ vals2 := []struct{ k, v string }{
{"do", "verb"},
{"ether", "wookiedoo"},
{"horse", "stallion"},
+ // {"shaman", "horse"},
+ // {"doge", "coin"},
+ // {"ether", ""},
+ // {"dog", "puppy"},
+ // {"somethingveryoddindeedthis is", "myothernodedata"},
+ // {"shaman", ""},
}
- for _, val := range vals {
- trie.UpdateString(val.k, val.v)
+ for _, val := range vals2 {
+ updateString(trie2, val.k, val.v)
}
- trie.Commit()
-
- before := common.CopyBytes(trie.roothash)
- trie.UpdateString("should", "revert")
- trie.Hash()
- // Should have no effect
- trie.Hash()
- trie.Hash()
- // ###
-
- trie.Reset()
- after := common.CopyBytes(trie.roothash)
+ if trie2.Hash() != exp {
+ t.Errorf("root failure. expected %x got %x", exp, hash)
+ }
+}
- if !bytes.Equal(before, after) {
- t.Errorf("expected roots to be equal. %x - %x", before, after)
+func paranoiaCheck(t1 *Trie) (bool, *Trie) {
+ t2 := new(Trie)
+ it := NewIterator(t1)
+ for it.Next() {
+ t2.Update(it.Key, it.Value)
}
+ return t2.Hash() == t1.Hash(), t2
}
func TestParanoia(t *testing.T) {
t.Skip()
- trie := NewEmpty()
+ trie := newEmpty()
vals := []struct{ k, v string }{
{"do", "verb"},
@@ -225,13 +252,13 @@ func TestParanoia(t *testing.T) {
{"somethingveryoddindeedthis is", "myothernodedata"},
}
for _, val := range vals {
- trie.UpdateString(val.k, val.v)
+ updateString(trie, val.k, val.v)
}
trie.Commit()
- ok, t2 := ParanoiaCheck(trie, trie.cache.backend)
+ ok, t2 := paranoiaCheck(trie)
if !ok {
- t.Errorf("trie paranoia check failed %x %x", trie.roothash, t2.roothash)
+ t.Errorf("trie paranoia check failed %x %x", trie.Hash(), t2.Hash())
}
}
@@ -240,51 +267,26 @@ func TestOutput(t *testing.T) {
t.Skip()
base := "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
- trie := NewEmpty()
+ trie := newEmpty()
for i := 0; i < 50; i++ {
- trie.UpdateString(fmt.Sprintf("%s%d", base, i), "valueeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee")
+ updateString(trie, fmt.Sprintf("%s%d", base, i), "valueeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee")
}
fmt.Println("############################## FULL ################################")
fmt.Println(trie.root)
trie.Commit()
fmt.Println("############################## SMALL ################################")
- trie2 := New(trie.roothash, trie.cache.backend)
- trie2.GetString(base + "20")
+ trie2, _ := New(trie.Hash(), trie.db)
+ getString(trie2, base+"20")
fmt.Println(trie2.root)
}
-func BenchmarkGets(b *testing.B) {
- trie := NewEmpty()
- vals := []struct{ k, v string }{
- {"do", "verb"},
- {"ether", "wookiedoo"},
- {"horse", "stallion"},
- {"shaman", "horse"},
- {"doge", "coin"},
- {"ether", ""},
- {"dog", "puppy"},
- {"shaman", ""},
- {"somethingveryoddindeedthis is", "myothernodedata"},
- }
- for _, val := range vals {
- trie.UpdateString(val.k, val.v)
- }
-
- b.ResetTimer()
- for i := 0; i < b.N; i++ {
- trie.Get([]byte("horse"))
- }
-}
-
-func BenchmarkUpdate(b *testing.B) {
- trie := NewEmpty()
-
- b.ResetTimer()
- for i := 0; i < b.N; i++ {
- trie.UpdateString(fmt.Sprintf("aaaaaaaaa%d", i), "value")
- }
+func TestLargeValue(t *testing.T) {
+ trie := newEmpty()
+ trie.Update([]byte("key1"), []byte{99, 99, 99, 99})
+ trie.Update([]byte("key2"), bytes.Repeat([]byte{1}, 32))
trie.Hash()
+
}
type kv struct {
@@ -293,7 +295,7 @@ type kv struct {
}
func TestLargeData(t *testing.T) {
- trie := NewEmpty()
+ trie := newEmpty()
vals := make(map[string]*kv)
for i := byte(0); i < 255; i++ {
@@ -305,7 +307,7 @@ func TestLargeData(t *testing.T) {
vals[string(value2.k)] = value2
}
- it := trie.Iterator()
+ it := NewIterator(trie)
for it.Next() {
vals[string(it.Key)].t = true
}
@@ -325,30 +327,82 @@ func TestLargeData(t *testing.T) {
}
}
-func TestSecureDelete(t *testing.T) {
- trie := NewEmptySecure()
+func BenchmarkGet(b *testing.B) { benchGet(b, false) }
+func BenchmarkGetDB(b *testing.B) { benchGet(b, true) }
+func BenchmarkUpdateBE(b *testing.B) { benchUpdate(b, binary.BigEndian) }
+func BenchmarkUpdateLE(b *testing.B) { benchUpdate(b, binary.LittleEndian) }
+func BenchmarkHashBE(b *testing.B) { benchHash(b, binary.BigEndian) }
+func BenchmarkHashLE(b *testing.B) { benchHash(b, binary.LittleEndian) }
- vals := []struct{ k, v string }{
- {"do", "verb"},
- {"ether", "wookiedoo"},
- {"horse", "stallion"},
- {"shaman", "horse"},
- {"doge", "coin"},
- {"ether", ""},
- {"dog", "puppy"},
- {"shaman", ""},
+const benchElemCount = 20000
+
+func benchGet(b *testing.B, commit bool) {
+ trie := new(Trie)
+ if commit {
+ dir, tmpdb := tempDB()
+ defer os.RemoveAll(dir)
+ trie, _ = New(common.Hash{}, tmpdb)
}
- for _, val := range vals {
- if val.v != "" {
- trie.UpdateString(val.k, val.v)
- } else {
- trie.DeleteString(val.k)
- }
+ k := make([]byte, 32)
+ for i := 0; i < benchElemCount; i++ {
+ binary.LittleEndian.PutUint64(k, uint64(i))
+ trie.Update(k, k)
+ }
+ binary.LittleEndian.PutUint64(k, benchElemCount/2)
+ if commit {
+ trie.Commit()
}
- hash := trie.Hash()
- exp := common.Hex2Bytes("29b235a58c3c25ab83010c327d5932bcf05324b7d6b1185e650798034783ca9d")
- if !bytes.Equal(hash, exp) {
- t.Errorf("expected %x got %x", exp, hash)
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ trie.Get(k)
}
}
+
+func benchUpdate(b *testing.B, e binary.ByteOrder) *Trie {
+ trie := newEmpty()
+ k := make([]byte, 32)
+ for i := 0; i < b.N; i++ {
+ e.PutUint64(k, uint64(i))
+ trie.Update(k, k)
+ }
+ return trie
+}
+
+func benchHash(b *testing.B, e binary.ByteOrder) {
+ trie := newEmpty()
+ k := make([]byte, 32)
+ for i := 0; i < benchElemCount; i++ {
+ e.PutUint64(k, uint64(i))
+ trie.Update(k, k)
+ }
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ trie.Hash()
+ }
+}
+
+func tempDB() (string, Database) {
+ dir, err := ioutil.TempDir("", "trie-bench")
+ if err != nil {
+ panic(fmt.Sprintf("can't create temporary directory: %v", err))
+ }
+ db, err := ethdb.NewLDBDatabase(dir, 300*1024)
+ if err != nil {
+ panic(fmt.Sprintf("can't create temporary database: %v", err))
+ }
+ return dir, db
+}
+
+func getString(trie *Trie, k string) []byte {
+ return trie.Get([]byte(k))
+}
+
+func updateString(trie *Trie, k, v string) {
+ trie.Update([]byte(k), []byte(v))
+}
+
+func deleteString(trie *Trie, k string) {
+ trie.Delete([]byte(k))
+}
diff --git a/trie/valuenode.go b/trie/valuenode.go
deleted file mode 100644
index 0afa64d54..000000000
--- a/trie/valuenode.go
+++ /dev/null
@@ -1,42 +0,0 @@
-// Copyright 2014 The go-ethereum Authors
-// This file is part of the go-ethereum library.
-//
-// The go-ethereum library is free software: you can redistribute it and/or modify
-// it under the terms of the GNU Lesser General Public License as published by
-// the Free Software Foundation, either version 3 of the License, or
-// (at your option) any later version.
-//
-// The go-ethereum library is distributed in the hope that it will be useful,
-// but WITHOUT ANY WARRANTY; without even the implied warranty of
-// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-// GNU Lesser General Public License for more details.
-//
-// You should have received a copy of the GNU Lesser General Public License
-// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
-
-package trie
-
-import "github.com/ethereum/go-ethereum/common"
-
-type ValueNode struct {
- trie *Trie
- data []byte
- dirty bool
-}
-
-func NewValueNode(trie *Trie, data []byte) *ValueNode {
- return &ValueNode{trie, data, false}
-}
-
-func (self *ValueNode) Value() Node { return self } // Best not to call :-)
-func (self *ValueNode) Val() []byte { return self.data }
-func (self *ValueNode) Dirty() bool { return self.dirty }
-func (self *ValueNode) Copy(t *Trie) Node {
- return &ValueNode{t, common.CopyBytes(self.data), self.dirty}
-}
-func (self *ValueNode) RlpData() interface{} { return self.data }
-func (self *ValueNode) Hash() interface{} { return self.data }
-
-func (self *ValueNode) setDirty(dirty bool) {
- self.dirty = dirty
-}