diff options
author | Felix Lange <fjl@users.noreply.github.com> | 2017-06-27 21:57:06 +0800 |
---|---|---|
committer | GitHub <noreply@github.com> | 2017-06-27 21:57:06 +0800 |
commit | 9e5f03b6c487175cc5aa1224e5e12fd573f483a7 (patch) | |
tree | 475e573ff6c7e77cd069a2f6238afdb27d4bce43 | |
parent | bb366271fe33cf87b462dc5a25ac6c448ac6d2e1 (diff) | |
download | dexon-9e5f03b6c487175cc5aa1224e5e12fd573f483a7.tar dexon-9e5f03b6c487175cc5aa1224e5e12fd573f483a7.tar.gz dexon-9e5f03b6c487175cc5aa1224e5e12fd573f483a7.tar.bz2 dexon-9e5f03b6c487175cc5aa1224e5e12fd573f483a7.tar.lz dexon-9e5f03b6c487175cc5aa1224e5e12fd573f483a7.tar.xz dexon-9e5f03b6c487175cc5aa1224e5e12fd573f483a7.tar.zst dexon-9e5f03b6c487175cc5aa1224e5e12fd573f483a7.zip |
core/state: access trie through Database interface, track errors (#14589)
With this commit, core/state's access to the underlying key/value database is
mediated through an interface. Database errors are tracked in StateDB and
returned by CommitTo or the new Error method.
Motivation for this change: We can remove the light client's duplicated copy of
core/state. The light client now supports node iteration, so tracing and storage
enumeration can work with the light client (not implemented in this commit).
49 files changed, 809 insertions, 1663 deletions
diff --git a/accounts/abi/bind/backends/simulated.go b/accounts/abi/bind/backends/simulated.go index 159fca136..7ac8b5820 100644 --- a/accounts/abi/bind/backends/simulated.go +++ b/accounts/abi/bind/backends/simulated.go @@ -90,7 +90,7 @@ func (b *SimulatedBackend) Rollback() { func (b *SimulatedBackend) rollback() { blocks, _ := core.GenerateChain(b.config, b.blockchain.CurrentBlock(), b.database, 1, func(int, *core.BlockGen) {}) b.pendingBlock = blocks[0] - b.pendingState, _ = state.New(b.pendingBlock.Root(), b.database) + b.pendingState, _ = state.New(b.pendingBlock.Root(), state.NewDatabase(b.database)) } // CodeAt returns the code associated with a certain account in the blockchain. @@ -279,7 +279,7 @@ func (b *SimulatedBackend) SendTransaction(ctx context.Context, tx *types.Transa block.AddTx(tx) }) b.pendingBlock = blocks[0] - b.pendingState, _ = state.New(b.pendingBlock.Root(), b.database) + b.pendingState, _ = state.New(b.pendingBlock.Root(), state.NewDatabase(b.database)) return nil } diff --git a/cmd/evm/runner.go b/cmd/evm/runner.go index 2ce0920f6..3f95a0c93 100644 --- a/cmd/evm/runner.go +++ b/cmd/evm/runner.go @@ -98,8 +98,8 @@ func runCmd(ctx *cli.Context) error { _, statedb = gen.ToBlock() chainConfig = gen.Config } else { - var db, _ = ethdb.NewMemDatabase() - statedb, _ = state.New(common.Hash{}, db) + db, _ := ethdb.NewMemDatabase() + statedb, _ = state.New(common.Hash{}, state.NewDatabase(db)) } if ctx.GlobalString(SenderFlag.Name) != "" { sender = common.HexToAddress(ctx.GlobalString(SenderFlag.Name)) @@ -188,7 +188,7 @@ func runCmd(ctx *cli.Context) error { execTime := time.Since(tstart) if ctx.GlobalBool(DumpFlag.Name) { - statedb.Commit(true) + statedb.IntermediateRoot(true) fmt.Println(string(statedb.Dump())) } diff --git a/cmd/geth/chaincmd.go b/cmd/geth/chaincmd.go index ab0e92f21..12bc1d7c6 100644 --- a/cmd/geth/chaincmd.go +++ b/cmd/geth/chaincmd.go @@ -312,7 +312,7 @@ func dump(ctx *cli.Context) error { fmt.Println("{}") utils.Fatalf("block not found") } else { - state, err := state.New(block.Root(), chainDb) + state, err := state.New(block.Root(), state.NewDatabase(chainDb)) if err != nil { utils.Fatalf("could not create new state: %v", err) } diff --git a/core/block_validator.go b/core/block_validator.go index 4f85df007..e9cfd0482 100644 --- a/core/block_validator.go +++ b/core/block_validator.go @@ -52,16 +52,10 @@ func NewBlockValidator(config *params.ChainConfig, blockchain *BlockChain, engin // validated at this point. func (v *BlockValidator) ValidateBody(block *types.Block) error { // Check whether the block's known, and if not, that it's linkable - if v.bc.HasBlock(block.Hash()) { - if _, err := state.New(block.Root(), v.bc.chainDb); err == nil { - return ErrKnownBlock - } + if v.bc.HasBlockAndState(block.Hash()) { + return ErrKnownBlock } - parent := v.bc.GetBlock(block.ParentHash(), block.NumberU64()-1) - if parent == nil { - return consensus.ErrUnknownAncestor - } - if _, err := state.New(parent.Root(), v.bc.chainDb); err != nil { + if !v.bc.HasBlockAndState(block.ParentHash()) { return consensus.ErrUnknownAncestor } // Header validity is known at this point, check the uncles and transactions diff --git a/core/blockchain.go b/core/blockchain.go index 073b91bab..aab2e72f3 100644 --- a/core/blockchain.go +++ b/core/blockchain.go @@ -92,7 +92,7 @@ type BlockChain struct { currentBlock *types.Block // Current head of the block chain currentFastBlock *types.Block // Current head of the fast-sync chain (may be above the block chain!) - stateCache *state.StateDB // State database to reuse between imports (contains state cache) + stateCache state.Database // State database to reuse between imports (contains state cache) bodyCache *lru.Cache // Cache for the most recent block bodies bodyRLPCache *lru.Cache // Cache for the most recent block bodies in RLP encoded format blockCache *lru.Cache // Cache for the most recent entire blocks @@ -125,6 +125,7 @@ func NewBlockChain(chainDb ethdb.Database, config *params.ChainConfig, engine co bc := &BlockChain{ config: config, chainDb: chainDb, + stateCache: state.NewDatabase(chainDb), eventMux: mux, quit: make(chan struct{}), bodyCache: bodyCache, @@ -190,7 +191,7 @@ func (bc *BlockChain) loadLastState() error { return bc.Reset() } // Make sure the state associated with the block is available - if _, err := state.New(currentBlock.Root(), bc.chainDb); err != nil { + if _, err := state.New(currentBlock.Root(), bc.stateCache); err != nil { // Dangling block without a state associated, init from scratch log.Warn("Head state missing, resetting chain", "number", currentBlock.Number(), "hash", currentBlock.Hash()) return bc.Reset() @@ -214,12 +215,6 @@ func (bc *BlockChain) loadLastState() error { bc.currentFastBlock = block } } - // Initialize a statedb cache to ensure singleton account bloom filter generation - statedb, err := state.New(bc.currentBlock.Root(), bc.chainDb) - if err != nil { - return err - } - bc.stateCache = statedb // Issue a status log for the user headerTd := bc.GetTd(currentHeader.Hash(), currentHeader.Number.Uint64()) @@ -261,7 +256,7 @@ func (bc *BlockChain) SetHead(head uint64) error { bc.currentBlock = bc.GetBlock(currentHeader.Hash(), currentHeader.Number.Uint64()) } if bc.currentBlock != nil { - if _, err := state.New(bc.currentBlock.Root(), bc.chainDb); err != nil { + if _, err := state.New(bc.currentBlock.Root(), bc.stateCache); err != nil { // Rewound state missing, rolled back to before pivot, reset to genesis bc.currentBlock = nil } @@ -384,7 +379,7 @@ func (bc *BlockChain) State() (*state.StateDB, error) { // StateAt returns a new mutable state based on a particular point in time. func (bc *BlockChain) StateAt(root common.Hash) (*state.StateDB, error) { - return bc.stateCache.New(root) + return state.New(root, bc.stateCache) } // Reset purges the entire blockchain, restoring it to its genesis state. @@ -531,7 +526,7 @@ func (bc *BlockChain) HasBlockAndState(hash common.Hash) bool { return false } // Ensure the associated state is also present - _, err := state.New(block.Root(), bc.chainDb) + _, err := bc.stateCache.OpenTrie(block.Root()) return err == nil } @@ -959,31 +954,30 @@ func (bc *BlockChain) InsertChain(chain types.Blocks) (int, error) { } // Create a new statedb using the parent block and report an // error if it fails. - switch { - case i == 0: - err = bc.stateCache.Reset(bc.GetBlock(block.ParentHash(), block.NumberU64()-1).Root()) - default: - err = bc.stateCache.Reset(chain[i-1].Root()) + var parent *types.Block + if i == 0 { + parent = bc.GetBlock(block.ParentHash(), block.NumberU64()-1) + } else { + parent = chain[i-1] } + state, err := state.New(parent.Root(), bc.stateCache) if err != nil { - bc.reportBlock(block, nil, err) return i, err } // Process block using the parent state as reference point. - receipts, logs, usedGas, err := bc.processor.Process(block, bc.stateCache, bc.vmConfig) + receipts, logs, usedGas, err := bc.processor.Process(block, state, bc.vmConfig) if err != nil { bc.reportBlock(block, receipts, err) return i, err } // Validate the state using the default validator - err = bc.Validator().ValidateState(block, bc.GetBlock(block.ParentHash(), block.NumberU64()-1), bc.stateCache, receipts, usedGas) + err = bc.Validator().ValidateState(block, parent, state, receipts, usedGas) if err != nil { bc.reportBlock(block, receipts, err) return i, err } // Write state changes to database - _, err = bc.stateCache.Commit(bc.config.IsEIP158(block.Number())) - if err != nil { + if _, err = state.CommitTo(bc.chainDb, bc.config.IsEIP158(block.Number())); err != nil { return i, err } @@ -1021,7 +1015,7 @@ func (bc *BlockChain) InsertChain(chain types.Blocks) (int, error) { return i, err } // Write hash preimages - if err := WritePreimages(bc.chainDb, block.NumberU64(), bc.stateCache.Preimages()); err != nil { + if err := WritePreimages(bc.chainDb, block.NumberU64(), state.Preimages()); err != nil { return i, err } case SideStatTy: diff --git a/core/blockchain_test.go b/core/blockchain_test.go index 7505208e1..371522ab7 100644 --- a/core/blockchain_test.go +++ b/core/blockchain_test.go @@ -131,7 +131,7 @@ func testBlockChainImport(chain types.Blocks, blockchain *BlockChain) error { } return err } - statedb, err := state.New(blockchain.GetBlockByHash(block.ParentHash()).Root(), blockchain.chainDb) + statedb, err := state.New(blockchain.GetBlockByHash(block.ParentHash()).Root(), blockchain.stateCache) if err != nil { return err } @@ -148,7 +148,7 @@ func testBlockChainImport(chain types.Blocks, blockchain *BlockChain) error { blockchain.mu.Lock() WriteTd(blockchain.chainDb, block.Hash(), block.NumberU64(), new(big.Int).Add(block.Difficulty(), blockchain.GetTdByHash(block.ParentHash()))) WriteBlock(blockchain.chainDb, block) - statedb.Commit(false) + statedb.CommitTo(blockchain.chainDb, false) blockchain.mu.Unlock() } return nil @@ -1131,7 +1131,7 @@ func TestEIP161AccountRemoval(t *testing.T) { if _, err := blockchain.InsertChain(types.Blocks{blocks[0]}); err != nil { t.Fatal(err) } - if !blockchain.stateCache.Exist(theAddr) { + if st, _ := blockchain.State(); !st.Exist(theAddr) { t.Error("expected account to exist") } @@ -1139,7 +1139,7 @@ func TestEIP161AccountRemoval(t *testing.T) { if _, err := blockchain.InsertChain(types.Blocks{blocks[1]}); err != nil { t.Fatal(err) } - if blockchain.stateCache.Exist(theAddr) { + if st, _ := blockchain.State(); st.Exist(theAddr) { t.Error("account should not exist") } @@ -1147,7 +1147,7 @@ func TestEIP161AccountRemoval(t *testing.T) { if _, err := blockchain.InsertChain(types.Blocks{blocks[2]}); err != nil { t.Fatal(err) } - if blockchain.stateCache.Exist(theAddr) { + if st, _ := blockchain.State(); st.Exist(theAddr) { t.Error("account should not exist") } } diff --git a/core/chain_makers.go b/core/chain_makers.go index cc14f8fb8..38a69d42a 100644 --- a/core/chain_makers.go +++ b/core/chain_makers.go @@ -181,7 +181,7 @@ func GenerateChain(config *params.ChainConfig, parent *types.Block, db ethdb.Dat gen(i, b) } ethash.AccumulateRewards(statedb, h, b.uncles) - root, err := statedb.Commit(config.IsEIP158(h.Number)) + root, err := statedb.CommitTo(db, config.IsEIP158(h.Number)) if err != nil { panic(fmt.Sprintf("state write error: %v", err)) } @@ -189,7 +189,7 @@ func GenerateChain(config *params.ChainConfig, parent *types.Block, db ethdb.Dat return types.NewBlock(h, b.txs, b.uncles, b.receipts), b.receipts } for i := 0; i < n; i++ { - statedb, err := state.New(parent.Root(), db) + statedb, err := state.New(parent.Root(), state.NewDatabase(db)) if err != nil { panic(err) } diff --git a/core/genesis.go b/core/genesis.go index 947a53c70..5815d5901 100644 --- a/core/genesis.go +++ b/core/genesis.go @@ -176,7 +176,7 @@ func (g *Genesis) configOrDefault(ghash common.Hash) *params.ChainConfig { // ToBlock creates the block and state of a genesis specification. func (g *Genesis) ToBlock() (*types.Block, *state.StateDB) { db, _ := ethdb.NewMemDatabase() - statedb, _ := state.New(common.Hash{}, db) + statedb, _ := state.New(common.Hash{}, state.NewDatabase(db)) for addr, account := range g.Alloc { statedb.AddBalance(addr, account.Balance) statedb.SetCode(addr, account.Code) diff --git a/core/state/database.go b/core/state/database.go new file mode 100644 index 000000000..946625e76 --- /dev/null +++ b/core/state/database.go @@ -0,0 +1,154 @@ +// Copyright 2017 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. + +package state + +import ( + "fmt" + "sync" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/ethdb" + "github.com/ethereum/go-ethereum/trie" + lru "github.com/hashicorp/golang-lru" +) + +// Trie cache generation limit after which to evic trie nodes from memory. +var MaxTrieCacheGen = uint16(120) + +const ( + // Number of past tries to keep. This value is chosen such that + // reasonable chain reorg depths will hit an existing trie. + maxPastTries = 12 + + // Number of codehash->size associations to keep. + codeSizeCacheSize = 100000 +) + +// Database wraps access to tries and contract code. +type Database interface { + // Accessing tries: + // OpenTrie opens the main account trie. + // OpenStorageTrie opens the storage trie of an account. + OpenTrie(root common.Hash) (Trie, error) + OpenStorageTrie(addrHash, root common.Hash) (Trie, error) + // Accessing contract code: + ContractCode(addrHash, codeHash common.Hash) ([]byte, error) + ContractCodeSize(addrHash, codeHash common.Hash) (int, error) + // CopyTrie returns an independent copy of the given trie. + CopyTrie(Trie) Trie +} + +// Trie is a Ethereum Merkle Trie. +type Trie interface { + TryGet(key []byte) ([]byte, error) + TryUpdate(key, value []byte) error + TryDelete(key []byte) error + CommitTo(trie.DatabaseWriter) (common.Hash, error) + Hash() common.Hash + NodeIterator(startKey []byte) trie.NodeIterator + GetKey([]byte) []byte // TODO(fjl): remove this when SecureTrie is removed +} + +// NewDatabase creates a backing store for state. The returned database is safe for +// concurrent use and retains cached trie nodes in memory. +func NewDatabase(db ethdb.Database) Database { + csc, _ := lru.New(codeSizeCacheSize) + return &cachingDB{db: db, codeSizeCache: csc} +} + +type cachingDB struct { + db ethdb.Database + mu sync.Mutex + pastTries []*trie.SecureTrie + codeSizeCache *lru.Cache +} + +func (db *cachingDB) OpenTrie(root common.Hash) (Trie, error) { + db.mu.Lock() + defer db.mu.Unlock() + + for i := len(db.pastTries) - 1; i >= 0; i-- { + if db.pastTries[i].Hash() == root { + return cachedTrie{db.pastTries[i].Copy(), db}, nil + } + } + tr, err := trie.NewSecure(root, db.db, MaxTrieCacheGen) + if err != nil { + return nil, err + } + return cachedTrie{tr, db}, nil +} + +func (db *cachingDB) pushTrie(t *trie.SecureTrie) { + db.mu.Lock() + defer db.mu.Unlock() + + if len(db.pastTries) >= maxPastTries { + copy(db.pastTries, db.pastTries[1:]) + db.pastTries[len(db.pastTries)-1] = t + } else { + db.pastTries = append(db.pastTries, t) + } +} + +func (db *cachingDB) OpenStorageTrie(addrHash, root common.Hash) (Trie, error) { + return trie.NewSecure(root, db.db, 0) +} + +func (db *cachingDB) CopyTrie(t Trie) Trie { + switch t := t.(type) { + case cachedTrie: + return cachedTrie{t.SecureTrie.Copy(), db} + case *trie.SecureTrie: + return t.Copy() + default: + panic(fmt.Errorf("unknown trie type %T", t)) + } +} + +func (db *cachingDB) ContractCode(addrHash, codeHash common.Hash) ([]byte, error) { + code, err := db.db.Get(codeHash[:]) + if err == nil { + db.codeSizeCache.Add(codeHash, len(code)) + } + return code, err +} + +func (db *cachingDB) ContractCodeSize(addrHash, codeHash common.Hash) (int, error) { + if cached, ok := db.codeSizeCache.Get(codeHash); ok { + return cached.(int), nil + } + code, err := db.ContractCode(addrHash, codeHash) + if err == nil { + db.codeSizeCache.Add(codeHash, len(code)) + } + return len(code), err +} + +// cachedTrie inserts its trie into a cachingDB on commit. +type cachedTrie struct { + *trie.SecureTrie + db *cachingDB +} + +func (m cachedTrie) CommitTo(dbw trie.DatabaseWriter) (common.Hash, error) { + root, err := m.SecureTrie.CommitTo(dbw) + if err == nil { + m.db.pushTrie(m.SecureTrie) + } + return root, err +} diff --git a/core/state/dump.go b/core/state/dump.go index ffa1a7283..46e612850 100644 --- a/core/state/dump.go +++ b/core/state/dump.go @@ -41,7 +41,7 @@ type Dump struct { func (self *StateDB) RawDump() Dump { dump := Dump{ - Root: common.Bytes2Hex(self.trie.Root()), + Root: fmt.Sprintf("%x", self.trie.Hash()), Accounts: make(map[string]DumpAccount), } diff --git a/core/state/iterator.go b/core/state/iterator.go index a8a2722ae..6a5c73d3d 100644 --- a/core/state/iterator.go +++ b/core/state/iterator.go @@ -19,7 +19,6 @@ package state import ( "bytes" "fmt" - "math/big" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/rlp" @@ -105,16 +104,11 @@ func (it *NodeIterator) step() error { return nil } // Otherwise we've reached an account node, initiate data iteration - var account struct { - Nonce uint64 - Balance *big.Int - Root common.Hash - CodeHash []byte - } + var account Account if err := rlp.Decode(bytes.NewReader(it.stateIt.LeafBlob()), &account); err != nil { return err } - dataTrie, err := trie.New(account.Root, it.state.db) + dataTrie, err := it.state.db.OpenStorageTrie(common.BytesToHash(it.stateIt.LeafKey()), account.Root) if err != nil { return err } @@ -124,7 +118,8 @@ func (it *NodeIterator) step() error { } if !bytes.Equal(account.CodeHash, emptyCodeHash) { it.codeHash = common.BytesToHash(account.CodeHash) - it.code, err = it.state.db.Get(account.CodeHash) + addrHash := common.BytesToHash(it.stateIt.LeafKey()) + it.code, err = it.state.db.ContractCode(addrHash, common.BytesToHash(account.CodeHash)) if err != nil { return fmt.Errorf("code %x: %v", account.CodeHash, err) } diff --git a/core/state/iterator_test.go b/core/state/iterator_test.go index aa9c5b728..ff66ba7a9 100644 --- a/core/state/iterator_test.go +++ b/core/state/iterator_test.go @@ -21,13 +21,12 @@ import ( "testing" "github.com/ethereum/go-ethereum/common" - "github.com/ethereum/go-ethereum/ethdb" ) // Tests that the node iterator indeed walks over the entire database contents. func TestNodeIteratorCoverage(t *testing.T) { // Create some arbitrary test state to iterate - db, root, _ := makeTestState() + db, mem, root, _ := makeTestState() state, err := New(root, db) if err != nil { @@ -40,13 +39,14 @@ func TestNodeIteratorCoverage(t *testing.T) { hashes[it.Hash] = struct{}{} } } + // Cross check the hashes and the database itself for hash := range hashes { - if _, err := db.Get(hash.Bytes()); err != nil { + if _, err := mem.Get(hash.Bytes()); err != nil { t.Errorf("failed to retrieve reported node %x: %v", hash, err) } } - for _, key := range db.(*ethdb.MemDatabase).Keys() { + for _, key := range mem.Keys() { if bytes.HasPrefix(key, []byte("secure-key-")) { continue } diff --git a/core/state/managed_state_test.go b/core/state/managed_state_test.go index ea5737a08..1cfdd3a89 100644 --- a/core/state/managed_state_test.go +++ b/core/state/managed_state_test.go @@ -27,7 +27,7 @@ var addr = common.BytesToAddress([]byte("test")) func create() (*ManagedState, *account) { db, _ := ethdb.NewMemDatabase() - statedb, _ := New(common.Hash{}, db) + statedb, _ := New(common.Hash{}, NewDatabase(db)) ms := ManageState(statedb) ms.StateDB.SetNonce(addr, 100) ms.accounts[addr] = newAccount(ms.StateDB.getStateObject(addr)) diff --git a/core/state/state_object.go b/core/state/state_object.go index dcad9d068..b2378c69c 100644 --- a/core/state/state_object.go +++ b/core/state/state_object.go @@ -62,9 +62,10 @@ func (self Storage) Copy() Storage { // Account values can be accessed and modified through the object. // Finally, call CommitTrie to write the modified storage trie into a database. type stateObject struct { - address common.Address // Ethereum address of this account - data Account - db *StateDB + address common.Address + addrHash common.Hash // hash of ethereum address of the account + data Account + db *StateDB // DB error. // State objects are used by the consensus core and VM which are @@ -74,8 +75,8 @@ type stateObject struct { dbErr error // Write caches. - trie *trie.SecureTrie // storage trie, which becomes non-nil on first access - code Code // contract bytecode, which gets set when code is loaded + trie Trie // storage trie, which becomes non-nil on first access + code Code // contract bytecode, which gets set when code is loaded cachedStorage Storage // Storage entry cache to avoid duplicate reads dirtyStorage Storage // Storage entries that need to be flushed to disk @@ -112,7 +113,15 @@ func newObject(db *StateDB, address common.Address, data Account, onDirty func(a if data.CodeHash == nil { data.CodeHash = emptyCodeHash } - return &stateObject{db: db, address: address, data: data, cachedStorage: make(Storage), dirtyStorage: make(Storage), onDirty: onDirty} + return &stateObject{ + db: db, + address: address, + addrHash: crypto.Keccak256Hash(address[:]), + data: data, + cachedStorage: make(Storage), + dirtyStorage: make(Storage), + onDirty: onDirty, + } } // EncodeRLP implements rlp.Encoder. @@ -148,12 +157,12 @@ func (c *stateObject) touch() { c.touched = true } -func (c *stateObject) getTrie(db trie.Database) *trie.SecureTrie { +func (c *stateObject) getTrie(db Database) Trie { if c.trie == nil { var err error - c.trie, err = trie.NewSecure(c.data.Root, db, 0) + c.trie, err = db.OpenStorageTrie(c.addrHash, c.data.Root) if err != nil { - c.trie, _ = trie.NewSecure(common.Hash{}, db, 0) + c.trie, _ = db.OpenStorageTrie(c.addrHash, common.Hash{}) c.setError(fmt.Errorf("can't create storage trie: %v", err)) } } @@ -161,13 +170,18 @@ func (c *stateObject) getTrie(db trie.Database) *trie.SecureTrie { } // GetState returns a value in account storage. -func (self *stateObject) GetState(db trie.Database, key common.Hash) common.Hash { +func (self *stateObject) GetState(db Database, key common.Hash) common.Hash { value, exists := self.cachedStorage[key] if exists { return value } // Load from DB in case it is missing. - if enc := self.getTrie(db).Get(key[:]); len(enc) > 0 { + enc, err := self.getTrie(db).TryGet(key[:]) + if err != nil { + self.setError(err) + return common.Hash{} + } + if len(enc) > 0 { _, content, _, err := rlp.Split(enc) if err != nil { self.setError(err) @@ -181,7 +195,7 @@ func (self *stateObject) GetState(db trie.Database, key common.Hash) common.Hash } // SetState updates a value in account storage. -func (self *stateObject) SetState(db trie.Database, key, value common.Hash) { +func (self *stateObject) SetState(db Database, key, value common.Hash) { self.db.journal = append(self.db.journal, storageChange{ account: &self.address, key: key, @@ -201,30 +215,30 @@ func (self *stateObject) setState(key, value common.Hash) { } // updateTrie writes cached storage modifications into the object's storage trie. -func (self *stateObject) updateTrie(db trie.Database) *trie.SecureTrie { +func (self *stateObject) updateTrie(db Database) Trie { tr := self.getTrie(db) for key, value := range self.dirtyStorage { delete(self.dirtyStorage, key) if (value == common.Hash{}) { - tr.Delete(key[:]) + self.setError(tr.TryDelete(key[:])) continue } // Encoding []byte cannot fail, ok to ignore the error. v, _ := rlp.EncodeToBytes(bytes.TrimLeft(value[:], "\x00")) - tr.Update(key[:], v) + self.setError(tr.TryUpdate(key[:], v)) } return tr } // UpdateRoot sets the trie root to the current root hash of -func (self *stateObject) updateRoot(db trie.Database) { +func (self *stateObject) updateRoot(db Database) { self.updateTrie(db) self.data.Root = self.trie.Hash() } // CommitTrie the storage trie of the object to dwb. // This updates the trie root. -func (self *stateObject) CommitTrie(db trie.Database, dbw trie.DatabaseWriter) error { +func (self *stateObject) CommitTrie(db Database, dbw trie.DatabaseWriter) error { self.updateTrie(db) if self.dbErr != nil { return self.dbErr @@ -282,9 +296,7 @@ func (c *stateObject) ReturnGas(gas *big.Int) {} func (self *stateObject) deepCopy(db *StateDB, onDirty func(addr common.Address)) *stateObject { stateObject := newObject(db, self.address, self.data, onDirty) if self.trie != nil { - // A shallow copy makes the two tries independent. - cpy := *self.trie - stateObject.trie = &cpy + stateObject.trie = db.db.CopyTrie(self.trie) } stateObject.code = self.code stateObject.dirtyStorage = self.dirtyStorage.Copy() @@ -305,14 +317,14 @@ func (c *stateObject) Address() common.Address { } // Code returns the contract code associated with this object, if any. -func (self *stateObject) Code(db trie.Database) []byte { +func (self *stateObject) Code(db Database) []byte { if self.code != nil { return self.code } if bytes.Equal(self.CodeHash(), emptyCodeHash) { return nil } - code, err := db.Get(self.CodeHash()) + code, err := db.ContractCode(self.addrHash, common.BytesToHash(self.CodeHash())) if err != nil { self.setError(fmt.Errorf("can't load code hash %x: %v", self.CodeHash(), err)) } diff --git a/core/state/state_test.go b/core/state/state_test.go index 3bc63c148..bbae3685b 100644 --- a/core/state/state_test.go +++ b/core/state/state_test.go @@ -21,14 +21,14 @@ import ( "math/big" "testing" - checker "gopkg.in/check.v1" - "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/ethdb" + checker "gopkg.in/check.v1" ) type StateSuite struct { + db *ethdb.MemDatabase state *StateDB } @@ -48,7 +48,7 @@ func (s *StateSuite) TestDump(c *checker.C) { // write some of them to the trie s.state.updateStateObject(obj1) s.state.updateStateObject(obj2) - s.state.Commit(false) + s.state.CommitTo(s.db, false) // check that dump contains the state objects that are in trie got := string(s.state.Dump()) @@ -87,23 +87,20 @@ func (s *StateSuite) TestDump(c *checker.C) { } func (s *StateSuite) SetUpTest(c *checker.C) { - db, _ := ethdb.NewMemDatabase() - s.state, _ = New(common.Hash{}, db) + s.db, _ = ethdb.NewMemDatabase() + s.state, _ = New(common.Hash{}, NewDatabase(s.db)) } -func TestNull(t *testing.T) { - db, _ := ethdb.NewMemDatabase() - state, _ := New(common.Hash{}, db) - +func (s *StateSuite) TestNull(c *checker.C) { address := common.HexToAddress("0x823140710bf13990e4500136726d8b55") - state.CreateAccount(address) + s.state.CreateAccount(address) //value := common.FromHex("0x823140710bf13990e4500136726d8b55") var value common.Hash - state.SetState(address, common.Hash{}, value) - state.Commit(false) - value = state.GetState(address, common.Hash{}) + s.state.SetState(address, common.Hash{}, value) + s.state.CommitTo(s.db, false) + value = s.state.GetState(address, common.Hash{}) if !common.EmptyHash(value) { - t.Errorf("expected empty hash. got %x", value) + c.Errorf("expected empty hash. got %x", value) } } @@ -129,17 +126,15 @@ func (s *StateSuite) TestSnapshot(c *checker.C) { c.Assert(data1, checker.DeepEquals, res) } -func TestSnapshotEmpty(t *testing.T) { - db, _ := ethdb.NewMemDatabase() - state, _ := New(common.Hash{}, db) - state.RevertToSnapshot(state.Snapshot()) +func (s *StateSuite) TestSnapshotEmpty(c *checker.C) { + s.state.RevertToSnapshot(s.state.Snapshot()) } // use testing instead of checker because checker does not support // printing/logging in tests (-check.vv does not work) func TestSnapshot2(t *testing.T) { db, _ := ethdb.NewMemDatabase() - state, _ := New(common.Hash{}, db) + state, _ := New(common.Hash{}, NewDatabase(db)) stateobjaddr0 := toAddr([]byte("so0")) stateobjaddr1 := toAddr([]byte("so1")) @@ -160,7 +155,7 @@ func TestSnapshot2(t *testing.T) { so0.deleted = false state.setStateObject(so0) - root, _ := state.Commit(false) + root, _ := state.CommitTo(db, false) state.Reset(root) // and one with deleted == true @@ -182,8 +177,8 @@ func TestSnapshot2(t *testing.T) { so0Restored := state.getStateObject(stateobjaddr0) // Update lazily-loaded values before comparing. - so0Restored.GetState(db, storageaddr) - so0Restored.Code(db) + so0Restored.GetState(state.db, storageaddr) + so0Restored.Code(state.db) // non-deleted is equal (restored) compareStateObjects(so0Restored, so0, t) diff --git a/core/state/statedb.go b/core/state/statedb.go index 05869a0c8..694374f82 100644 --- a/core/state/statedb.go +++ b/core/state/statedb.go @@ -26,23 +26,9 @@ import ( "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/crypto" - "github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/rlp" "github.com/ethereum/go-ethereum/trie" - lru "github.com/hashicorp/golang-lru" -) - -// Trie cache generation limit after which to evic trie nodes from memory. -var MaxTrieCacheGen = uint16(120) - -const ( - // Number of past tries to keep. This value is chosen such that - // reasonable chain reorg depths will hit an existing trie. - maxPastTries = 12 - - // Number of codehash->size associations to keep. - codeSizeCacheSize = 100000 ) type revision struct { @@ -56,16 +42,21 @@ type revision struct { // * Contracts // * Accounts type StateDB struct { - db ethdb.Database - trie *trie.SecureTrie - pastTries []*trie.SecureTrie - codeSizeCache *lru.Cache + db Database + trie Trie // This map holds 'live' objects, which will get modified while processing a state transition. stateObjects map[common.Address]*stateObject stateObjectsDirty map[common.Address]struct{} stateObjectsDestructed map[common.Address]struct{} + // DB error. + // State objects are used by the consensus core and VM which are + // unable to deal with database-level errors. Any error that occurs + // during a database read is memoized here and will eventually be returned + // by StateDB.Commit. + dbErr error + // The refund counter, also used by state transitioning. refund *big.Int @@ -86,16 +77,14 @@ type StateDB struct { } // Create a new state from a given trie -func New(root common.Hash, db ethdb.Database) (*StateDB, error) { - tr, err := trie.NewSecure(root, db, MaxTrieCacheGen) +func New(root common.Hash, db Database) (*StateDB, error) { + tr, err := db.OpenTrie(root) if err != nil { return nil, err } - csc, _ := lru.New(codeSizeCacheSize) return &StateDB{ db: db, trie: tr, - codeSizeCache: csc, stateObjects: make(map[common.Address]*stateObject), stateObjectsDirty: make(map[common.Address]struct{}), stateObjectsDestructed: make(map[common.Address]struct{}), @@ -105,36 +94,21 @@ func New(root common.Hash, db ethdb.Database) (*StateDB, error) { }, nil } -// New creates a new statedb by reusing any journalled tries to avoid costly -// disk io. -func (self *StateDB) New(root common.Hash) (*StateDB, error) { - self.lock.Lock() - defer self.lock.Unlock() - - tr, err := self.openTrie(root) - if err != nil { - return nil, err +// setError remembers the first non-nil error it is called with. +func (self *StateDB) setError(err error) { + if self.dbErr == nil { + self.dbErr = err } - return &StateDB{ - db: self.db, - trie: tr, - codeSizeCache: self.codeSizeCache, - stateObjects: make(map[common.Address]*stateObject), - stateObjectsDirty: make(map[common.Address]struct{}), - stateObjectsDestructed: make(map[common.Address]struct{}), - refund: new(big.Int), - logs: make(map[common.Hash][]*types.Log), - preimages: make(map[common.Hash][]byte), - }, nil +} + +func (self *StateDB) Error() error { + return self.dbErr } // Reset clears out all emphemeral state objects from the state db, but keeps // the underlying state trie to avoid reloading data for the next operations. func (self *StateDB) Reset(root common.Hash) error { - self.lock.Lock() - defer self.lock.Unlock() - - tr, err := self.openTrie(root) + tr, err := self.db.OpenTrie(root) if err != nil { return err } @@ -149,34 +123,9 @@ func (self *StateDB) Reset(root common.Hash) error { self.logSize = 0 self.preimages = make(map[common.Hash][]byte) self.clearJournalAndRefund() - return nil } -// openTrie creates a trie. It uses an existing trie if one is available -// from the journal if available. -func (self *StateDB) openTrie(root common.Hash) (*trie.SecureTrie, error) { - for i := len(self.pastTries) - 1; i >= 0; i-- { - if self.pastTries[i].Hash() == root { - tr := *self.pastTries[i] - return &tr, nil - } - } - return trie.NewSecure(root, self.db, MaxTrieCacheGen) -} - -func (self *StateDB) pushTrie(t *trie.SecureTrie) { - self.lock.Lock() - defer self.lock.Unlock() - - if len(self.pastTries) >= maxPastTries { - copy(self.pastTries, self.pastTries[1:]) - self.pastTries[len(self.pastTries)-1] = t - } else { - self.pastTries = append(self.pastTries, t) - } -} - func (self *StateDB) AddLog(log *types.Log) { self.journal = append(self.journal, addLogChange{txhash: self.thash}) @@ -254,10 +203,7 @@ func (self *StateDB) GetNonce(addr common.Address) uint64 { func (self *StateDB) GetCode(addr common.Address) []byte { stateObject := self.getStateObject(addr) if stateObject != nil { - code := stateObject.Code(self.db) - key := common.BytesToHash(stateObject.CodeHash()) - self.codeSizeCache.Add(key, len(code)) - return code + return stateObject.Code(self.db) } return nil } @@ -267,13 +213,12 @@ func (self *StateDB) GetCodeSize(addr common.Address) int { if stateObject == nil { return 0 } - key := common.BytesToHash(stateObject.CodeHash()) - if cached, ok := self.codeSizeCache.Get(key); ok { - return cached.(int) + if stateObject.code != nil { + return len(stateObject.code) } - size := len(stateObject.Code(self.db)) - if stateObject.dbErr == nil { - self.codeSizeCache.Add(key, size) + size, err := self.db.ContractCodeSize(stateObject.addrHash, common.BytesToHash(stateObject.CodeHash())) + if err != nil { + self.setError(err) } return size } @@ -296,7 +241,7 @@ func (self *StateDB) GetState(a common.Address, b common.Hash) common.Hash { // StorageTrie returns the storage trie of an account. // The return value is a copy and is nil for non-existent accounts. -func (self *StateDB) StorageTrie(a common.Address) *trie.SecureTrie { +func (self *StateDB) StorageTrie(a common.Address) Trie { stateObject := self.getStateObject(a) if stateObject == nil { return nil @@ -394,14 +339,14 @@ func (self *StateDB) updateStateObject(stateObject *stateObject) { if err != nil { panic(fmt.Errorf("can't encode object at %x: %v", addr[:], err)) } - self.trie.Update(addr[:], data) + self.setError(self.trie.TryUpdate(addr[:], data)) } // deleteStateObject removes the given object from the state trie. func (self *StateDB) deleteStateObject(stateObject *stateObject) { stateObject.deleted = true addr := stateObject.Address() - self.trie.Delete(addr[:]) + self.setError(self.trie.TryDelete(addr[:])) } // Retrieve a state object given my the address. Returns nil if not found. @@ -415,8 +360,9 @@ func (self *StateDB) getStateObject(addr common.Address) (stateObject *stateObje } // Load the object from the database. - enc := self.trie.Get(addr[:]) + enc, err := self.trie.TryGet(addr[:]) if len(enc) == 0 { + self.setError(err) return nil } var data Account @@ -512,8 +458,6 @@ func (self *StateDB) Copy() *StateDB { state := &StateDB{ db: self.db, trie: self.trie, - pastTries: self.pastTries, - codeSizeCache: self.codeSizeCache, stateObjects: make(map[common.Address]*stateObject, len(self.stateObjectsDirty)), stateObjectsDirty: make(map[common.Address]struct{}, len(self.stateObjectsDirty)), stateObjectsDestructed: make(map[common.Address]struct{}, len(self.stateObjectsDestructed)), @@ -636,23 +580,6 @@ func (s *StateDB) DeleteSuicides() { } } -// Commit commits all state changes to the database. -func (s *StateDB) Commit(deleteEmptyObjects bool) (root common.Hash, err error) { - root, batch := s.CommitBatch(deleteEmptyObjects) - return root, batch.Write() -} - -// CommitBatch commits all state changes to a write batch but does not -// execute the batch. It is used to validate state changes against -// the root hash stored in a block. -func (s *StateDB) CommitBatch(deleteEmptyObjects bool) (root common.Hash, batch ethdb.Batch) { - batch = s.db.NewBatch() - root, _ = s.CommitTo(batch, deleteEmptyObjects) - - log.Debug("Trie cache stats after commit", "misses", trie.CacheMisses(), "unloads", trie.CacheUnloads()) - return root, batch -} - func (s *StateDB) clearJournalAndRefund() { s.journal = nil s.validRevisions = s.validRevisions[:0] @@ -690,8 +617,6 @@ func (s *StateDB) CommitTo(dbw trie.DatabaseWriter, deleteEmptyObjects bool) (ro } // Write trie changes. root, err = s.trie.CommitTo(dbw) - if err == nil { - s.pushTrie(s.trie) - } + log.Debug("Trie cache stats after commit", "misses", trie.CacheMisses(), "unloads", trie.CacheUnloads()) return root, err } diff --git a/core/state/statedb_test.go b/core/state/statedb_test.go index 72b638f97..b2bd18e65 100644 --- a/core/state/statedb_test.go +++ b/core/state/statedb_test.go @@ -28,6 +28,8 @@ import ( "testing" "testing/quick" + check "gopkg.in/check.v1" + "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/ethdb" @@ -38,7 +40,7 @@ import ( func TestUpdateLeaks(t *testing.T) { // Create an empty state database db, _ := ethdb.NewMemDatabase() - state, _ := New(common.Hash{}, db) + state, _ := New(common.Hash{}, NewDatabase(db)) // Update it with some accounts for i := byte(0); i < 255; i++ { @@ -66,8 +68,8 @@ func TestIntermediateLeaks(t *testing.T) { // Create two state databases, one transitioning to the final state, the other final from the beginning transDb, _ := ethdb.NewMemDatabase() finalDb, _ := ethdb.NewMemDatabase() - transState, _ := New(common.Hash{}, transDb) - finalState, _ := New(common.Hash{}, finalDb) + transState, _ := New(common.Hash{}, NewDatabase(transDb)) + finalState, _ := New(common.Hash{}, NewDatabase(finalDb)) modify := func(state *StateDB, addr common.Address, i, tweak byte) { state.SetBalance(addr, big.NewInt(int64(11*i)+int64(tweak))) @@ -95,10 +97,10 @@ func TestIntermediateLeaks(t *testing.T) { } // Commit and cross check the databases. - if _, err := transState.Commit(false); err != nil { + if _, err := transState.CommitTo(transDb, false); err != nil { t.Fatalf("failed to commit transition state: %v", err) } - if _, err := finalState.Commit(false); err != nil { + if _, err := finalState.CommitTo(finalDb, false); err != nil { t.Fatalf("failed to commit final state: %v", err) } for _, key := range finalDb.Keys() { @@ -282,7 +284,7 @@ func (test *snapshotTest) run() bool { // Run all actions and create snapshots. var ( db, _ = ethdb.NewMemDatabase() - state, _ = New(common.Hash{}, db) + state, _ = New(common.Hash{}, NewDatabase(db)) snapshotRevs = make([]int, len(test.snapshots)) sindex = 0 ) @@ -297,7 +299,7 @@ func (test *snapshotTest) run() bool { // Revert all snapshots in reverse order. Each revert must yield a state // that is equivalent to fresh state with all actions up the snapshot applied. for sindex--; sindex >= 0; sindex-- { - checkstate, _ := New(common.Hash{}, db) + checkstate, _ := New(common.Hash{}, NewDatabase(db)) for _, action := range test.actions[:test.snapshots[sindex]] { action.fn(action, checkstate) } @@ -354,21 +356,19 @@ func (test *snapshotTest) checkEqual(state, checkstate *StateDB) error { return nil } -func TestTouchDelete(t *testing.T) { - db, _ := ethdb.NewMemDatabase() - state, _ := New(common.Hash{}, db) - state.GetOrNewStateObject(common.Address{}) - root, _ := state.Commit(false) - state.Reset(root) +func (s *StateSuite) TestTouchDelete(c *check.C) { + s.state.GetOrNewStateObject(common.Address{}) + root, _ := s.state.CommitTo(s.db, false) + s.state.Reset(root) - snapshot := state.Snapshot() - state.AddBalance(common.Address{}, new(big.Int)) - if len(state.stateObjectsDirty) != 1 { - t.Fatal("expected one dirty state object") + snapshot := s.state.Snapshot() + s.state.AddBalance(common.Address{}, new(big.Int)) + if len(s.state.stateObjectsDirty) != 1 { + c.Fatal("expected one dirty state object") } - state.RevertToSnapshot(snapshot) - if len(state.stateObjectsDirty) != 0 { - t.Fatal("expected no dirty state object") + s.state.RevertToSnapshot(snapshot) + if len(s.state.stateObjectsDirty) != 0 { + c.Fatal("expected no dirty state object") } } diff --git a/core/state/sync_test.go b/core/state/sync_test.go index 108ebb320..06c572ea6 100644 --- a/core/state/sync_test.go +++ b/core/state/sync_test.go @@ -36,9 +36,10 @@ type testAccount struct { } // makeTestState create a sample test state to test node-wise reconstruction. -func makeTestState() (ethdb.Database, common.Hash, []*testAccount) { +func makeTestState() (Database, *ethdb.MemDatabase, common.Hash, []*testAccount) { // Create an empty state - db, _ := ethdb.NewMemDatabase() + mem, _ := ethdb.NewMemDatabase() + db := NewDatabase(mem) state, _ := New(common.Hash{}, db) // Fill it with some arbitrary data @@ -60,17 +61,17 @@ func makeTestState() (ethdb.Database, common.Hash, []*testAccount) { state.updateStateObject(obj) accounts = append(accounts, acc) } - root, _ := state.Commit(false) + root, _ := state.CommitTo(mem, false) // Return the generated state - return db, root, accounts + return db, mem, root, accounts } // checkStateAccounts cross references a reconstructed state with an expected // account array. func checkStateAccounts(t *testing.T, db ethdb.Database, root common.Hash, accounts []*testAccount) { // Check root availability and state contents - state, err := New(root, db) + state, err := New(root, NewDatabase(db)) if err != nil { t.Fatalf("failed to create state trie at %x: %v", root, err) } @@ -90,13 +91,28 @@ func checkStateAccounts(t *testing.T, db ethdb.Database, root common.Hash, accou } } -// checkStateConsistency checks that all nodes in a state trie are indeed present. +// checkTrieConsistency checks that all nodes in a (sub-)trie are indeed present. +func checkTrieConsistency(db ethdb.Database, root common.Hash) error { + if v, _ := db.Get(root[:]); v == nil { + return nil // Consider a non existent state consistent. + } + trie, err := trie.New(root, db) + if err != nil { + return err + } + it := trie.NodeIterator(nil) + for it.Next(true) { + } + return it.Error() +} + +// checkStateConsistency checks that all data of a state root is present. func checkStateConsistency(db ethdb.Database, root common.Hash) error { // Create and iterate a state trie rooted in a sub-node if _, err := db.Get(root.Bytes()); err != nil { - return nil // Consider a non existent state consistent + return nil // Consider a non existent state consistent. } - state, err := New(root, db) + state, err := New(root, NewDatabase(db)) if err != nil { return err } @@ -122,7 +138,7 @@ func TestIterativeStateSyncBatched(t *testing.T) { testIterativeStateSync(t, func testIterativeStateSync(t *testing.T, batch int) { // Create a random state to copy - srcDb, srcRoot, srcAccounts := makeTestState() + _, srcMem, srcRoot, srcAccounts := makeTestState() // Create a destination state and sync with the scheduler dstDb, _ := ethdb.NewMemDatabase() @@ -132,7 +148,7 @@ func testIterativeStateSync(t *testing.T, batch int) { for len(queue) > 0 { results := make([]trie.SyncResult, len(queue)) for i, hash := range queue { - data, err := srcDb.Get(hash.Bytes()) + data, err := srcMem.Get(hash.Bytes()) if err != nil { t.Fatalf("failed to retrieve node data for %x: %v", hash, err) } @@ -154,7 +170,7 @@ func testIterativeStateSync(t *testing.T, batch int) { // partial results are returned, and the others sent only later. func TestIterativeDelayedStateSync(t *testing.T) { // Create a random state to copy - srcDb, srcRoot, srcAccounts := makeTestState() + _, srcMem, srcRoot, srcAccounts := makeTestState() // Create a destination state and sync with the scheduler dstDb, _ := ethdb.NewMemDatabase() @@ -165,7 +181,7 @@ func TestIterativeDelayedStateSync(t *testing.T) { // Sync only half of the scheduled nodes results := make([]trie.SyncResult, len(queue)/2+1) for i, hash := range queue[:len(results)] { - data, err := srcDb.Get(hash.Bytes()) + data, err := srcMem.Get(hash.Bytes()) if err != nil { t.Fatalf("failed to retrieve node data for %x: %v", hash, err) } @@ -191,7 +207,7 @@ func TestIterativeRandomStateSyncBatched(t *testing.T) { testIterativeRandomS func testIterativeRandomStateSync(t *testing.T, batch int) { // Create a random state to copy - srcDb, srcRoot, srcAccounts := makeTestState() + _, srcMem, srcRoot, srcAccounts := makeTestState() // Create a destination state and sync with the scheduler dstDb, _ := ethdb.NewMemDatabase() @@ -205,7 +221,7 @@ func testIterativeRandomStateSync(t *testing.T, batch int) { // Fetch all the queued nodes in a random order results := make([]trie.SyncResult, 0, len(queue)) for hash := range queue { - data, err := srcDb.Get(hash.Bytes()) + data, err := srcMem.Get(hash.Bytes()) if err != nil { t.Fatalf("failed to retrieve node data for %x: %v", hash, err) } @@ -231,7 +247,7 @@ func testIterativeRandomStateSync(t *testing.T, batch int) { // partial results are returned (Even those randomly), others sent only later. func TestIterativeRandomDelayedStateSync(t *testing.T) { // Create a random state to copy - srcDb, srcRoot, srcAccounts := makeTestState() + _, srcMem, srcRoot, srcAccounts := makeTestState() // Create a destination state and sync with the scheduler dstDb, _ := ethdb.NewMemDatabase() @@ -247,7 +263,7 @@ func TestIterativeRandomDelayedStateSync(t *testing.T) { for hash := range queue { delete(queue, hash) - data, err := srcDb.Get(hash.Bytes()) + data, err := srcMem.Get(hash.Bytes()) if err != nil { t.Fatalf("failed to retrieve node data for %x: %v", hash, err) } @@ -276,7 +292,9 @@ func TestIterativeRandomDelayedStateSync(t *testing.T) { // the database. func TestIncompleteStateSync(t *testing.T) { // Create a random state to copy - srcDb, srcRoot, srcAccounts := makeTestState() + _, srcMem, srcRoot, srcAccounts := makeTestState() + + checkTrieConsistency(srcMem, srcRoot) // Create a destination state and sync with the scheduler dstDb, _ := ethdb.NewMemDatabase() @@ -288,7 +306,7 @@ func TestIncompleteStateSync(t *testing.T) { // Fetch a batch of state nodes results := make([]trie.SyncResult, len(queue)) for i, hash := range queue { - data, err := srcDb.Get(hash.Bytes()) + data, err := srcMem.Get(hash.Bytes()) if err != nil { t.Fatalf("failed to retrieve node data for %x: %v", hash, err) } @@ -304,21 +322,18 @@ func TestIncompleteStateSync(t *testing.T) { for _, result := range results { added = append(added, result.Hash) } - // Check that all known sub-tries in the synced state is complete - for _, root := range added { - // Skim through the accounts and make sure the root hash is not a code node - codeHash := false + // Check that all known sub-tries added so far are complete or missing entirely. + checkSubtries: + for _, hash := range added { for _, acc := range srcAccounts { - if root == crypto.Keccak256Hash(acc.code) { - codeHash = true - break + if hash == crypto.Keccak256Hash(acc.code) { + continue checkSubtries // skip trie check of code nodes. } } - // If the root is a real trie node, check consistency - if !codeHash { - if err := checkStateConsistency(dstDb, root); err != nil { - t.Fatalf("state inconsistent: %v", err) - } + // Can't use checkStateConsistency here because subtrie keys may have odd + // length and crash in LeafKey. + if err := checkTrieConsistency(dstDb, hash); err != nil { + t.Fatalf("state inconsistent: %v", err) } } // Fetch the next batch to retrieve diff --git a/core/tx_pool_test.go b/core/tx_pool_test.go index 4e28522e9..4903bc3ca 100644 --- a/core/tx_pool_test.go +++ b/core/tx_pool_test.go @@ -44,7 +44,7 @@ func pricedTransaction(nonce uint64, gaslimit, gasprice *big.Int, key *ecdsa.Pri func setupTxPool() (*TxPool, *ecdsa.PrivateKey) { db, _ := ethdb.NewMemDatabase() - statedb, _ := state.New(common.Hash{}, db) + statedb, _ := state.New(common.Hash{}, state.NewDatabase(db)) key, _ := crypto.GenerateKey() newPool := NewTxPool(DefaultTxPoolConfig, params.TestChainConfig, new(event.TypeMux), func() (*state.StateDB, error) { return statedb, nil }, func() *big.Int { return big.NewInt(1000000) }) @@ -95,7 +95,7 @@ func TestStateChangeDuringPoolReset(t *testing.T) { key, _ = crypto.GenerateKey() address = crypto.PubkeyToAddress(key.PublicKey) mux = new(event.TypeMux) - statedb, _ = state.New(common.Hash{}, db) + statedb, _ = state.New(common.Hash{}, state.NewDatabase(db)) trigger = false ) @@ -114,7 +114,7 @@ func TestStateChangeDuringPoolReset(t *testing.T) { // a state change between those fetches. stdb := statedb if trigger { - statedb, _ = state.New(common.Hash{}, db) + statedb, _ = state.New(common.Hash{}, state.NewDatabase(db)) // simulate that the new head block included tx0 and tx1 statedb.SetNonce(address, 2) statedb.SetBalance(address, new(big.Int).SetUint64(params.Ether)) @@ -292,7 +292,7 @@ func TestTransactionChainFork(t *testing.T) { addr := crypto.PubkeyToAddress(key.PublicKey) resetState := func() { db, _ := ethdb.NewMemDatabase() - statedb, _ := state.New(common.Hash{}, db) + statedb, _ := state.New(common.Hash{}, state.NewDatabase(db)) pool.currentState = func() (*state.StateDB, error) { return statedb, nil } currentState, _ := pool.currentState() currentState.AddBalance(addr, big.NewInt(100000000000000)) @@ -318,7 +318,7 @@ func TestTransactionDoubleNonce(t *testing.T) { addr := crypto.PubkeyToAddress(key.PublicKey) resetState := func() { db, _ := ethdb.NewMemDatabase() - statedb, _ := state.New(common.Hash{}, db) + statedb, _ := state.New(common.Hash{}, state.NewDatabase(db)) pool.currentState = func() (*state.StateDB, error) { return statedb, nil } currentState, _ := pool.currentState() currentState.AddBalance(addr, big.NewInt(100000000000000)) @@ -628,7 +628,7 @@ func TestTransactionQueueGlobalLimiting(t *testing.T) { // Create the pool to test the limit enforcement with db, _ := ethdb.NewMemDatabase() - statedb, _ := state.New(common.Hash{}, db) + statedb, _ := state.New(common.Hash{}, state.NewDatabase(db)) pool := NewTxPool(DefaultTxPoolConfig, params.TestChainConfig, new(event.TypeMux), func() (*state.StateDB, error) { return statedb, nil }, func() *big.Int { return big.NewInt(1000000) }) pool.resetState() @@ -783,7 +783,7 @@ func TestTransactionPendingGlobalLimiting(t *testing.T) { // Create the pool to test the limit enforcement with db, _ := ethdb.NewMemDatabase() - statedb, _ := state.New(common.Hash{}, db) + statedb, _ := state.New(common.Hash{}, state.NewDatabase(db)) pool := NewTxPool(DefaultTxPoolConfig, params.TestChainConfig, new(event.TypeMux), func() (*state.StateDB, error) { return statedb, nil }, func() *big.Int { return big.NewInt(1000000) }) pool.resetState() @@ -835,7 +835,7 @@ func TestTransactionCapClearsFromAll(t *testing.T) { // Create the pool to test the limit enforcement with db, _ := ethdb.NewMemDatabase() - statedb, _ := state.New(common.Hash{}, db) + statedb, _ := state.New(common.Hash{}, state.NewDatabase(db)) pool := NewTxPool(DefaultTxPoolConfig, params.TestChainConfig, new(event.TypeMux), func() (*state.StateDB, error) { return statedb, nil }, func() *big.Int { return big.NewInt(1000000) }) pool.resetState() @@ -868,7 +868,7 @@ func TestTransactionPendingMinimumAllowance(t *testing.T) { // Create the pool to test the limit enforcement with db, _ := ethdb.NewMemDatabase() - statedb, _ := state.New(common.Hash{}, db) + statedb, _ := state.New(common.Hash{}, state.NewDatabase(db)) pool := NewTxPool(DefaultTxPoolConfig, params.TestChainConfig, new(event.TypeMux), func() (*state.StateDB, error) { return statedb, nil }, func() *big.Int { return big.NewInt(1000000) }) pool.resetState() @@ -913,7 +913,7 @@ func TestTransactionPendingMinimumAllowance(t *testing.T) { func TestTransactionPoolRepricing(t *testing.T) { // Create the pool to test the pricing enforcement with db, _ := ethdb.NewMemDatabase() - statedb, _ := state.New(common.Hash{}, db) + statedb, _ := state.New(common.Hash{}, state.NewDatabase(db)) pool := NewTxPool(DefaultTxPoolConfig, params.TestChainConfig, new(event.TypeMux), func() (*state.StateDB, error) { return statedb, nil }, func() *big.Int { return big.NewInt(1000000) }) pool.resetState() @@ -1006,7 +1006,7 @@ func TestTransactionPoolUnderpricing(t *testing.T) { // Create the pool to test the pricing enforcement with db, _ := ethdb.NewMemDatabase() - statedb, _ := state.New(common.Hash{}, db) + statedb, _ := state.New(common.Hash{}, state.NewDatabase(db)) pool := NewTxPool(DefaultTxPoolConfig, params.TestChainConfig, new(event.TypeMux), func() (*state.StateDB, error) { return statedb, nil }, func() *big.Int { return big.NewInt(1000000) }) pool.resetState() @@ -1091,7 +1091,7 @@ func TestTransactionPoolUnderpricing(t *testing.T) { func TestTransactionReplacement(t *testing.T) { // Create the pool to test the pricing enforcement with db, _ := ethdb.NewMemDatabase() - statedb, _ := state.New(common.Hash{}, db) + statedb, _ := state.New(common.Hash{}, state.NewDatabase(db)) pool := NewTxPool(DefaultTxPoolConfig, params.TestChainConfig, new(event.TypeMux), func() (*state.StateDB, error) { return statedb, nil }, func() *big.Int { return big.NewInt(1000000) }) pool.resetState() diff --git a/core/vm/runtime/runtime.go b/core/vm/runtime/runtime.go index aa386a995..44cde4f70 100644 --- a/core/vm/runtime/runtime.go +++ b/core/vm/runtime/runtime.go @@ -102,7 +102,7 @@ func Execute(code, input []byte, cfg *Config) ([]byte, *state.StateDB, error) { if cfg.State == nil { db, _ := ethdb.NewMemDatabase() - cfg.State, _ = state.New(common.Hash{}, db) + cfg.State, _ = state.New(common.Hash{}, state.NewDatabase(db)) } var ( address = common.StringToAddress("contract") @@ -133,7 +133,7 @@ func Create(input []byte, cfg *Config) ([]byte, common.Address, uint64, error) { if cfg.State == nil { db, _ := ethdb.NewMemDatabase() - cfg.State, _ = state.New(common.Hash{}, db) + cfg.State, _ = state.New(common.Hash{}, state.NewDatabase(db)) } var ( vmenv = NewEnv(cfg, cfg.State) diff --git a/core/vm/runtime/runtime_test.go b/core/vm/runtime/runtime_test.go index 7f40770d2..2c4dc5026 100644 --- a/core/vm/runtime/runtime_test.go +++ b/core/vm/runtime/runtime_test.go @@ -95,7 +95,7 @@ func TestExecute(t *testing.T) { func TestCall(t *testing.T) { db, _ := ethdb.NewMemDatabase() - state, _ := state.New(common.Hash{}, db) + state, _ := state.New(common.Hash{}, state.NewDatabase(db)) address := common.HexToAddress("0x0a") state.SetCode(address, []byte{ byte(vm.PUSH1), 10, diff --git a/eth/api.go b/eth/api.go index 81570988c..0d90759b6 100644 --- a/eth/api.go +++ b/eth/api.go @@ -637,7 +637,7 @@ func (api *PrivateDebugAPI) StorageRangeAt(ctx context.Context, blockHash common return storageRangeAt(st, keyStart, maxResult), nil } -func storageRangeAt(st *trie.SecureTrie, start []byte, maxResult int) StorageRangeResult { +func storageRangeAt(st state.Trie, start []byte, maxResult int) StorageRangeResult { it := trie.NewIterator(st.NodeIterator(start)) result := StorageRangeResult{Storage: storageMap{}} for i := 0; i < maxResult && it.Next(); i++ { diff --git a/eth/api_backend.go b/eth/api_backend.go index fe108d272..166b5084d 100644 --- a/eth/api_backend.go +++ b/eth/api_backend.go @@ -31,7 +31,6 @@ import ( "github.com/ethereum/go-ethereum/eth/gasprice" "github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/event" - "github.com/ethereum/go-ethereum/internal/ethapi" "github.com/ethereum/go-ethereum/params" "github.com/ethereum/go-ethereum/rpc" ) @@ -81,11 +80,11 @@ func (b *EthApiBackend) BlockByNumber(ctx context.Context, blockNr rpc.BlockNumb return b.eth.blockchain.GetBlockByNumber(uint64(blockNr)), nil } -func (b *EthApiBackend) StateAndHeaderByNumber(ctx context.Context, blockNr rpc.BlockNumber) (ethapi.State, *types.Header, error) { +func (b *EthApiBackend) StateAndHeaderByNumber(ctx context.Context, blockNr rpc.BlockNumber) (*state.StateDB, *types.Header, error) { // Pending state is only known by the miner if blockNr == rpc.PendingBlockNumber { block, state := b.eth.miner.Pending() - return EthApiState{state}, block.Header(), nil + return state, block.Header(), nil } // Otherwise resolve the block number and return its state header, err := b.HeaderByNumber(ctx, blockNr) @@ -93,7 +92,7 @@ func (b *EthApiBackend) StateAndHeaderByNumber(ctx context.Context, blockNr rpc. return nil, nil, err } stateDb, err := b.eth.BlockChain().StateAt(header.Root) - return EthApiState{stateDb}, header, err + return stateDb, header, err } func (b *EthApiBackend) GetBlock(ctx context.Context, blockHash common.Hash) (*types.Block, error) { @@ -108,14 +107,12 @@ func (b *EthApiBackend) GetTd(blockHash common.Hash) *big.Int { return b.eth.blockchain.GetTdByHash(blockHash) } -func (b *EthApiBackend) GetEVM(ctx context.Context, msg core.Message, state ethapi.State, header *types.Header, vmCfg vm.Config) (*vm.EVM, func() error, error) { - statedb := state.(EthApiState).state - from := statedb.GetOrNewStateObject(msg.From()) - from.SetBalance(math.MaxBig256) +func (b *EthApiBackend) GetEVM(ctx context.Context, msg core.Message, state *state.StateDB, header *types.Header, vmCfg vm.Config) (*vm.EVM, func() error, error) { + state.SetBalance(msg.From(), math.MaxBig256) vmError := func() error { return nil } context := core.NewEVMContext(msg, header, b.eth.BlockChain(), nil) - return vm.NewEVM(context, statedb, b.eth.chainConfig, vmCfg), vmError, nil + return vm.NewEVM(context, state, b.eth.chainConfig, vmCfg), vmError, nil } func (b *EthApiBackend) SendTx(ctx context.Context, signedTx *types.Transaction) error { @@ -200,23 +197,3 @@ func (b *EthApiBackend) EventMux() *event.TypeMux { func (b *EthApiBackend) AccountManager() *accounts.Manager { return b.eth.AccountManager() } - -type EthApiState struct { - state *state.StateDB -} - -func (s EthApiState) GetBalance(ctx context.Context, addr common.Address) (*big.Int, error) { - return s.state.GetBalance(addr), nil -} - -func (s EthApiState) GetCode(ctx context.Context, addr common.Address) ([]byte, error) { - return s.state.GetCode(addr), nil -} - -func (s EthApiState) GetState(ctx context.Context, a common.Address, b common.Hash) (common.Hash, error) { - return s.state.GetState(a, b), nil -} - -func (s EthApiState) GetNonce(ctx context.Context, addr common.Address) (uint64, error) { - return s.state.GetNonce(addr), nil -} diff --git a/eth/api_test.go b/eth/api_test.go index f8d2e9c76..49ce38688 100644 --- a/eth/api_test.go +++ b/eth/api_test.go @@ -32,7 +32,7 @@ func TestStorageRangeAt(t *testing.T) { // Create a state where account 0x010000... has a few storage entries. var ( db, _ = ethdb.NewMemDatabase() - state, _ = state.New(common.Hash{}, db) + state, _ = state.New(common.Hash{}, state.NewDatabase(db)) addr = common.Address{0x01} keys = []common.Hash{ // hashes of Keys of storage common.HexToHash("340dd630ad21bf010b4e676dbfa9ba9a02175262d1fa356232cfde6cb5b47ef2"), diff --git a/eth/bind.go b/eth/bind.go index e5abd8617..0385db1f9 100644 --- a/eth/bind.go +++ b/eth/bind.go @@ -54,14 +54,12 @@ func NewContractBackend(apiBackend ethapi.Backend) *ContractBackend { // CodeAt retrieves any code associated with the contract from the local API. func (b *ContractBackend) CodeAt(ctx context.Context, contract common.Address, blockNum *big.Int) ([]byte, error) { - out, err := b.bcapi.GetCode(ctx, contract, toBlockNumber(blockNum)) - return common.FromHex(out), err + return b.bcapi.GetCode(ctx, contract, toBlockNumber(blockNum)) } // CodeAt retrieves any code associated with the contract from the local API. func (b *ContractBackend) PendingCodeAt(ctx context.Context, contract common.Address) ([]byte, error) { - out, err := b.bcapi.GetCode(ctx, contract, rpc.PendingBlockNumber) - return common.FromHex(out), err + return b.bcapi.GetCode(ctx, contract, rpc.PendingBlockNumber) } // ContractCall implements bind.ContractCaller executing an Ethereum contract diff --git a/eth/downloader/downloader_test.go b/eth/downloader/downloader_test.go index 267a0def9..1fb5a0910 100644 --- a/eth/downloader/downloader_test.go +++ b/eth/downloader/downloader_test.go @@ -657,7 +657,7 @@ func assertOwnForkedChain(t *testing.T, tester *downloadTester, common int, leng index = len(tester.ownHashes) - lengths[len(lengths)-1] + int(tester.downloader.queue.fastSyncPivot) } if index > 0 { - if statedb, err := state.New(tester.ownHeaders[tester.ownHashes[index]].Root, tester.stateDb); statedb == nil || err != nil { + if statedb, err := state.New(tester.ownHeaders[tester.ownHashes[index]].Root, state.NewDatabase(tester.stateDb)); statedb == nil || err != nil { t.Fatalf("state reconstruction failed: %v", err) } } diff --git a/eth/handler_test.go b/eth/handler_test.go index 413ed2bff..ca9c9e1b4 100644 --- a/eth/handler_test.go +++ b/eth/handler_test.go @@ -374,7 +374,7 @@ func testGetNodeData(t *testing.T, protocol int) { } accounts := []common.Address{testBank, acc1Addr, acc2Addr} for i := uint64(0); i <= pm.blockchain.CurrentBlock().NumberU64(); i++ { - trie, _ := state.New(pm.blockchain.GetBlockByNumber(i).Root(), statedb) + trie, _ := state.New(pm.blockchain.GetBlockByNumber(i).Root(), state.NewDatabase(statedb)) for j, acc := range accounts { state, _ := pm.blockchain.State() diff --git a/internal/ethapi/api.go b/internal/ethapi/api.go index da5dc5d58..c22c56dfb 100644 --- a/internal/ethapi/api.go +++ b/internal/ethapi/api.go @@ -447,8 +447,8 @@ func (s *PublicBlockChainAPI) GetBalance(ctx context.Context, address common.Add if state == nil || err != nil { return nil, err } - - return state.GetBalance(ctx, address) + b := state.GetBalance(address) + return b, state.Error() } // GetBlockByNumber returns the requested block. When blockNr is -1 the chain head is returned. When fullTx is true all @@ -529,31 +529,25 @@ func (s *PublicBlockChainAPI) GetUncleCountByBlockHash(ctx context.Context, bloc } // GetCode returns the code stored at the given address in the state for the given block number. -func (s *PublicBlockChainAPI) GetCode(ctx context.Context, address common.Address, blockNr rpc.BlockNumber) (string, error) { +func (s *PublicBlockChainAPI) GetCode(ctx context.Context, address common.Address, blockNr rpc.BlockNumber) (hexutil.Bytes, error) { state, _, err := s.b.StateAndHeaderByNumber(ctx, blockNr) if state == nil || err != nil { - return "", err - } - res, err := state.GetCode(ctx, address) - if len(res) == 0 || err != nil { // backwards compatibility - return "0x", err + return nil, err } - return common.ToHex(res), nil + code := state.GetCode(address) + return code, state.Error() } // GetStorageAt returns the storage from the state at the given address, key and // block number. The rpc.LatestBlockNumber and rpc.PendingBlockNumber meta block // numbers are also allowed. -func (s *PublicBlockChainAPI) GetStorageAt(ctx context.Context, address common.Address, key string, blockNr rpc.BlockNumber) (string, error) { +func (s *PublicBlockChainAPI) GetStorageAt(ctx context.Context, address common.Address, key string, blockNr rpc.BlockNumber) (hexutil.Bytes, error) { state, _, err := s.b.StateAndHeaderByNumber(ctx, blockNr) if state == nil || err != nil { - return "0x", err - } - res, err := state.GetState(ctx, address, common.HexToHash(key)) - if err != nil { - return "0x", err + return nil, err } - return res.Hex(), nil + res := state.GetState(address, common.HexToHash(key)) + return res[:], state.Error() } // callmsg is the message type used for call transitions. @@ -978,11 +972,8 @@ func (s *PublicTransactionPoolAPI) GetTransactionCount(ctx context.Context, addr if state == nil || err != nil { return nil, err } - nonce, err := state.GetNonce(ctx, address) - if err != nil { - return nil, err - } - return (*hexutil.Uint64)(&nonce), nil + nonce := state.GetNonce(address) + return (*hexutil.Uint64)(&nonce), state.Error() } // getTransactionBlockData fetches the meta data for the given transaction from the chain database. This is useful to diff --git a/internal/ethapi/backend.go b/internal/ethapi/backend.go index 68b5069d0..d122b7915 100644 --- a/internal/ethapi/backend.go +++ b/internal/ethapi/backend.go @@ -24,6 +24,7 @@ import ( "github.com/ethereum/go-ethereum/accounts" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core" + "github.com/ethereum/go-ethereum/core/state" "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/core/vm" "github.com/ethereum/go-ethereum/eth/downloader" @@ -47,11 +48,12 @@ type Backend interface { SetHead(number uint64) HeaderByNumber(ctx context.Context, blockNr rpc.BlockNumber) (*types.Header, error) BlockByNumber(ctx context.Context, blockNr rpc.BlockNumber) (*types.Block, error) - StateAndHeaderByNumber(ctx context.Context, blockNr rpc.BlockNumber) (State, *types.Header, error) + StateAndHeaderByNumber(ctx context.Context, blockNr rpc.BlockNumber) (*state.StateDB, *types.Header, error) GetBlock(ctx context.Context, blockHash common.Hash) (*types.Block, error) GetReceipts(ctx context.Context, blockHash common.Hash) (types.Receipts, error) GetTd(blockHash common.Hash) *big.Int - GetEVM(ctx context.Context, msg core.Message, state State, header *types.Header, vmCfg vm.Config) (*vm.EVM, func() error, error) + GetEVM(ctx context.Context, msg core.Message, state *state.StateDB, header *types.Header, vmCfg vm.Config) (*vm.EVM, func() error, error) + // TxPool API SendTx(ctx context.Context, signedTx *types.Transaction) error RemoveTx(txHash common.Hash) @@ -65,13 +67,6 @@ type Backend interface { CurrentBlock() *types.Block } -type State interface { - GetBalance(ctx context.Context, addr common.Address) (*big.Int, error) - GetCode(ctx context.Context, addr common.Address) ([]byte, error) - GetState(ctx context.Context, a common.Address, b common.Hash) (common.Hash, error) - GetNonce(ctx context.Context, addr common.Address) (uint64, error) -} - func GetAPIs(apiBackend Backend) []rpc.API { nonceLock := new(AddrLocker) return []rpc.API{ diff --git a/les/api_backend.go b/les/api_backend.go index 7d69046de..7a3c2447c 100644 --- a/les/api_backend.go +++ b/les/api_backend.go @@ -24,13 +24,13 @@ import ( "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common/math" "github.com/ethereum/go-ethereum/core" + "github.com/ethereum/go-ethereum/core/state" "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/core/vm" "github.com/ethereum/go-ethereum/eth/downloader" "github.com/ethereum/go-ethereum/eth/gasprice" "github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/event" - "github.com/ethereum/go-ethereum/internal/ethapi" "github.com/ethereum/go-ethereum/light" "github.com/ethereum/go-ethereum/params" "github.com/ethereum/go-ethereum/rpc" @@ -70,12 +70,12 @@ func (b *LesApiBackend) BlockByNumber(ctx context.Context, blockNr rpc.BlockNumb return b.GetBlock(ctx, header.Hash()) } -func (b *LesApiBackend) StateAndHeaderByNumber(ctx context.Context, blockNr rpc.BlockNumber) (ethapi.State, *types.Header, error) { +func (b *LesApiBackend) StateAndHeaderByNumber(ctx context.Context, blockNr rpc.BlockNumber) (*state.StateDB, *types.Header, error) { header, err := b.HeaderByNumber(ctx, blockNr) if header == nil || err != nil { return nil, nil, err } - return light.NewLightState(light.StateTrieID(header), b.eth.odr), header, nil + return light.NewState(ctx, header, b.eth.odr), header, nil } func (b *LesApiBackend) GetBlock(ctx context.Context, blockHash common.Hash) (*types.Block, error) { @@ -90,18 +90,10 @@ func (b *LesApiBackend) GetTd(blockHash common.Hash) *big.Int { return b.eth.blockchain.GetTdByHash(blockHash) } -func (b *LesApiBackend) GetEVM(ctx context.Context, msg core.Message, state ethapi.State, header *types.Header, vmCfg vm.Config) (*vm.EVM, func() error, error) { - stateDb := state.(*light.LightState).Copy() - addr := msg.From() - from, err := stateDb.GetOrNewStateObject(ctx, addr) - if err != nil { - return nil, nil, err - } - from.SetBalance(math.MaxBig256) - - vmstate := light.NewVMState(ctx, stateDb) +func (b *LesApiBackend) GetEVM(ctx context.Context, msg core.Message, state *state.StateDB, header *types.Header, vmCfg vm.Config) (*vm.EVM, func() error, error) { + state.SetBalance(msg.From(), math.MaxBig256) context := core.NewEVMContext(msg, header, b.eth.blockchain, nil) - return vm.NewEVM(context, vmstate, b.eth.chainConfig, vmCfg), vmstate.Error, nil + return vm.NewEVM(context, state, b.eth.chainConfig, vmCfg), state.Error, nil } func (b *LesApiBackend) SendTx(ctx context.Context, signedTx *types.Transaction) error { diff --git a/les/odr_test.go b/les/odr_test.go index 7b34996ce..3a0fd6738 100644 --- a/les/odr_test.go +++ b/les/odr_test.go @@ -75,24 +75,23 @@ func odrAccounts(ctx context.Context, db ethdb.Database, config *params.ChainCon dummyAddr := common.HexToAddress("1234567812345678123456781234567812345678") acc := []common.Address{testBankAddress, acc1Addr, acc2Addr, dummyAddr} - var res []byte + var ( + res []byte + st *state.StateDB + err error + ) for _, addr := range acc { if bc != nil { header := bc.GetHeaderByHash(bhash) - st, err := state.New(header.Root, db) - if err == nil { - bal := st.GetBalance(addr) - rlp, _ := rlp.EncodeToBytes(bal) - res = append(res, rlp...) - } + st, err = state.New(header.Root, state.NewDatabase(db)) } else { header := lc.GetHeaderByHash(bhash) - st := light.NewLightState(light.StateTrieID(header), lc.Odr()) - bal, err := st.GetBalance(ctx, addr) - if err == nil { - rlp, _ := rlp.EncodeToBytes(bal) - res = append(res, rlp...) - } + st = light.NewState(ctx, header, lc.Odr()) + } + if err == nil { + bal := st.GetBalance(addr) + rlp, _ := rlp.EncodeToBytes(bal) + res = append(res, rlp...) } } @@ -115,7 +114,7 @@ func odrContractCall(ctx context.Context, db ethdb.Database, config *params.Chai data[35] = byte(i) if bc != nil { header := bc.GetHeaderByHash(bhash) - statedb, err := state.New(header.Root, db) + statedb, err := state.New(header.Root, state.NewDatabase(db)) if err == nil { from := statedb.GetOrNewStateObject(testBankAddress) @@ -133,23 +132,15 @@ func odrContractCall(ctx context.Context, db ethdb.Database, config *params.Chai } } else { header := lc.GetHeaderByHash(bhash) - state := light.NewLightState(light.StateTrieID(header), lc.Odr()) - vmstate := light.NewVMState(ctx, state) - from, err := state.GetOrNewStateObject(ctx, testBankAddress) - if err == nil { - from.SetBalance(math.MaxBig256) - - msg := callmsg{types.NewMessage(from.Address(), &testContractAddr, 0, new(big.Int), big.NewInt(100000), new(big.Int), data, false)} - - context := core.NewEVMContext(msg, header, lc, nil) - vmenv := vm.NewEVM(context, vmstate, config, vm.Config{}) - - //vmenv := light.NewEnv(ctx, state, config, lc, msg, header, vm.Config{}) - gp := new(core.GasPool).AddGas(math.MaxBig256) - ret, _, _ := core.ApplyMessage(vmenv, msg, gp) - if vmstate.Error() == nil { - res = append(res, ret...) - } + state := light.NewState(ctx, header, lc.Odr()) + state.SetBalance(testBankAddress, math.MaxBig256) + msg := callmsg{types.NewMessage(testBankAddress, &testContractAddr, 0, new(big.Int), big.NewInt(100000), new(big.Int), data, false)} + context := core.NewEVMContext(msg, header, lc, nil) + vmenv := vm.NewEVM(context, state, config, vm.Config{}) + gp := new(core.GasPool).AddGas(math.MaxBig256) + ret, _, _ := core.ApplyMessage(vmenv, msg, gp) + if state.Error() == nil { + res = append(res, ret...) } } } diff --git a/les/request_test.go b/les/request_test.go index 3add5f20d..6b594462d 100644 --- a/les/request_test.go +++ b/les/request_test.go @@ -62,7 +62,7 @@ func tfCodeAccess(db ethdb.Database, bhash common.Hash, number uint64) light.Odr return nil } sti := light.StateTrieID(header) - ci := light.StorageTrieID(sti, testContractAddr, common.Hash{}) + ci := light.StorageTrieID(sti, crypto.Keccak256Hash(testContractAddr[:]), common.Hash{}) return &light.CodeRequest{Id: ci, Hash: crypto.Keccak256Hash(testContractCodeDeployed)} } diff --git a/light/lightchain.go b/light/lightchain.go index 5b7e57041..87436f4a5 100644 --- a/light/lightchain.go +++ b/light/lightchain.go @@ -180,11 +180,6 @@ func (self *LightChain) Status() (td *big.Int, currentBlock common.Hash, genesis return self.GetTd(hash, header.Number.Uint64()), hash, self.genesisBlock.Hash() } -// State returns a new mutable state based on the current HEAD block. -func (self *LightChain) State() *LightState { - return NewLightState(StateTrieID(self.hc.CurrentHeader()), self.odr) -} - // Reset purges the entire blockchain, restoring it to its genesis state. func (bc *LightChain) Reset() { bc.ResetWithGenesisBlock(bc.genesisBlock) diff --git a/light/odr.go b/light/odr.go index ca6364f28..d19a488f6 100644 --- a/light/odr.go +++ b/light/odr.go @@ -34,7 +34,7 @@ import ( // service is not required. var NoOdr = context.Background() -// OdrBackend is an interface to a backend service that handles ODR retrievals +// OdrBackend is an interface to a backend service that handles ODR retrievals type type OdrBackend interface { Database() ethdb.Database Retrieve(ctx context.Context, req OdrRequest) error @@ -66,11 +66,11 @@ func StateTrieID(header *types.Header) *TrieID { // StorageTrieID returns a TrieID for a contract storage trie at a given account // of a given state trie. It also requires the root hash of the trie for // checking Merkle proofs. -func StorageTrieID(state *TrieID, addr common.Address, root common.Hash) *TrieID { +func StorageTrieID(state *TrieID, addrHash, root common.Hash) *TrieID { return &TrieID{ BlockHash: state.BlockHash, BlockNumber: state.BlockNumber, - AccKey: crypto.Keccak256(addr[:]), + AccKey: addrHash[:], Root: root, } } @@ -102,7 +102,7 @@ func storeProof(db ethdb.Database, proof []rlp.RawValue) { // CodeRequest is the ODR request type for retrieving contract code type CodeRequest struct { OdrRequest - Id *TrieID + Id *TrieID // references storage trie of the account Hash common.Hash Data []byte } diff --git a/light/odr_test.go b/light/odr_test.go index 576e3abc9..544b64eff 100644 --- a/light/odr_test.go +++ b/light/odr_test.go @@ -86,11 +86,11 @@ func (odr *testOdr) Retrieve(ctx context.Context, req OdrRequest) error { return nil } -type odrTestFn func(ctx context.Context, db ethdb.Database, bc *core.BlockChain, lc *LightChain, bhash common.Hash) []byte +type odrTestFn func(ctx context.Context, db ethdb.Database, bc *core.BlockChain, lc *LightChain, bhash common.Hash) ([]byte, error) -func TestOdrGetBlockLes1(t *testing.T) { testChainOdr(t, 1, 1, odrGetBlock) } +func TestOdrGetBlockLes1(t *testing.T) { testChainOdr(t, 1, odrGetBlock) } -func odrGetBlock(ctx context.Context, db ethdb.Database, bc *core.BlockChain, lc *LightChain, bhash common.Hash) []byte { +func odrGetBlock(ctx context.Context, db ethdb.Database, bc *core.BlockChain, lc *LightChain, bhash common.Hash) ([]byte, error) { var block *types.Block if bc != nil { block = bc.GetBlockByHash(bhash) @@ -98,15 +98,15 @@ func odrGetBlock(ctx context.Context, db ethdb.Database, bc *core.BlockChain, lc block, _ = lc.GetBlockByHash(ctx, bhash) } if block == nil { - return nil + return nil, nil } rlp, _ := rlp.EncodeToBytes(block) - return rlp + return rlp, nil } -func TestOdrGetReceiptsLes1(t *testing.T) { testChainOdr(t, 1, 1, odrGetReceipts) } +func TestOdrGetReceiptsLes1(t *testing.T) { testChainOdr(t, 1, odrGetReceipts) } -func odrGetReceipts(ctx context.Context, db ethdb.Database, bc *core.BlockChain, lc *LightChain, bhash common.Hash) []byte { +func odrGetReceipts(ctx context.Context, db ethdb.Database, bc *core.BlockChain, lc *LightChain, bhash common.Hash) ([]byte, error) { var receipts types.Receipts if bc != nil { receipts = core.GetBlockReceipts(db, bhash, core.GetBlockNumber(db, bhash)) @@ -114,43 +114,37 @@ func odrGetReceipts(ctx context.Context, db ethdb.Database, bc *core.BlockChain, receipts, _ = GetBlockReceipts(ctx, lc.Odr(), bhash, core.GetBlockNumber(db, bhash)) } if receipts == nil { - return nil + return nil, nil } rlp, _ := rlp.EncodeToBytes(receipts) - return rlp + return rlp, nil } -func TestOdrAccountsLes1(t *testing.T) { testChainOdr(t, 1, 1, odrAccounts) } +func TestOdrAccountsLes1(t *testing.T) { testChainOdr(t, 1, odrAccounts) } -func odrAccounts(ctx context.Context, db ethdb.Database, bc *core.BlockChain, lc *LightChain, bhash common.Hash) []byte { +func odrAccounts(ctx context.Context, db ethdb.Database, bc *core.BlockChain, lc *LightChain, bhash common.Hash) ([]byte, error) { dummyAddr := common.HexToAddress("1234567812345678123456781234567812345678") acc := []common.Address{testBankAddress, acc1Addr, acc2Addr, dummyAddr} + var st *state.StateDB + if bc == nil { + header := lc.GetHeaderByHash(bhash) + st = NewState(ctx, header, lc.Odr()) + } else { + header := bc.GetHeaderByHash(bhash) + st, _ = state.New(header.Root, state.NewDatabase(db)) + } + var res []byte for _, addr := range acc { - if bc != nil { - header := bc.GetHeaderByHash(bhash) - st, err := state.New(header.Root, db) - if err == nil { - bal := st.GetBalance(addr) - rlp, _ := rlp.EncodeToBytes(bal) - res = append(res, rlp...) - } - } else { - header := lc.GetHeaderByHash(bhash) - st := NewLightState(StateTrieID(header), lc.Odr()) - bal, err := st.GetBalance(ctx, addr) - if err == nil { - rlp, _ := rlp.EncodeToBytes(bal) - res = append(res, rlp...) - } - } + bal := st.GetBalance(addr) + rlp, _ := rlp.EncodeToBytes(bal) + res = append(res, rlp...) } - - return res + return res, st.Error() } -func TestOdrContractCallLes1(t *testing.T) { testChainOdr(t, 1, 2, odrContractCall) } +func TestOdrContractCallLes1(t *testing.T) { testChainOdr(t, 1, odrContractCall) } type callmsg struct { types.Message @@ -158,50 +152,42 @@ type callmsg struct { func (callmsg) CheckNonce() bool { return false } -func odrContractCall(ctx context.Context, db ethdb.Database, bc *core.BlockChain, lc *LightChain, bhash common.Hash) []byte { +func odrContractCall(ctx context.Context, db ethdb.Database, bc *core.BlockChain, lc *LightChain, bhash common.Hash) ([]byte, error) { data := common.Hex2Bytes("60CD26850000000000000000000000000000000000000000000000000000000000000000") - config := params.TestChainConfig var res []byte for i := 0; i < 3; i++ { data[35] = byte(i) - if bc != nil { - header := bc.GetHeaderByHash(bhash) - statedb, err := state.New(header.Root, db) - if err == nil { - from := statedb.GetOrNewStateObject(testBankAddress) - from.SetBalance(math.MaxBig256) - - msg := callmsg{types.NewMessage(from.Address(), &testContractAddr, 0, new(big.Int), big.NewInt(1000000), new(big.Int), data, false)} - context := core.NewEVMContext(msg, header, bc, nil) - vmenv := vm.NewEVM(context, statedb, config, vm.Config{}) - - gp := new(core.GasPool).AddGas(math.MaxBig256) - ret, _, _ := core.ApplyMessage(vmenv, msg, gp) - res = append(res, ret...) - } + var ( + st *state.StateDB + header *types.Header + chain core.ChainContext + ) + if bc == nil { + chain = lc + header = lc.GetHeaderByHash(bhash) + st = NewState(ctx, header, lc.Odr()) } else { - header := lc.GetHeaderByHash(bhash) - state := NewLightState(StateTrieID(header), lc.Odr()) - vmstate := NewVMState(ctx, state) - from, err := state.GetOrNewStateObject(ctx, testBankAddress) - if err == nil { - from.SetBalance(math.MaxBig256) - - msg := callmsg{types.NewMessage(from.Address(), &testContractAddr, 0, new(big.Int), big.NewInt(1000000), new(big.Int), data, false)} - context := core.NewEVMContext(msg, header, lc, nil) - vmenv := vm.NewEVM(context, vmstate, config, vm.Config{}) - gp := new(core.GasPool).AddGas(math.MaxBig256) - ret, _, _ := core.ApplyMessage(vmenv, msg, gp) - if vmstate.Error() == nil { - res = append(res, ret...) - } - } + chain = bc + header = bc.GetHeaderByHash(bhash) + st, _ = state.New(header.Root, state.NewDatabase(db)) + } + + // Perform read-only call. + st.SetBalance(testBankAddress, math.MaxBig256) + msg := callmsg{types.NewMessage(testBankAddress, &testContractAddr, 0, new(big.Int), big.NewInt(1000000), new(big.Int), data, false)} + context := core.NewEVMContext(msg, header, chain, nil) + vmenv := vm.NewEVM(context, st, config, vm.Config{}) + gp := new(core.GasPool).AddGas(math.MaxBig256) + ret, _, _ := core.ApplyMessage(vmenv, msg, gp) + res = append(res, ret...) + if st.Error() != nil { + return res, st.Error() } } - return res + return res, nil } func testChainGen(i int, block *core.BlockGen) { @@ -245,7 +231,7 @@ func testChainGen(i int, block *core.BlockGen) { } } -func testChainOdr(t *testing.T, protocol int, expFail uint64, fn odrTestFn) { +func testChainOdr(t *testing.T, protocol int, fn odrTestFn) { var ( evmux = new(event.TypeMux) sdb, _ = ethdb.NewMemDatabase() @@ -258,46 +244,58 @@ func testChainOdr(t *testing.T, protocol int, expFail uint64, fn odrTestFn) { blockchain, _ := core.NewBlockChain(sdb, params.TestChainConfig, ethash.NewFullFaker(), evmux, vm.Config{}) gchain, _ := core.GenerateChain(params.TestChainConfig, genesis, sdb, 4, testChainGen) if _, err := blockchain.InsertChain(gchain); err != nil { - panic(err) + t.Fatal(err) } odr := &testOdr{sdb: sdb, ldb: ldb} - lightchain, _ := NewLightChain(odr, params.TestChainConfig, ethash.NewFullFaker(), evmux) + lightchain, err := NewLightChain(odr, params.TestChainConfig, ethash.NewFullFaker(), evmux) + if err != nil { + t.Fatal(err) + } headers := make([]*types.Header, len(gchain)) for i, block := range gchain { headers[i] = block.Header() } if _, err := lightchain.InsertHeaderChain(headers, 1); err != nil { - panic(err) + t.Fatal(err) } - test := func(expFail uint64) { + test := func(expFail int) { for i := uint64(0); i <= blockchain.CurrentHeader().Number.Uint64(); i++ { bhash := core.GetCanonicalHash(sdb, i) - b1 := fn(NoOdr, sdb, blockchain, nil, bhash) + b1, err := fn(NoOdr, sdb, blockchain, nil, bhash) + if err != nil { + t.Fatalf("error in full-node test for block %d: %v", i, err) + } ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond) defer cancel() - b2 := fn(ctx, ldb, nil, lightchain, bhash) + + exp := i < uint64(expFail) + b2, err := fn(ctx, ldb, nil, lightchain, bhash) + if err != nil && exp { + t.Errorf("error in ODR test for block %d: %v", i, err) + } eq := bytes.Equal(b1, b2) - exp := i < expFail if exp && !eq { - t.Errorf("odr mismatch") - } - if !exp && eq { - t.Errorf("unexpected odr match") + t.Errorf("ODR test output for block %d doesn't match full node", i) } } } - odr.disable = true // expect retrievals to fail (except genesis block) without a les peer - test(expFail) - odr.disable = false - // expect all retrievals to pass - test(5) + t.Log("checking without ODR") odr.disable = true + test(1) + + // expect all retrievals to pass with ODR enabled + t.Log("checking with ODR") + odr.disable = false + test(len(gchain)) + // still expect all retrievals to pass, now data should be cached locally - test(5) + t.Log("checking without ODR, should be cached") + odr.disable = true + test(len(gchain)) } diff --git a/light/odr_util.go b/light/odr_util.go index d7f8458f1..fcdfdb82c 100644 --- a/light/odr_util.go +++ b/light/odr_util.go @@ -106,25 +106,6 @@ func GetCanonicalHash(ctx context.Context, odr OdrBackend, number uint64) (commo return common.Hash{}, err } -// retrieveContractCode tries to retrieve the contract code of the given account -// with the given hash from the network (id points to the storage trie belonging -// to the same account) -func retrieveContractCode(ctx context.Context, odr OdrBackend, id *TrieID, hash common.Hash) ([]byte, error) { - if hash == sha3_nil { - return nil, nil - } - res, _ := odr.Database().Get(hash[:]) - if res != nil { - return res, nil - } - r := &CodeRequest{Id: id, Hash: hash} - if err := odr.Retrieve(ctx, r); err != nil { - return nil, err - } else { - return r.Data, nil - } -} - // GetBodyRLP retrieves the block body (transactions and uncles) in RLP encoding. func GetBodyRLP(ctx context.Context, odr OdrBackend, hash common.Hash, number uint64) (rlp.RawValue, error) { if data := core.GetBodyRLP(odr.Database(), hash, number); data != nil { diff --git a/light/state.go b/light/state.go deleted file mode 100644 index b184dc3a5..000000000 --- a/light/state.go +++ /dev/null @@ -1,316 +0,0 @@ -// Copyright 2015 The go-ethereum Authors -// This file is part of the go-ethereum library. -// -// The go-ethereum library is free software: you can redistribute it and/or modify -// it under the terms of the GNU Lesser General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// The go-ethereum library is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU Lesser General Public License for more details. -// -// You should have received a copy of the GNU Lesser General Public License -// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. - -package light - -import ( - "context" - "math/big" - - "github.com/ethereum/go-ethereum/common" - "github.com/ethereum/go-ethereum/crypto" -) - -// LightState is a memory representation of a state. -// This version is ODR capable, caching only the already accessed part of the -// state, retrieving unknown parts on-demand from the ODR backend. Changes are -// never stored in the local database, only in the memory objects. -type LightState struct { - odr OdrBackend - trie *LightTrie - id *TrieID - stateObjects map[string]*StateObject - refund *big.Int -} - -// NewLightState creates a new LightState with the specified root. -// Note that the creation of a light state is always successful, even if the -// root is non-existent. In that case, ODR retrieval will always be unsuccessful -// and every operation will return with an error or wait for the context to be -// cancelled. -func NewLightState(id *TrieID, odr OdrBackend) *LightState { - var tr *LightTrie - if id != nil { - tr = NewLightTrie(id, odr, true) - } - return &LightState{ - odr: odr, - trie: tr, - id: id, - stateObjects: make(map[string]*StateObject), - refund: new(big.Int), - } -} - -// AddRefund adds an amount to the refund value collected during a vm execution -func (self *LightState) AddRefund(gas *big.Int) { - self.refund.Add(self.refund, gas) -} - -// HasAccount returns true if an account exists at the given address -func (self *LightState) HasAccount(ctx context.Context, addr common.Address) (bool, error) { - so, err := self.GetStateObject(ctx, addr) - return so != nil, err -} - -// GetBalance retrieves the balance from the given address or 0 if the account does -// not exist -func (self *LightState) GetBalance(ctx context.Context, addr common.Address) (*big.Int, error) { - stateObject, err := self.GetStateObject(ctx, addr) - if err != nil { - return common.Big0, err - } - if stateObject != nil { - return stateObject.balance, nil - } - - return common.Big0, nil -} - -// GetNonce returns the nonce at the given address or 0 if the account does -// not exist -func (self *LightState) GetNonce(ctx context.Context, addr common.Address) (uint64, error) { - stateObject, err := self.GetStateObject(ctx, addr) - if err != nil { - return 0, err - } - if stateObject != nil { - return stateObject.nonce, nil - } - return 0, nil -} - -// GetCode returns the contract code at the given address or nil if the account -// does not exist -func (self *LightState) GetCode(ctx context.Context, addr common.Address) ([]byte, error) { - stateObject, err := self.GetStateObject(ctx, addr) - if err != nil { - return nil, err - } - if stateObject != nil { - return stateObject.code, nil - } - return nil, nil -} - -// GetState returns the contract storage value at storage address b from the -// contract address a or common.Hash{} if the account does not exist -func (self *LightState) GetState(ctx context.Context, a common.Address, b common.Hash) (common.Hash, error) { - stateObject, err := self.GetStateObject(ctx, a) - if err == nil && stateObject != nil { - return stateObject.GetState(ctx, b) - } - return common.Hash{}, err -} - -// HasSuicided returns true if the given account has been marked for deletion -// or false if the account does not exist -func (self *LightState) HasSuicided(ctx context.Context, addr common.Address) (bool, error) { - stateObject, err := self.GetStateObject(ctx, addr) - if err == nil && stateObject != nil { - return stateObject.remove, nil - } - return false, err -} - -/* - * SETTERS - */ - -// AddBalance adds the given amount to the balance of the specified account -func (self *LightState) AddBalance(ctx context.Context, addr common.Address, amount *big.Int) error { - stateObject, err := self.GetOrNewStateObject(ctx, addr) - if err == nil && stateObject != nil { - stateObject.AddBalance(amount) - } - return err -} - -// SubBalance adds the given amount to the balance of the specified account -func (self *LightState) SubBalance(ctx context.Context, addr common.Address, amount *big.Int) error { - stateObject, err := self.GetOrNewStateObject(ctx, addr) - if err == nil && stateObject != nil { - stateObject.SubBalance(amount) - } - return err -} - -// SetNonce sets the nonce of the specified account -func (self *LightState) SetNonce(ctx context.Context, addr common.Address, nonce uint64) error { - stateObject, err := self.GetOrNewStateObject(ctx, addr) - if err == nil && stateObject != nil { - stateObject.SetNonce(nonce) - } - return err -} - -// SetCode sets the contract code at the specified account -func (self *LightState) SetCode(ctx context.Context, addr common.Address, code []byte) error { - stateObject, err := self.GetOrNewStateObject(ctx, addr) - if err == nil && stateObject != nil { - stateObject.SetCode(crypto.Keccak256Hash(code), code) - } - return err -} - -// SetState sets the storage value at storage address key of the account addr -func (self *LightState) SetState(ctx context.Context, addr common.Address, key common.Hash, value common.Hash) error { - stateObject, err := self.GetOrNewStateObject(ctx, addr) - if err == nil && stateObject != nil { - stateObject.SetState(key, value) - } - return err -} - -// Delete marks an account to be removed and clears its balance -func (self *LightState) Suicide(ctx context.Context, addr common.Address) (bool, error) { - stateObject, err := self.GetOrNewStateObject(ctx, addr) - if err == nil && stateObject != nil { - stateObject.MarkForDeletion() - stateObject.balance = new(big.Int) - - return true, nil - } - - return false, err -} - -// -// Get, set, new state object methods -// - -// GetStateObject returns the state object of the given account or nil if the -// account does not exist -func (self *LightState) GetStateObject(ctx context.Context, addr common.Address) (stateObject *StateObject, err error) { - stateObject = self.stateObjects[addr.Str()] - if stateObject != nil { - if stateObject.deleted { - stateObject = nil - } - return stateObject, nil - } - data, err := self.trie.Get(ctx, addr[:]) - if err != nil { - return nil, err - } - if len(data) == 0 { - return nil, nil - } - - stateObject, err = DecodeObject(ctx, self.id, addr, self.odr, []byte(data)) - if err != nil { - return nil, err - } - - self.SetStateObject(stateObject) - - return stateObject, nil -} - -// SetStateObject sets the state object of the given account -func (self *LightState) SetStateObject(object *StateObject) { - self.stateObjects[object.Address().Str()] = object -} - -// GetOrNewStateObject returns the state object of the given account or creates a -// new one if the account does not exist -func (self *LightState) GetOrNewStateObject(ctx context.Context, addr common.Address) (*StateObject, error) { - stateObject, err := self.GetStateObject(ctx, addr) - if err == nil && (stateObject == nil || stateObject.deleted) { - stateObject, err = self.CreateStateObject(ctx, addr) - } - return stateObject, err -} - -// newStateObject creates a state object whether it exists in the state or not -func (self *LightState) newStateObject(addr common.Address) *StateObject { - stateObject := NewStateObject(addr, self.odr) - self.stateObjects[addr.Str()] = stateObject - - return stateObject -} - -// CreateStateObject creates creates a new state object and takes ownership. -// This is different from "NewStateObject" -func (self *LightState) CreateStateObject(ctx context.Context, addr common.Address) (*StateObject, error) { - // Get previous (if any) - so, err := self.GetStateObject(ctx, addr) - if err != nil { - return nil, err - } - // Create a new one - newSo := self.newStateObject(addr) - - // If it existed set the balance to the new account - if so != nil { - newSo.balance = so.balance - } - - return newSo, nil -} - -// ForEachStorage calls a callback function for every key/value pair found -// in the local storage cache. Note that unlike core/state.StateObject, -// light.StateObject only returns cached values and doesn't download the -// entire storage tree. -func (self *LightState) ForEachStorage(ctx context.Context, addr common.Address, cb func(key, value common.Hash) bool) error { - so, err := self.GetStateObject(ctx, addr) - if err != nil { - return err - } - - if so == nil { - return nil - } - - for h, v := range so.storage { - cb(h, v) - } - return nil -} - -// -// Setting, copying of the state methods -// - -// Copy creates a copy of the state -func (self *LightState) Copy() *LightState { - // ignore error - we assume state-to-be-copied always exists - state := NewLightState(nil, self.odr) - state.trie = self.trie - state.id = self.id - for k, stateObject := range self.stateObjects { - if stateObject.dirty { - state.stateObjects[k] = stateObject.Copy() - } - } - - state.refund.Set(self.refund) - return state -} - -// Set copies the contents of the given state onto this state, overwriting -// its contents -func (self *LightState) Set(state *LightState) { - self.trie = state.trie - self.stateObjects = state.stateObjects - self.refund = state.refund -} - -// GetRefund returns the refund value collected during a vm execution -func (self *LightState) GetRefund() *big.Int { - return self.refund -} diff --git a/light/state_object.go b/light/state_object.go deleted file mode 100644 index a54ea1d9f..000000000 --- a/light/state_object.go +++ /dev/null @@ -1,275 +0,0 @@ -// Copyright 2015 The go-ethereum Authors -// This file is part of the go-ethereum library. -// -// The go-ethereum library is free software: you can redistribute it and/or modify -// it under the terms of the GNU Lesser General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// The go-ethereum library is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU Lesser General Public License for more details. -// -// You should have received a copy of the GNU Lesser General Public License -// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. - -package light - -import ( - "bytes" - "context" - "fmt" - "math/big" - - "github.com/ethereum/go-ethereum/common" - "github.com/ethereum/go-ethereum/crypto" - "github.com/ethereum/go-ethereum/rlp" -) - -var emptyCodeHash = crypto.Keccak256(nil) - -// Code represents a contract code in binary form -type Code []byte - -// String returns a string representation of the code -func (self Code) String() string { - return string(self) //strings.Join(Disassemble(self), " ") -} - -// Storage is a memory map cache of a contract storage -type Storage map[common.Hash]common.Hash - -// String returns a string representation of the storage cache -func (self Storage) String() (str string) { - for key, value := range self { - str += fmt.Sprintf("%X : %X\n", key, value) - } - - return -} - -// Copy copies the contents of a storage cache -func (self Storage) Copy() Storage { - cpy := make(Storage) - for key, value := range self { - cpy[key] = value - } - - return cpy -} - -// StateObject is a memory representation of an account or contract and its storage. -// This version is ODR capable, caching only the already accessed part of the -// storage, retrieving unknown parts on-demand from the ODR backend. Changes are -// never stored in the local database, only in the memory objects. -type StateObject struct { - odr OdrBackend - trie *LightTrie - - // Address belonging to this account - address common.Address - // The balance of the account - balance *big.Int - // The nonce of the account - nonce uint64 - // The code hash if code is present (i.e. a contract) - codeHash []byte - // The code for this account - code Code - // Cached storage (flushed when updated) - storage Storage - - // Mark for deletion - // When an object is marked for deletion it will be delete from the trie - // during the "update" phase of the state transition - remove bool - deleted bool - dirty bool -} - -// NewStateObject creates a new StateObject of the specified account address -func NewStateObject(address common.Address, odr OdrBackend) *StateObject { - object := &StateObject{ - odr: odr, - address: address, - balance: new(big.Int), - dirty: true, - codeHash: emptyCodeHash, - storage: make(Storage), - } - object.trie = NewLightTrie(&TrieID{}, odr, true) - return object -} - -// MarkForDeletion marks an account to be removed -func (self *StateObject) MarkForDeletion() { - self.remove = true - self.dirty = true -} - -// getAddr gets the storage value at the given address from the trie -func (c *StateObject) getAddr(ctx context.Context, addr common.Hash) (common.Hash, error) { - var ret []byte - val, err := c.trie.Get(ctx, addr[:]) - if err != nil { - return common.Hash{}, err - } - rlp.DecodeBytes(val, &ret) - return common.BytesToHash(ret), nil -} - -// Storage returns the storage cache object of the account -func (self *StateObject) Storage() Storage { - return self.storage -} - -// GetState returns the storage value at the given address from either the cache -// or the trie -func (self *StateObject) GetState(ctx context.Context, key common.Hash) (common.Hash, error) { - value, exists := self.storage[key] - if !exists { - var err error - value, err = self.getAddr(ctx, key) - if err != nil { - return common.Hash{}, err - } - if (value != common.Hash{}) { - self.storage[key] = value - } - } - - return value, nil -} - -// SetState sets the storage value at the given address -func (self *StateObject) SetState(k, value common.Hash) { - self.storage[k] = value - self.dirty = true -} - -// AddBalance adds the given amount to the account balance -func (c *StateObject) AddBalance(amount *big.Int) { - c.SetBalance(new(big.Int).Add(c.balance, amount)) -} - -// SubBalance subtracts the given amount from the account balance -func (c *StateObject) SubBalance(amount *big.Int) { - c.SetBalance(new(big.Int).Sub(c.balance, amount)) -} - -// SetBalance sets the account balance to the given amount -func (c *StateObject) SetBalance(amount *big.Int) { - c.balance = amount - c.dirty = true -} - -// ReturnGas returns the gas back to the origin. Used by the Virtual machine or Closures -func (c *StateObject) ReturnGas(gas *big.Int) {} - -// Copy creates a copy of the state object -func (self *StateObject) Copy() *StateObject { - stateObject := NewStateObject(self.Address(), self.odr) - stateObject.balance.Set(self.balance) - stateObject.codeHash = common.CopyBytes(self.codeHash) - stateObject.nonce = self.nonce - stateObject.trie = self.trie - stateObject.code = self.code - stateObject.storage = self.storage.Copy() - stateObject.remove = self.remove - stateObject.dirty = self.dirty - stateObject.deleted = self.deleted - - return stateObject -} - -// -// Attribute accessors -// - -// empty returns whether the account is considered empty. -func (self *StateObject) empty() bool { - return self.nonce == 0 && self.balance.Sign() == 0 && bytes.Equal(self.codeHash, emptyCodeHash) -} - -// Balance returns the account balance -func (self *StateObject) Balance() *big.Int { - return self.balance -} - -// Address returns the address of the contract/account -func (self *StateObject) Address() common.Address { - return self.address -} - -// Code returns the contract code -func (self *StateObject) Code() []byte { - return self.code -} - -// SetCode sets the contract code -func (self *StateObject) SetCode(hash common.Hash, code []byte) { - self.code = code - self.codeHash = hash[:] - self.dirty = true -} - -// SetNonce sets the account nonce -func (self *StateObject) SetNonce(nonce uint64) { - self.nonce = nonce - self.dirty = true -} - -// Nonce returns the account nonce -func (self *StateObject) Nonce() uint64 { - return self.nonce -} - -// ForEachStorage calls a callback function for every key/value pair found -// in the local storage cache. Note that unlike core/state.StateObject, -// light.StateObject only returns cached values and doesn't download the -// entire storage tree. -func (self *StateObject) ForEachStorage(cb func(key, value common.Hash) bool) { - for h, v := range self.storage { - cb(h, v) - } -} - -// Never called, but must be present to allow StateObject to be used -// as a vm.Account interface that also satisfies the vm.ContractRef -// interface. Interfaces are awesome. -func (self *StateObject) Value() *big.Int { - panic("Value on StateObject should never be called") -} - -// Encoding - -type extStateObject struct { - Nonce uint64 - Balance *big.Int - Root common.Hash - CodeHash []byte -} - -// DecodeObject decodes an RLP-encoded state object. -func DecodeObject(ctx context.Context, stateID *TrieID, address common.Address, odr OdrBackend, data []byte) (*StateObject, error) { - var ( - obj = &StateObject{address: address, odr: odr, storage: make(Storage)} - ext extStateObject - err error - ) - if err = rlp.DecodeBytes(data, &ext); err != nil { - return nil, err - } - trieID := StorageTrieID(stateID, address, ext.Root) - obj.trie = NewLightTrie(trieID, odr, true) - if !bytes.Equal(ext.CodeHash, emptyCodeHash) { - if obj.code, err = retrieveContractCode(ctx, obj.odr, trieID, common.BytesToHash(ext.CodeHash)); err != nil { - return nil, fmt.Errorf("can't find code for hash %x: %v", ext.CodeHash, err) - } - } - obj.nonce = ext.Nonce - obj.balance = ext.Balance - obj.codeHash = ext.CodeHash - return obj, nil -} diff --git a/light/state_test.go b/light/state_test.go deleted file mode 100644 index e776efec8..000000000 --- a/light/state_test.go +++ /dev/null @@ -1,248 +0,0 @@ -// Copyright 2015 The go-ethereum Authors -// This file is part of the go-ethereum library. -// -// The go-ethereum library is free software: you can redistribute it and/or modify -// it under the terms of the GNU Lesser General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// The go-ethereum library is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU Lesser General Public License for more details. -// -// You should have received a copy of the GNU Lesser General Public License -// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. - -package light - -import ( - "bytes" - "context" - "math/big" - "testing" - - "github.com/ethereum/go-ethereum/common" - "github.com/ethereum/go-ethereum/core" - "github.com/ethereum/go-ethereum/core/state" - "github.com/ethereum/go-ethereum/core/types" - "github.com/ethereum/go-ethereum/ethdb" -) - -func makeTestState() (common.Hash, ethdb.Database) { - sdb, _ := ethdb.NewMemDatabase() - st, _ := state.New(common.Hash{}, sdb) - for i := byte(0); i < 100; i++ { - addr := common.Address{i} - for j := byte(0); j < 100; j++ { - st.SetState(addr, common.Hash{j}, common.Hash{i, j}) - } - st.SetNonce(addr, 100) - st.AddBalance(addr, big.NewInt(int64(i))) - st.SetCode(addr, []byte{i, i, i}) - } - root, _ := st.Commit(false) - return root, sdb -} - -func TestLightStateOdr(t *testing.T) { - root, sdb := makeTestState() - header := &types.Header{Root: root, Number: big.NewInt(0)} - core.WriteHeader(sdb, header) - ldb, _ := ethdb.NewMemDatabase() - odr := &testOdr{sdb: sdb, ldb: ldb} - ls := NewLightState(StateTrieID(header), odr) - ctx := context.Background() - - for i := byte(0); i < 100; i++ { - addr := common.Address{i} - err := ls.AddBalance(ctx, addr, big.NewInt(1000)) - if err != nil { - t.Fatalf("Error adding balance to acc[%d]: %v", i, err) - } - err = ls.SetState(ctx, addr, common.Hash{100}, common.Hash{i, 100}) - if err != nil { - t.Fatalf("Error setting storage of acc[%d]: %v", i, err) - } - } - - addr := common.Address{100} - _, err := ls.CreateStateObject(ctx, addr) - if err != nil { - t.Fatalf("Error creating state object: %v", err) - } - err = ls.SetCode(ctx, addr, []byte{100, 100, 100}) - if err != nil { - t.Fatalf("Error setting code: %v", err) - } - err = ls.AddBalance(ctx, addr, big.NewInt(1100)) - if err != nil { - t.Fatalf("Error adding balance to acc[100]: %v", err) - } - for j := byte(0); j < 101; j++ { - err = ls.SetState(ctx, addr, common.Hash{j}, common.Hash{100, j}) - if err != nil { - t.Fatalf("Error setting storage of acc[100]: %v", err) - } - } - err = ls.SetNonce(ctx, addr, 100) - if err != nil { - t.Fatalf("Error setting nonce for acc[100]: %v", err) - } - - for i := byte(0); i < 101; i++ { - addr := common.Address{i} - - bal, err := ls.GetBalance(ctx, addr) - if err != nil { - t.Fatalf("Error getting balance of acc[%d]: %v", i, err) - } - if bal.Int64() != int64(i)+1000 { - t.Fatalf("Incorrect balance at acc[%d]: expected %v, got %v", i, int64(i)+1000, bal.Int64()) - } - - nonce, err := ls.GetNonce(ctx, addr) - if err != nil { - t.Fatalf("Error getting nonce of acc[%d]: %v", i, err) - } - if nonce != 100 { - t.Fatalf("Incorrect nonce at acc[%d]: expected %v, got %v", i, 100, nonce) - } - - code, err := ls.GetCode(ctx, addr) - exp := []byte{i, i, i} - if err != nil { - t.Fatalf("Error getting code of acc[%d]: %v", i, err) - } - if !bytes.Equal(code, exp) { - t.Fatalf("Incorrect code at acc[%d]: expected %v, got %v", i, exp, code) - } - - for j := byte(0); j < 101; j++ { - exp := common.Hash{i, j} - val, err := ls.GetState(ctx, addr, common.Hash{j}) - if err != nil { - t.Fatalf("Error retrieving acc[%d].storage[%d]: %v", i, j, err) - } - if val != exp { - t.Fatalf("Retrieved wrong value from acc[%d].storage[%d]: expected %04x, got %04x", i, j, exp, val) - } - } - } -} - -func TestLightStateSetCopy(t *testing.T) { - root, sdb := makeTestState() - header := &types.Header{Root: root, Number: big.NewInt(0)} - core.WriteHeader(sdb, header) - ldb, _ := ethdb.NewMemDatabase() - odr := &testOdr{sdb: sdb, ldb: ldb} - ls := NewLightState(StateTrieID(header), odr) - ctx := context.Background() - - for i := byte(0); i < 100; i++ { - addr := common.Address{i} - err := ls.AddBalance(ctx, addr, big.NewInt(1000)) - if err != nil { - t.Fatalf("Error adding balance to acc[%d]: %v", i, err) - } - err = ls.SetState(ctx, addr, common.Hash{100}, common.Hash{i, 100}) - if err != nil { - t.Fatalf("Error setting storage of acc[%d]: %v", i, err) - } - } - - ls2 := ls.Copy() - - for i := byte(0); i < 100; i++ { - addr := common.Address{i} - err := ls2.AddBalance(ctx, addr, big.NewInt(1000)) - if err != nil { - t.Fatalf("Error adding balance to acc[%d]: %v", i, err) - } - err = ls2.SetState(ctx, addr, common.Hash{100}, common.Hash{i, 200}) - if err != nil { - t.Fatalf("Error setting storage of acc[%d]: %v", i, err) - } - } - - lsx := ls.Copy() - ls.Set(ls2) - ls2.Set(lsx) - - for i := byte(0); i < 100; i++ { - addr := common.Address{i} - // check balance in ls - bal, err := ls.GetBalance(ctx, addr) - if err != nil { - t.Fatalf("Error getting balance to acc[%d]: %v", i, err) - } - if bal.Int64() != int64(i)+2000 { - t.Fatalf("Incorrect balance at ls.acc[%d]: expected %v, got %v", i, int64(i)+1000, bal.Int64()) - } - // check balance in ls2 - bal, err = ls2.GetBalance(ctx, addr) - if err != nil { - t.Fatalf("Error getting balance to acc[%d]: %v", i, err) - } - if bal.Int64() != int64(i)+1000 { - t.Fatalf("Incorrect balance at ls.acc[%d]: expected %v, got %v", i, int64(i)+1000, bal.Int64()) - } - // check storage in ls - exp := common.Hash{i, 200} - val, err := ls.GetState(ctx, addr, common.Hash{100}) - if err != nil { - t.Fatalf("Error retrieving acc[%d].storage[100]: %v", i, err) - } - if val != exp { - t.Fatalf("Retrieved wrong value from acc[%d].storage[100]: expected %04x, got %04x", i, exp, val) - } - // check storage in ls2 - exp = common.Hash{i, 100} - val, err = ls2.GetState(ctx, addr, common.Hash{100}) - if err != nil { - t.Fatalf("Error retrieving acc[%d].storage[100]: %v", i, err) - } - if val != exp { - t.Fatalf("Retrieved wrong value from acc[%d].storage[100]: expected %04x, got %04x", i, exp, val) - } - } -} - -func TestLightStateDelete(t *testing.T) { - root, sdb := makeTestState() - header := &types.Header{Root: root, Number: big.NewInt(0)} - core.WriteHeader(sdb, header) - ldb, _ := ethdb.NewMemDatabase() - odr := &testOdr{sdb: sdb, ldb: ldb} - ls := NewLightState(StateTrieID(header), odr) - ctx := context.Background() - - addr := common.Address{42} - - b, err := ls.HasAccount(ctx, addr) - if err != nil { - t.Fatalf("HasAccount error: %v", err) - } - if !b { - t.Fatalf("HasAccount returned false, expected true") - } - - b, err = ls.HasSuicided(ctx, addr) - if err != nil { - t.Fatalf("HasSuicided error: %v", err) - } - if b { - t.Fatalf("HasSuicided returned true, expected false") - } - - ls.Suicide(ctx, addr) - - b, err = ls.HasSuicided(ctx, addr) - if err != nil { - t.Fatalf("HasSuicided error: %v", err) - } - if !b { - t.Fatalf("HasSuicided returned false, expected true") - } -} diff --git a/light/trie.go b/light/trie.go index 2988a16cf..7502b6e5d 100644 --- a/light/trie.go +++ b/light/trie.go @@ -18,99 +18,216 @@ package light import ( "context" + "fmt" + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/state" + "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/crypto" - "github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/trie" ) -// LightTrie is an ODR-capable wrapper around trie.SecureTrie -type LightTrie struct { - trie *trie.SecureTrie +func NewState(ctx context.Context, head *types.Header, odr OdrBackend) *state.StateDB { + state, _ := state.New(head.Root, NewStateDatabase(ctx, head, odr)) + return state +} + +func NewStateDatabase(ctx context.Context, head *types.Header, odr OdrBackend) state.Database { + return &odrDatabase{ctx, StateTrieID(head), odr} +} + +type odrDatabase struct { + ctx context.Context + id *TrieID + backend OdrBackend +} + +func (db *odrDatabase) OpenTrie(root common.Hash) (state.Trie, error) { + return &odrTrie{db: db, id: db.id}, nil +} + +func (db *odrDatabase) OpenStorageTrie(addrHash, root common.Hash) (state.Trie, error) { + return &odrTrie{db: db, id: StorageTrieID(db.id, addrHash, root)}, nil +} + +func (db *odrDatabase) CopyTrie(t state.Trie) state.Trie { + switch t := t.(type) { + case *odrTrie: + cpy := &odrTrie{db: t.db, id: t.id} + if t.trie != nil { + cpytrie := *t.trie + cpy.trie = &cpytrie + } + return cpy + default: + panic(fmt.Errorf("unknown trie type %T", t)) + } +} + +func (db *odrDatabase) ContractCode(addrHash, codeHash common.Hash) ([]byte, error) { + if codeHash == sha3_nil { + return nil, nil + } + if code, err := db.backend.Database().Get(codeHash[:]); err == nil { + return code, nil + } + id := *db.id + id.AccKey = addrHash[:] + req := &CodeRequest{Id: &id, Hash: codeHash} + err := db.backend.Retrieve(db.ctx, req) + return req.Data, err +} + +func (db *odrDatabase) ContractCodeSize(addrHash, codeHash common.Hash) (int, error) { + code, err := db.ContractCode(addrHash, codeHash) + return len(code), err +} + +type odrTrie struct { + db *odrDatabase id *TrieID - odr OdrBackend - db ethdb.Database -} - -// NewLightTrie creates a new LightTrie instance. It doesn't instantly try to -// access the db or network and retrieve the root node, it only initializes its -// encapsulated SecureTrie at the first actual operation. -func NewLightTrie(id *TrieID, odr OdrBackend, useFakeMap bool) *LightTrie { - return &LightTrie{ - // SecureTrie is initialized before first request - id: id, - odr: odr, - db: odr.Database(), + trie *trie.Trie +} + +func (t *odrTrie) TryGet(key []byte) ([]byte, error) { + key = crypto.Keccak256(key) + var res []byte + err := t.do(key, func() (err error) { + res, err = t.trie.TryGet(key) + return err + }) + return res, err +} + +func (t *odrTrie) TryUpdate(key, value []byte) error { + key = crypto.Keccak256(key) + return t.do(key, func() error { + return t.trie.TryDelete(key) + }) +} + +func (t *odrTrie) TryDelete(key []byte) error { + key = crypto.Keccak256(key) + return t.do(key, func() error { + return t.trie.TryDelete(key) + }) +} + +func (t *odrTrie) CommitTo(db trie.DatabaseWriter) (common.Hash, error) { + if t.trie == nil { + return t.id.Root, nil + } + return t.trie.CommitTo(db) +} + +func (t *odrTrie) Hash() common.Hash { + if t.trie == nil { + return t.id.Root } + return t.trie.Hash() +} + +func (t *odrTrie) NodeIterator(startkey []byte) trie.NodeIterator { + return newNodeIterator(t, startkey) } -// retrieveKey retrieves a single key, returns true and stores nodes in local -// database if successful -func (t *LightTrie) retrieveKey(ctx context.Context, key []byte) bool { - r := &TrieRequest{Id: t.id, Key: crypto.Keccak256(key)} - return t.odr.Retrieve(ctx, r) == nil +func (t *odrTrie) GetKey(sha []byte) []byte { + return nil } // do tries and retries to execute a function until it returns with no error or // an error type other than MissingNodeError -func (t *LightTrie) do(ctx context.Context, key []byte, fn func() error) error { - err := fn() - for err != nil { +func (t *odrTrie) do(key []byte, fn func() error) error { + for { + var err error + if t.trie == nil { + t.trie, err = trie.New(t.id.Root, t.db.backend.Database()) + } + if err == nil { + err = fn() + } if _, ok := err.(*trie.MissingNodeError); !ok { return err } - if !t.retrieveKey(ctx, key) { - break + r := &TrieRequest{Id: t.id, Key: key} + if err := t.db.backend.Retrieve(t.db.ctx, r); err != nil { + return fmt.Errorf("can't fetch trie key %x: %v", key, err) } - err = fn() } - return err } -// Get returns the value for key stored in the trie. -// The value bytes must not be modified by the caller. -func (t *LightTrie) Get(ctx context.Context, key []byte) (res []byte, err error) { - err = t.do(ctx, key, func() (err error) { - if t.trie == nil { - t.trie, err = trie.NewSecure(t.id.Root, t.db, 0) - } - if err == nil { - res, err = t.trie.TryGet(key) - } - return +type nodeIterator struct { + trie.NodeIterator + t *odrTrie + err error +} + +func newNodeIterator(t *odrTrie, startkey []byte) trie.NodeIterator { + it := &nodeIterator{t: t} + // Open the actual non-ODR trie if that hasn't happened yet. + if t.trie == nil { + it.do(func() error { + t, err := trie.New(t.id.Root, t.db.backend.Database()) + if err == nil { + it.t.trie = t + } + return err + }) + } + it.do(func() error { + it.NodeIterator = it.t.trie.NodeIterator(startkey) + return it.NodeIterator.Error() }) - return + return it } -// Update associates key with value in the trie. Subsequent calls to -// Get will return value. If value has length zero, any existing value -// is deleted from the trie and calls to Get will return nil. -// -// The value bytes must not be modified by the caller while they are -// stored in the trie. -func (t *LightTrie) Update(ctx context.Context, key, value []byte) (err error) { - err = t.do(ctx, key, func() (err error) { - if t.trie == nil { - t.trie, err = trie.NewSecure(t.id.Root, t.db, 0) - } - if err == nil { - err = t.trie.TryUpdate(key, value) - } - return +func (it *nodeIterator) Next(descend bool) bool { + var ok bool + it.do(func() error { + ok = it.NodeIterator.Next(descend) + return it.NodeIterator.Error() }) - return + return ok } -// Delete removes any existing value for key from the trie. -func (t *LightTrie) Delete(ctx context.Context, key []byte) (err error) { - err = t.do(ctx, key, func() (err error) { - if t.trie == nil { - t.trie, err = trie.NewSecure(t.id.Root, t.db, 0) +// do runs fn and attempts to fill in missing nodes by retrieving. +func (it *nodeIterator) do(fn func() error) { + var lasthash common.Hash + for { + it.err = fn() + missing, ok := it.err.(*trie.MissingNodeError) + if !ok { + return } - if err == nil { - err = t.trie.TryDelete(key) + if missing.NodeHash == lasthash { + it.err = fmt.Errorf("retrieve loop for trie node %x", missing.NodeHash) + return } - return - }) - return + lasthash = missing.NodeHash + r := &TrieRequest{Id: it.t.id, Key: nibblesToKey(missing.Path)} + if it.err = it.t.db.backend.Retrieve(it.t.db.ctx, r); it.err != nil { + return + } + } +} + +func (it *nodeIterator) Error() error { + if it.err != nil { + return it.err + } + return it.NodeIterator.Error() +} + +func nibblesToKey(nib []byte) []byte { + if len(nib) > 0 && nib[len(nib)-1] == 0x10 { + nib = nib[:len(nib)-1] // drop terminator + } + if len(nib)&1 == 1 { + nib = append(nib, 0) // make even + } + key := make([]byte, len(nib)/2) + for bi, ni := 0, 0; ni < len(nib); bi, ni = bi+1, ni+2 { + key[bi] = nib[ni]<<4 | nib[ni+1] + } + return key } diff --git a/light/trie_test.go b/light/trie_test.go new file mode 100644 index 000000000..9b2cf7c2b --- /dev/null +++ b/light/trie_test.go @@ -0,0 +1,83 @@ +// Copyright 2017 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. + +package light + +import ( + "bytes" + "context" + "fmt" + "testing" + + "github.com/davecgh/go-spew/spew" + "github.com/ethereum/go-ethereum/consensus/ethash" + "github.com/ethereum/go-ethereum/core" + "github.com/ethereum/go-ethereum/core/state" + "github.com/ethereum/go-ethereum/core/vm" + "github.com/ethereum/go-ethereum/ethdb" + "github.com/ethereum/go-ethereum/event" + "github.com/ethereum/go-ethereum/params" + "github.com/ethereum/go-ethereum/trie" +) + +func TestNodeIterator(t *testing.T) { + var ( + fulldb, _ = ethdb.NewMemDatabase() + lightdb, _ = ethdb.NewMemDatabase() + gspec = core.Genesis{Alloc: core.GenesisAlloc{testBankAddress: {Balance: testBankFunds}}} + genesis = gspec.MustCommit(fulldb) + ) + gspec.MustCommit(lightdb) + blockchain, _ := core.NewBlockChain(fulldb, params.TestChainConfig, ethash.NewFullFaker(), new(event.TypeMux), vm.Config{}) + gchain, _ := core.GenerateChain(params.TestChainConfig, genesis, fulldb, 4, testChainGen) + if _, err := blockchain.InsertChain(gchain); err != nil { + panic(err) + } + + ctx := context.Background() + odr := &testOdr{sdb: fulldb, ldb: lightdb} + head := blockchain.CurrentHeader() + lightTrie, _ := NewStateDatabase(ctx, head, odr).OpenTrie(head.Root) + fullTrie, _ := state.NewDatabase(fulldb).OpenTrie(head.Root) + if err := diffTries(fullTrie, lightTrie); err != nil { + t.Fatal(err) + } +} + +func diffTries(t1, t2 state.Trie) error { + i1 := trie.NewIterator(t1.NodeIterator(nil)) + i2 := trie.NewIterator(t2.NodeIterator(nil)) + for i1.Next() && i2.Next() { + if !bytes.Equal(i1.Key, i2.Key) { + spew.Dump(i2) + return fmt.Errorf("tries have different keys %x, %x", i1.Key, i2.Key) + } + if !bytes.Equal(i2.Value, i2.Value) { + return fmt.Errorf("tries differ at key %x", i1.Key) + } + } + switch { + case i1.Err != nil: + return fmt.Errorf("full trie iterator error: %v", i1.Err) + case i2.Err != nil: + return fmt.Errorf("light trie iterator error: %v", i1.Err) + case i1.Next(): + return fmt.Errorf("full trie iterator has more k/v pairs") + case i2.Next(): + return fmt.Errorf("light trie iterator has more k/v pairs") + } + return nil +} diff --git a/light/txpool.go b/light/txpool.go index 7276874b8..0430b280f 100644 --- a/light/txpool.go +++ b/light/txpool.go @@ -24,6 +24,7 @@ import ( "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core" + "github.com/ethereum/go-ethereum/core/state" "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/event" @@ -100,17 +101,18 @@ func NewTxPool(config *params.ChainConfig, eventMux *event.TypeMux, chain *Light } // currentState returns the light state of the current head header -func (pool *TxPool) currentState() *LightState { - return NewLightState(StateTrieID(pool.chain.CurrentHeader()), pool.odr) +func (pool *TxPool) currentState(ctx context.Context) *state.StateDB { + return NewState(ctx, pool.chain.CurrentHeader(), pool.odr) } // GetNonce returns the "pending" nonce of a given address. It always queries // the nonce belonging to the latest header too in order to detect if another // client using the same key sent a transaction. func (pool *TxPool) GetNonce(ctx context.Context, addr common.Address) (uint64, error) { - nonce, err := pool.currentState().GetNonce(ctx, addr) - if err != nil { - return 0, err + state := pool.currentState(ctx) + nonce := state.GetNonce(addr) + if state.Error() != nil { + return 0, state.Error() } sn, ok := pool.nonce[addr] if ok && sn > nonce { @@ -357,13 +359,9 @@ func (pool *TxPool) validateTx(ctx context.Context, tx *types.Transaction) error return core.ErrInvalidSender } // Last but not least check for nonce errors - currentState := pool.currentState() - if n, err := currentState.GetNonce(ctx, from); err == nil { - if n > tx.Nonce() { - return core.ErrNonceTooLow - } - } else { - return err + currentState := pool.currentState(ctx) + if n := currentState.GetNonce(from); n > tx.Nonce() { + return core.ErrNonceTooLow } // Check the transaction doesn't exceed the current @@ -382,12 +380,8 @@ func (pool *TxPool) validateTx(ctx context.Context, tx *types.Transaction) error // Transactor should have enough funds to cover the costs // cost == V + GP * GL - if b, err := currentState.GetBalance(ctx, from); err == nil { - if b.Cmp(tx.Cost()) < 0 { - return core.ErrInsufficientFunds - } - } else { - return err + if b := currentState.GetBalance(from); b.Cmp(tx.Cost()) < 0 { + return core.ErrInsufficientFunds } // Should supply enough intrinsic gas @@ -395,7 +389,7 @@ func (pool *TxPool) validateTx(ctx context.Context, tx *types.Transaction) error return core.ErrIntrinsicGas } - return nil + return currentState.Error() } // add validates a new transaction and sets its state pending if processable. diff --git a/light/vm_env.go b/light/vm_env.go deleted file mode 100644 index 54aa12875..000000000 --- a/light/vm_env.go +++ /dev/null @@ -1,194 +0,0 @@ -// Copyright 2016 The go-ethereum Authors -// This file is part of the go-ethereum library. -// -// The go-ethereum library is free software: you can redistribute it and/or modify -// it under the terms of the GNU Lesser General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// The go-ethereum library is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU Lesser General Public License for more details. -// -// You should have received a copy of the GNU Lesser General Public License -// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. - -package light - -import ( - "context" - "math/big" - - "github.com/ethereum/go-ethereum/common" - "github.com/ethereum/go-ethereum/core/types" - "github.com/ethereum/go-ethereum/crypto" -) - -// VMState is a wrapper for the light state that holds the actual context and -// passes it to any state operation that requires it. -type VMState struct { - ctx context.Context - state *LightState - snapshots []*LightState - err error -} - -func NewVMState(ctx context.Context, state *LightState) *VMState { - return &VMState{ctx: ctx, state: state} -} - -func (s *VMState) Error() error { - return s.err -} - -func (s *VMState) AddLog(log *types.Log) {} - -func (s *VMState) AddPreimage(hash common.Hash, preimage []byte) {} - -// errHandler handles and stores any state error that happens during execution. -func (s *VMState) errHandler(err error) { - if err != nil && s.err == nil { - s.err = err - } -} - -func (self *VMState) Snapshot() int { - self.snapshots = append(self.snapshots, self.state.Copy()) - return len(self.snapshots) - 1 -} - -func (self *VMState) RevertToSnapshot(idx int) { - self.state.Set(self.snapshots[idx]) - self.snapshots = self.snapshots[:idx] -} - -// CreateAccount creates creates a new account object and takes ownership. -func (s *VMState) CreateAccount(addr common.Address) { - _, err := s.state.CreateStateObject(s.ctx, addr) - s.errHandler(err) -} - -// AddBalance adds the given amount to the balance of the specified account -func (s *VMState) AddBalance(addr common.Address, amount *big.Int) { - err := s.state.AddBalance(s.ctx, addr, amount) - s.errHandler(err) -} - -// SubBalance adds the given amount to the balance of the specified account -func (s *VMState) SubBalance(addr common.Address, amount *big.Int) { - err := s.state.SubBalance(s.ctx, addr, amount) - s.errHandler(err) -} - -// ForEachStorage calls a callback function for every key/value pair found -// in the local storage cache. Note that unlike core/state.StateObject, -// light.StateObject only returns cached values and doesn't download the -// entire storage tree. -func (s *VMState) ForEachStorage(addr common.Address, cb func(key, value common.Hash) bool) { - err := s.state.ForEachStorage(s.ctx, addr, cb) - s.errHandler(err) -} - -// GetBalance retrieves the balance from the given address or 0 if the account does -// not exist -func (s *VMState) GetBalance(addr common.Address) *big.Int { - res, err := s.state.GetBalance(s.ctx, addr) - s.errHandler(err) - return res -} - -// GetNonce returns the nonce at the given address or 0 if the account does -// not exist -func (s *VMState) GetNonce(addr common.Address) uint64 { - res, err := s.state.GetNonce(s.ctx, addr) - s.errHandler(err) - return res -} - -// SetNonce sets the nonce of the specified account -func (s *VMState) SetNonce(addr common.Address, nonce uint64) { - err := s.state.SetNonce(s.ctx, addr, nonce) - s.errHandler(err) -} - -// GetCode returns the contract code at the given address or nil if the account -// does not exist -func (s *VMState) GetCode(addr common.Address) []byte { - res, err := s.state.GetCode(s.ctx, addr) - s.errHandler(err) - return res -} - -// GetCodeHash returns the contract code hash at the given address -func (s *VMState) GetCodeHash(addr common.Address) common.Hash { - res, err := s.state.GetCode(s.ctx, addr) - s.errHandler(err) - return crypto.Keccak256Hash(res) -} - -// GetCodeSize returns the contract code size at the given address -func (s *VMState) GetCodeSize(addr common.Address) int { - res, err := s.state.GetCode(s.ctx, addr) - s.errHandler(err) - return len(res) -} - -// SetCode sets the contract code at the specified account -func (s *VMState) SetCode(addr common.Address, code []byte) { - err := s.state.SetCode(s.ctx, addr, code) - s.errHandler(err) -} - -// AddRefund adds an amount to the refund value collected during a vm execution -func (s *VMState) AddRefund(gas *big.Int) { - s.state.AddRefund(gas) -} - -// GetRefund returns the refund value collected during a vm execution -func (s *VMState) GetRefund() *big.Int { - return s.state.GetRefund() -} - -// GetState returns the contract storage value at storage address b from the -// contract address a or common.Hash{} if the account does not exist -func (s *VMState) GetState(a common.Address, b common.Hash) common.Hash { - res, err := s.state.GetState(s.ctx, a, b) - s.errHandler(err) - return res -} - -// SetState sets the storage value at storage address key of the account addr -func (s *VMState) SetState(addr common.Address, key common.Hash, value common.Hash) { - err := s.state.SetState(s.ctx, addr, key, value) - s.errHandler(err) -} - -// Suicide marks an account to be removed and clears its balance -func (s *VMState) Suicide(addr common.Address) bool { - res, err := s.state.Suicide(s.ctx, addr) - s.errHandler(err) - return res -} - -// Exist returns true if an account exists at the given address -func (s *VMState) Exist(addr common.Address) bool { - res, err := s.state.HasAccount(s.ctx, addr) - s.errHandler(err) - return res -} - -// Empty returns true if the account at the given address is considered empty -func (s *VMState) Empty(addr common.Address) bool { - so, err := s.state.GetStateObject(s.ctx, addr) - s.errHandler(err) - return so == nil || so.empty() -} - -// HasSuicided returns true if the given account has been marked for deletion -// or false if the account does not exist -func (s *VMState) HasSuicided(addr common.Address) bool { - res, err := s.state.HasSuicided(s.ctx, addr) - s.errHandler(err) - return res -} diff --git a/miner/worker.go b/miner/worker.go index 803015390..e44514755 100644 --- a/miner/worker.go +++ b/miner/worker.go @@ -274,7 +274,7 @@ func (self *worker) wait() { } go self.mux.Post(core.NewMinedBlockEvent{Block: block}) } else { - work.state.Commit(self.config.IsEIP158(block.Number())) + work.state.CommitTo(self.chainDb, self.config.IsEIP158(block.Number())) stat, err := self.chain.WriteBlock(block) if err != nil { log.Error("Failed writing block to chain", "err", err) diff --git a/tests/block_test_util.go b/tests/block_test_util.go index b9678a77b..24d4672b6 100644 --- a/tests/block_test_util.go +++ b/tests/block_test_util.go @@ -204,7 +204,7 @@ func runBlockTest(homesteadBlock, daoForkBlock, gasPriceFork *big.Int, test *Blo // InsertPreState populates the given database with the genesis // accounts defined by the test. func (t *BlockTest) InsertPreState(db ethdb.Database) (*state.StateDB, error) { - statedb, err := state.New(common.Hash{}, db) + statedb, err := state.New(common.Hash{}, state.NewDatabase(db)) if err != nil { return nil, err } @@ -232,7 +232,7 @@ func (t *BlockTest) InsertPreState(db ethdb.Database) (*state.StateDB, error) { } } - root, err := statedb.Commit(false) + root, err := statedb.CommitTo(db, false) if err != nil { return nil, fmt.Errorf("error writing state: %v", err) } diff --git a/tests/state_test_util.go b/tests/state_test_util.go index c1892cdcc..58acdd488 100644 --- a/tests/state_test_util.go +++ b/tests/state_test_util.go @@ -20,7 +20,6 @@ import ( "bytes" "fmt" "io" - "math/big" "strconv" "strings" "testing" @@ -99,7 +98,7 @@ func benchStateTest(chainConfig *params.ChainConfig, test VmTest, env map[string statedb := makePreState(db, test.Pre) b.StartTimer() - RunState(chainConfig, statedb, env, test.Exec) + RunState(chainConfig, statedb, db, env, test.Exec) } func runStateTests(chainConfig *params.ChainConfig, tests map[string]VmTest, skipTests []string) error { @@ -143,16 +142,9 @@ func runStateTest(chainConfig *params.ChainConfig, test VmTest) error { env["currentTimestamp"] = test.Env.CurrentTimestamp.(string) } - var ( - ret []byte - // gas *big.Int - // err error - logs []*types.Log - ) + ret, logs, root, _ := RunState(chainConfig, statedb, db, env, test.Transaction) - ret, logs, _, _ = RunState(chainConfig, statedb, env, test.Transaction) - - // Compare expected and actual return + // Return value: var rexp []byte if strings.HasPrefix(test.Out, "#") { n, _ := strconv.Atoi(test.Out[1:]) @@ -163,61 +155,43 @@ func runStateTest(chainConfig *params.ChainConfig, test VmTest) error { if !bytes.Equal(rexp, ret) { return fmt.Errorf("return failed. Expected %x, got %x\n", rexp, ret) } - - // check post state + // Post state content: for addr, account := range test.Post { address := common.HexToAddress(addr) if !statedb.Exist(address) { return fmt.Errorf("did not find expected post-state account: %s", addr) } - if balance := statedb.GetBalance(address); balance.Cmp(math.MustParseBig256(account.Balance)) != 0 { return fmt.Errorf("(%x) balance failed. Expected: %v have: %v\n", address[:4], math.MustParseBig256(account.Balance), balance) } - if nonce := statedb.GetNonce(address); nonce != math.MustParseUint64(account.Nonce) { return fmt.Errorf("(%x) nonce failed. Expected: %v have: %v\n", address[:4], account.Nonce, nonce) } - for addr, value := range account.Storage { v := statedb.GetState(address, common.HexToHash(addr)) vexp := common.HexToHash(value) - if v != vexp { return fmt.Errorf("storage failed:\n%x: %s:\nexpected: %x\nhave: %x\n(%v %v)\n", address[:4], addr, vexp, v, vexp.Big(), v.Big()) } } } - - root, _ := statedb.Commit(false) + // Root: if common.HexToHash(test.PostStateRoot) != root { return fmt.Errorf("Post state root error. Expected: %s have: %x", test.PostStateRoot, root) } - - // check logs - if len(test.Logs) > 0 { - if err := checkLogs(test.Logs, logs); err != nil { - return err - } - } - - return nil + // Logs: + return checkLogs(test.Logs, logs) } -func RunState(chainConfig *params.ChainConfig, statedb *state.StateDB, env, tx map[string]string) ([]byte, []*types.Log, *big.Int, error) { +func RunState(chainConfig *params.ChainConfig, statedb *state.StateDB, db ethdb.Database, env, tx map[string]string) ([]byte, []*types.Log, common.Hash, error) { environment, msg := NewEVMEnvironment(false, chainConfig, statedb, env, tx) gaspool := new(core.GasPool).AddGas(math.MustParseBig256(env["currentGasLimit"])) - root, _ := statedb.Commit(false) - statedb.Reset(root) - snapshot := statedb.Snapshot() - - ret, gasUsed, err := core.ApplyMessage(environment, msg, gaspool) + ret, _, err := core.ApplyMessage(environment, msg, gaspool) if err != nil { statedb.RevertToSnapshot(snapshot) } - statedb.Commit(chainConfig.IsEIP158(environment.Context.BlockNumber)) - - return ret, statedb.Logs(), gasUsed, err + root, _ := statedb.CommitTo(db, chainConfig.IsEIP158(environment.Context.BlockNumber)) + return ret, statedb.Logs(), root, err } diff --git a/tests/util.go b/tests/util.go index a3a9a1f64..ff02679ec 100644 --- a/tests/util.go +++ b/tests/util.go @@ -48,7 +48,6 @@ func init() { } func checkLogs(tlog []Log, logs []*types.Log) error { - if len(tlog) != len(logs) { return fmt.Errorf("log length mismatch. Expected %d, got %d", len(tlog), len(logs)) } else { @@ -106,10 +105,14 @@ func (self Log) Topics() [][]byte { } func makePreState(db ethdb.Database, accounts map[string]Account) *state.StateDB { - statedb, _ := state.New(common.Hash{}, db) + sdb := state.NewDatabase(db) + statedb, _ := state.New(common.Hash{}, sdb) for addr, account := range accounts { insertAccount(statedb, addr, account) } + // Commit and re-open to start with a clean state. + root, _ := statedb.CommitTo(db, false) + statedb, _ = state.New(root, sdb) return statedb } diff --git a/trie/proof.go b/trie/proof.go index 1f8f76b1b..298f648c4 100644 --- a/trie/proof.go +++ b/trie/proof.go @@ -125,7 +125,7 @@ func VerifyProof(rootHash common.Hash, key []byte, proof []rlp.RawValue) (value } func get(tn node, key []byte) ([]byte, node) { - for len(key) > 0 { + for { switch n := tn.(type) { case *shortNode: if len(key) < len(n.Key) || !bytes.Equal(n.Key, key[:len(n.Key)]) { @@ -140,9 +140,10 @@ func get(tn node, key []byte) ([]byte, node) { return key, n case nil: return key, nil + case valueNode: + return nil, n default: panic(fmt.Sprintf("%T: invalid node: %v", tn, tn)) } } - return nil, tn.(valueNode) } diff --git a/trie/secure_trie.go b/trie/secure_trie.go index 37d1d4b09..20c303f31 100644 --- a/trie/secure_trie.go +++ b/trie/secure_trie.go @@ -156,6 +156,11 @@ func (t *SecureTrie) Root() []byte { return t.trie.Root() } +func (t *SecureTrie) Copy() *SecureTrie { + cpy := *t + return &cpy +} + // NodeIterator returns an iterator that returns nodes of the underlying trie. Iteration // starts at the key after the given start key. func (t *SecureTrie) NodeIterator(start []byte) NodeIterator { |