aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorPéter Szilágyi <peterke@gmail.com>2018-02-06 00:40:32 +0800
committerFelix Lange <fjl@users.noreply.github.com>2018-02-06 00:40:32 +0800
commit55599ee95d4151a2502465e0afc7c47bd1acba77 (patch)
tree4165e73ae852db4f025a5ed57f0bc499e87cb8b9
parent59336283c0dbeb1d0a74ff7a8b717b2b3bb0cf40 (diff)
downloaddexon-55599ee95d4151a2502465e0afc7c47bd1acba77.tar
dexon-55599ee95d4151a2502465e0afc7c47bd1acba77.tar.gz
dexon-55599ee95d4151a2502465e0afc7c47bd1acba77.tar.bz2
dexon-55599ee95d4151a2502465e0afc7c47bd1acba77.tar.lz
dexon-55599ee95d4151a2502465e0afc7c47bd1acba77.tar.xz
dexon-55599ee95d4151a2502465e0afc7c47bd1acba77.tar.zst
dexon-55599ee95d4151a2502465e0afc7c47bd1acba77.zip
core, trie: intermediate mempool between trie and database (#15857)
This commit reduces database I/O by not writing every state trie to disk.
-rw-r--r--accounts/abi/bind/backends/simulated.go14
-rw-r--r--cmd/evm/runner.go4
-rw-r--r--cmd/geth/chaincmd.go4
-rw-r--r--cmd/geth/main.go3
-rw-r--r--cmd/geth/usage.go6
-rw-r--r--cmd/utils/cmd.go26
-rw-r--r--cmd/utils/flags.go47
-rw-r--r--common/size.go27
-rw-r--r--consensus/errors.go4
-rw-r--r--core/bench_test.go4
-rw-r--r--core/block_validator.go9
-rw-r--r--core/block_validator_test.go8
-rw-r--r--core/blockchain.go334
-rw-r--r--core/blockchain_test.go145
-rw-r--r--core/chain_indexer.go3
-rw-r--r--core/chain_makers.go9
-rw-r--r--core/chain_makers_test.go2
-rw-r--r--core/dao_test.go24
-rw-r--r--core/genesis.go24
-rw-r--r--core/genesis_test.go6
-rw-r--r--core/state/database.go51
-rw-r--r--core/state/iterator_test.go17
-rw-r--r--core/state/state_object.go5
-rw-r--r--core/state/state_test.go6
-rw-r--r--core/state/statedb.go38
-rw-r--r--core/state/statedb_test.go14
-rw-r--r--core/state/sync_test.go44
-rw-r--r--core/tx_pool_test.go4
-rw-r--r--core/types/block.go9
-rw-r--r--core/types/receipt.go13
-rw-r--r--core/types/transaction.go2
-rw-r--r--eth/api.go4
-rw-r--r--eth/api_tracer.go135
-rw-r--r--eth/backend.go8
-rw-r--r--eth/config.go8
-rw-r--r--eth/downloader/downloader.go317
-rw-r--r--eth/downloader/downloader_test.go192
-rw-r--r--eth/downloader/queue.go169
-rw-r--r--eth/downloader/statesync.go31
-rw-r--r--eth/handler.go6
-rw-r--r--eth/handler_test.go16
-rw-r--r--eth/helper_test.go14
-rw-r--r--eth/protocol_test.go6
-rw-r--r--eth/sync_test.go4
-rw-r--r--internal/ethapi/api.go2
-rw-r--r--les/handler.go193
-rw-r--r--les/handler_test.go2
-rw-r--r--les/helper_test.go2
-rw-r--r--les/odr_test.go1
-rw-r--r--light/lightchain.go7
-rw-r--r--light/nodeset.go8
-rw-r--r--light/odr_test.go4
-rw-r--r--light/postprocess.go64
-rw-r--r--light/trie.go18
-rw-r--r--light/trie_test.go2
-rw-r--r--light/txpool_test.go2
-rw-r--r--miner/worker.go2
-rw-r--r--tests/block_test_util.go2
-rw-r--r--tests/state_test_util.go6
-rw-r--r--trie/database.go355
-rw-r--r--trie/hasher.go61
-rw-r--r--trie/iterator_test.go125
-rw-r--r--trie/proof.go47
-rw-r--r--trie/secure_trie.go62
-rw-r--r--trie/secure_trie_test.go20
-rw-r--r--trie/sync.go14
-rw-r--r--trie/sync_test.go103
-rw-r--r--trie/trie.go90
-rw-r--r--trie/trie_test.go104
69 files changed, 1953 insertions, 1159 deletions
diff --git a/accounts/abi/bind/backends/simulated.go b/accounts/abi/bind/backends/simulated.go
index 1803d3f23..bd342a8cb 100644
--- a/accounts/abi/bind/backends/simulated.go
+++ b/accounts/abi/bind/backends/simulated.go
@@ -68,7 +68,7 @@ func NewSimulatedBackend(alloc core.GenesisAlloc) *SimulatedBackend {
database, _ := ethdb.NewMemDatabase()
genesis := core.Genesis{Config: params.AllEthashProtocolChanges, Alloc: alloc}
genesis.MustCommit(database)
- blockchain, _ := core.NewBlockChain(database, genesis.Config, ethash.NewFaker(), vm.Config{})
+ blockchain, _ := core.NewBlockChain(database, nil, genesis.Config, ethash.NewFaker(), vm.Config{})
backend := &SimulatedBackend{
database: database,
@@ -102,8 +102,10 @@ func (b *SimulatedBackend) Rollback() {
func (b *SimulatedBackend) rollback() {
blocks, _ := core.GenerateChain(b.config, b.blockchain.CurrentBlock(), ethash.NewFaker(), b.database, 1, func(int, *core.BlockGen) {})
+ statedb, _ := b.blockchain.State()
+
b.pendingBlock = blocks[0]
- b.pendingState, _ = state.New(b.pendingBlock.Root(), state.NewDatabase(b.database))
+ b.pendingState, _ = state.New(b.pendingBlock.Root(), statedb.Database())
}
// CodeAt returns the code associated with a certain account in the blockchain.
@@ -309,8 +311,10 @@ func (b *SimulatedBackend) SendTransaction(ctx context.Context, tx *types.Transa
}
block.AddTx(tx)
})
+ statedb, _ := b.blockchain.State()
+
b.pendingBlock = blocks[0]
- b.pendingState, _ = state.New(b.pendingBlock.Root(), state.NewDatabase(b.database))
+ b.pendingState, _ = state.New(b.pendingBlock.Root(), statedb.Database())
return nil
}
@@ -386,8 +390,10 @@ func (b *SimulatedBackend) AdjustTime(adjustment time.Duration) error {
}
block.OffsetTime(int64(adjustment.Seconds()))
})
+ statedb, _ := b.blockchain.State()
+
b.pendingBlock = blocks[0]
- b.pendingState, _ = state.New(b.pendingBlock.Root(), state.NewDatabase(b.database))
+ b.pendingState, _ = state.New(b.pendingBlock.Root(), statedb.Database())
return nil
}
diff --git a/cmd/evm/runner.go b/cmd/evm/runner.go
index 96de0c76a..a9a2e5420 100644
--- a/cmd/evm/runner.go
+++ b/cmd/evm/runner.go
@@ -96,7 +96,9 @@ func runCmd(ctx *cli.Context) error {
}
if ctx.GlobalString(GenesisFlag.Name) != "" {
gen := readGenesis(ctx.GlobalString(GenesisFlag.Name))
- _, statedb = gen.ToBlock()
+ db, _ := ethdb.NewMemDatabase()
+ genesis := gen.ToBlock(db)
+ statedb, _ = state.New(genesis.Root(), state.NewDatabase(db))
chainConfig = gen.Config
} else {
db, _ := ethdb.NewMemDatabase()
diff --git a/cmd/geth/chaincmd.go b/cmd/geth/chaincmd.go
index 4a9a7b11b..35bf576e1 100644
--- a/cmd/geth/chaincmd.go
+++ b/cmd/geth/chaincmd.go
@@ -202,7 +202,7 @@ func importChain(ctx *cli.Context) error {
if len(ctx.Args()) == 1 {
if err := utils.ImportChain(chain, ctx.Args().First()); err != nil {
- utils.Fatalf("Import error: %v", err)
+ log.Error("Import error", "err", err)
}
} else {
for _, arg := range ctx.Args() {
@@ -211,7 +211,7 @@ func importChain(ctx *cli.Context) error {
}
}
}
-
+ chain.Stop()
fmt.Printf("Import done in %v.\n\n", time.Since(start))
// Output pre-compaction stats mostly to see the import trashing
diff --git a/cmd/geth/main.go b/cmd/geth/main.go
index b955bd243..cb8d63bf7 100644
--- a/cmd/geth/main.go
+++ b/cmd/geth/main.go
@@ -85,10 +85,13 @@ var (
utils.FastSyncFlag,
utils.LightModeFlag,
utils.SyncModeFlag,
+ utils.GCModeFlag,
utils.LightServFlag,
utils.LightPeersFlag,
utils.LightKDFFlag,
utils.CacheFlag,
+ utils.CacheDatabaseFlag,
+ utils.CacheGCFlag,
utils.TrieCacheGenFlag,
utils.ListenPortFlag,
utils.MaxPeersFlag,
diff --git a/cmd/geth/usage.go b/cmd/geth/usage.go
index a834d5b7a..a2bcaff02 100644
--- a/cmd/geth/usage.go
+++ b/cmd/geth/usage.go
@@ -22,10 +22,11 @@ import (
"io"
"sort"
+ "strings"
+
"github.com/ethereum/go-ethereum/cmd/utils"
"github.com/ethereum/go-ethereum/internal/debug"
"gopkg.in/urfave/cli.v1"
- "strings"
)
// AppHelpTemplate is the test template for the default, global app help topic.
@@ -74,6 +75,7 @@ var AppHelpFlagGroups = []flagGroup{
utils.TestnetFlag,
utils.RinkebyFlag,
utils.SyncModeFlag,
+ utils.GCModeFlag,
utils.EthStatsURLFlag,
utils.IdentityFlag,
utils.LightServFlag,
@@ -127,6 +129,8 @@ var AppHelpFlagGroups = []flagGroup{
Name: "PERFORMANCE TUNING",
Flags: []cli.Flag{
utils.CacheFlag,
+ utils.CacheDatabaseFlag,
+ utils.CacheGCFlag,
utils.TrieCacheGenFlag,
},
},
diff --git a/cmd/utils/cmd.go b/cmd/utils/cmd.go
index 23b10c2d7..53cdf7861 100644
--- a/cmd/utils/cmd.go
+++ b/cmd/utils/cmd.go
@@ -116,7 +116,6 @@ func ImportChain(chain *core.BlockChain, fn string) error {
return err
}
}
-
stream := rlp.NewStream(reader, 0)
// Run actual the import.
@@ -150,25 +149,34 @@ func ImportChain(chain *core.BlockChain, fn string) error {
if checkInterrupt() {
return fmt.Errorf("interrupted")
}
- if hasAllBlocks(chain, blocks[:i]) {
+ missing := missingBlocks(chain, blocks[:i])
+ if len(missing) == 0 {
log.Info("Skipping batch as all blocks present", "batch", batch, "first", blocks[0].Hash(), "last", blocks[i-1].Hash())
continue
}
-
- if _, err := chain.InsertChain(blocks[:i]); err != nil {
+ if _, err := chain.InsertChain(missing); err != nil {
return fmt.Errorf("invalid block %d: %v", n, err)
}
}
return nil
}
-func hasAllBlocks(chain *core.BlockChain, bs []*types.Block) bool {
- for _, b := range bs {
- if !chain.HasBlock(b.Hash(), b.NumberU64()) {
- return false
+func missingBlocks(chain *core.BlockChain, blocks []*types.Block) []*types.Block {
+ head := chain.CurrentBlock()
+ for i, block := range blocks {
+ // If we're behind the chain head, only check block, state is available at head
+ if head.NumberU64() > block.NumberU64() {
+ if !chain.HasBlock(block.Hash(), block.NumberU64()) {
+ return blocks[i:]
+ }
+ continue
+ }
+ // If we're above the chain head, state availability is a must
+ if !chain.HasBlockAndState(block.Hash(), block.NumberU64()) {
+ return blocks[i:]
}
}
- return true
+ return nil
}
func ExportChain(blockchain *core.BlockChain, fn string) error {
diff --git a/cmd/utils/flags.go b/cmd/utils/flags.go
index 833cd95de..2a2909ff2 100644
--- a/cmd/utils/flags.go
+++ b/cmd/utils/flags.go
@@ -170,7 +170,11 @@ var (
Usage: `Blockchain sync mode ("fast", "full", or "light")`,
Value: &defaultSyncMode,
}
-
+ GCModeFlag = cli.StringFlag{
+ Name: "gcmode",
+ Usage: `Blockchain garbage collection mode ("full", "archive")`,
+ Value: "full",
+ }
LightServFlag = cli.IntFlag{
Name: "lightserv",
Usage: "Maximum percentage of time allowed for serving LES requests (0-90)",
@@ -293,8 +297,18 @@ var (
// Performance tuning settings
CacheFlag = cli.IntFlag{
Name: "cache",
- Usage: "Megabytes of memory allocated to internal caching (min 16MB / database forced)",
- Value: 128,
+ Usage: "Megabytes of memory allocated to internal caching",
+ Value: 1024,
+ }
+ CacheDatabaseFlag = cli.IntFlag{
+ Name: "cache.database",
+ Usage: "Percentage of cache memory allowance to use for database io",
+ Value: 75,
+ }
+ CacheGCFlag = cli.IntFlag{
+ Name: "cache.gc",
+ Usage: "Percentage of cache memory allowance to use for trie pruning",
+ Value: 25,
}
TrieCacheGenFlag = cli.IntFlag{
Name: "trie-cache-gens",
@@ -1021,11 +1035,19 @@ func SetEthConfig(ctx *cli.Context, stack *node.Node, cfg *eth.Config) {
cfg.NetworkId = ctx.GlobalUint64(NetworkIdFlag.Name)
}
- if ctx.GlobalIsSet(CacheFlag.Name) {
- cfg.DatabaseCache = ctx.GlobalInt(CacheFlag.Name)
+ if ctx.GlobalIsSet(CacheFlag.Name) || ctx.GlobalIsSet(CacheDatabaseFlag.Name) {
+ cfg.DatabaseCache = ctx.GlobalInt(CacheFlag.Name) * ctx.GlobalInt(CacheDatabaseFlag.Name) / 100
}
cfg.DatabaseHandles = makeDatabaseHandles()
+ if gcmode := ctx.GlobalString(GCModeFlag.Name); gcmode != "full" && gcmode != "archive" {
+ Fatalf("--%s must be either 'full' or 'archive'", GCModeFlag.Name)
+ }
+ cfg.NoPruning = ctx.GlobalString(GCModeFlag.Name) == "archive"
+
+ if ctx.GlobalIsSet(CacheFlag.Name) || ctx.GlobalIsSet(CacheGCFlag.Name) {
+ cfg.TrieCache = ctx.GlobalInt(CacheFlag.Name) * ctx.GlobalInt(CacheGCFlag.Name) / 100
+ }
if ctx.GlobalIsSet(MinerThreadsFlag.Name) {
cfg.MinerThreads = ctx.GlobalInt(MinerThreadsFlag.Name)
}
@@ -1157,7 +1179,7 @@ func SetupNetwork(ctx *cli.Context) {
// MakeChainDatabase open an LevelDB using the flags passed to the client and will hard crash if it fails.
func MakeChainDatabase(ctx *cli.Context, stack *node.Node) ethdb.Database {
var (
- cache = ctx.GlobalInt(CacheFlag.Name)
+ cache = ctx.GlobalInt(CacheFlag.Name) * ctx.GlobalInt(CacheDatabaseFlag.Name) / 100
handles = makeDatabaseHandles()
)
name := "chaindata"
@@ -1209,8 +1231,19 @@ func MakeChain(ctx *cli.Context, stack *node.Node) (chain *core.BlockChain, chai
})
}
}
+ if gcmode := ctx.GlobalString(GCModeFlag.Name); gcmode != "full" && gcmode != "archive" {
+ Fatalf("--%s must be either 'full' or 'archive'", GCModeFlag.Name)
+ }
+ cache := &core.CacheConfig{
+ Disabled: ctx.GlobalString(GCModeFlag.Name) == "archive",
+ TrieNodeLimit: eth.DefaultConfig.TrieCache,
+ TrieTimeLimit: eth.DefaultConfig.TrieTimeout,
+ }
+ if ctx.GlobalIsSet(CacheFlag.Name) || ctx.GlobalIsSet(CacheGCFlag.Name) {
+ cache.TrieNodeLimit = ctx.GlobalInt(CacheFlag.Name) * ctx.GlobalInt(CacheGCFlag.Name) / 100
+ }
vmcfg := vm.Config{EnablePreimageRecording: ctx.GlobalBool(VMEnableDebugFlag.Name)}
- chain, err = core.NewBlockChain(chainDb, config, engine, vmcfg)
+ chain, err = core.NewBlockChain(chainDb, cache, config, engine, vmcfg)
if err != nil {
Fatalf("Can't create BlockChain: %v", err)
}
diff --git a/common/size.go b/common/size.go
index c5a0cb0f2..bd0fc85c7 100644
--- a/common/size.go
+++ b/common/size.go
@@ -20,18 +20,29 @@ import (
"fmt"
)
+// StorageSize is a wrapper around a float value that supports user friendly
+// formatting.
type StorageSize float64
-func (self StorageSize) String() string {
- if self > 1000000 {
- return fmt.Sprintf("%.2f mB", self/1000000)
- } else if self > 1000 {
- return fmt.Sprintf("%.2f kB", self/1000)
+// String implements the stringer interface.
+func (s StorageSize) String() string {
+ if s > 1000000 {
+ return fmt.Sprintf("%.2f mB", s/1000000)
+ } else if s > 1000 {
+ return fmt.Sprintf("%.2f kB", s/1000)
} else {
- return fmt.Sprintf("%.2f B", self)
+ return fmt.Sprintf("%.2f B", s)
}
}
-func (self StorageSize) Int64() int64 {
- return int64(self)
+// TerminalString implements log.TerminalStringer, formatting a string for console
+// output during logging.
+func (s StorageSize) TerminalString() string {
+ if s > 1000000 {
+ return fmt.Sprintf("%.2fmB", s/1000000)
+ } else if s > 1000 {
+ return fmt.Sprintf("%.2fkB", s/1000)
+ } else {
+ return fmt.Sprintf("%.2fB", s)
+ }
}
diff --git a/consensus/errors.go b/consensus/errors.go
index 3b136dbdd..a005c5f63 100644
--- a/consensus/errors.go
+++ b/consensus/errors.go
@@ -23,6 +23,10 @@ var (
// that is unknown.
ErrUnknownAncestor = errors.New("unknown ancestor")
+ // ErrPrunedAncestor is returned when validating a block requires an ancestor
+ // that is known, but the state of which is not available.
+ ErrPrunedAncestor = errors.New("pruned ancestor")
+
// ErrFutureBlock is returned when a block's timestamp is in the future according
// to the current node.
ErrFutureBlock = errors.New("block in the future")
diff --git a/core/bench_test.go b/core/bench_test.go
index f976331d1..e23f0d19d 100644
--- a/core/bench_test.go
+++ b/core/bench_test.go
@@ -173,7 +173,7 @@ func benchInsertChain(b *testing.B, disk bool, gen func(int, *BlockGen)) {
// Time the insertion of the new chain.
// State and blocks are stored in the same DB.
- chainman, _ := NewBlockChain(db, gspec.Config, ethash.NewFaker(), vm.Config{})
+ chainman, _ := NewBlockChain(db, nil, gspec.Config, ethash.NewFaker(), vm.Config{})
defer chainman.Stop()
b.ReportAllocs()
b.ResetTimer()
@@ -283,7 +283,7 @@ func benchReadChain(b *testing.B, full bool, count uint64) {
if err != nil {
b.Fatalf("error opening database at %v: %v", dir, err)
}
- chain, err := NewBlockChain(db, params.TestChainConfig, ethash.NewFaker(), vm.Config{})
+ chain, err := NewBlockChain(db, nil, params.TestChainConfig, ethash.NewFaker(), vm.Config{})
if err != nil {
b.Fatalf("error creating chain: %v", err)
}
diff --git a/core/block_validator.go b/core/block_validator.go
index 143728bb8..98958809b 100644
--- a/core/block_validator.go
+++ b/core/block_validator.go
@@ -50,11 +50,14 @@ func NewBlockValidator(config *params.ChainConfig, blockchain *BlockChain, engin
// validated at this point.
func (v *BlockValidator) ValidateBody(block *types.Block) error {
// Check whether the block's known, and if not, that it's linkable
- if v.bc.HasBlockAndState(block.Hash()) {
+ if v.bc.HasBlockAndState(block.Hash(), block.NumberU64()) {
return ErrKnownBlock
}
- if !v.bc.HasBlockAndState(block.ParentHash()) {
- return consensus.ErrUnknownAncestor
+ if !v.bc.HasBlockAndState(block.ParentHash(), block.NumberU64()-1) {
+ if !v.bc.HasBlock(block.ParentHash(), block.NumberU64()-1) {
+ return consensus.ErrUnknownAncestor
+ }
+ return consensus.ErrPrunedAncestor
}
// Header validity is known at this point, check the uncles and transactions
header := block.Header()
diff --git a/core/block_validator_test.go b/core/block_validator_test.go
index e668601f3..e334b3c3c 100644
--- a/core/block_validator_test.go
+++ b/core/block_validator_test.go
@@ -42,7 +42,7 @@ func TestHeaderVerification(t *testing.T) {
headers[i] = block.Header()
}
// Run the header checker for blocks one-by-one, checking for both valid and invalid nonces
- chain, _ := NewBlockChain(testdb, params.TestChainConfig, ethash.NewFaker(), vm.Config{})
+ chain, _ := NewBlockChain(testdb, nil, params.TestChainConfig, ethash.NewFaker(), vm.Config{})
defer chain.Stop()
for i := 0; i < len(blocks); i++ {
@@ -106,11 +106,11 @@ func testHeaderConcurrentVerification(t *testing.T, threads int) {
var results <-chan error
if valid {
- chain, _ := NewBlockChain(testdb, params.TestChainConfig, ethash.NewFaker(), vm.Config{})
+ chain, _ := NewBlockChain(testdb, nil, params.TestChainConfig, ethash.NewFaker(), vm.Config{})
_, results = chain.engine.VerifyHeaders(chain, headers, seals)
chain.Stop()
} else {
- chain, _ := NewBlockChain(testdb, params.TestChainConfig, ethash.NewFakeFailer(uint64(len(headers)-1)), vm.Config{})
+ chain, _ := NewBlockChain(testdb, nil, params.TestChainConfig, ethash.NewFakeFailer(uint64(len(headers)-1)), vm.Config{})
_, results = chain.engine.VerifyHeaders(chain, headers, seals)
chain.Stop()
}
@@ -173,7 +173,7 @@ func testHeaderConcurrentAbortion(t *testing.T, threads int) {
defer runtime.GOMAXPROCS(old)
// Start the verifications and immediately abort
- chain, _ := NewBlockChain(testdb, params.TestChainConfig, ethash.NewFakeDelayer(time.Millisecond), vm.Config{})
+ chain, _ := NewBlockChain(testdb, nil, params.TestChainConfig, ethash.NewFakeDelayer(time.Millisecond), vm.Config{})
defer chain.Stop()
abort, results := chain.engine.VerifyHeaders(chain, headers, seals)
diff --git a/core/blockchain.go b/core/blockchain.go
index d5e139e31..8d141fddb 100644
--- a/core/blockchain.go
+++ b/core/blockchain.go
@@ -42,6 +42,7 @@ import (
"github.com/ethereum/go-ethereum/rlp"
"github.com/ethereum/go-ethereum/trie"
"github.com/hashicorp/golang-lru"
+ "gopkg.in/karalabe/cookiejar.v2/collections/prque"
)
var (
@@ -56,11 +57,20 @@ const (
maxFutureBlocks = 256
maxTimeFutureBlocks = 30
badBlockLimit = 10
+ triesInMemory = 128
// BlockChainVersion ensures that an incompatible database forces a resync from scratch.
BlockChainVersion = 3
)
+// CacheConfig contains the configuration values for the trie caching/pruning
+// that's resident in a blockchain.
+type CacheConfig struct {
+ Disabled bool // Whether to disable trie write caching (archive node)
+ TrieNodeLimit int // Memory limit (MB) at which to flush the current in-memory trie to disk
+ TrieTimeLimit time.Duration // Time limit after which to flush the current in-memory trie to disk
+}
+
// BlockChain represents the canonical chain given a database with a genesis
// block. The Blockchain manages chain imports, reverts, chain reorganisations.
//
@@ -76,10 +86,14 @@ const (
// included in the canonical one where as GetBlockByNumber always represents the
// canonical chain.
type BlockChain struct {
- config *params.ChainConfig // chain & network configuration
+ chainConfig *params.ChainConfig // Chain & network configuration
+ cacheConfig *CacheConfig // Cache configuration for pruning
+
+ db ethdb.Database // Low level persistent database to store final content in
+ triegc *prque.Prque // Priority queue mapping block numbers to tries to gc
+ gcproc time.Duration // Accumulates canonical block processing for trie dumping
hc *HeaderChain
- chainDb ethdb.Database
rmLogsFeed event.Feed
chainFeed event.Feed
chainSideFeed event.Feed
@@ -119,7 +133,13 @@ type BlockChain struct {
// NewBlockChain returns a fully initialised block chain using information
// available in the database. It initialises the default Ethereum Validator and
// Processor.
-func NewBlockChain(chainDb ethdb.Database, config *params.ChainConfig, engine consensus.Engine, vmConfig vm.Config) (*BlockChain, error) {
+func NewBlockChain(db ethdb.Database, cacheConfig *CacheConfig, chainConfig *params.ChainConfig, engine consensus.Engine, vmConfig vm.Config) (*BlockChain, error) {
+ if cacheConfig == nil {
+ cacheConfig = &CacheConfig{
+ TrieNodeLimit: 256 * 1024 * 1024,
+ TrieTimeLimit: 5 * time.Minute,
+ }
+ }
bodyCache, _ := lru.New(bodyCacheLimit)
bodyRLPCache, _ := lru.New(bodyCacheLimit)
blockCache, _ := lru.New(blockCacheLimit)
@@ -127,9 +147,11 @@ func NewBlockChain(chainDb ethdb.Database, config *params.ChainConfig, engine co
badBlocks, _ := lru.New(badBlockLimit)
bc := &BlockChain{
- config: config,
- chainDb: chainDb,
- stateCache: state.NewDatabase(chainDb),
+ chainConfig: chainConfig,
+ cacheConfig: cacheConfig,
+ db: db,
+ triegc: prque.New(),
+ stateCache: state.NewDatabase(db),
quit: make(chan struct{}),
bodyCache: bodyCache,
bodyRLPCache: bodyRLPCache,
@@ -139,11 +161,11 @@ func NewBlockChain(chainDb ethdb.Database, config *params.ChainConfig, engine co
vmConfig: vmConfig,
badBlocks: badBlocks,
}
- bc.SetValidator(NewBlockValidator(config, bc, engine))
- bc.SetProcessor(NewStateProcessor(config, bc, engine))
+ bc.SetValidator(NewBlockValidator(chainConfig, bc, engine))
+ bc.SetProcessor(NewStateProcessor(chainConfig, bc, engine))
var err error
- bc.hc, err = NewHeaderChain(chainDb, config, engine, bc.getProcInterrupt)
+ bc.hc, err = NewHeaderChain(db, chainConfig, engine, bc.getProcInterrupt)
if err != nil {
return nil, err
}
@@ -180,7 +202,7 @@ func (bc *BlockChain) getProcInterrupt() bool {
// assumes that the chain manager mutex is held.
func (bc *BlockChain) loadLastState() error {
// Restore the last known head block
- head := GetHeadBlockHash(bc.chainDb)
+ head := GetHeadBlockHash(bc.db)
if head == (common.Hash{}) {
// Corrupt or empty database, init from scratch
log.Warn("Empty database, resetting chain")
@@ -196,15 +218,17 @@ func (bc *BlockChain) loadLastState() error {
// Make sure the state associated with the block is available
if _, err := state.New(currentBlock.Root(), bc.stateCache); err != nil {
// Dangling block without a state associated, init from scratch
- log.Warn("Head state missing, resetting chain", "number", currentBlock.Number(), "hash", currentBlock.Hash())
- return bc.Reset()
+ log.Warn("Head state missing, repairing chain", "number", currentBlock.Number(), "hash", currentBlock.Hash())
+ if err := bc.repair(&currentBlock); err != nil {
+ return err
+ }
}
// Everything seems to be fine, set as the head block
bc.currentBlock = currentBlock
// Restore the last known head header
currentHeader := bc.currentBlock.Header()
- if head := GetHeadHeaderHash(bc.chainDb); head != (common.Hash{}) {
+ if head := GetHeadHeaderHash(bc.db); head != (common.Hash{}) {
if header := bc.GetHeaderByHash(head); header != nil {
currentHeader = header
}
@@ -213,7 +237,7 @@ func (bc *BlockChain) loadLastState() error {
// Restore the last known head fast block
bc.currentFastBlock = bc.currentBlock
- if head := GetHeadFastBlockHash(bc.chainDb); head != (common.Hash{}) {
+ if head := GetHeadFastBlockHash(bc.db); head != (common.Hash{}) {
if block := bc.GetBlockByHash(head); block != nil {
bc.currentFastBlock = block
}
@@ -243,7 +267,7 @@ func (bc *BlockChain) SetHead(head uint64) error {
// Rewind the header chain, deleting all block bodies until then
delFn := func(hash common.Hash, num uint64) {
- DeleteBody(bc.chainDb, hash, num)
+ DeleteBody(bc.db, hash, num)
}
bc.hc.SetHead(head, delFn)
currentHeader := bc.hc.CurrentHeader()
@@ -275,10 +299,10 @@ func (bc *BlockChain) SetHead(head uint64) error {
if bc.currentFastBlock == nil {
bc.currentFastBlock = bc.genesisBlock
}
- if err := WriteHeadBlockHash(bc.chainDb, bc.currentBlock.Hash()); err != nil {
+ if err := WriteHeadBlockHash(bc.db, bc.currentBlock.Hash()); err != nil {
log.Crit("Failed to reset head full block", "err", err)
}
- if err := WriteHeadFastBlockHash(bc.chainDb, bc.currentFastBlock.Hash()); err != nil {
+ if err := WriteHeadFastBlockHash(bc.db, bc.currentFastBlock.Hash()); err != nil {
log.Crit("Failed to reset head fast block", "err", err)
}
return bc.loadLastState()
@@ -292,7 +316,7 @@ func (bc *BlockChain) FastSyncCommitHead(hash common.Hash) error {
if block == nil {
return fmt.Errorf("non existent block [%x…]", hash[:4])
}
- if _, err := trie.NewSecure(block.Root(), bc.chainDb, 0); err != nil {
+ if _, err := trie.NewSecure(block.Root(), bc.stateCache.TrieDB(), 0); err != nil {
return err
}
// If all checks out, manually set the head block
@@ -387,7 +411,7 @@ func (bc *BlockChain) ResetWithGenesisBlock(genesis *types.Block) error {
if err := bc.hc.WriteTd(genesis.Hash(), genesis.NumberU64(), genesis.Difficulty()); err != nil {
log.Crit("Failed to write genesis block TD", "err", err)
}
- if err := WriteBlock(bc.chainDb, genesis); err != nil {
+ if err := WriteBlock(bc.db, genesis); err != nil {
log.Crit("Failed to write genesis block", "err", err)
}
bc.genesisBlock = genesis
@@ -400,6 +424,24 @@ func (bc *BlockChain) ResetWithGenesisBlock(genesis *types.Block) error {
return nil
}
+// repair tries to repair the current blockchain by rolling back the current block
+// until one with associated state is found. This is needed to fix incomplete db
+// writes caused either by crashes/power outages, or simply non-committed tries.
+//
+// This method only rolls back the current block. The current header and current
+// fast block are left intact.
+func (bc *BlockChain) repair(head **types.Block) error {
+ for {
+ // Abort if we've rewound to a head block that does have associated state
+ if _, err := state.New((*head).Root(), bc.stateCache); err == nil {
+ log.Info("Rewound blockchain to past state", "number", (*head).Number(), "hash", (*head).Hash())
+ return nil
+ }
+ // Otherwise rewind one block and recheck state availability there
+ (*head) = bc.GetBlock((*head).ParentHash(), (*head).NumberU64()-1)
+ }
+}
+
// Export writes the active chain to the given writer.
func (bc *BlockChain) Export(w io.Writer) error {
return bc.ExportN(w, uint64(0), bc.currentBlock.NumberU64())
@@ -437,13 +479,13 @@ func (bc *BlockChain) ExportN(w io.Writer, first uint64, last uint64) error {
// Note, this function assumes that the `mu` mutex is held!
func (bc *BlockChain) insert(block *types.Block) {
// If the block is on a side chain or an unknown one, force other heads onto it too
- updateHeads := GetCanonicalHash(bc.chainDb, block.NumberU64()) != block.Hash()
+ updateHeads := GetCanonicalHash(bc.db, block.NumberU64()) != block.Hash()
// Add the block to the canonical chain number scheme and mark as the head
- if err := WriteCanonicalHash(bc.chainDb, block.Hash(), block.NumberU64()); err != nil {
+ if err := WriteCanonicalHash(bc.db, block.Hash(), block.NumberU64()); err != nil {
log.Crit("Failed to insert block number", "err", err)
}
- if err := WriteHeadBlockHash(bc.chainDb, block.Hash()); err != nil {
+ if err := WriteHeadBlockHash(bc.db, block.Hash()); err != nil {
log.Crit("Failed to insert head block hash", "err", err)
}
bc.currentBlock = block
@@ -452,7 +494,7 @@ func (bc *BlockChain) insert(block *types.Block) {
if updateHeads {
bc.hc.SetCurrentHeader(block.Header())
- if err := WriteHeadFastBlockHash(bc.chainDb, block.Hash()); err != nil {
+ if err := WriteHeadFastBlockHash(bc.db, block.Hash()); err != nil {
log.Crit("Failed to insert head fast block hash", "err", err)
}
bc.currentFastBlock = block
@@ -472,7 +514,7 @@ func (bc *BlockChain) GetBody(hash common.Hash) *types.Body {
body := cached.(*types.Body)
return body
}
- body := GetBody(bc.chainDb, hash, bc.hc.GetBlockNumber(hash))
+ body := GetBody(bc.db, hash, bc.hc.GetBlockNumber(hash))
if body == nil {
return nil
}
@@ -488,7 +530,7 @@ func (bc *BlockChain) GetBodyRLP(hash common.Hash) rlp.RawValue {
if cached, ok := bc.bodyRLPCache.Get(hash); ok {
return cached.(rlp.RawValue)
}
- body := GetBodyRLP(bc.chainDb, hash, bc.hc.GetBlockNumber(hash))
+ body := GetBodyRLP(bc.db, hash, bc.hc.GetBlockNumber(hash))
if len(body) == 0 {
return nil
}
@@ -502,21 +544,25 @@ func (bc *BlockChain) HasBlock(hash common.Hash, number uint64) bool {
if bc.blockCache.Contains(hash) {
return true
}
- ok, _ := bc.chainDb.Has(blockBodyKey(hash, number))
+ ok, _ := bc.db.Has(blockBodyKey(hash, number))
return ok
}
+// HasState checks if state trie is fully present in the database or not.
+func (bc *BlockChain) HasState(hash common.Hash) bool {
+ _, err := bc.stateCache.OpenTrie(hash)
+ return err == nil
+}
+
// HasBlockAndState checks if a block and associated state trie is fully present
// in the database or not, caching it if present.
-func (bc *BlockChain) HasBlockAndState(hash common.Hash) bool {
+func (bc *BlockChain) HasBlockAndState(hash common.Hash, number uint64) bool {
// Check first that the block itself is known
- block := bc.GetBlockByHash(hash)
+ block := bc.GetBlock(hash, number)
if block == nil {
return false
}
- // Ensure the associated state is also present
- _, err := bc.stateCache.OpenTrie(block.Root())
- return err == nil
+ return bc.HasState(block.Root())
}
// GetBlock retrieves a block from the database by hash and number,
@@ -526,7 +572,7 @@ func (bc *BlockChain) GetBlock(hash common.Hash, number uint64) *types.Block {
if block, ok := bc.blockCache.Get(hash); ok {
return block.(*types.Block)
}
- block := GetBlock(bc.chainDb, hash, number)
+ block := GetBlock(bc.db, hash, number)
if block == nil {
return nil
}
@@ -543,13 +589,18 @@ func (bc *BlockChain) GetBlockByHash(hash common.Hash) *types.Block {
// GetBlockByNumber retrieves a block from the database by number, caching it
// (associated with its hash) if found.
func (bc *BlockChain) GetBlockByNumber(number uint64) *types.Block {
- hash := GetCanonicalHash(bc.chainDb, number)
+ hash := GetCanonicalHash(bc.db, number)
if hash == (common.Hash{}) {
return nil
}
return bc.GetBlock(hash, number)
}
+// GetReceiptsByHash retrieves the receipts for all transactions in a given block.
+func (bc *BlockChain) GetReceiptsByHash(hash common.Hash) types.Receipts {
+ return GetBlockReceipts(bc.db, hash, GetBlockNumber(bc.db, hash))
+}
+
// GetBlocksFromHash returns the block corresponding to hash and up to n-1 ancestors.
// [deprecated by eth/62]
func (bc *BlockChain) GetBlocksFromHash(hash common.Hash, n int) (blocks []*types.Block) {
@@ -577,6 +628,12 @@ func (bc *BlockChain) GetUnclesInChain(block *types.Block, length int) []*types.
return uncles
}
+// TrieNode retrieves a blob of data associated with a trie node (or code hash)
+// either from ephemeral in-memory cache, or from persistent storage.
+func (bc *BlockChain) TrieNode(hash common.Hash) ([]byte, error) {
+ return bc.stateCache.TrieDB().Node(hash)
+}
+
// Stop stops the blockchain service. If any imports are currently in progress
// it will abort them using the procInterrupt.
func (bc *BlockChain) Stop() {
@@ -589,6 +646,33 @@ func (bc *BlockChain) Stop() {
atomic.StoreInt32(&bc.procInterrupt, 1)
bc.wg.Wait()
+
+ // Ensure the state of a recent block is also stored to disk before exiting.
+ // It is fine if this state does not exist (fast start/stop cycle), but it is
+ // advisable to leave an N block gap from the head so 1) a restart loads up
+ // the last N blocks as sync assistance to remote nodes; 2) a restart during
+ // a (small) reorg doesn't require deep reprocesses; 3) chain "repair" from
+ // missing states are constantly tested.
+ //
+ // This may be tuned a bit on mainnet if its too annoying to reprocess the last
+ // N blocks.
+ if !bc.cacheConfig.Disabled {
+ triedb := bc.stateCache.TrieDB()
+ if number := bc.CurrentBlock().NumberU64(); number >= triesInMemory {
+ recent := bc.GetBlockByNumber(bc.CurrentBlock().NumberU64() - triesInMemory + 1)
+
+ log.Info("Writing cached state to disk", "block", recent.Number(), "hash", recent.Hash(), "root", recent.Root())
+ if err := triedb.Commit(recent.Root(), true); err != nil {
+ log.Error("Failed to commit recent state trie", "err", err)
+ }
+ }
+ for !bc.triegc.Empty() {
+ triedb.Dereference(bc.triegc.PopItem().(common.Hash), common.Hash{})
+ }
+ if size := triedb.Size(); size != 0 {
+ log.Error("Dangling trie nodes after full cleanup")
+ }
+ }
log.Info("Blockchain manager stopped")
}
@@ -633,11 +717,11 @@ func (bc *BlockChain) Rollback(chain []common.Hash) {
}
if bc.currentFastBlock.Hash() == hash {
bc.currentFastBlock = bc.GetBlock(bc.currentFastBlock.ParentHash(), bc.currentFastBlock.NumberU64()-1)
- WriteHeadFastBlockHash(bc.chainDb, bc.currentFastBlock.Hash())
+ WriteHeadFastBlockHash(bc.db, bc.currentFastBlock.Hash())
}
if bc.currentBlock.Hash() == hash {
bc.currentBlock = bc.GetBlock(bc.currentBlock.ParentHash(), bc.currentBlock.NumberU64()-1)
- WriteHeadBlockHash(bc.chainDb, bc.currentBlock.Hash())
+ WriteHeadBlockHash(bc.db, bc.currentBlock.Hash())
}
}
}
@@ -696,7 +780,7 @@ func (bc *BlockChain) InsertReceiptChain(blockChain types.Blocks, receiptChain [
stats = struct{ processed, ignored int32 }{}
start = time.Now()
bytes = 0
- batch = bc.chainDb.NewBatch()
+ batch = bc.db.NewBatch()
)
for i, block := range blockChain {
receipts := receiptChain[i]
@@ -714,7 +798,7 @@ func (bc *BlockChain) InsertReceiptChain(blockChain types.Blocks, receiptChain [
continue
}
// Compute all the non-consensus fields of the receipts
- SetReceiptsData(bc.config, block, receipts)
+ SetReceiptsData(bc.chainConfig, block, receipts)
// Write all the data out into the database
if err := WriteBody(batch, block.Hash(), block.NumberU64(), block.Body()); err != nil {
return i, fmt.Errorf("failed to write block body: %v", err)
@@ -747,7 +831,7 @@ func (bc *BlockChain) InsertReceiptChain(blockChain types.Blocks, receiptChain [
head := blockChain[len(blockChain)-1]
if td := bc.GetTd(head.Hash(), head.NumberU64()); td != nil { // Rewind may have occurred, skip in that case
if bc.GetTd(bc.currentFastBlock.Hash(), bc.currentFastBlock.NumberU64()).Cmp(td) < 0 {
- if err := WriteHeadFastBlockHash(bc.chainDb, head.Hash()); err != nil {
+ if err := WriteHeadFastBlockHash(bc.db, head.Hash()); err != nil {
log.Crit("Failed to update head fast block hash", "err", err)
}
bc.currentFastBlock = head
@@ -758,15 +842,33 @@ func (bc *BlockChain) InsertReceiptChain(blockChain types.Blocks, receiptChain [
log.Info("Imported new block receipts",
"count", stats.processed,
"elapsed", common.PrettyDuration(time.Since(start)),
- "bytes", bytes,
"number", head.Number(),
"hash", head.Hash(),
+ "size", common.StorageSize(bytes),
"ignored", stats.ignored)
return 0, nil
}
-// WriteBlock writes the block to the chain.
-func (bc *BlockChain) WriteBlockAndState(block *types.Block, receipts []*types.Receipt, state *state.StateDB) (status WriteStatus, err error) {
+var lastWrite uint64
+
+// WriteBlockWithoutState writes only the block and its metadata to the database,
+// but does not write any state. This is used to construct competing side forks
+// up to the point where they exceed the canonical total difficulty.
+func (bc *BlockChain) WriteBlockWithoutState(block *types.Block, td *big.Int) (err error) {
+ bc.wg.Add(1)
+ defer bc.wg.Done()
+
+ if err := bc.hc.WriteTd(block.Hash(), block.NumberU64(), td); err != nil {
+ return err
+ }
+ if err := WriteBlock(bc.db, block); err != nil {
+ return err
+ }
+ return nil
+}
+
+// WriteBlockWithState writes the block and all associated state to the database.
+func (bc *BlockChain) WriteBlockWithState(block *types.Block, receipts []*types.Receipt, state *state.StateDB) (status WriteStatus, err error) {
bc.wg.Add(1)
defer bc.wg.Done()
@@ -787,17 +889,73 @@ func (bc *BlockChain) WriteBlockAndState(block *types.Block, receipts []*types.R
return NonStatTy, err
}
// Write other block data using a batch.
- batch := bc.chainDb.NewBatch()
+ batch := bc.db.NewBatch()
if err := WriteBlock(batch, block); err != nil {
return NonStatTy, err
}
- if _, err := state.CommitTo(batch, bc.config.IsEIP158(block.Number())); err != nil {
+ root, err := state.Commit(bc.chainConfig.IsEIP158(block.Number()))
+ if err != nil {
return NonStatTy, err
}
+ triedb := bc.stateCache.TrieDB()
+
+ // If we're running an archive node, always flush
+ if bc.cacheConfig.Disabled {
+ if err := triedb.Commit(root, false); err != nil {
+ return NonStatTy, err
+ }
+ } else {
+ // Full but not archive node, do proper garbage collection
+ triedb.Reference(root, common.Hash{}) // metadata reference to keep trie alive
+ bc.triegc.Push(root, -float32(block.NumberU64()))
+
+ if current := block.NumberU64(); current > triesInMemory {
+ // Find the next state trie we need to commit
+ header := bc.GetHeaderByNumber(current - triesInMemory)
+ chosen := header.Number.Uint64()
+
+ // Only write to disk if we exceeded our memory allowance *and* also have at
+ // least a given number of tries gapped.
+ var (
+ size = triedb.Size()
+ limit = common.StorageSize(bc.cacheConfig.TrieNodeLimit) * 1024 * 1024
+ )
+ if size > limit || bc.gcproc > bc.cacheConfig.TrieTimeLimit {
+ // If we're exceeding limits but haven't reached a large enough memory gap,
+ // warn the user that the system is becoming unstable.
+ if chosen < lastWrite+triesInMemory {
+ switch {
+ case size >= 2*limit:
+ log.Error("Trie memory critical, forcing to disk", "size", size, "limit", limit, "optimum", float64(chosen-lastWrite)/triesInMemory)
+ case bc.gcproc >= 2*bc.cacheConfig.TrieTimeLimit:
+ log.Error("Trie timing critical, forcing to disk", "time", bc.gcproc, "allowance", bc.cacheConfig.TrieTimeLimit, "optimum", float64(chosen-lastWrite)/triesInMemory)
+ case size > limit:
+ log.Warn("Trie memory at dangerous levels", "size", size, "limit", limit, "optimum", float64(chosen-lastWrite)/triesInMemory)
+ case bc.gcproc > bc.cacheConfig.TrieTimeLimit:
+ log.Warn("Trie timing at dangerous levels", "time", bc.gcproc, "limit", bc.cacheConfig.TrieTimeLimit, "optimum", float64(chosen-lastWrite)/triesInMemory)
+ }
+ }
+ // If optimum or critical limits reached, write to disk
+ if chosen >= lastWrite+triesInMemory || size >= 2*limit || bc.gcproc >= 2*bc.cacheConfig.TrieTimeLimit {
+ triedb.Commit(header.Root, true)
+ lastWrite = chosen
+ bc.gcproc = 0
+ }
+ }
+ // Garbage collect anything below our required write retention
+ for !bc.triegc.Empty() {
+ root, number := bc.triegc.Pop()
+ if uint64(-number) > chosen {
+ bc.triegc.Push(root, number)
+ break
+ }
+ triedb.Dereference(root.(common.Hash), common.Hash{})
+ }
+ }
+ }
if err := WriteBlockReceipts(batch, block.Hash(), block.NumberU64(), receipts); err != nil {
return NonStatTy, err
}
-
// If the total difficulty is higher than our known, add it to the canonical chain
// Second clause in the if statement reduces the vulnerability to selfish mining.
// Please refer to http://www.cs.cornell.edu/~ie53/publications/btcProcFC.pdf
@@ -818,7 +976,7 @@ func (bc *BlockChain) WriteBlockAndState(block *types.Block, receipts []*types.R
return NonStatTy, err
}
// Write hash preimages
- if err := WritePreimages(bc.chainDb, block.NumberU64(), state.Preimages()); err != nil {
+ if err := WritePreimages(bc.db, block.NumberU64(), state.Preimages()); err != nil {
return NonStatTy, err
}
status = CanonStatTy
@@ -910,31 +1068,60 @@ func (bc *BlockChain) insertChain(chain types.Blocks) (int, []interface{}, []*ty
if err == nil {
err = bc.Validator().ValidateBody(block)
}
- if err != nil {
- if err == ErrKnownBlock {
- stats.ignored++
- continue
+ switch {
+ case err == ErrKnownBlock:
+ stats.ignored++
+ continue
+
+ case err == consensus.ErrFutureBlock:
+ // Allow up to MaxFuture second in the future blocks. If this limit is exceeded
+ // the chain is discarded and processed at a later time if given.
+ max := big.NewInt(time.Now().Unix() + maxTimeFutureBlocks)
+ if block.Time().Cmp(max) > 0 {
+ return i, events, coalescedLogs, fmt.Errorf("future block: %v > %v", block.Time(), max)
}
+ bc.futureBlocks.Add(block.Hash(), block)
+ stats.queued++
+ continue
+
+ case err == consensus.ErrUnknownAncestor && bc.futureBlocks.Contains(block.ParentHash()):
+ bc.futureBlocks.Add(block.Hash(), block)
+ stats.queued++
+ continue
- if err == consensus.ErrFutureBlock {
- // Allow up to MaxFuture second in the future blocks. If this limit
- // is exceeded the chain is discarded and processed at a later time
- // if given.
- max := big.NewInt(time.Now().Unix() + maxTimeFutureBlocks)
- if block.Time().Cmp(max) > 0 {
- return i, events, coalescedLogs, fmt.Errorf("future block: %v > %v", block.Time(), max)
+ case err == consensus.ErrPrunedAncestor:
+ // Block competing with the canonical chain, store in the db, but don't process
+ // until the competitor TD goes above the canonical TD
+ localTd := bc.GetTd(bc.currentBlock.Hash(), bc.currentBlock.NumberU64())
+ externTd := new(big.Int).Add(bc.GetTd(block.ParentHash(), block.NumberU64()-1), block.Difficulty())
+ if localTd.Cmp(externTd) > 0 {
+ if err = bc.WriteBlockWithoutState(block, externTd); err != nil {
+ return i, events, coalescedLogs, err
}
- bc.futureBlocks.Add(block.Hash(), block)
- stats.queued++
continue
}
+ // Competitor chain beat canonical, gather all blocks from the common ancestor
+ var winner []*types.Block
- if err == consensus.ErrUnknownAncestor && bc.futureBlocks.Contains(block.ParentHash()) {
- bc.futureBlocks.Add(block.Hash(), block)
- stats.queued++
- continue
+ parent := bc.GetBlock(block.ParentHash(), block.NumberU64()-1)
+ for !bc.HasState(parent.Root()) {
+ winner = append(winner, parent)
+ parent = bc.GetBlock(parent.ParentHash(), parent.NumberU64()-1)
+ }
+ for j := 0; j < len(winner)/2; j++ {
+ winner[j], winner[len(winner)-1-j] = winner[len(winner)-1-j], winner[j]
+ }
+ // Import all the pruned blocks to make the state available
+ bc.chainmu.Unlock()
+ _, evs, logs, err := bc.insertChain(winner)
+ bc.chainmu.Lock()
+ events, coalescedLogs = evs, logs
+
+ if err != nil {
+ return i, events, coalescedLogs, err
}
+ case err != nil:
bc.reportBlock(block, nil, err)
return i, events, coalescedLogs, err
}
@@ -962,8 +1149,10 @@ func (bc *BlockChain) insertChain(chain types.Blocks) (int, []interface{}, []*ty
bc.reportBlock(block, receipts, err)
return i, events, coalescedLogs, err
}
+ proctime := time.Since(bstart)
+
// Write the block to the chain and get the status.
- status, err := bc.WriteBlockAndState(block, receipts, state)
+ status, err := bc.WriteBlockWithState(block, receipts, state)
if err != nil {
return i, events, coalescedLogs, err
}
@@ -977,6 +1166,9 @@ func (bc *BlockChain) insertChain(chain types.Blocks) (int, []interface{}, []*ty
events = append(events, ChainEvent{block, block.Hash(), logs})
lastCanon = block
+ // Only count canonical blocks for GC processing time
+ bc.gcproc += proctime
+
case SideStatTy:
log.Debug("Inserted forked block", "number", block.Number(), "hash", block.Hash(), "diff", block.Difficulty(), "elapsed",
common.PrettyDuration(time.Since(bstart)), "txs", len(block.Transactions()), "gas", block.GasUsed(), "uncles", len(block.Uncles()))
@@ -986,7 +1178,7 @@ func (bc *BlockChain) insertChain(chain types.Blocks) (int, []interface{}, []*ty
}
stats.processed++
stats.usedGas += usedGas
- stats.report(chain, i)
+ stats.report(chain, i, bc.stateCache.TrieDB().Size())
}
// Append a single chain head event if we've progressed the chain
if lastCanon != nil && bc.CurrentBlock().Hash() == lastCanon.Hash() {
@@ -1009,7 +1201,7 @@ const statsReportLimit = 8 * time.Second
// report prints statistics if some number of blocks have been processed
// or more than a few seconds have passed since the last message.
-func (st *insertStats) report(chain []*types.Block, index int) {
+func (st *insertStats) report(chain []*types.Block, index int, cache common.StorageSize) {
// Fetch the timings for the batch
var (
now = mclock.Now()
@@ -1024,7 +1216,7 @@ func (st *insertStats) report(chain []*types.Block, index int) {
context := []interface{}{
"blocks", st.processed, "txs", txs, "mgas", float64(st.usedGas) / 1000000,
"elapsed", common.PrettyDuration(elapsed), "mgasps", float64(st.usedGas) * 1000 / float64(elapsed),
- "number", end.Number(), "hash", end.Hash(),
+ "number", end.Number(), "hash", end.Hash(), "cache", cache,
}
if st.queued > 0 {
context = append(context, []interface{}{"queued", st.queued}...)
@@ -1060,7 +1252,7 @@ func (bc *BlockChain) reorg(oldBlock, newBlock *types.Block) error {
// These logs are later announced as deleted.
collectLogs = func(h common.Hash) {
// Coalesce logs and set 'Removed'.
- receipts := GetBlockReceipts(bc.chainDb, h, bc.hc.GetBlockNumber(h))
+ receipts := GetBlockReceipts(bc.db, h, bc.hc.GetBlockNumber(h))
for _, receipt := range receipts {
for _, log := range receipt.Logs {
del := *log
@@ -1129,7 +1321,7 @@ func (bc *BlockChain) reorg(oldBlock, newBlock *types.Block) error {
// insert the block in the canonical way, re-writing history
bc.insert(newChain[i])
// write lookup entries for hash based transaction/receipt searches
- if err := WriteTxLookupEntries(bc.chainDb, newChain[i]); err != nil {
+ if err := WriteTxLookupEntries(bc.db, newChain[i]); err != nil {
return err
}
addedTxs = append(addedTxs, newChain[i].Transactions()...)
@@ -1139,7 +1331,7 @@ func (bc *BlockChain) reorg(oldBlock, newBlock *types.Block) error {
// When transactions get deleted from the database that means the
// receipts that were created in the fork must also be deleted
for _, tx := range diff {
- DeleteTxLookupEntry(bc.chainDb, tx.Hash())
+ DeleteTxLookupEntry(bc.db, tx.Hash())
}
if len(deletedLogs) > 0 {
go bc.rmLogsFeed.Send(RemovedLogsEvent{deletedLogs})
@@ -1231,7 +1423,7 @@ Hash: 0x%x
Error: %v
##############################
-`, bc.config, block.Number(), block.Hash(), receiptString, err))
+`, bc.chainConfig, block.Number(), block.Hash(), receiptString, err))
}
// InsertHeaderChain attempts to insert the given header chain in to the local
@@ -1338,7 +1530,7 @@ func (bc *BlockChain) GetHeaderByNumber(number uint64) *types.Header {
}
// Config retrieves the blockchain's chain configuration.
-func (bc *BlockChain) Config() *params.ChainConfig { return bc.config }
+func (bc *BlockChain) Config() *params.ChainConfig { return bc.chainConfig }
// Engine retrieves the blockchain's consensus engine.
func (bc *BlockChain) Engine() consensus.Engine { return bc.engine }
diff --git a/core/blockchain_test.go b/core/blockchain_test.go
index cbde3bcd2..635379161 100644
--- a/core/blockchain_test.go
+++ b/core/blockchain_test.go
@@ -46,7 +46,7 @@ func newTestBlockChain(fake bool) *BlockChain {
if !fake {
engine = ethash.NewTester()
}
- blockchain, err := NewBlockChain(db, gspec.Config, engine, vm.Config{})
+ blockchain, err := NewBlockChain(db, nil, gspec.Config, engine, vm.Config{})
if err != nil {
panic(err)
}
@@ -148,9 +148,9 @@ func testBlockChainImport(chain types.Blocks, blockchain *BlockChain) error {
return err
}
blockchain.mu.Lock()
- WriteTd(blockchain.chainDb, block.Hash(), block.NumberU64(), new(big.Int).Add(block.Difficulty(), blockchain.GetTdByHash(block.ParentHash())))
- WriteBlock(blockchain.chainDb, block)
- statedb.CommitTo(blockchain.chainDb, false)
+ WriteTd(blockchain.db, block.Hash(), block.NumberU64(), new(big.Int).Add(block.Difficulty(), blockchain.GetTdByHash(block.ParentHash())))
+ WriteBlock(blockchain.db, block)
+ statedb.Commit(false)
blockchain.mu.Unlock()
}
return nil
@@ -166,8 +166,8 @@ func testHeaderChainImport(chain []*types.Header, blockchain *BlockChain) error
}
// Manually insert the header into the database, but don't reorganise (allows subsequent testing)
blockchain.mu.Lock()
- WriteTd(blockchain.chainDb, header.Hash(), header.Number.Uint64(), new(big.Int).Add(header.Difficulty, blockchain.GetTdByHash(header.ParentHash)))
- WriteHeader(blockchain.chainDb, header)
+ WriteTd(blockchain.db, header.Hash(), header.Number.Uint64(), new(big.Int).Add(header.Difficulty, blockchain.GetTdByHash(header.ParentHash)))
+ WriteHeader(blockchain.db, header)
blockchain.mu.Unlock()
}
return nil
@@ -186,9 +186,9 @@ func TestLastBlock(t *testing.T) {
bchain := newTestBlockChain(false)
defer bchain.Stop()
- block := makeBlockChain(bchain.CurrentBlock(), 1, ethash.NewFaker(), bchain.chainDb, 0)[0]
+ block := makeBlockChain(bchain.CurrentBlock(), 1, ethash.NewFaker(), bchain.db, 0)[0]
bchain.insert(block)
- if block.Hash() != GetHeadBlockHash(bchain.chainDb) {
+ if block.Hash() != GetHeadBlockHash(bchain.db) {
t.Errorf("Write/Get HeadBlockHash failed")
}
}
@@ -496,7 +496,7 @@ func testReorgBadHashes(t *testing.T, full bool) {
}
// Create a new BlockChain and check that it rolled back the state.
- ncm, err := NewBlockChain(bc.chainDb, bc.config, ethash.NewFaker(), vm.Config{})
+ ncm, err := NewBlockChain(bc.db, nil, bc.chainConfig, ethash.NewFaker(), vm.Config{})
if err != nil {
t.Fatalf("failed to create new chain manager: %v", err)
}
@@ -609,7 +609,7 @@ func TestFastVsFullChains(t *testing.T) {
// Import the chain as an archive node for the comparison baseline
archiveDb, _ := ethdb.NewMemDatabase()
gspec.MustCommit(archiveDb)
- archive, _ := NewBlockChain(archiveDb, gspec.Config, ethash.NewFaker(), vm.Config{})
+ archive, _ := NewBlockChain(archiveDb, nil, gspec.Config, ethash.NewFaker(), vm.Config{})
defer archive.Stop()
if n, err := archive.InsertChain(blocks); err != nil {
@@ -618,7 +618,7 @@ func TestFastVsFullChains(t *testing.T) {
// Fast import the chain as a non-archive node to test
fastDb, _ := ethdb.NewMemDatabase()
gspec.MustCommit(fastDb)
- fast, _ := NewBlockChain(fastDb, gspec.Config, ethash.NewFaker(), vm.Config{})
+ fast, _ := NewBlockChain(fastDb, nil, gspec.Config, ethash.NewFaker(), vm.Config{})
defer fast.Stop()
headers := make([]*types.Header, len(blocks))
@@ -696,7 +696,7 @@ func TestLightVsFastVsFullChainHeads(t *testing.T) {
archiveDb, _ := ethdb.NewMemDatabase()
gspec.MustCommit(archiveDb)
- archive, _ := NewBlockChain(archiveDb, gspec.Config, ethash.NewFaker(), vm.Config{})
+ archive, _ := NewBlockChain(archiveDb, nil, gspec.Config, ethash.NewFaker(), vm.Config{})
if n, err := archive.InsertChain(blocks); err != nil {
t.Fatalf("failed to process block %d: %v", n, err)
}
@@ -709,7 +709,7 @@ func TestLightVsFastVsFullChainHeads(t *testing.T) {
// Import the chain as a non-archive node and ensure all pointers are updated
fastDb, _ := ethdb.NewMemDatabase()
gspec.MustCommit(fastDb)
- fast, _ := NewBlockChain(fastDb, gspec.Config, ethash.NewFaker(), vm.Config{})
+ fast, _ := NewBlockChain(fastDb, nil, gspec.Config, ethash.NewFaker(), vm.Config{})
defer fast.Stop()
headers := make([]*types.Header, len(blocks))
@@ -730,7 +730,7 @@ func TestLightVsFastVsFullChainHeads(t *testing.T) {
lightDb, _ := ethdb.NewMemDatabase()
gspec.MustCommit(lightDb)
- light, _ := NewBlockChain(lightDb, gspec.Config, ethash.NewFaker(), vm.Config{})
+ light, _ := NewBlockChain(lightDb, nil, gspec.Config, ethash.NewFaker(), vm.Config{})
if n, err := light.InsertHeaderChain(headers, 1); err != nil {
t.Fatalf("failed to insert header %d: %v", n, err)
}
@@ -799,7 +799,7 @@ func TestChainTxReorgs(t *testing.T) {
}
})
// Import the chain. This runs all block validation rules.
- blockchain, _ := NewBlockChain(db, gspec.Config, ethash.NewFaker(), vm.Config{})
+ blockchain, _ := NewBlockChain(db, nil, gspec.Config, ethash.NewFaker(), vm.Config{})
if i, err := blockchain.InsertChain(chain); err != nil {
t.Fatalf("failed to insert original chain[%d]: %v", i, err)
}
@@ -870,7 +870,7 @@ func TestLogReorgs(t *testing.T) {
signer = types.NewEIP155Signer(gspec.Config.ChainId)
)
- blockchain, _ := NewBlockChain(db, gspec.Config, ethash.NewFaker(), vm.Config{})
+ blockchain, _ := NewBlockChain(db, nil, gspec.Config, ethash.NewFaker(), vm.Config{})
defer blockchain.Stop()
rmLogsCh := make(chan RemovedLogsEvent)
@@ -917,7 +917,7 @@ func TestReorgSideEvent(t *testing.T) {
signer = types.NewEIP155Signer(gspec.Config.ChainId)
)
- blockchain, _ := NewBlockChain(db, gspec.Config, ethash.NewFaker(), vm.Config{})
+ blockchain, _ := NewBlockChain(db, nil, gspec.Config, ethash.NewFaker(), vm.Config{})
defer blockchain.Stop()
chain, _ := GenerateChain(gspec.Config, genesis, ethash.NewFaker(), db, 3, func(i int, gen *BlockGen) {})
@@ -992,7 +992,7 @@ func TestCanonicalBlockRetrieval(t *testing.T) {
bc := newTestBlockChain(true)
defer bc.Stop()
- chain, _ := GenerateChain(bc.config, bc.genesisBlock, ethash.NewFaker(), bc.chainDb, 10, func(i int, gen *BlockGen) {})
+ chain, _ := GenerateChain(bc.chainConfig, bc.genesisBlock, ethash.NewFaker(), bc.db, 10, func(i int, gen *BlockGen) {})
var pend sync.WaitGroup
pend.Add(len(chain))
@@ -1003,14 +1003,14 @@ func TestCanonicalBlockRetrieval(t *testing.T) {
// try to retrieve a block by its canonical hash and see if the block data can be retrieved.
for {
- ch := GetCanonicalHash(bc.chainDb, block.NumberU64())
+ ch := GetCanonicalHash(bc.db, block.NumberU64())
if ch == (common.Hash{}) {
continue // busy wait for canonical hash to be written
}
if ch != block.Hash() {
t.Fatalf("unknown canonical hash, want %s, got %s", block.Hash().Hex(), ch.Hex())
}
- fb := GetBlock(bc.chainDb, ch, block.NumberU64())
+ fb := GetBlock(bc.db, ch, block.NumberU64())
if fb == nil {
t.Fatalf("unable to retrieve block %d for canonical hash: %s", block.NumberU64(), ch.Hex())
}
@@ -1043,7 +1043,7 @@ func TestEIP155Transition(t *testing.T) {
genesis = gspec.MustCommit(db)
)
- blockchain, _ := NewBlockChain(db, gspec.Config, ethash.NewFaker(), vm.Config{})
+ blockchain, _ := NewBlockChain(db, nil, gspec.Config, ethash.NewFaker(), vm.Config{})
defer blockchain.Stop()
blocks, _ := GenerateChain(gspec.Config, genesis, ethash.NewFaker(), db, 4, func(i int, block *BlockGen) {
@@ -1151,7 +1151,7 @@ func TestEIP161AccountRemoval(t *testing.T) {
}
genesis = gspec.MustCommit(db)
)
- blockchain, _ := NewBlockChain(db, gspec.Config, ethash.NewFaker(), vm.Config{})
+ blockchain, _ := NewBlockChain(db, nil, gspec.Config, ethash.NewFaker(), vm.Config{})
defer blockchain.Stop()
blocks, _ := GenerateChain(gspec.Config, genesis, ethash.NewFaker(), db, 3, func(i int, block *BlockGen) {
@@ -1226,7 +1226,7 @@ func TestBlockchainHeaderchainReorgConsistency(t *testing.T) {
diskdb, _ := ethdb.NewMemDatabase()
new(Genesis).MustCommit(diskdb)
- chain, err := NewBlockChain(diskdb, params.TestChainConfig, engine, vm.Config{})
+ chain, err := NewBlockChain(diskdb, nil, params.TestChainConfig, engine, vm.Config{})
if err != nil {
t.Fatalf("failed to create tester chain: %v", err)
}
@@ -1245,3 +1245,102 @@ func TestBlockchainHeaderchainReorgConsistency(t *testing.T) {
}
}
}
+
+// Tests that importing small side forks doesn't leave junk in the trie database
+// cache (which would eventually cause memory issues).
+func TestTrieForkGC(t *testing.T) {
+ // Generate a canonical chain to act as the main dataset
+ engine := ethash.NewFaker()
+
+ db, _ := ethdb.NewMemDatabase()
+ genesis := new(Genesis).MustCommit(db)
+ blocks, _ := GenerateChain(params.TestChainConfig, genesis, engine, db, 2*triesInMemory, func(i int, b *BlockGen) { b.SetCoinbase(common.Address{1}) })
+
+ // Generate a bunch of fork blocks, each side forking from the canonical chain
+ forks := make([]*types.Block, len(blocks))
+ for i := 0; i < len(forks); i++ {
+ parent := genesis
+ if i > 0 {
+ parent = blocks[i-1]
+ }
+ fork, _ := GenerateChain(params.TestChainConfig, parent, engine, db, 1, func(i int, b *BlockGen) { b.SetCoinbase(common.Address{2}) })
+ forks[i] = fork[0]
+ }
+ // Import the canonical and fork chain side by side, forcing the trie cache to cache both
+ diskdb, _ := ethdb.NewMemDatabase()
+ new(Genesis).MustCommit(diskdb)
+
+ chain, err := NewBlockChain(diskdb, nil, params.TestChainConfig, engine, vm.Config{})
+ if err != nil {
+ t.Fatalf("failed to create tester chain: %v", err)
+ }
+ for i := 0; i < len(blocks); i++ {
+ if _, err := chain.InsertChain(blocks[i : i+1]); err != nil {
+ t.Fatalf("block %d: failed to insert into chain: %v", i, err)
+ }
+ if _, err := chain.InsertChain(forks[i : i+1]); err != nil {
+ t.Fatalf("fork %d: failed to insert into chain: %v", i, err)
+ }
+ }
+ // Dereference all the recent tries and ensure no past trie is left in
+ for i := 0; i < triesInMemory; i++ {
+ chain.stateCache.TrieDB().Dereference(blocks[len(blocks)-1-i].Root(), common.Hash{})
+ chain.stateCache.TrieDB().Dereference(forks[len(blocks)-1-i].Root(), common.Hash{})
+ }
+ if len(chain.stateCache.TrieDB().Nodes()) > 0 {
+ t.Fatalf("stale tries still alive after garbase collection")
+ }
+}
+
+// Tests that doing large reorgs works even if the state associated with the
+// forking point is not available any more.
+func TestLargeReorgTrieGC(t *testing.T) {
+ // Generate the original common chain segment and the two competing forks
+ engine := ethash.NewFaker()
+
+ db, _ := ethdb.NewMemDatabase()
+ genesis := new(Genesis).MustCommit(db)
+
+ shared, _ := GenerateChain(params.TestChainConfig, genesis, engine, db, 64, func(i int, b *BlockGen) { b.SetCoinbase(common.Address{1}) })
+ original, _ := GenerateChain(params.TestChainConfig, shared[len(shared)-1], engine, db, 2*triesInMemory, func(i int, b *BlockGen) { b.SetCoinbase(common.Address{2}) })
+ competitor, _ := GenerateChain(params.TestChainConfig, shared[len(shared)-1], engine, db, 2*triesInMemory+1, func(i int, b *BlockGen) { b.SetCoinbase(common.Address{3}) })
+
+ // Import the shared chain and the original canonical one
+ diskdb, _ := ethdb.NewMemDatabase()
+ new(Genesis).MustCommit(diskdb)
+
+ chain, err := NewBlockChain(diskdb, nil, params.TestChainConfig, engine, vm.Config{})
+ if err != nil {
+ t.Fatalf("failed to create tester chain: %v", err)
+ }
+ if _, err := chain.InsertChain(shared); err != nil {
+ t.Fatalf("failed to insert shared chain: %v", err)
+ }
+ if _, err := chain.InsertChain(original); err != nil {
+ t.Fatalf("failed to insert shared chain: %v", err)
+ }
+ // Ensure that the state associated with the forking point is pruned away
+ if node, _ := chain.stateCache.TrieDB().Node(shared[len(shared)-1].Root()); node != nil {
+ t.Fatalf("common-but-old ancestor still cache")
+ }
+ // Import the competitor chain without exceeding the canonical's TD and ensure
+ // we have not processed any of the blocks (protection against malicious blocks)
+ if _, err := chain.InsertChain(competitor[:len(competitor)-2]); err != nil {
+ t.Fatalf("failed to insert competitor chain: %v", err)
+ }
+ for i, block := range competitor[:len(competitor)-2] {
+ if node, _ := chain.stateCache.TrieDB().Node(block.Root()); node != nil {
+ t.Fatalf("competitor %d: low TD chain became processed", i)
+ }
+ }
+ // Import the head of the competitor chain, triggering the reorg and ensure we
+ // successfully reprocess all the stashed away blocks.
+ if _, err := chain.InsertChain(competitor[len(competitor)-2:]); err != nil {
+ t.Fatalf("failed to finalize competitor chain: %v", err)
+ }
+ for i, block := range competitor[:len(competitor)-triesInMemory] {
+ if node, _ := chain.stateCache.TrieDB().Node(block.Root()); node != nil {
+ t.Fatalf("competitor %d: competing chain state missing", i)
+ }
+ }
+}
diff --git a/core/chain_indexer.go b/core/chain_indexer.go
index 7fb184aaa..158ed8324 100644
--- a/core/chain_indexer.go
+++ b/core/chain_indexer.go
@@ -203,6 +203,9 @@ func (c *ChainIndexer) eventLoop(currentHeader *types.Header, events chan ChainE
if header.ParentHash != prevHash {
// Reorg to the common ancestor (might not exist in light sync mode, skip reorg then)
// TODO(karalabe, zsfelfoldi): This seems a bit brittle, can we detect this case explicitly?
+
+ // TODO(karalabe): This operation is expensive and might block, causing the event system to
+ // potentially also lock up. We need to do with on a different thread somehow.
if h := FindCommonAncestor(c.chainDb, prevHeader, header); h != nil {
c.newHead(h.Number.Uint64(), true)
}
diff --git a/core/chain_makers.go b/core/chain_makers.go
index 5e264a994..6744428ff 100644
--- a/core/chain_makers.go
+++ b/core/chain_makers.go
@@ -166,7 +166,7 @@ func GenerateChain(config *params.ChainConfig, parent *types.Block, engine conse
genblock := func(i int, parent *types.Block, statedb *state.StateDB) (*types.Block, types.Receipts) {
// TODO(karalabe): This is needed for clique, which depends on multiple blocks.
// It's nonetheless ugly to spin up a blockchain here. Get rid of this somehow.
- blockchain, _ := NewBlockChain(db, config, engine, vm.Config{})
+ blockchain, _ := NewBlockChain(db, nil, config, engine, vm.Config{})
defer blockchain.Stop()
b := &BlockGen{i: i, parent: parent, chain: blocks, chainReader: blockchain, statedb: statedb, config: config, engine: engine}
@@ -192,10 +192,13 @@ func GenerateChain(config *params.ChainConfig, parent *types.Block, engine conse
if b.engine != nil {
block, _ := b.engine.Finalize(b.chainReader, b.header, statedb, b.txs, b.uncles, b.receipts)
// Write state changes to db
- _, err := statedb.CommitTo(db, config.IsEIP158(b.header.Number))
+ root, err := statedb.Commit(config.IsEIP158(b.header.Number))
if err != nil {
panic(fmt.Sprintf("state write error: %v", err))
}
+ if err := statedb.Database().TrieDB().Commit(root, false); err != nil {
+ panic(fmt.Sprintf("trie write error: %v", err))
+ }
return block, b.receipts
}
return nil, nil
@@ -246,7 +249,7 @@ func newCanonical(engine consensus.Engine, n int, full bool) (ethdb.Database, *B
db, _ := ethdb.NewMemDatabase()
genesis := gspec.MustCommit(db)
- blockchain, _ := NewBlockChain(db, params.AllEthashProtocolChanges, engine, vm.Config{})
+ blockchain, _ := NewBlockChain(db, nil, params.AllEthashProtocolChanges, engine, vm.Config{})
// Create and inject the requested chain
if n == 0 {
return db, blockchain, nil
diff --git a/core/chain_makers_test.go b/core/chain_makers_test.go
index a3b80da29..93be43ddc 100644
--- a/core/chain_makers_test.go
+++ b/core/chain_makers_test.go
@@ -79,7 +79,7 @@ func ExampleGenerateChain() {
})
// Import the chain. This runs all block validation rules.
- blockchain, _ := NewBlockChain(db, gspec.Config, ethash.NewFaker(), vm.Config{})
+ blockchain, _ := NewBlockChain(db, nil, gspec.Config, ethash.NewFaker(), vm.Config{})
defer blockchain.Stop()
if i, err := blockchain.InsertChain(chain); err != nil {
diff --git a/core/dao_test.go b/core/dao_test.go
index 43e2982a5..e0a3e3ff3 100644
--- a/core/dao_test.go
+++ b/core/dao_test.go
@@ -45,7 +45,7 @@ func TestDAOForkRangeExtradata(t *testing.T) {
proConf.DAOForkBlock = forkBlock
proConf.DAOForkSupport = true
- proBc, _ := NewBlockChain(proDb, &proConf, ethash.NewFaker(), vm.Config{})
+ proBc, _ := NewBlockChain(proDb, nil, &proConf, ethash.NewFaker(), vm.Config{})
defer proBc.Stop()
conDb, _ := ethdb.NewMemDatabase()
@@ -55,7 +55,7 @@ func TestDAOForkRangeExtradata(t *testing.T) {
conConf.DAOForkBlock = forkBlock
conConf.DAOForkSupport = false
- conBc, _ := NewBlockChain(conDb, &conConf, ethash.NewFaker(), vm.Config{})
+ conBc, _ := NewBlockChain(conDb, nil, &conConf, ethash.NewFaker(), vm.Config{})
defer conBc.Stop()
if _, err := proBc.InsertChain(prefix); err != nil {
@@ -69,7 +69,7 @@ func TestDAOForkRangeExtradata(t *testing.T) {
// Create a pro-fork block, and try to feed into the no-fork chain
db, _ = ethdb.NewMemDatabase()
gspec.MustCommit(db)
- bc, _ := NewBlockChain(db, &conConf, ethash.NewFaker(), vm.Config{})
+ bc, _ := NewBlockChain(db, nil, &conConf, ethash.NewFaker(), vm.Config{})
defer bc.Stop()
blocks := conBc.GetBlocksFromHash(conBc.CurrentBlock().Hash(), int(conBc.CurrentBlock().NumberU64()))
@@ -79,6 +79,9 @@ func TestDAOForkRangeExtradata(t *testing.T) {
if _, err := bc.InsertChain(blocks); err != nil {
t.Fatalf("failed to import contra-fork chain for expansion: %v", err)
}
+ if err := bc.stateCache.TrieDB().Commit(bc.CurrentHeader().Root, true); err != nil {
+ t.Fatalf("failed to commit contra-fork head for expansion: %v", err)
+ }
blocks, _ = GenerateChain(&proConf, conBc.CurrentBlock(), ethash.NewFaker(), db, 1, func(i int, gen *BlockGen) {})
if _, err := conBc.InsertChain(blocks); err == nil {
t.Fatalf("contra-fork chain accepted pro-fork block: %v", blocks[0])
@@ -91,7 +94,7 @@ func TestDAOForkRangeExtradata(t *testing.T) {
// Create a no-fork block, and try to feed into the pro-fork chain
db, _ = ethdb.NewMemDatabase()
gspec.MustCommit(db)
- bc, _ = NewBlockChain(db, &proConf, ethash.NewFaker(), vm.Config{})
+ bc, _ = NewBlockChain(db, nil, &proConf, ethash.NewFaker(), vm.Config{})
defer bc.Stop()
blocks = proBc.GetBlocksFromHash(proBc.CurrentBlock().Hash(), int(proBc.CurrentBlock().NumberU64()))
@@ -101,6 +104,9 @@ func TestDAOForkRangeExtradata(t *testing.T) {
if _, err := bc.InsertChain(blocks); err != nil {
t.Fatalf("failed to import pro-fork chain for expansion: %v", err)
}
+ if err := bc.stateCache.TrieDB().Commit(bc.CurrentHeader().Root, true); err != nil {
+ t.Fatalf("failed to commit pro-fork head for expansion: %v", err)
+ }
blocks, _ = GenerateChain(&conConf, proBc.CurrentBlock(), ethash.NewFaker(), db, 1, func(i int, gen *BlockGen) {})
if _, err := proBc.InsertChain(blocks); err == nil {
t.Fatalf("pro-fork chain accepted contra-fork block: %v", blocks[0])
@@ -114,7 +120,7 @@ func TestDAOForkRangeExtradata(t *testing.T) {
// Verify that contra-forkers accept pro-fork extra-datas after forking finishes
db, _ = ethdb.NewMemDatabase()
gspec.MustCommit(db)
- bc, _ := NewBlockChain(db, &conConf, ethash.NewFaker(), vm.Config{})
+ bc, _ := NewBlockChain(db, nil, &conConf, ethash.NewFaker(), vm.Config{})
defer bc.Stop()
blocks := conBc.GetBlocksFromHash(conBc.CurrentBlock().Hash(), int(conBc.CurrentBlock().NumberU64()))
@@ -124,6 +130,9 @@ func TestDAOForkRangeExtradata(t *testing.T) {
if _, err := bc.InsertChain(blocks); err != nil {
t.Fatalf("failed to import contra-fork chain for expansion: %v", err)
}
+ if err := bc.stateCache.TrieDB().Commit(bc.CurrentHeader().Root, true); err != nil {
+ t.Fatalf("failed to commit contra-fork head for expansion: %v", err)
+ }
blocks, _ = GenerateChain(&proConf, conBc.CurrentBlock(), ethash.NewFaker(), db, 1, func(i int, gen *BlockGen) {})
if _, err := conBc.InsertChain(blocks); err != nil {
t.Fatalf("contra-fork chain didn't accept pro-fork block post-fork: %v", err)
@@ -131,7 +140,7 @@ func TestDAOForkRangeExtradata(t *testing.T) {
// Verify that pro-forkers accept contra-fork extra-datas after forking finishes
db, _ = ethdb.NewMemDatabase()
gspec.MustCommit(db)
- bc, _ = NewBlockChain(db, &proConf, ethash.NewFaker(), vm.Config{})
+ bc, _ = NewBlockChain(db, nil, &proConf, ethash.NewFaker(), vm.Config{})
defer bc.Stop()
blocks = proBc.GetBlocksFromHash(proBc.CurrentBlock().Hash(), int(proBc.CurrentBlock().NumberU64()))
@@ -141,6 +150,9 @@ func TestDAOForkRangeExtradata(t *testing.T) {
if _, err := bc.InsertChain(blocks); err != nil {
t.Fatalf("failed to import pro-fork chain for expansion: %v", err)
}
+ if err := bc.stateCache.TrieDB().Commit(bc.CurrentHeader().Root, true); err != nil {
+ t.Fatalf("failed to commit pro-fork head for expansion: %v", err)
+ }
blocks, _ = GenerateChain(&conConf, proBc.CurrentBlock(), ethash.NewFaker(), db, 1, func(i int, gen *BlockGen) {})
if _, err := proBc.InsertChain(blocks); err != nil {
t.Fatalf("pro-fork chain didn't accept contra-fork block post-fork: %v", err)
diff --git a/core/genesis.go b/core/genesis.go
index e22985b80..b6ead2250 100644
--- a/core/genesis.go
+++ b/core/genesis.go
@@ -169,10 +169,9 @@ func SetupGenesisBlock(db ethdb.Database, genesis *Genesis) (*params.ChainConfig
// Check whether the genesis block is already written.
if genesis != nil {
- block, _ := genesis.ToBlock()
- hash := block.Hash()
+ hash := genesis.ToBlock(nil).Hash()
if hash != stored {
- return genesis.Config, block.Hash(), &GenesisMismatchError{stored, hash}
+ return genesis.Config, hash, &GenesisMismatchError{stored, hash}
}
}
@@ -220,9 +219,12 @@ func (g *Genesis) configOrDefault(ghash common.Hash) *params.ChainConfig {
}
}
-// ToBlock creates the block and state of a genesis specification.
-func (g *Genesis) ToBlock() (*types.Block, *state.StateDB) {
- db, _ := ethdb.NewMemDatabase()
+// ToBlock creates the genesis block and writes state of a genesis specification
+// to the given database (or discards it if nil).
+func (g *Genesis) ToBlock(db ethdb.Database) *types.Block {
+ if db == nil {
+ db, _ = ethdb.NewMemDatabase()
+ }
statedb, _ := state.New(common.Hash{}, state.NewDatabase(db))
for addr, account := range g.Alloc {
statedb.AddBalance(addr, account.Balance)
@@ -252,19 +254,19 @@ func (g *Genesis) ToBlock() (*types.Block, *state.StateDB) {
if g.Difficulty == nil {
head.Difficulty = params.GenesisDifficulty
}
- return types.NewBlock(head, nil, nil, nil), statedb
+ statedb.Commit(false)
+ statedb.Database().TrieDB().Commit(root, true)
+
+ return types.NewBlock(head, nil, nil, nil)
}
// Commit writes the block and state of a genesis specification to the database.
// The block is committed as the canonical head block.
func (g *Genesis) Commit(db ethdb.Database) (*types.Block, error) {
- block, statedb := g.ToBlock()
+ block := g.ToBlock(db)
if block.Number().Sign() != 0 {
return nil, fmt.Errorf("can't commit genesis block with number > 0")
}
- if _, err := statedb.CommitTo(db, false); err != nil {
- return nil, fmt.Errorf("cannot write state: %v", err)
- }
if err := WriteTd(db, block.Hash(), block.NumberU64(), g.Difficulty); err != nil {
return nil, err
}
diff --git a/core/genesis_test.go b/core/genesis_test.go
index 2fe931b24..cd548d4b1 100644
--- a/core/genesis_test.go
+++ b/core/genesis_test.go
@@ -30,11 +30,11 @@ import (
)
func TestDefaultGenesisBlock(t *testing.T) {
- block, _ := DefaultGenesisBlock().ToBlock()
+ block := DefaultGenesisBlock().ToBlock(nil)
if block.Hash() != params.MainnetGenesisHash {
t.Errorf("wrong mainnet genesis hash, got %v, want %v", block.Hash(), params.MainnetGenesisHash)
}
- block, _ = DefaultTestnetGenesisBlock().ToBlock()
+ block = DefaultTestnetGenesisBlock().ToBlock(nil)
if block.Hash() != params.TestnetGenesisHash {
t.Errorf("wrong testnet genesis hash, got %v, want %v", block.Hash(), params.TestnetGenesisHash)
}
@@ -118,7 +118,7 @@ func TestSetupGenesis(t *testing.T) {
// Commit the 'old' genesis block with Homestead transition at #2.
// Advance to block #4, past the homestead transition block of customg.
genesis := oldcustomg.MustCommit(db)
- bc, _ := NewBlockChain(db, oldcustomg.Config, ethash.NewFullFaker(), vm.Config{})
+ bc, _ := NewBlockChain(db, nil, oldcustomg.Config, ethash.NewFullFaker(), vm.Config{})
defer bc.Stop()
bc.SetValidator(bproc{})
bc.InsertChain(makeBlockChainWithDiff(genesis, []int{2, 3, 4, 5}, 0))
diff --git a/core/state/database.go b/core/state/database.go
index 946625e76..36926ec69 100644
--- a/core/state/database.go
+++ b/core/state/database.go
@@ -40,16 +40,23 @@ const (
// Database wraps access to tries and contract code.
type Database interface {
- // Accessing tries:
// OpenTrie opens the main account trie.
- // OpenStorageTrie opens the storage trie of an account.
OpenTrie(root common.Hash) (Trie, error)
+
+ // OpenStorageTrie opens the storage trie of an account.
OpenStorageTrie(addrHash, root common.Hash) (Trie, error)
- // Accessing contract code:
- ContractCode(addrHash, codeHash common.Hash) ([]byte, error)
- ContractCodeSize(addrHash, codeHash common.Hash) (int, error)
+
// CopyTrie returns an independent copy of the given trie.
CopyTrie(Trie) Trie
+
+ // ContractCode retrieves a particular contract's code.
+ ContractCode(addrHash, codeHash common.Hash) ([]byte, error)
+
+ // ContractCodeSize retrieves a particular contracts code's size.
+ ContractCodeSize(addrHash, codeHash common.Hash) (int, error)
+
+ // TrieDB retrieves the low level trie database used for data storage.
+ TrieDB() *trie.Database
}
// Trie is a Ethereum Merkle Trie.
@@ -57,26 +64,33 @@ type Trie interface {
TryGet(key []byte) ([]byte, error)
TryUpdate(key, value []byte) error
TryDelete(key []byte) error
- CommitTo(trie.DatabaseWriter) (common.Hash, error)
+ Commit(onleaf trie.LeafCallback) (common.Hash, error)
Hash() common.Hash
NodeIterator(startKey []byte) trie.NodeIterator
GetKey([]byte) []byte // TODO(fjl): remove this when SecureTrie is removed
+ Prove(key []byte, fromLevel uint, proofDb ethdb.Putter) error
}
// NewDatabase creates a backing store for state. The returned database is safe for
-// concurrent use and retains cached trie nodes in memory.
+// concurrent use and retains cached trie nodes in memory. The pool is an optional
+// intermediate trie-node memory pool between the low level storage layer and the
+// high level trie abstraction.
func NewDatabase(db ethdb.Database) Database {
csc, _ := lru.New(codeSizeCacheSize)
- return &cachingDB{db: db, codeSizeCache: csc}
+ return &cachingDB{
+ db: trie.NewDatabase(db),
+ codeSizeCache: csc,
+ }
}
type cachingDB struct {
- db ethdb.Database
+ db *trie.Database
mu sync.Mutex
pastTries []*trie.SecureTrie
codeSizeCache *lru.Cache
}
+// OpenTrie opens the main account trie.
func (db *cachingDB) OpenTrie(root common.Hash) (Trie, error) {
db.mu.Lock()
defer db.mu.Unlock()
@@ -105,10 +119,12 @@ func (db *cachingDB) pushTrie(t *trie.SecureTrie) {
}
}
+// OpenStorageTrie opens the storage trie of an account.
func (db *cachingDB) OpenStorageTrie(addrHash, root common.Hash) (Trie, error) {
return trie.NewSecure(root, db.db, 0)
}
+// CopyTrie returns an independent copy of the given trie.
func (db *cachingDB) CopyTrie(t Trie) Trie {
switch t := t.(type) {
case cachedTrie:
@@ -120,14 +136,16 @@ func (db *cachingDB) CopyTrie(t Trie) Trie {
}
}
+// ContractCode retrieves a particular contract's code.
func (db *cachingDB) ContractCode(addrHash, codeHash common.Hash) ([]byte, error) {
- code, err := db.db.Get(codeHash[:])
+ code, err := db.db.Node(codeHash)
if err == nil {
db.codeSizeCache.Add(codeHash, len(code))
}
return code, err
}
+// ContractCodeSize retrieves a particular contracts code's size.
func (db *cachingDB) ContractCodeSize(addrHash, codeHash common.Hash) (int, error) {
if cached, ok := db.codeSizeCache.Get(codeHash); ok {
return cached.(int), nil
@@ -139,16 +157,25 @@ func (db *cachingDB) ContractCodeSize(addrHash, codeHash common.Hash) (int, erro
return len(code), err
}
+// TrieDB retrieves any intermediate trie-node caching layer.
+func (db *cachingDB) TrieDB() *trie.Database {
+ return db.db
+}
+
// cachedTrie inserts its trie into a cachingDB on commit.
type cachedTrie struct {
*trie.SecureTrie
db *cachingDB
}
-func (m cachedTrie) CommitTo(dbw trie.DatabaseWriter) (common.Hash, error) {
- root, err := m.SecureTrie.CommitTo(dbw)
+func (m cachedTrie) Commit(onleaf trie.LeafCallback) (common.Hash, error) {
+ root, err := m.SecureTrie.Commit(onleaf)
if err == nil {
m.db.pushTrie(m.SecureTrie)
}
return root, err
}
+
+func (m cachedTrie) Prove(key []byte, fromLevel uint, proofDb ethdb.Putter) error {
+ return m.SecureTrie.Prove(key, fromLevel, proofDb)
+}
diff --git a/core/state/iterator_test.go b/core/state/iterator_test.go
index ff66ba7a9..9e46c851c 100644
--- a/core/state/iterator_test.go
+++ b/core/state/iterator_test.go
@@ -21,12 +21,13 @@ import (
"testing"
"github.com/ethereum/go-ethereum/common"
+ "github.com/ethereum/go-ethereum/ethdb"
)
// Tests that the node iterator indeed walks over the entire database contents.
func TestNodeIteratorCoverage(t *testing.T) {
// Create some arbitrary test state to iterate
- db, mem, root, _ := makeTestState()
+ db, root, _ := makeTestState()
state, err := New(root, db)
if err != nil {
@@ -39,14 +40,18 @@ func TestNodeIteratorCoverage(t *testing.T) {
hashes[it.Hash] = struct{}{}
}
}
-
- // Cross check the hashes and the database itself
+ // Cross check the iterated hashes and the database/nodepool content
for hash := range hashes {
- if _, err := mem.Get(hash.Bytes()); err != nil {
- t.Errorf("failed to retrieve reported node %x: %v", hash, err)
+ if _, err := db.TrieDB().Node(hash); err != nil {
+ t.Errorf("failed to retrieve reported node %x", hash)
+ }
+ }
+ for _, hash := range db.TrieDB().Nodes() {
+ if _, ok := hashes[hash]; !ok {
+ t.Errorf("state entry not reported %x", hash)
}
}
- for _, key := range mem.Keys() {
+ for _, key := range db.TrieDB().DiskDB().(*ethdb.MemDatabase).Keys() {
if bytes.HasPrefix(key, []byte("secure-key-")) {
continue
}
diff --git a/core/state/state_object.go b/core/state/state_object.go
index b2378c69c..b2112bfae 100644
--- a/core/state/state_object.go
+++ b/core/state/state_object.go
@@ -25,7 +25,6 @@ import (
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/rlp"
- "github.com/ethereum/go-ethereum/trie"
)
var emptyCodeHash = crypto.Keccak256(nil)
@@ -238,12 +237,12 @@ func (self *stateObject) updateRoot(db Database) {
// CommitTrie the storage trie of the object to dwb.
// This updates the trie root.
-func (self *stateObject) CommitTrie(db Database, dbw trie.DatabaseWriter) error {
+func (self *stateObject) CommitTrie(db Database) error {
self.updateTrie(db)
if self.dbErr != nil {
return self.dbErr
}
- root, err := self.trie.CommitTo(dbw)
+ root, err := self.trie.Commit(nil)
if err == nil {
self.data.Root = root
}
diff --git a/core/state/state_test.go b/core/state/state_test.go
index bbae3685b..6d42d63d8 100644
--- a/core/state/state_test.go
+++ b/core/state/state_test.go
@@ -48,7 +48,7 @@ func (s *StateSuite) TestDump(c *checker.C) {
// write some of them to the trie
s.state.updateStateObject(obj1)
s.state.updateStateObject(obj2)
- s.state.CommitTo(s.db, false)
+ s.state.Commit(false)
// check that dump contains the state objects that are in trie
got := string(s.state.Dump())
@@ -97,7 +97,7 @@ func (s *StateSuite) TestNull(c *checker.C) {
//value := common.FromHex("0x823140710bf13990e4500136726d8b55")
var value common.Hash
s.state.SetState(address, common.Hash{}, value)
- s.state.CommitTo(s.db, false)
+ s.state.Commit(false)
value = s.state.GetState(address, common.Hash{})
if !common.EmptyHash(value) {
c.Errorf("expected empty hash. got %x", value)
@@ -155,7 +155,7 @@ func TestSnapshot2(t *testing.T) {
so0.deleted = false
state.setStateObject(so0)
- root, _ := state.CommitTo(db, false)
+ root, _ := state.Commit(false)
state.Reset(root)
// and one with deleted == true
diff --git a/core/state/statedb.go b/core/state/statedb.go
index 8e29104d5..776693e24 100644
--- a/core/state/statedb.go
+++ b/core/state/statedb.go
@@ -36,6 +36,14 @@ type revision struct {
journalIndex int
}
+var (
+ // emptyState is the known hash of an empty state trie entry.
+ emptyState = crypto.Keccak256Hash(nil)
+
+ // emptyCode is the known hash of the empty EVM bytecode.
+ emptyCode = crypto.Keccak256Hash(nil)
+)
+
// StateDBs within the ethereum protocol are used to store anything
// within the merkle trie. StateDBs take care of caching and storing
// nested states. It's the general query interface to retrieve:
@@ -235,6 +243,11 @@ func (self *StateDB) GetState(a common.Address, b common.Hash) common.Hash {
return common.Hash{}
}
+// Database retrieves the low level database supporting the lower level trie ops.
+func (self *StateDB) Database() Database {
+ return self.db
+}
+
// StorageTrie returns the storage trie of an account.
// The return value is a copy and is nil for non-existent accounts.
func (self *StateDB) StorageTrie(a common.Address) Trie {
@@ -568,8 +581,8 @@ func (s *StateDB) clearJournalAndRefund() {
s.refund = 0
}
-// CommitTo writes the state to the given database.
-func (s *StateDB) CommitTo(dbw trie.DatabaseWriter, deleteEmptyObjects bool) (root common.Hash, err error) {
+// Commit writes the state to the underlying in-memory trie database.
+func (s *StateDB) Commit(deleteEmptyObjects bool) (root common.Hash, err error) {
defer s.clearJournalAndRefund()
// Commit objects to the trie.
@@ -583,13 +596,11 @@ func (s *StateDB) CommitTo(dbw trie.DatabaseWriter, deleteEmptyObjects bool) (ro
case isDirty:
// Write any contract code associated with the state object
if stateObject.code != nil && stateObject.dirtyCode {
- if err := dbw.Put(stateObject.CodeHash(), stateObject.code); err != nil {
- return common.Hash{}, err
- }
+ s.db.TrieDB().Insert(common.BytesToHash(stateObject.CodeHash()), stateObject.code)
stateObject.dirtyCode = false
}
// Write any storage changes in the state object to its storage trie.
- if err := stateObject.CommitTrie(s.db, dbw); err != nil {
+ if err := stateObject.CommitTrie(s.db); err != nil {
return common.Hash{}, err
}
// Update the object in the main account trie.
@@ -598,7 +609,20 @@ func (s *StateDB) CommitTo(dbw trie.DatabaseWriter, deleteEmptyObjects bool) (ro
delete(s.stateObjectsDirty, addr)
}
// Write trie changes.
- root, err = s.trie.CommitTo(dbw)
+ root, err = s.trie.Commit(func(leaf []byte, parent common.Hash) error {
+ var account Account
+ if err := rlp.DecodeBytes(leaf, &account); err != nil {
+ return nil
+ }
+ if account.Root != emptyState {
+ s.db.TrieDB().Reference(account.Root, parent)
+ }
+ code := common.BytesToHash(account.CodeHash)
+ if code != emptyCode {
+ s.db.TrieDB().Reference(code, parent)
+ }
+ return nil
+ })
log.Debug("Trie cache stats after commit", "misses", trie.CacheMisses(), "unloads", trie.CacheUnloads())
return root, err
}
diff --git a/core/state/statedb_test.go b/core/state/statedb_test.go
index 5c80e3aa5..d9e3d9b79 100644
--- a/core/state/statedb_test.go
+++ b/core/state/statedb_test.go
@@ -97,10 +97,10 @@ func TestIntermediateLeaks(t *testing.T) {
}
// Commit and cross check the databases.
- if _, err := transState.CommitTo(transDb, false); err != nil {
+ if _, err := transState.Commit(false); err != nil {
t.Fatalf("failed to commit transition state: %v", err)
}
- if _, err := finalState.CommitTo(finalDb, false); err != nil {
+ if _, err := finalState.Commit(false); err != nil {
t.Fatalf("failed to commit final state: %v", err)
}
for _, key := range finalDb.Keys() {
@@ -122,8 +122,8 @@ func TestIntermediateLeaks(t *testing.T) {
// https://github.com/ethereum/go-ethereum/pull/15549.
func TestCopy(t *testing.T) {
// Create a random state test to copy and modify "independently"
- mem, _ := ethdb.NewMemDatabase()
- orig, _ := New(common.Hash{}, NewDatabase(mem))
+ db, _ := ethdb.NewMemDatabase()
+ orig, _ := New(common.Hash{}, NewDatabase(db))
for i := byte(0); i < 255; i++ {
obj := orig.GetOrNewStateObject(common.BytesToAddress([]byte{i}))
@@ -346,11 +346,10 @@ func (test *snapshotTest) run() bool {
}
action.fn(action, state)
}
-
// Revert all snapshots in reverse order. Each revert must yield a state
// that is equivalent to fresh state with all actions up the snapshot applied.
for sindex--; sindex >= 0; sindex-- {
- checkstate, _ := New(common.Hash{}, NewDatabase(db))
+ checkstate, _ := New(common.Hash{}, state.Database())
for _, action := range test.actions[:test.snapshots[sindex]] {
action.fn(action, checkstate)
}
@@ -409,7 +408,7 @@ func (test *snapshotTest) checkEqual(state, checkstate *StateDB) error {
func (s *StateSuite) TestTouchDelete(c *check.C) {
s.state.GetOrNewStateObject(common.Address{})
- root, _ := s.state.CommitTo(s.db, false)
+ root, _ := s.state.Commit(false)
s.state.Reset(root)
snapshot := s.state.Snapshot()
@@ -417,7 +416,6 @@ func (s *StateSuite) TestTouchDelete(c *check.C) {
if len(s.state.stateObjectsDirty) != 1 {
c.Fatal("expected one dirty state object")
}
-
s.state.RevertToSnapshot(snapshot)
if len(s.state.stateObjectsDirty) != 0 {
c.Fatal("expected no dirty state object")
diff --git a/core/state/sync_test.go b/core/state/sync_test.go
index 06c572ea6..8f14a44e7 100644
--- a/core/state/sync_test.go
+++ b/core/state/sync_test.go
@@ -36,10 +36,10 @@ type testAccount struct {
}
// makeTestState create a sample test state to test node-wise reconstruction.
-func makeTestState() (Database, *ethdb.MemDatabase, common.Hash, []*testAccount) {
+func makeTestState() (Database, common.Hash, []*testAccount) {
// Create an empty state
- mem, _ := ethdb.NewMemDatabase()
- db := NewDatabase(mem)
+ diskdb, _ := ethdb.NewMemDatabase()
+ db := NewDatabase(diskdb)
state, _ := New(common.Hash{}, db)
// Fill it with some arbitrary data
@@ -61,10 +61,10 @@ func makeTestState() (Database, *ethdb.MemDatabase, common.Hash, []*testAccount)
state.updateStateObject(obj)
accounts = append(accounts, acc)
}
- root, _ := state.CommitTo(mem, false)
+ root, _ := state.Commit(false)
// Return the generated state
- return db, mem, root, accounts
+ return db, root, accounts
}
// checkStateAccounts cross references a reconstructed state with an expected
@@ -96,7 +96,7 @@ func checkTrieConsistency(db ethdb.Database, root common.Hash) error {
if v, _ := db.Get(root[:]); v == nil {
return nil // Consider a non existent state consistent.
}
- trie, err := trie.New(root, db)
+ trie, err := trie.New(root, trie.NewDatabase(db))
if err != nil {
return err
}
@@ -138,7 +138,7 @@ func TestIterativeStateSyncBatched(t *testing.T) { testIterativeStateSync(t,
func testIterativeStateSync(t *testing.T, batch int) {
// Create a random state to copy
- _, srcMem, srcRoot, srcAccounts := makeTestState()
+ srcDb, srcRoot, srcAccounts := makeTestState()
// Create a destination state and sync with the scheduler
dstDb, _ := ethdb.NewMemDatabase()
@@ -148,9 +148,9 @@ func testIterativeStateSync(t *testing.T, batch int) {
for len(queue) > 0 {
results := make([]trie.SyncResult, len(queue))
for i, hash := range queue {
- data, err := srcMem.Get(hash.Bytes())
+ data, err := srcDb.TrieDB().Node(hash)
if err != nil {
- t.Fatalf("failed to retrieve node data for %x: %v", hash, err)
+ t.Fatalf("failed to retrieve node data for %x", hash)
}
results[i] = trie.SyncResult{Hash: hash, Data: data}
}
@@ -170,7 +170,7 @@ func testIterativeStateSync(t *testing.T, batch int) {
// partial results are returned, and the others sent only later.
func TestIterativeDelayedStateSync(t *testing.T) {
// Create a random state to copy
- _, srcMem, srcRoot, srcAccounts := makeTestState()
+ srcDb, srcRoot, srcAccounts := makeTestState()
// Create a destination state and sync with the scheduler
dstDb, _ := ethdb.NewMemDatabase()
@@ -181,9 +181,9 @@ func TestIterativeDelayedStateSync(t *testing.T) {
// Sync only half of the scheduled nodes
results := make([]trie.SyncResult, len(queue)/2+1)
for i, hash := range queue[:len(results)] {
- data, err := srcMem.Get(hash.Bytes())
+ data, err := srcDb.TrieDB().Node(hash)
if err != nil {
- t.Fatalf("failed to retrieve node data for %x: %v", hash, err)
+ t.Fatalf("failed to retrieve node data for %x", hash)
}
results[i] = trie.SyncResult{Hash: hash, Data: data}
}
@@ -207,7 +207,7 @@ func TestIterativeRandomStateSyncBatched(t *testing.T) { testIterativeRandomS
func testIterativeRandomStateSync(t *testing.T, batch int) {
// Create a random state to copy
- _, srcMem, srcRoot, srcAccounts := makeTestState()
+ srcDb, srcRoot, srcAccounts := makeTestState()
// Create a destination state and sync with the scheduler
dstDb, _ := ethdb.NewMemDatabase()
@@ -221,9 +221,9 @@ func testIterativeRandomStateSync(t *testing.T, batch int) {
// Fetch all the queued nodes in a random order
results := make([]trie.SyncResult, 0, len(queue))
for hash := range queue {
- data, err := srcMem.Get(hash.Bytes())
+ data, err := srcDb.TrieDB().Node(hash)
if err != nil {
- t.Fatalf("failed to retrieve node data for %x: %v", hash, err)
+ t.Fatalf("failed to retrieve node data for %x", hash)
}
results = append(results, trie.SyncResult{Hash: hash, Data: data})
}
@@ -247,7 +247,7 @@ func testIterativeRandomStateSync(t *testing.T, batch int) {
// partial results are returned (Even those randomly), others sent only later.
func TestIterativeRandomDelayedStateSync(t *testing.T) {
// Create a random state to copy
- _, srcMem, srcRoot, srcAccounts := makeTestState()
+ srcDb, srcRoot, srcAccounts := makeTestState()
// Create a destination state and sync with the scheduler
dstDb, _ := ethdb.NewMemDatabase()
@@ -263,9 +263,9 @@ func TestIterativeRandomDelayedStateSync(t *testing.T) {
for hash := range queue {
delete(queue, hash)
- data, err := srcMem.Get(hash.Bytes())
+ data, err := srcDb.TrieDB().Node(hash)
if err != nil {
- t.Fatalf("failed to retrieve node data for %x: %v", hash, err)
+ t.Fatalf("failed to retrieve node data for %x", hash)
}
results = append(results, trie.SyncResult{Hash: hash, Data: data})
@@ -292,9 +292,9 @@ func TestIterativeRandomDelayedStateSync(t *testing.T) {
// the database.
func TestIncompleteStateSync(t *testing.T) {
// Create a random state to copy
- _, srcMem, srcRoot, srcAccounts := makeTestState()
+ srcDb, srcRoot, srcAccounts := makeTestState()
- checkTrieConsistency(srcMem, srcRoot)
+ checkTrieConsistency(srcDb.TrieDB().DiskDB().(ethdb.Database), srcRoot)
// Create a destination state and sync with the scheduler
dstDb, _ := ethdb.NewMemDatabase()
@@ -306,9 +306,9 @@ func TestIncompleteStateSync(t *testing.T) {
// Fetch a batch of state nodes
results := make([]trie.SyncResult, len(queue))
for i, hash := range queue {
- data, err := srcMem.Get(hash.Bytes())
+ data, err := srcDb.TrieDB().Node(hash)
if err != nil {
- t.Fatalf("failed to retrieve node data for %x: %v", hash, err)
+ t.Fatalf("failed to retrieve node data for %x", hash)
}
results[i] = trie.SyncResult{Hash: hash, Data: data}
}
diff --git a/core/tx_pool_test.go b/core/tx_pool_test.go
index cd11f2ba2..158b9776b 100644
--- a/core/tx_pool_test.go
+++ b/core/tx_pool_test.go
@@ -78,8 +78,8 @@ func pricedTransaction(nonce uint64, gaslimit uint64, gasprice *big.Int, key *ec
}
func setupTxPool() (*TxPool, *ecdsa.PrivateKey) {
- db, _ := ethdb.NewMemDatabase()
- statedb, _ := state.New(common.Hash{}, state.NewDatabase(db))
+ diskdb, _ := ethdb.NewMemDatabase()
+ statedb, _ := state.New(common.Hash{}, state.NewDatabase(diskdb))
blockchain := &testBlockChain{statedb, 1000000, new(event.Feed)}
key, _ := crypto.GenerateKey()
diff --git a/core/types/block.go b/core/types/block.go
index ffe317342..92b868d9d 100644
--- a/core/types/block.go
+++ b/core/types/block.go
@@ -25,6 +25,7 @@ import (
"sort"
"sync/atomic"
"time"
+ "unsafe"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/common/hexutil"
@@ -121,6 +122,12 @@ func (h *Header) HashNoNonce() common.Hash {
})
}
+// Size returns the approximate memory used by all internal contents. It is used
+// to approximate and limit the memory consumption of various caches.
+func (h *Header) Size() common.StorageSize {
+ return common.StorageSize(unsafe.Sizeof(*h)) + common.StorageSize(len(h.Extra)+(h.Difficulty.BitLen()+h.Number.BitLen()+h.Time.BitLen())/8)
+}
+
func rlpHash(x interface{}) (h common.Hash) {
hw := sha3.NewKeccak256()
rlp.Encode(hw, x)
@@ -322,6 +329,8 @@ func (b *Block) HashNoNonce() common.Hash {
return b.header.HashNoNonce()
}
+// Size returns the true RLP encoded storage size of the block, either by encoding
+// and returning it, or returning a previsouly cached value.
func (b *Block) Size() common.StorageSize {
if size := b.size.Load(); size != nil {
return size.(common.StorageSize)
diff --git a/core/types/receipt.go b/core/types/receipt.go
index 208d54aaa..f945f6f6a 100644
--- a/core/types/receipt.go
+++ b/core/types/receipt.go
@@ -20,6 +20,7 @@ import (
"bytes"
"fmt"
"io"
+ "unsafe"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/common/hexutil"
@@ -136,6 +137,18 @@ func (r *Receipt) statusEncoding() []byte {
return r.PostState
}
+// Size returns the approximate memory used by all internal contents. It is used
+// to approximate and limit the memory consumption of various caches.
+func (r *Receipt) Size() common.StorageSize {
+ size := common.StorageSize(unsafe.Sizeof(*r)) + common.StorageSize(len(r.PostState))
+
+ size += common.StorageSize(len(r.Logs)) * common.StorageSize(unsafe.Sizeof(Log{}))
+ for _, log := range r.Logs {
+ size += common.StorageSize(len(log.Topics)*common.HashLength + len(log.Data))
+ }
+ return size
+}
+
// String implements the Stringer interface.
func (r *Receipt) String() string {
if len(r.PostState) == 0 {
diff --git a/core/types/transaction.go b/core/types/transaction.go
index a7ed211e4..5660582ba 100644
--- a/core/types/transaction.go
+++ b/core/types/transaction.go
@@ -206,6 +206,8 @@ func (tx *Transaction) Hash() common.Hash {
return v
}
+// Size returns the true RLP encoded storage size of the transaction, either by
+// encoding and returning it, or returning a previsouly cached value.
func (tx *Transaction) Size() common.StorageSize {
if size := tx.size.Load(); size != nil {
return size.(common.StorageSize)
diff --git a/eth/api.go b/eth/api.go
index 0db3eb554..a345b57e4 100644
--- a/eth/api.go
+++ b/eth/api.go
@@ -462,11 +462,11 @@ func (api *PrivateDebugAPI) getModifiedAccounts(startBlock, endBlock *types.Bloc
return nil, fmt.Errorf("start block height (%d) must be less than end block height (%d)", startBlock.Number().Uint64(), endBlock.Number().Uint64())
}
- oldTrie, err := trie.NewSecure(startBlock.Root(), api.eth.chainDb, 0)
+ oldTrie, err := trie.NewSecure(startBlock.Root(), trie.NewDatabase(api.eth.chainDb), 0)
if err != nil {
return nil, err
}
- newTrie, err := trie.NewSecure(endBlock.Root(), api.eth.chainDb, 0)
+ newTrie, err := trie.NewSecure(endBlock.Root(), trie.NewDatabase(api.eth.chainDb), 0)
if err != nil {
return nil, err
}
diff --git a/eth/api_tracer.go b/eth/api_tracer.go
index d49f077ae..07c4457bc 100644
--- a/eth/api_tracer.go
+++ b/eth/api_tracer.go
@@ -24,7 +24,6 @@ import (
"io/ioutil"
"runtime"
"sync"
- "sync/atomic"
"time"
"github.com/ethereum/go-ethereum/common"
@@ -34,7 +33,6 @@ import (
"github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/core/vm"
"github.com/ethereum/go-ethereum/eth/tracers"
- "github.com/ethereum/go-ethereum/ethdb"
"github.com/ethereum/go-ethereum/internal/ethapi"
"github.com/ethereum/go-ethereum/log"
"github.com/ethereum/go-ethereum/rlp"
@@ -72,6 +70,7 @@ type txTraceResult struct {
type blockTraceTask struct {
statedb *state.StateDB // Intermediate state prepped for tracing
block *types.Block // Block to trace the transactions from
+ rootref common.Hash // Trie root reference held for this task
results []*txTraceResult // Trace results procudes by the task
}
@@ -90,59 +89,6 @@ type txTraceTask struct {
index int // Transaction offset in the block
}
-// ephemeralDatabase is a memory wrapper around a proper database, which acts as
-// an ephemeral write layer. This construct is used by the chain tracer to write
-// state tries for intermediate blocks without serializing to disk, but at the
-// same time to allow disk fallback for reads that do no hit the memory layer.
-type ephemeralDatabase struct {
- diskdb ethdb.Database // Persistent disk database to fall back to with reads
- memdb *ethdb.MemDatabase // Ephemeral memory database for primary reads and writes
-}
-
-func (db *ephemeralDatabase) Put(key []byte, value []byte) error { return db.memdb.Put(key, value) }
-func (db *ephemeralDatabase) Delete(key []byte) error { return errors.New("delete not supported") }
-func (db *ephemeralDatabase) Close() { db.memdb.Close() }
-func (db *ephemeralDatabase) NewBatch() ethdb.Batch {
- return db.memdb.NewBatch()
-}
-func (db *ephemeralDatabase) Has(key []byte) (bool, error) {
- if has, _ := db.memdb.Has(key); has {
- return has, nil
- }
- return db.diskdb.Has(key)
-}
-func (db *ephemeralDatabase) Get(key []byte) ([]byte, error) {
- if blob, _ := db.memdb.Get(key); blob != nil {
- return blob, nil
- }
- return db.diskdb.Get(key)
-}
-
-// Prune does a state sync into a new memory write layer and replaces the old one.
-// This allows us to discard entries that are no longer referenced from the current
-// state.
-func (db *ephemeralDatabase) Prune(root common.Hash) {
- // Pull the still relevant state data into memory
- sync := state.NewStateSync(root, db.diskdb)
- for sync.Pending() > 0 {
- hash := sync.Missing(1)[0]
-
- // Move the next trie node from the memory layer into a sync struct
- node, err := db.memdb.Get(hash[:])
- if err != nil {
- panic(err) // memdb must have the data
- }
- if _, _, err := sync.Process([]trie.SyncResult{{Hash: hash, Data: node}}); err != nil {
- panic(err) // it's not possible to fail processing a node
- }
- }
- // Discard the old memory layer and write a new one
- db.memdb, _ = ethdb.NewMemDatabaseWithCap(db.memdb.Len())
- if _, err := sync.Commit(db); err != nil {
- panic(err) // writing into a memdb cannot fail
- }
-}
-
// TraceChain returns the structured logs created during the execution of EVM
// between two blocks (excluding start) and returns them as a JSON object.
func (api *PrivateDebugAPI) TraceChain(ctx context.Context, start, end rpc.BlockNumber, config *TraceConfig) (*rpc.Subscription, error) {
@@ -188,19 +134,15 @@ func (api *PrivateDebugAPI) traceChain(ctx context.Context, start, end *types.Bl
// Ensure we have a valid starting state before doing any work
origin := start.NumberU64()
+ database := state.NewDatabase(api.eth.ChainDb())
- memdb, _ := ethdb.NewMemDatabase()
- db := &ephemeralDatabase{
- diskdb: api.eth.ChainDb(),
- memdb: memdb,
- }
if number := start.NumberU64(); number > 0 {
start = api.eth.blockchain.GetBlock(start.ParentHash(), start.NumberU64()-1)
if start == nil {
return nil, fmt.Errorf("parent block #%d not found", number-1)
}
}
- statedb, err := state.New(start.Root(), state.NewDatabase(db))
+ statedb, err := state.New(start.Root(), database)
if err != nil {
// If the starting state is missing, allow some number of blocks to be reexecuted
reexec := defaultTraceReexec
@@ -213,7 +155,7 @@ func (api *PrivateDebugAPI) traceChain(ctx context.Context, start, end *types.Bl
if start == nil {
break
}
- if statedb, err = state.New(start.Root(), state.NewDatabase(db)); err == nil {
+ if statedb, err = state.New(start.Root(), database); err == nil {
break
}
}
@@ -256,7 +198,7 @@ func (api *PrivateDebugAPI) traceChain(ctx context.Context, start, end *types.Bl
res, err := api.traceTx(ctx, msg, vmctx, task.statedb, config)
if err != nil {
task.results[i] = &txTraceResult{Error: err.Error()}
- log.Warn("Tracing failed", "err", err)
+ log.Warn("Tracing failed", "hash", tx.Hash(), "block", task.block.NumberU64(), "err", err)
break
}
task.statedb.DeleteSuicides()
@@ -273,7 +215,6 @@ func (api *PrivateDebugAPI) traceChain(ctx context.Context, start, end *types.Bl
}
// Start a goroutine to feed all the blocks into the tracers
begin := time.Now()
- complete := start.NumberU64()
go func() {
var (
@@ -281,6 +222,7 @@ func (api *PrivateDebugAPI) traceChain(ctx context.Context, start, end *types.Bl
number uint64
traced uint64
failed error
+ proot common.Hash
)
// Ensure everything is properly cleaned up on any exit path
defer func() {
@@ -308,7 +250,7 @@ func (api *PrivateDebugAPI) traceChain(ctx context.Context, start, end *types.Bl
// Print progress logs if long enough time elapsed
if time.Since(logged) > 8*time.Second {
if number > origin {
- log.Info("Tracing chain segment", "start", origin, "end", end.NumberU64(), "current", number, "transactions", traced, "elapsed", time.Since(begin))
+ log.Info("Tracing chain segment", "start", origin, "end", end.NumberU64(), "current", number, "transactions", traced, "elapsed", time.Since(begin), "memory", database.TrieDB().Size())
} else {
log.Info("Preparing state for chain trace", "block", number, "start", origin, "elapsed", time.Since(begin))
}
@@ -325,13 +267,11 @@ func (api *PrivateDebugAPI) traceChain(ctx context.Context, start, end *types.Bl
txs := block.Transactions()
select {
- case tasks <- &blockTraceTask{statedb: statedb.Copy(), block: block, results: make([]*txTraceResult, len(txs))}:
+ case tasks <- &blockTraceTask{statedb: statedb.Copy(), block: block, rootref: proot, results: make([]*txTraceResult, len(txs))}:
case <-notifier.Closed():
return
}
traced += uint64(len(txs))
- } else {
- atomic.StoreUint64(&complete, number)
}
// Generate the next state snapshot fast without tracing
_, _, _, err := api.eth.blockchain.Processor().Process(block, statedb, vm.Config{})
@@ -340,7 +280,7 @@ func (api *PrivateDebugAPI) traceChain(ctx context.Context, start, end *types.Bl
break
}
// Finalize the state so any modifications are written to the trie
- root, err := statedb.CommitTo(db, true)
+ root, err := statedb.Commit(true)
if err != nil {
failed = err
break
@@ -349,26 +289,14 @@ func (api *PrivateDebugAPI) traceChain(ctx context.Context, start, end *types.Bl
failed = err
break
}
- // After every N blocks, prune the database to only retain relevant data
- if (number-start.NumberU64())%4096 == 0 {
- // Wait until currently pending trace jobs finish
- for atomic.LoadUint64(&complete) != number {
- select {
- case <-time.After(100 * time.Millisecond):
- case <-notifier.Closed():
- return
- }
- }
- // No more concurrent access at this point, prune the database
- var (
- nodes = db.memdb.Len()
- start = time.Now()
- )
- db.Prune(root)
- log.Info("Pruned tracer state entries", "deleted", nodes-db.memdb.Len(), "left", db.memdb.Len(), "elapsed", time.Since(start))
-
- statedb, _ = state.New(root, state.NewDatabase(db))
+ // Reference the trie twice, once for us, once for the trancer
+ database.TrieDB().Reference(root, common.Hash{})
+ if number >= origin {
+ database.TrieDB().Reference(root, common.Hash{})
}
+ // Dereference all past tries we ourselves are done working with
+ database.TrieDB().Dereference(proot, common.Hash{})
+ proot = root
}
}()
@@ -387,12 +315,14 @@ func (api *PrivateDebugAPI) traceChain(ctx context.Context, start, end *types.Bl
}
done[uint64(result.Block)] = result
+ // Dereference any paret tries held in memory by this task
+ database.TrieDB().Dereference(res.rootref, common.Hash{})
+
// Stream completed traces to the user, aborting on the first error
for result, ok := done[next]; ok; result, ok = done[next] {
if len(result.Traces) > 0 || next == end.NumberU64() {
notifier.Notify(sub.ID, result)
}
- atomic.StoreUint64(&complete, next)
delete(done, next)
next++
}
@@ -544,18 +474,14 @@ func (api *PrivateDebugAPI) computeStateDB(block *types.Block, reexec uint64) (*
}
// Otherwise try to reexec blocks until we find a state or reach our limit
origin := block.NumberU64()
+ database := state.NewDatabase(api.eth.ChainDb())
- memdb, _ := ethdb.NewMemDatabase()
- db := &ephemeralDatabase{
- diskdb: api.eth.ChainDb(),
- memdb: memdb,
- }
for i := uint64(0); i < reexec; i++ {
block = api.eth.blockchain.GetBlock(block.ParentHash(), block.NumberU64()-1)
if block == nil {
break
}
- if statedb, err = state.New(block.Root(), state.NewDatabase(db)); err == nil {
+ if statedb, err = state.New(block.Root(), database); err == nil {
break
}
}
@@ -571,6 +497,7 @@ func (api *PrivateDebugAPI) computeStateDB(block *types.Block, reexec uint64) (*
var (
start = time.Now()
logged time.Time
+ proot common.Hash
)
for block.NumberU64() < origin {
// Print progress logs if long enough time elapsed
@@ -587,26 +514,18 @@ func (api *PrivateDebugAPI) computeStateDB(block *types.Block, reexec uint64) (*
return nil, err
}
// Finalize the state so any modifications are written to the trie
- root, err := statedb.CommitTo(db, true)
+ root, err := statedb.Commit(true)
if err != nil {
return nil, err
}
if err := statedb.Reset(root); err != nil {
return nil, err
}
- // After every N blocks, prune the database to only retain relevant data
- if block.NumberU64()%4096 == 0 || block.NumberU64() == origin {
- var (
- nodes = db.memdb.Len()
- begin = time.Now()
- )
- db.Prune(root)
- log.Info("Pruned tracer state entries", "deleted", nodes-db.memdb.Len(), "left", db.memdb.Len(), "elapsed", time.Since(begin))
-
- statedb, _ = state.New(root, state.NewDatabase(db))
- }
+ database.TrieDB().Reference(root, common.Hash{})
+ database.TrieDB().Dereference(proot, common.Hash{})
+ proot = root
}
- log.Info("Historical state regenerated", "block", block.NumberU64(), "elapsed", time.Since(start))
+ log.Info("Historical state regenerated", "block", block.NumberU64(), "elapsed", time.Since(start), "size", database.TrieDB().Size())
return statedb, nil
}
diff --git a/eth/backend.go b/eth/backend.go
index bcd724c0c..94aad2310 100644
--- a/eth/backend.go
+++ b/eth/backend.go
@@ -144,9 +144,11 @@ func New(ctx *node.ServiceContext, config *Config) (*Ethereum, error) {
}
core.WriteBlockChainVersion(chainDb, core.BlockChainVersion)
}
-
- vmConfig := vm.Config{EnablePreimageRecording: config.EnablePreimageRecording}
- eth.blockchain, err = core.NewBlockChain(chainDb, eth.chainConfig, eth.engine, vmConfig)
+ var (
+ vmConfig = vm.Config{EnablePreimageRecording: config.EnablePreimageRecording}
+ cacheConfig = &core.CacheConfig{Disabled: config.NoPruning, TrieNodeLimit: config.TrieCache, TrieTimeLimit: config.TrieTimeout}
+ )
+ eth.blockchain, err = core.NewBlockChain(chainDb, cacheConfig, eth.chainConfig, eth.engine, vmConfig)
if err != nil {
return nil, err
}
diff --git a/eth/config.go b/eth/config.go
index 2158c71ba..dd7f42c7d 100644
--- a/eth/config.go
+++ b/eth/config.go
@@ -22,6 +22,7 @@ import (
"os/user"
"path/filepath"
"runtime"
+ "time"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/common/hexutil"
@@ -44,7 +45,9 @@ var DefaultConfig = Config{
},
NetworkId: 1,
LightPeers: 100,
- DatabaseCache: 128,
+ DatabaseCache: 768,
+ TrieCache: 256,
+ TrieTimeout: 5 * time.Minute,
GasPrice: big.NewInt(18 * params.Shannon),
TxPool: core.DefaultTxPoolConfig,
@@ -78,6 +81,7 @@ type Config struct {
// Protocol options
NetworkId uint64 // Network ID to use for selecting peers to connect to
SyncMode downloader.SyncMode
+ NoPruning bool
// Light client options
LightServ int `toml:",omitempty"` // Maximum percentage of time allowed for serving LES requests
@@ -87,6 +91,8 @@ type Config struct {
SkipBcVersionCheck bool `toml:"-"`
DatabaseHandles int `toml:"-"`
DatabaseCache int
+ TrieCache int
+ TrieTimeout time.Duration
// Mining-related options
Etherbase common.Address `toml:",omitempty"`
diff --git a/eth/downloader/downloader.go b/eth/downloader/downloader.go
index 746c6a402..7f490d9e9 100644
--- a/eth/downloader/downloader.go
+++ b/eth/downloader/downloader.go
@@ -18,10 +18,8 @@
package downloader
import (
- "crypto/rand"
"errors"
"fmt"
- "math"
"math/big"
"sync"
"sync/atomic"
@@ -61,12 +59,11 @@ var (
maxHeadersProcess = 2048 // Number of header download results to import at once into the chain
maxResultsProcess = 2048 // Number of content download results to import at once into the chain
- fsHeaderCheckFrequency = 100 // Verification frequency of the downloaded headers during fast sync
- fsHeaderSafetyNet = 2048 // Number of headers to discard in case a chain violation is detected
- fsHeaderForceVerify = 24 // Number of headers to verify before and after the pivot to accept it
- fsPivotInterval = 256 // Number of headers out of which to randomize the pivot point
- fsMinFullBlocks = 64 // Number of blocks to retrieve fully even in fast sync
- fsCriticalTrials = uint32(32) // Number of times to retry in the cricical section before bailing
+ fsHeaderCheckFrequency = 100 // Verification frequency of the downloaded headers during fast sync
+ fsHeaderSafetyNet = 2048 // Number of headers to discard in case a chain violation is detected
+ fsHeaderForceVerify = 24 // Number of headers to verify before and after the pivot to accept it
+ fsHeaderContCheck = 3 * time.Second // Time interval to check for header continuations during state download
+ fsMinFullBlocks = 64 // Number of blocks to retrieve fully even in fast sync
)
var (
@@ -102,9 +99,6 @@ type Downloader struct {
peers *peerSet // Set of active peers from which download can proceed
stateDB ethdb.Database
- fsPivotLock *types.Header // Pivot header on critical section entry (cannot change between retries)
- fsPivotFails uint32 // Number of subsequent fast sync failures in the critical section
-
rttEstimate uint64 // Round trip time to target for download requests
rttConfidence uint64 // Confidence in the estimated RTT (unit: millionths to allow atomic ops)
@@ -124,6 +118,7 @@ type Downloader struct {
synchroniseMock func(id string, hash common.Hash) error // Replacement for synchronise during testing
synchronising int32
notified int32
+ committed int32
// Channels
headerCh chan dataPack // [eth/62] Channel receiving inbound block headers
@@ -156,7 +151,7 @@ type Downloader struct {
// LightChain encapsulates functions required to synchronise a light chain.
type LightChain interface {
// HasHeader verifies a header's presence in the local chain.
- HasHeader(h common.Hash, number uint64) bool
+ HasHeader(common.Hash, uint64) bool
// GetHeaderByHash retrieves a header from the local chain.
GetHeaderByHash(common.Hash) *types.Header
@@ -179,7 +174,7 @@ type BlockChain interface {
LightChain
// HasBlockAndState verifies block and associated states' presence in the local chain.
- HasBlockAndState(common.Hash) bool
+ HasBlockAndState(common.Hash, uint64) bool
// GetBlockByHash retrieves a block from the local chain.
GetBlockByHash(common.Hash) *types.Block
@@ -391,9 +386,7 @@ func (d *Downloader) synchronise(id string, hash common.Hash, td *big.Int, mode
// Set the requested sync mode, unless it's forbidden
d.mode = mode
- if d.mode == FastSync && atomic.LoadUint32(&d.fsPivotFails) >= fsCriticalTrials {
- d.mode = FullSync
- }
+
// Retrieve the origin peer and initiate the downloading process
p := d.peers.Peer(id)
if p == nil {
@@ -441,57 +434,40 @@ func (d *Downloader) syncWithPeer(p *peerConnection, hash common.Hash, td *big.I
d.syncStatsChainHeight = height
d.syncStatsLock.Unlock()
- // Initiate the sync using a concurrent header and content retrieval algorithm
+ // Ensure our origin point is below any fast sync pivot point
pivot := uint64(0)
- switch d.mode {
- case LightSync:
- pivot = height
- case FastSync:
- // Calculate the new fast/slow sync pivot point
- if d.fsPivotLock == nil {
- pivotOffset, err := rand.Int(rand.Reader, big.NewInt(int64(fsPivotInterval)))
- if err != nil {
- panic(fmt.Sprintf("Failed to access crypto random source: %v", err))
- }
- if height > uint64(fsMinFullBlocks)+pivotOffset.Uint64() {
- pivot = height - uint64(fsMinFullBlocks) - pivotOffset.Uint64()
- }
+ if d.mode == FastSync {
+ if height <= uint64(fsMinFullBlocks) {
+ origin = 0
} else {
- // Pivot point locked in, use this and do not pick a new one!
- pivot = d.fsPivotLock.Number.Uint64()
- }
- // If the point is below the origin, move origin back to ensure state download
- if pivot < origin {
- if pivot > 0 {
+ pivot = height - uint64(fsMinFullBlocks)
+ if pivot <= origin {
origin = pivot - 1
- } else {
- origin = 0
}
}
- log.Debug("Fast syncing until pivot block", "pivot", pivot)
}
- d.queue.Prepare(origin+1, d.mode, pivot, latest)
+ d.committed = 1
+ if d.mode == FastSync && pivot != 0 {
+ d.committed = 0
+ }
+ // Initiate the sync using a concurrent header and content retrieval algorithm
+ d.queue.Prepare(origin+1, d.mode)
if d.syncInitHook != nil {
d.syncInitHook(origin, height)
}
fetchers := []func() error{
- func() error { return d.fetchHeaders(p, origin+1) }, // Headers are always retrieved
- func() error { return d.fetchBodies(origin + 1) }, // Bodies are retrieved during normal and fast sync
- func() error { return d.fetchReceipts(origin + 1) }, // Receipts are retrieved during fast sync
- func() error { return d.processHeaders(origin+1, td) },
+ func() error { return d.fetchHeaders(p, origin+1, pivot) }, // Headers are always retrieved
+ func() error { return d.fetchBodies(origin + 1) }, // Bodies are retrieved during normal and fast sync
+ func() error { return d.fetchReceipts(origin + 1) }, // Receipts are retrieved during fast sync
+ func() error { return d.processHeaders(origin+1, pivot, td) },
}
if d.mode == FastSync {
fetchers = append(fetchers, func() error { return d.processFastSyncContent(latest) })
} else if d.mode == FullSync {
fetchers = append(fetchers, d.processFullSyncContent)
}
- err = d.spawnSync(fetchers)
- if err != nil && d.mode == FastSync && d.fsPivotLock != nil {
- // If sync failed in the critical section, bump the fail counter.
- atomic.AddUint32(&d.fsPivotFails, 1)
- }
- return err
+ return d.spawnSync(fetchers)
}
// spawnSync runs d.process and all given fetcher functions to completion in
@@ -671,7 +647,7 @@ func (d *Downloader) findAncestor(p *peerConnection, height uint64) (uint64, err
continue
}
// Otherwise check if we already know the header or not
- if (d.mode == FullSync && d.blockchain.HasBlockAndState(headers[i].Hash())) || (d.mode != FullSync && d.lightchain.HasHeader(headers[i].Hash(), headers[i].Number.Uint64())) {
+ if (d.mode == FullSync && d.blockchain.HasBlockAndState(headers[i].Hash(), headers[i].Number.Uint64())) || (d.mode != FullSync && d.lightchain.HasHeader(headers[i].Hash(), headers[i].Number.Uint64())) {
number, hash = headers[i].Number.Uint64(), headers[i].Hash()
// If every header is known, even future ones, the peer straight out lied about its head
@@ -736,7 +712,7 @@ func (d *Downloader) findAncestor(p *peerConnection, height uint64) (uint64, err
arrived = true
// Modify the search interval based on the response
- if (d.mode == FullSync && !d.blockchain.HasBlockAndState(headers[0].Hash())) || (d.mode != FullSync && !d.lightchain.HasHeader(headers[0].Hash(), headers[0].Number.Uint64())) {
+ if (d.mode == FullSync && !d.blockchain.HasBlockAndState(headers[0].Hash(), headers[0].Number.Uint64())) || (d.mode != FullSync && !d.lightchain.HasHeader(headers[0].Hash(), headers[0].Number.Uint64())) {
end = check
break
}
@@ -774,7 +750,7 @@ func (d *Downloader) findAncestor(p *peerConnection, height uint64) (uint64, err
// other peers are only accepted if they map cleanly to the skeleton. If no one
// can fill in the skeleton - not even the origin peer - it's assumed invalid and
// the origin is dropped.
-func (d *Downloader) fetchHeaders(p *peerConnection, from uint64) error {
+func (d *Downloader) fetchHeaders(p *peerConnection, from uint64, pivot uint64) error {
p.log.Debug("Directing header downloads", "origin", from)
defer p.log.Debug("Header download terminated")
@@ -825,6 +801,18 @@ func (d *Downloader) fetchHeaders(p *peerConnection, from uint64) error {
}
// If no more headers are inbound, notify the content fetchers and return
if packet.Items() == 0 {
+ // Don't abort header fetches while the pivot is downloading
+ if atomic.LoadInt32(&d.committed) == 0 && pivot <= from {
+ p.log.Debug("No headers, waiting for pivot commit")
+ select {
+ case <-time.After(fsHeaderContCheck):
+ getHeaders(from)
+ continue
+ case <-d.cancelCh:
+ return errCancelHeaderFetch
+ }
+ }
+ // Pivot done (or not in fast sync) and no more headers, terminate the process
p.log.Debug("No more headers available")
select {
case d.headerProcCh <- nil:
@@ -1129,10 +1117,8 @@ func (d *Downloader) fetchParts(errCancel error, deliveryCh chan dataPack, deliv
}
if request.From > 0 {
peer.log.Trace("Requesting new batch of data", "type", kind, "from", request.From)
- } else if len(request.Headers) > 0 {
- peer.log.Trace("Requesting new batch of data", "type", kind, "count", len(request.Headers), "from", request.Headers[0].Number)
} else {
- peer.log.Trace("Requesting new batch of data", "type", kind, "count", len(request.Hashes))
+ peer.log.Trace("Requesting new batch of data", "type", kind, "count", len(request.Headers), "from", request.Headers[0].Number)
}
// Fetch the chunk and make sure any errors return the hashes to the queue
if fetchHook != nil {
@@ -1160,10 +1146,7 @@ func (d *Downloader) fetchParts(errCancel error, deliveryCh chan dataPack, deliv
// processHeaders takes batches of retrieved headers from an input channel and
// keeps processing and scheduling them into the header chain and downloader's
// queue until the stream ends or a failure occurs.
-func (d *Downloader) processHeaders(origin uint64, td *big.Int) error {
- // Calculate the pivoting point for switching from fast to slow sync
- pivot := d.queue.FastSyncPivot()
-
+func (d *Downloader) processHeaders(origin uint64, pivot uint64, td *big.Int) error {
// Keep a count of uncertain headers to roll back
rollback := []*types.Header{}
defer func() {
@@ -1188,19 +1171,6 @@ func (d *Downloader) processHeaders(origin uint64, td *big.Int) error {
"header", fmt.Sprintf("%d->%d", lastHeader, d.lightchain.CurrentHeader().Number),
"fast", fmt.Sprintf("%d->%d", lastFastBlock, curFastBlock),
"block", fmt.Sprintf("%d->%d", lastBlock, curBlock))
-
- // If we're already past the pivot point, this could be an attack, thread carefully
- if rollback[len(rollback)-1].Number.Uint64() > pivot {
- // If we didn't ever fail, lock in the pivot header (must! not! change!)
- if atomic.LoadUint32(&d.fsPivotFails) == 0 {
- for _, header := range rollback {
- if header.Number.Uint64() == pivot {
- log.Warn("Fast-sync pivot locked in", "number", pivot, "hash", header.Hash())
- d.fsPivotLock = header
- }
- }
- }
- }
}
}()
@@ -1302,13 +1272,6 @@ func (d *Downloader) processHeaders(origin uint64, td *big.Int) error {
rollback = append(rollback[:0], rollback[len(rollback)-fsHeaderSafetyNet:]...)
}
}
- // If we're fast syncing and just pulled in the pivot, make sure it's the one locked in
- if d.mode == FastSync && d.fsPivotLock != nil && chunk[0].Number.Uint64() <= pivot && chunk[len(chunk)-1].Number.Uint64() >= pivot {
- if pivot := chunk[int(pivot-chunk[0].Number.Uint64())]; pivot.Hash() != d.fsPivotLock.Hash() {
- log.Warn("Pivot doesn't match locked in one", "remoteNumber", pivot.Number, "remoteHash", pivot.Hash(), "localNumber", d.fsPivotLock.Number, "localHash", d.fsPivotLock.Hash())
- return errInvalidChain
- }
- }
// Unless we're doing light chains, schedule the headers for associated content retrieval
if d.mode == FullSync || d.mode == FastSync {
// If we've reached the allowed number of pending headers, stall a bit
@@ -1343,7 +1306,7 @@ func (d *Downloader) processHeaders(origin uint64, td *big.Int) error {
// processFullSyncContent takes fetch results from the queue and imports them into the chain.
func (d *Downloader) processFullSyncContent() error {
for {
- results := d.queue.WaitResults()
+ results := d.queue.Results(true)
if len(results) == 0 {
return nil
}
@@ -1357,30 +1320,28 @@ func (d *Downloader) processFullSyncContent() error {
}
func (d *Downloader) importBlockResults(results []*fetchResult) error {
- for len(results) != 0 {
- // Check for any termination requests. This makes clean shutdown faster.
- select {
- case <-d.quitCh:
- return errCancelContentProcessing
- default:
- }
- // Retrieve the a batch of results to import
- items := int(math.Min(float64(len(results)), float64(maxResultsProcess)))
- first, last := results[0].Header, results[items-1].Header
- log.Debug("Inserting downloaded chain", "items", len(results),
- "firstnum", first.Number, "firsthash", first.Hash(),
- "lastnum", last.Number, "lasthash", last.Hash(),
- )
- blocks := make([]*types.Block, items)
- for i, result := range results[:items] {
- blocks[i] = types.NewBlockWithHeader(result.Header).WithBody(result.Transactions, result.Uncles)
- }
- if index, err := d.blockchain.InsertChain(blocks); err != nil {
- log.Debug("Downloaded item processing failed", "number", results[index].Header.Number, "hash", results[index].Header.Hash(), "err", err)
- return errInvalidChain
- }
- // Shift the results to the next batch
- results = results[items:]
+ // Check for any early termination requests
+ if len(results) == 0 {
+ return nil
+ }
+ select {
+ case <-d.quitCh:
+ return errCancelContentProcessing
+ default:
+ }
+ // Retrieve the a batch of results to import
+ first, last := results[0].Header, results[len(results)-1].Header
+ log.Debug("Inserting downloaded chain", "items", len(results),
+ "firstnum", first.Number, "firsthash", first.Hash(),
+ "lastnum", last.Number, "lasthash", last.Hash(),
+ )
+ blocks := make([]*types.Block, len(results))
+ for i, result := range results {
+ blocks[i] = types.NewBlockWithHeader(result.Header).WithBody(result.Transactions, result.Uncles)
+ }
+ if index, err := d.blockchain.InsertChain(blocks); err != nil {
+ log.Debug("Downloaded item processing failed", "number", results[index].Header.Number, "hash", results[index].Header.Hash(), "err", err)
+ return errInvalidChain
}
return nil
}
@@ -1388,35 +1349,92 @@ func (d *Downloader) importBlockResults(results []*fetchResult) error {
// processFastSyncContent takes fetch results from the queue and writes them to the
// database. It also controls the synchronisation of state nodes of the pivot block.
func (d *Downloader) processFastSyncContent(latest *types.Header) error {
- // Start syncing state of the reported head block.
- // This should get us most of the state of the pivot block.
+ // Start syncing state of the reported head block. This should get us most of
+ // the state of the pivot block.
stateSync := d.syncState(latest.Root)
defer stateSync.Cancel()
go func() {
- if err := stateSync.Wait(); err != nil {
+ if err := stateSync.Wait(); err != nil && err != errCancelStateFetch {
d.queue.Close() // wake up WaitResults
}
}()
-
- pivot := d.queue.FastSyncPivot()
+ // Figure out the ideal pivot block. Note, that this goalpost may move if the
+ // sync takes long enough for the chain head to move significantly.
+ pivot := uint64(0)
+ if height := latest.Number.Uint64(); height > uint64(fsMinFullBlocks) {
+ pivot = height - uint64(fsMinFullBlocks)
+ }
+ // To cater for moving pivot points, track the pivot block and subsequently
+ // accumulated download results separatey.
+ var (
+ oldPivot *fetchResult // Locked in pivot block, might change eventually
+ oldTail []*fetchResult // Downloaded content after the pivot
+ )
for {
- results := d.queue.WaitResults()
+ // Wait for the next batch of downloaded data to be available, and if the pivot
+ // block became stale, move the goalpost
+ results := d.queue.Results(oldPivot == nil) // Block if we're not monitoring pivot staleness
if len(results) == 0 {
- return stateSync.Cancel()
+ // If pivot sync is done, stop
+ if oldPivot == nil {
+ return stateSync.Cancel()
+ }
+ // If sync failed, stop
+ select {
+ case <-d.cancelCh:
+ return stateSync.Cancel()
+ default:
+ }
}
if d.chainInsertHook != nil {
d.chainInsertHook(results)
}
+ if oldPivot != nil {
+ results = append(append([]*fetchResult{oldPivot}, oldTail...), results...)
+ }
+ // Split around the pivot block and process the two sides via fast/full sync
+ if atomic.LoadInt32(&d.committed) == 0 {
+ latest = results[len(results)-1].Header
+ if height := latest.Number.Uint64(); height > pivot+2*uint64(fsMinFullBlocks) {
+ log.Warn("Pivot became stale, moving", "old", pivot, "new", height-uint64(fsMinFullBlocks))
+ pivot = height - uint64(fsMinFullBlocks)
+ }
+ }
P, beforeP, afterP := splitAroundPivot(pivot, results)
if err := d.commitFastSyncData(beforeP, stateSync); err != nil {
return err
}
if P != nil {
- stateSync.Cancel()
- if err := d.commitPivotBlock(P); err != nil {
- return err
+ // If new pivot block found, cancel old state retrieval and restart
+ if oldPivot != P {
+ stateSync.Cancel()
+
+ stateSync = d.syncState(P.Header.Root)
+ defer stateSync.Cancel()
+ go func() {
+ if err := stateSync.Wait(); err != nil && err != errCancelStateFetch {
+ d.queue.Close() // wake up WaitResults
+ }
+ }()
+ oldPivot = P
+ }
+ // Wait for completion, occasionally checking for pivot staleness
+ select {
+ case <-stateSync.done:
+ if stateSync.err != nil {
+ return stateSync.err
+ }
+ if err := d.commitPivotBlock(P); err != nil {
+ return err
+ }
+ oldPivot = nil
+
+ case <-time.After(time.Second):
+ oldTail = afterP
+ continue
}
}
+ // Fast sync done, pivot commit done, full import
if err := d.importBlockResults(afterP); err != nil {
return err
}
@@ -1439,52 +1457,49 @@ func splitAroundPivot(pivot uint64, results []*fetchResult) (p *fetchResult, bef
}
func (d *Downloader) commitFastSyncData(results []*fetchResult, stateSync *stateSync) error {
- for len(results) != 0 {
- // Check for any termination requests.
- select {
- case <-d.quitCh:
- return errCancelContentProcessing
- case <-stateSync.done:
- if err := stateSync.Wait(); err != nil {
- return err
- }
- default:
- }
- // Retrieve the a batch of results to import
- items := int(math.Min(float64(len(results)), float64(maxResultsProcess)))
- first, last := results[0].Header, results[items-1].Header
- log.Debug("Inserting fast-sync blocks", "items", len(results),
- "firstnum", first.Number, "firsthash", first.Hash(),
- "lastnumn", last.Number, "lasthash", last.Hash(),
- )
- blocks := make([]*types.Block, items)
- receipts := make([]types.Receipts, items)
- for i, result := range results[:items] {
- blocks[i] = types.NewBlockWithHeader(result.Header).WithBody(result.Transactions, result.Uncles)
- receipts[i] = result.Receipts
- }
- if index, err := d.blockchain.InsertReceiptChain(blocks, receipts); err != nil {
- log.Debug("Downloaded item processing failed", "number", results[index].Header.Number, "hash", results[index].Header.Hash(), "err", err)
- return errInvalidChain
+ // Check for any early termination requests
+ if len(results) == 0 {
+ return nil
+ }
+ select {
+ case <-d.quitCh:
+ return errCancelContentProcessing
+ case <-stateSync.done:
+ if err := stateSync.Wait(); err != nil {
+ return err
}
- // Shift the results to the next batch
- results = results[items:]
+ default:
+ }
+ // Retrieve the a batch of results to import
+ first, last := results[0].Header, results[len(results)-1].Header
+ log.Debug("Inserting fast-sync blocks", "items", len(results),
+ "firstnum", first.Number, "firsthash", first.Hash(),
+ "lastnumn", last.Number, "lasthash", last.Hash(),
+ )
+ blocks := make([]*types.Block, len(results))
+ receipts := make([]types.Receipts, len(results))
+ for i, result := range results {
+ blocks[i] = types.NewBlockWithHeader(result.Header).WithBody(result.Transactions, result.Uncles)
+ receipts[i] = result.Receipts
+ }
+ if index, err := d.blockchain.InsertReceiptChain(blocks, receipts); err != nil {
+ log.Debug("Downloaded item processing failed", "number", results[index].Header.Number, "hash", results[index].Header.Hash(), "err", err)
+ return errInvalidChain
}
return nil
}
func (d *Downloader) commitPivotBlock(result *fetchResult) error {
- b := types.NewBlockWithHeader(result.Header).WithBody(result.Transactions, result.Uncles)
- // Sync the pivot block state. This should complete reasonably quickly because
- // we've already synced up to the reported head block state earlier.
- if err := d.syncState(b.Root()).Wait(); err != nil {
+ block := types.NewBlockWithHeader(result.Header).WithBody(result.Transactions, result.Uncles)
+ log.Debug("Committing fast sync pivot as new head", "number", block.Number(), "hash", block.Hash())
+ if _, err := d.blockchain.InsertReceiptChain([]*types.Block{block}, []types.Receipts{result.Receipts}); err != nil {
return err
}
- log.Debug("Committing fast sync pivot as new head", "number", b.Number(), "hash", b.Hash())
- if _, err := d.blockchain.InsertReceiptChain([]*types.Block{b}, []types.Receipts{result.Receipts}); err != nil {
+ if err := d.blockchain.FastSyncCommitHead(block.Hash()); err != nil {
return err
}
- return d.blockchain.FastSyncCommitHead(b.Hash())
+ atomic.StoreInt32(&d.committed, 1)
+ return nil
}
// DeliverHeaders injects a new batch of block headers received from a remote
diff --git a/eth/downloader/downloader_test.go b/eth/downloader/downloader_test.go
index e9c7b6170..d94d55f11 100644
--- a/eth/downloader/downloader_test.go
+++ b/eth/downloader/downloader_test.go
@@ -28,7 +28,6 @@ import (
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/consensus/ethash"
"github.com/ethereum/go-ethereum/core"
- "github.com/ethereum/go-ethereum/core/state"
"github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/ethdb"
@@ -45,8 +44,8 @@ var (
// Reduce some of the parameters to make the tester faster.
func init() {
MaxForkAncestry = uint64(10000)
- blockCacheLimit = 1024
- fsCriticalTrials = 10
+ blockCacheItems = 1024
+ fsHeaderContCheck = 500 * time.Millisecond
}
// downloadTester is a test simulator for mocking out local block chain.
@@ -223,7 +222,7 @@ func (dl *downloadTester) HasHeader(hash common.Hash, number uint64) bool {
}
// HasBlockAndState checks if a block and associated state is present in the testers canonical chain.
-func (dl *downloadTester) HasBlockAndState(hash common.Hash) bool {
+func (dl *downloadTester) HasBlockAndState(hash common.Hash, number uint64) bool {
block := dl.GetBlockByHash(hash)
if block == nil {
return false
@@ -293,7 +292,7 @@ func (dl *downloadTester) CurrentFastBlock() *types.Block {
func (dl *downloadTester) FastSyncCommitHead(hash common.Hash) error {
// For now only check that the state trie is correct
if block := dl.GetBlockByHash(hash); block != nil {
- _, err := trie.NewSecure(block.Root(), dl.stateDb, 0)
+ _, err := trie.NewSecure(block.Root(), trie.NewDatabase(dl.stateDb), 0)
return err
}
return fmt.Errorf("non existent block: %x", hash[:4])
@@ -619,28 +618,22 @@ func assertOwnChain(t *testing.T, tester *downloadTester, length int) {
// number of items of the various chain components.
func assertOwnForkedChain(t *testing.T, tester *downloadTester, common int, lengths []int) {
// Initialize the counters for the first fork
- headers, blocks := lengths[0], lengths[0]
+ headers, blocks, receipts := lengths[0], lengths[0], lengths[0]-fsMinFullBlocks
- minReceipts, maxReceipts := lengths[0]-fsMinFullBlocks-fsPivotInterval, lengths[0]-fsMinFullBlocks
- if minReceipts < 0 {
- minReceipts = 1
- }
- if maxReceipts < 0 {
- maxReceipts = 1
+ if receipts < 0 {
+ receipts = 1
}
// Update the counters for each subsequent fork
for _, length := range lengths[1:] {
headers += length - common
blocks += length - common
-
- minReceipts += length - common - fsMinFullBlocks - fsPivotInterval
- maxReceipts += length - common - fsMinFullBlocks
+ receipts += length - common - fsMinFullBlocks
}
switch tester.downloader.mode {
case FullSync:
- minReceipts, maxReceipts = 1, 1
+ receipts = 1
case LightSync:
- blocks, minReceipts, maxReceipts = 1, 1, 1
+ blocks, receipts = 1, 1
}
if hs := len(tester.ownHeaders); hs != headers {
t.Fatalf("synchronised headers mismatch: have %v, want %v", hs, headers)
@@ -648,11 +641,12 @@ func assertOwnForkedChain(t *testing.T, tester *downloadTester, common int, leng
if bs := len(tester.ownBlocks); bs != blocks {
t.Fatalf("synchronised blocks mismatch: have %v, want %v", bs, blocks)
}
- if rs := len(tester.ownReceipts); rs < minReceipts || rs > maxReceipts {
- t.Fatalf("synchronised receipts mismatch: have %v, want between [%v, %v]", rs, minReceipts, maxReceipts)
+ if rs := len(tester.ownReceipts); rs != receipts {
+ t.Fatalf("synchronised receipts mismatch: have %v, want %v", rs, receipts)
}
// Verify the state trie too for fast syncs
- if tester.downloader.mode == FastSync {
+ /*if tester.downloader.mode == FastSync {
+ pivot := uint64(0)
var index int
if pivot := int(tester.downloader.queue.fastSyncPivot); pivot < common {
index = pivot
@@ -660,11 +654,11 @@ func assertOwnForkedChain(t *testing.T, tester *downloadTester, common int, leng
index = len(tester.ownHashes) - lengths[len(lengths)-1] + int(tester.downloader.queue.fastSyncPivot)
}
if index > 0 {
- if statedb, err := state.New(tester.ownHeaders[tester.ownHashes[index]].Root, state.NewDatabase(tester.stateDb)); statedb == nil || err != nil {
+ if statedb, err := state.New(tester.ownHeaders[tester.ownHashes[index]].Root, state.NewDatabase(trie.NewDatabase(tester.stateDb))); statedb == nil || err != nil {
t.Fatalf("state reconstruction failed: %v", err)
}
}
- }
+ }*/
}
// Tests that simple synchronization against a canonical chain works correctly.
@@ -684,7 +678,7 @@ func testCanonicalSynchronisation(t *testing.T, protocol int, mode SyncMode) {
defer tester.terminate()
// Create a small enough block chain to download
- targetBlocks := blockCacheLimit - 15
+ targetBlocks := blockCacheItems - 15
hashes, headers, blocks, receipts := tester.makeChain(targetBlocks, 0, tester.genesis, nil, false)
tester.newPeer("peer", protocol, hashes, headers, blocks, receipts)
@@ -710,7 +704,7 @@ func testThrottling(t *testing.T, protocol int, mode SyncMode) {
defer tester.terminate()
// Create a long block chain to download and the tester
- targetBlocks := 8 * blockCacheLimit
+ targetBlocks := 8 * blockCacheItems
hashes, headers, blocks, receipts := tester.makeChain(targetBlocks, 0, tester.genesis, nil, false)
tester.newPeer("peer", protocol, hashes, headers, blocks, receipts)
@@ -745,9 +739,9 @@ func testThrottling(t *testing.T, protocol int, mode SyncMode) {
cached = len(tester.downloader.queue.blockDonePool)
if mode == FastSync {
if receipts := len(tester.downloader.queue.receiptDonePool); receipts < cached {
- if tester.downloader.queue.resultCache[receipts].Header.Number.Uint64() < tester.downloader.queue.fastSyncPivot {
- cached = receipts
- }
+ //if tester.downloader.queue.resultCache[receipts].Header.Number.Uint64() < tester.downloader.queue.fastSyncPivot {
+ cached = receipts
+ //}
}
}
frozen = int(atomic.LoadUint32(&blocked))
@@ -755,7 +749,7 @@ func testThrottling(t *testing.T, protocol int, mode SyncMode) {
tester.downloader.queue.lock.Unlock()
tester.lock.Unlock()
- if cached == blockCacheLimit || retrieved+cached+frozen == targetBlocks+1 {
+ if cached == blockCacheItems || retrieved+cached+frozen == targetBlocks+1 {
break
}
}
@@ -765,8 +759,8 @@ func testThrottling(t *testing.T, protocol int, mode SyncMode) {
tester.lock.RLock()
retrieved = len(tester.ownBlocks)
tester.lock.RUnlock()
- if cached != blockCacheLimit && retrieved+cached+frozen != targetBlocks+1 {
- t.Fatalf("block count mismatch: have %v, want %v (owned %v, blocked %v, target %v)", cached, blockCacheLimit, retrieved, frozen, targetBlocks+1)
+ if cached != blockCacheItems && retrieved+cached+frozen != targetBlocks+1 {
+ t.Fatalf("block count mismatch: have %v, want %v (owned %v, blocked %v, target %v)", cached, blockCacheItems, retrieved, frozen, targetBlocks+1)
}
// Permit the blocked blocks to import
if atomic.LoadUint32(&blocked) > 0 {
@@ -974,7 +968,7 @@ func testCancel(t *testing.T, protocol int, mode SyncMode) {
defer tester.terminate()
// Create a small enough block chain to download and the tester
- targetBlocks := blockCacheLimit - 15
+ targetBlocks := blockCacheItems - 15
if targetBlocks >= MaxHashFetch {
targetBlocks = MaxHashFetch - 15
}
@@ -1016,12 +1010,12 @@ func testMultiSynchronisation(t *testing.T, protocol int, mode SyncMode) {
// Create various peers with various parts of the chain
targetPeers := 8
- targetBlocks := targetPeers*blockCacheLimit - 15
+ targetBlocks := targetPeers*blockCacheItems - 15
hashes, headers, blocks, receipts := tester.makeChain(targetBlocks, 0, tester.genesis, nil, false)
for i := 0; i < targetPeers; i++ {
id := fmt.Sprintf("peer #%d", i)
- tester.newPeer(id, protocol, hashes[i*blockCacheLimit:], headers, blocks, receipts)
+ tester.newPeer(id, protocol, hashes[i*blockCacheItems:], headers, blocks, receipts)
}
if err := tester.sync("peer #0", nil, mode); err != nil {
t.Fatalf("failed to synchronise blocks: %v", err)
@@ -1045,7 +1039,7 @@ func testMultiProtoSync(t *testing.T, protocol int, mode SyncMode) {
defer tester.terminate()
// Create a small enough block chain to download
- targetBlocks := blockCacheLimit - 15
+ targetBlocks := blockCacheItems - 15
hashes, headers, blocks, receipts := tester.makeChain(targetBlocks, 0, tester.genesis, nil, false)
// Create peers of every type
@@ -1084,7 +1078,7 @@ func testEmptyShortCircuit(t *testing.T, protocol int, mode SyncMode) {
defer tester.terminate()
// Create a block chain to download
- targetBlocks := 2*blockCacheLimit - 15
+ targetBlocks := 2*blockCacheItems - 15
hashes, headers, blocks, receipts := tester.makeChain(targetBlocks, 0, tester.genesis, nil, false)
tester.newPeer("peer", protocol, hashes, headers, blocks, receipts)
@@ -1110,8 +1104,8 @@ func testEmptyShortCircuit(t *testing.T, protocol int, mode SyncMode) {
bodiesNeeded++
}
}
- for hash, receipt := range receipts {
- if mode == FastSync && len(receipt) > 0 && headers[hash].Number.Uint64() <= tester.downloader.queue.fastSyncPivot {
+ for _, receipt := range receipts {
+ if mode == FastSync && len(receipt) > 0 {
receiptsNeeded++
}
}
@@ -1139,7 +1133,7 @@ func testMissingHeaderAttack(t *testing.T, protocol int, mode SyncMode) {
defer tester.terminate()
// Create a small enough block chain to download
- targetBlocks := blockCacheLimit - 15
+ targetBlocks := blockCacheItems - 15
hashes, headers, blocks, receipts := tester.makeChain(targetBlocks, 0, tester.genesis, nil, false)
// Attempt a full sync with an attacker feeding gapped headers
@@ -1174,7 +1168,7 @@ func testShiftedHeaderAttack(t *testing.T, protocol int, mode SyncMode) {
defer tester.terminate()
// Create a small enough block chain to download
- targetBlocks := blockCacheLimit - 15
+ targetBlocks := blockCacheItems - 15
hashes, headers, blocks, receipts := tester.makeChain(targetBlocks, 0, tester.genesis, nil, false)
// Attempt a full sync with an attacker feeding shifted headers
@@ -1208,7 +1202,7 @@ func testInvalidHeaderRollback(t *testing.T, protocol int, mode SyncMode) {
defer tester.terminate()
// Create a small enough block chain to download
- targetBlocks := 3*fsHeaderSafetyNet + fsPivotInterval + fsMinFullBlocks
+ targetBlocks := 3*fsHeaderSafetyNet + 256 + fsMinFullBlocks
hashes, headers, blocks, receipts := tester.makeChain(targetBlocks, 0, tester.genesis, nil, false)
// Attempt to sync with an attacker that feeds junk during the fast sync phase.
@@ -1248,7 +1242,6 @@ func testInvalidHeaderRollback(t *testing.T, protocol int, mode SyncMode) {
tester.newPeer("withhold-attack", protocol, hashes, headers, blocks, receipts)
missing = 3*fsHeaderSafetyNet + MaxHeaderFetch + 1
- tester.downloader.fsPivotFails = 0
tester.downloader.syncInitHook = func(uint64, uint64) {
for i := missing; i <= len(hashes); i++ {
delete(tester.peerHeaders["withhold-attack"], hashes[len(hashes)-i])
@@ -1267,8 +1260,6 @@ func testInvalidHeaderRollback(t *testing.T, protocol int, mode SyncMode) {
t.Errorf("fast sync pivot block #%d not rolled back", head)
}
}
- tester.downloader.fsPivotFails = fsCriticalTrials
-
// Synchronise with the valid peer and make sure sync succeeds. Since the last
// rollback should also disable fast syncing for this process, verify that we
// did a fresh full sync. Note, we can't assert anything about the receipts
@@ -1383,7 +1374,7 @@ func testSyncProgress(t *testing.T, protocol int, mode SyncMode) {
defer tester.terminate()
// Create a small enough block chain to download
- targetBlocks := blockCacheLimit - 15
+ targetBlocks := blockCacheItems - 15
hashes, headers, blocks, receipts := tester.makeChain(targetBlocks, 0, tester.genesis, nil, false)
// Set a sync init hook to catch progress changes
@@ -1532,7 +1523,7 @@ func testFailedSyncProgress(t *testing.T, protocol int, mode SyncMode) {
defer tester.terminate()
// Create a small enough block chain to download
- targetBlocks := blockCacheLimit - 15
+ targetBlocks := blockCacheItems - 15
hashes, headers, blocks, receipts := tester.makeChain(targetBlocks, 0, tester.genesis, nil, false)
// Set a sync init hook to catch progress changes
@@ -1609,7 +1600,7 @@ func testFakedSyncProgress(t *testing.T, protocol int, mode SyncMode) {
defer tester.terminate()
// Create a small block chain
- targetBlocks := blockCacheLimit - 15
+ targetBlocks := blockCacheItems - 15
hashes, headers, blocks, receipts := tester.makeChain(targetBlocks+3, 0, tester.genesis, nil, false)
// Set a sync init hook to catch progress changes
@@ -1697,6 +1688,7 @@ func TestDeliverHeadersHang(t *testing.T) {
type floodingTestPeer struct {
peer Peer
tester *downloadTester
+ pend sync.WaitGroup
}
func (ftp *floodingTestPeer) Head() (common.Hash, *big.Int) { return ftp.peer.Head() }
@@ -1717,9 +1709,12 @@ func (ftp *floodingTestPeer) RequestHeadersByNumber(from uint64, count, skip int
deliveriesDone := make(chan struct{}, 500)
for i := 0; i < cap(deliveriesDone); i++ {
peer := fmt.Sprintf("fake-peer%d", i)
+ ftp.pend.Add(1)
+
go func() {
ftp.tester.downloader.DeliverHeaders(peer, []*types.Header{{}, {}, {}, {}})
deliveriesDone <- struct{}{}
+ ftp.pend.Done()
}()
}
// Deliver the actual requested headers.
@@ -1751,110 +1746,15 @@ func testDeliverHeadersHang(t *testing.T, protocol int, mode SyncMode) {
// Whenever the downloader requests headers, flood it with
// a lot of unrequested header deliveries.
tester.downloader.peers.peers["peer"].peer = &floodingTestPeer{
- tester.downloader.peers.peers["peer"].peer,
- tester,
+ peer: tester.downloader.peers.peers["peer"].peer,
+ tester: tester,
}
if err := tester.sync("peer", nil, mode); err != nil {
- t.Errorf("sync failed: %v", err)
+ t.Errorf("test %d: sync failed: %v", i, err)
}
tester.terminate()
- }
-}
-
-// Tests that if fast sync aborts in the critical section, it can restart a few
-// times before giving up.
-// We use data driven subtests to manage this so that it will be parallel on its own
-// and not with the other tests, avoiding intermittent failures.
-func TestFastCriticalRestarts(t *testing.T) {
- testCases := []struct {
- protocol int
- progress bool
- }{
- {63, false},
- {64, false},
- {63, true},
- {64, true},
- }
- for _, tc := range testCases {
- t.Run(fmt.Sprintf("protocol %d progress %v", tc.protocol, tc.progress), func(t *testing.T) {
- testFastCriticalRestarts(t, tc.protocol, tc.progress)
- })
- }
-}
-
-func testFastCriticalRestarts(t *testing.T, protocol int, progress bool) {
- t.Parallel()
-
- tester := newTester()
- defer tester.terminate()
-
- // Create a large enough blockchin to actually fast sync on
- targetBlocks := fsMinFullBlocks + 2*fsPivotInterval - 15
- hashes, headers, blocks, receipts := tester.makeChain(targetBlocks, 0, tester.genesis, nil, false)
-
- // Create a tester peer with a critical section header missing (force failures)
- tester.newPeer("peer", protocol, hashes, headers, blocks, receipts)
- delete(tester.peerHeaders["peer"], hashes[fsMinFullBlocks-1])
- tester.downloader.dropPeer = func(id string) {} // We reuse the same "faulty" peer throughout the test
-
- // Remove all possible pivot state roots and slow down replies (test failure resets later)
- for i := 0; i < fsPivotInterval; i++ {
- tester.peerMissingStates["peer"][headers[hashes[fsMinFullBlocks+i]].Root] = true
- }
- (tester.downloader.peers.peers["peer"].peer).(*downloadTesterPeer).setDelay(500 * time.Millisecond) // Enough to reach the critical section
-
- // Synchronise with the peer a few times and make sure they fail until the retry limit
- for i := 0; i < int(fsCriticalTrials)-1; i++ {
- // Attempt a sync and ensure it fails properly
- if err := tester.sync("peer", nil, FastSync); err == nil {
- t.Fatalf("failing fast sync succeeded: %v", err)
- }
- time.Sleep(150 * time.Millisecond) // Make sure no in-flight requests remain
-
- // If it's the first failure, pivot should be locked => reenable all others to detect pivot changes
- if i == 0 {
- time.Sleep(150 * time.Millisecond) // Make sure no in-flight requests remain
- if tester.downloader.fsPivotLock == nil {
- time.Sleep(400 * time.Millisecond) // Make sure the first huge timeout expires too
- t.Fatalf("pivot block not locked in after critical section failure")
- }
- tester.lock.Lock()
- tester.peerHeaders["peer"][hashes[fsMinFullBlocks-1]] = headers[hashes[fsMinFullBlocks-1]]
- tester.peerMissingStates["peer"] = map[common.Hash]bool{tester.downloader.fsPivotLock.Root: true}
- (tester.downloader.peers.peers["peer"].peer).(*downloadTesterPeer).setDelay(0)
- tester.lock.Unlock()
- }
- }
- // Return all nodes if we're testing fast sync progression
- if progress {
- tester.lock.Lock()
- tester.peerMissingStates["peer"] = map[common.Hash]bool{}
- tester.lock.Unlock()
-
- if err := tester.sync("peer", nil, FastSync); err != nil {
- t.Fatalf("failed to synchronise blocks in progressed fast sync: %v", err)
- }
- time.Sleep(150 * time.Millisecond) // Make sure no in-flight requests remain
- if fails := atomic.LoadUint32(&tester.downloader.fsPivotFails); fails != 1 {
- t.Fatalf("progressed pivot trial count mismatch: have %v, want %v", fails, 1)
- }
- assertOwnChain(t, tester, targetBlocks+1)
- } else {
- if err := tester.sync("peer", nil, FastSync); err == nil {
- t.Fatalf("succeeded to synchronise blocks in failed fast sync")
- }
- time.Sleep(150 * time.Millisecond) // Make sure no in-flight requests remain
-
- if fails := atomic.LoadUint32(&tester.downloader.fsPivotFails); fails != fsCriticalTrials {
- t.Fatalf("failed pivot trial count mismatch: have %v, want %v", fails, fsCriticalTrials)
- }
- }
- // Retry limit exhausted, downloader will switch to full sync, should succeed
- if err := tester.sync("peer", nil, FastSync); err != nil {
- t.Fatalf("failed to synchronise blocks in slow sync: %v", err)
+ // Flush all goroutines to prevent messing with subsequent tests
+ tester.downloader.peers.peers["peer"].peer.(*floodingTestPeer).pend.Wait()
}
- // Note, we can't assert the chain here because the test asserter assumes sync
- // completed using a single mode of operation, whereas fast-then-slow can result
- // in arbitrary intermediate state that's not cleanly verifiable.
}
diff --git a/eth/downloader/queue.go b/eth/downloader/queue.go
index 6926f1d8c..a1a70e46e 100644
--- a/eth/downloader/queue.go
+++ b/eth/downloader/queue.go
@@ -32,7 +32,11 @@ import (
"gopkg.in/karalabe/cookiejar.v2/collections/prque"
)
-var blockCacheLimit = 8192 // Maximum number of blocks to cache before throttling the download
+var (
+ blockCacheItems = 8192 // Maximum number of blocks to cache before throttling the download
+ blockCacheMemory = 64 * 1024 * 1024 // Maximum amount of memory to use for block caching
+ blockCacheSizeWeight = 0.1 // Multiplier to approximate the average block size based on past ones
+)
var (
errNoFetchesPending = errors.New("no fetches pending")
@@ -41,17 +45,17 @@ var (
// fetchRequest is a currently running data retrieval operation.
type fetchRequest struct {
- Peer *peerConnection // Peer to which the request was sent
- From uint64 // [eth/62] Requested chain element index (used for skeleton fills only)
- Hashes map[common.Hash]int // [eth/61] Requested hashes with their insertion index (priority)
- Headers []*types.Header // [eth/62] Requested headers, sorted by request order
- Time time.Time // Time when the request was made
+ Peer *peerConnection // Peer to which the request was sent
+ From uint64 // [eth/62] Requested chain element index (used for skeleton fills only)
+ Headers []*types.Header // [eth/62] Requested headers, sorted by request order
+ Time time.Time // Time when the request was made
}
// fetchResult is a struct collecting partial results from data fetchers until
// all outstanding pieces complete and the result as a whole can be processed.
type fetchResult struct {
- Pending int // Number of data fetches still pending
+ Pending int // Number of data fetches still pending
+ Hash common.Hash // Hash of the header to prevent recalculating
Header *types.Header
Uncles []*types.Header
@@ -61,12 +65,10 @@ type fetchResult struct {
// queue represents hashes that are either need fetching or are being fetched
type queue struct {
- mode SyncMode // Synchronisation mode to decide on the block parts to schedule for fetching
- fastSyncPivot uint64 // Block number where the fast sync pivots into archive synchronisation mode
-
- headerHead common.Hash // [eth/62] Hash of the last queued header to verify order
+ mode SyncMode // Synchronisation mode to decide on the block parts to schedule for fetching
// Headers are "special", they download in batches, supported by a skeleton chain
+ headerHead common.Hash // [eth/62] Hash of the last queued header to verify order
headerTaskPool map[uint64]*types.Header // [eth/62] Pending header retrieval tasks, mapping starting indexes to skeleton headers
headerTaskQueue *prque.Prque // [eth/62] Priority queue of the skeleton indexes to fetch the filling headers for
headerPeerMiss map[string]map[uint64]struct{} // [eth/62] Set of per-peer header batches known to be unavailable
@@ -87,8 +89,9 @@ type queue struct {
receiptPendPool map[string]*fetchRequest // [eth/63] Currently pending receipt retrieval operations
receiptDonePool map[common.Hash]struct{} // [eth/63] Set of the completed receipt fetches
- resultCache []*fetchResult // Downloaded but not yet delivered fetch results
- resultOffset uint64 // Offset of the first cached fetch result in the block chain
+ resultCache []*fetchResult // Downloaded but not yet delivered fetch results
+ resultOffset uint64 // Offset of the first cached fetch result in the block chain
+ resultSize common.StorageSize // Approximate size of a block (exponential moving average)
lock *sync.Mutex
active *sync.Cond
@@ -109,7 +112,7 @@ func newQueue() *queue {
receiptTaskQueue: prque.New(),
receiptPendPool: make(map[string]*fetchRequest),
receiptDonePool: make(map[common.Hash]struct{}),
- resultCache: make([]*fetchResult, blockCacheLimit),
+ resultCache: make([]*fetchResult, blockCacheItems),
active: sync.NewCond(lock),
lock: lock,
}
@@ -122,10 +125,8 @@ func (q *queue) Reset() {
q.closed = false
q.mode = FullSync
- q.fastSyncPivot = 0
q.headerHead = common.Hash{}
-
q.headerPendPool = make(map[string]*fetchRequest)
q.blockTaskPool = make(map[common.Hash]*types.Header)
@@ -138,7 +139,7 @@ func (q *queue) Reset() {
q.receiptPendPool = make(map[string]*fetchRequest)
q.receiptDonePool = make(map[common.Hash]struct{})
- q.resultCache = make([]*fetchResult, blockCacheLimit)
+ q.resultCache = make([]*fetchResult, blockCacheItems)
q.resultOffset = 0
}
@@ -214,27 +215,13 @@ func (q *queue) Idle() bool {
return (queued + pending + cached) == 0
}
-// FastSyncPivot retrieves the currently used fast sync pivot point.
-func (q *queue) FastSyncPivot() uint64 {
- q.lock.Lock()
- defer q.lock.Unlock()
-
- return q.fastSyncPivot
-}
-
// ShouldThrottleBlocks checks if the download should be throttled (active block (body)
// fetches exceed block cache).
func (q *queue) ShouldThrottleBlocks() bool {
q.lock.Lock()
defer q.lock.Unlock()
- // Calculate the currently in-flight block (body) requests
- pending := 0
- for _, request := range q.blockPendPool {
- pending += len(request.Hashes) + len(request.Headers)
- }
- // Throttle if more blocks (bodies) are in-flight than free space in the cache
- return pending >= len(q.resultCache)-len(q.blockDonePool)
+ return q.resultSlots(q.blockPendPool, q.blockDonePool) <= 0
}
// ShouldThrottleReceipts checks if the download should be throttled (active receipt
@@ -243,13 +230,39 @@ func (q *queue) ShouldThrottleReceipts() bool {
q.lock.Lock()
defer q.lock.Unlock()
- // Calculate the currently in-flight receipt requests
+ return q.resultSlots(q.receiptPendPool, q.receiptDonePool) <= 0
+}
+
+// resultSlots calculates the number of results slots available for requests
+// whilst adhering to both the item and the memory limit too of the results
+// cache.
+func (q *queue) resultSlots(pendPool map[string]*fetchRequest, donePool map[common.Hash]struct{}) int {
+ // Calculate the maximum length capped by the memory limit
+ limit := len(q.resultCache)
+ if common.StorageSize(len(q.resultCache))*q.resultSize > common.StorageSize(blockCacheMemory) {
+ limit = int((common.StorageSize(blockCacheMemory) + q.resultSize - 1) / q.resultSize)
+ }
+ // Calculate the number of slots already finished
+ finished := 0
+ for _, result := range q.resultCache[:limit] {
+ if result == nil {
+ break
+ }
+ if _, ok := donePool[result.Hash]; ok {
+ finished++
+ }
+ }
+ // Calculate the number of slots currently downloading
pending := 0
- for _, request := range q.receiptPendPool {
- pending += len(request.Headers)
+ for _, request := range pendPool {
+ for _, header := range request.Headers {
+ if header.Number.Uint64() < q.resultOffset+uint64(limit) {
+ pending++
+ }
+ }
}
- // Throttle if more receipts are in-flight than free space in the cache
- return pending >= len(q.resultCache)-len(q.receiptDonePool)
+ // Return the free slots to distribute
+ return limit - finished - pending
}
// ScheduleSkeleton adds a batch of header retrieval tasks to the queue to fill
@@ -323,8 +336,7 @@ func (q *queue) Schedule(headers []*types.Header, from uint64) []*types.Header {
q.blockTaskPool[hash] = header
q.blockTaskQueue.Push(header, -float32(header.Number.Uint64()))
- if q.mode == FastSync && header.Number.Uint64() <= q.fastSyncPivot {
- // Fast phase of the fast sync, retrieve receipts too
+ if q.mode == FastSync {
q.receiptTaskPool[hash] = header
q.receiptTaskQueue.Push(header, -float32(header.Number.Uint64()))
}
@@ -335,18 +347,25 @@ func (q *queue) Schedule(headers []*types.Header, from uint64) []*types.Header {
return inserts
}
-// WaitResults retrieves and permanently removes a batch of fetch
-// results from the cache. the result slice will be empty if the queue
-// has been closed.
-func (q *queue) WaitResults() []*fetchResult {
+// Results retrieves and permanently removes a batch of fetch results from
+// the cache. the result slice will be empty if the queue has been closed.
+func (q *queue) Results(block bool) []*fetchResult {
q.lock.Lock()
defer q.lock.Unlock()
+ // Count the number of items available for processing
nproc := q.countProcessableItems()
for nproc == 0 && !q.closed {
+ if !block {
+ return nil
+ }
q.active.Wait()
nproc = q.countProcessableItems()
}
+ // Since we have a batch limit, don't pull more into "dangling" memory
+ if nproc > maxResultsProcess {
+ nproc = maxResultsProcess
+ }
results := make([]*fetchResult, nproc)
copy(results, q.resultCache[:nproc])
if len(results) > 0 {
@@ -363,6 +382,21 @@ func (q *queue) WaitResults() []*fetchResult {
}
// Advance the expected block number of the first cache entry.
q.resultOffset += uint64(nproc)
+
+ // Recalculate the result item weights to prevent memory exhaustion
+ for _, result := range results {
+ size := result.Header.Size()
+ for _, uncle := range result.Uncles {
+ size += uncle.Size()
+ }
+ for _, receipt := range result.Receipts {
+ size += receipt.Size()
+ }
+ for _, tx := range result.Transactions {
+ size += tx.Size()
+ }
+ q.resultSize = common.StorageSize(blockCacheSizeWeight)*size + (1-common.StorageSize(blockCacheSizeWeight))*q.resultSize
+ }
}
return results
}
@@ -370,21 +404,9 @@ func (q *queue) WaitResults() []*fetchResult {
// countProcessableItems counts the processable items.
func (q *queue) countProcessableItems() int {
for i, result := range q.resultCache {
- // Don't process incomplete or unavailable items.
if result == nil || result.Pending > 0 {
return i
}
- // Stop before processing the pivot block to ensure that
- // resultCache has space for fsHeaderForceVerify items. Not
- // doing this could leave us unable to download the required
- // amount of headers.
- if q.mode == FastSync && result.Header.Number.Uint64() == q.fastSyncPivot {
- for j := 0; j < fsHeaderForceVerify; j++ {
- if i+j+1 >= len(q.resultCache) || q.resultCache[i+j+1] == nil {
- return i
- }
- }
- }
}
return len(q.resultCache)
}
@@ -473,10 +495,8 @@ func (q *queue) reserveHeaders(p *peerConnection, count int, taskPool map[common
return nil, false, nil
}
// Calculate an upper limit on the items we might fetch (i.e. throttling)
- space := len(q.resultCache) - len(donePool)
- for _, request := range pendPool {
- space -= len(request.Headers)
- }
+ space := q.resultSlots(pendPool, donePool)
+
// Retrieve a batch of tasks, skipping previously failed ones
send := make([]*types.Header, 0, count)
skip := make([]*types.Header, 0)
@@ -484,6 +504,7 @@ func (q *queue) reserveHeaders(p *peerConnection, count int, taskPool map[common
progress := false
for proc := 0; proc < space && len(send) < count && !taskQueue.Empty(); proc++ {
header := taskQueue.PopItem().(*types.Header)
+ hash := header.Hash()
// If we're the first to request this task, initialise the result container
index := int(header.Number.Int64() - int64(q.resultOffset))
@@ -493,18 +514,19 @@ func (q *queue) reserveHeaders(p *peerConnection, count int, taskPool map[common
}
if q.resultCache[index] == nil {
components := 1
- if q.mode == FastSync && header.Number.Uint64() <= q.fastSyncPivot {
+ if q.mode == FastSync {
components = 2
}
q.resultCache[index] = &fetchResult{
Pending: components,
+ Hash: hash,
Header: header,
}
}
// If this fetch task is a noop, skip this fetch operation
if isNoop(header) {
- donePool[header.Hash()] = struct{}{}
- delete(taskPool, header.Hash())
+ donePool[hash] = struct{}{}
+ delete(taskPool, hash)
space, proc = space-1, proc-1
q.resultCache[index].Pending--
@@ -512,7 +534,7 @@ func (q *queue) reserveHeaders(p *peerConnection, count int, taskPool map[common
continue
}
// Otherwise unless the peer is known not to have the data, add to the retrieve list
- if p.Lacks(header.Hash()) {
+ if p.Lacks(hash) {
skip = append(skip, header)
} else {
send = append(send, header)
@@ -565,9 +587,6 @@ func (q *queue) cancel(request *fetchRequest, taskQueue *prque.Prque, pendPool m
if request.From > 0 {
taskQueue.Push(request.From, -float32(request.From))
}
- for hash, index := range request.Hashes {
- taskQueue.Push(hash, float32(index))
- }
for _, header := range request.Headers {
taskQueue.Push(header, -float32(header.Number.Uint64()))
}
@@ -640,18 +659,11 @@ func (q *queue) expire(timeout time.Duration, pendPool map[string]*fetchRequest,
if request.From > 0 {
taskQueue.Push(request.From, -float32(request.From))
}
- for hash, index := range request.Hashes {
- taskQueue.Push(hash, float32(index))
- }
for _, header := range request.Headers {
taskQueue.Push(header, -float32(header.Number.Uint64()))
}
// Add the peer to the expiry report along the the number of failed requests
- expirations := len(request.Hashes)
- if expirations < len(request.Headers) {
- expirations = len(request.Headers)
- }
- expiries[id] = expirations
+ expiries[id] = len(request.Headers)
}
}
// Remove the expired requests from the pending pool
@@ -828,14 +840,16 @@ func (q *queue) deliver(id string, taskPool map[common.Hash]*types.Header, taskQ
failure = err
break
}
- donePool[header.Hash()] = struct{}{}
+ hash := header.Hash()
+
+ donePool[hash] = struct{}{}
q.resultCache[index].Pending--
useful = true
accepted++
// Clean up a successful fetch
request.Headers[i] = nil
- delete(taskPool, header.Hash())
+ delete(taskPool, hash)
}
// Return all failed or missing fetches to the queue
for _, header := range request.Headers {
@@ -860,7 +874,7 @@ func (q *queue) deliver(id string, taskPool map[common.Hash]*types.Header, taskQ
// Prepare configures the result cache to allow accepting and caching inbound
// fetch results.
-func (q *queue) Prepare(offset uint64, mode SyncMode, pivot uint64, head *types.Header) {
+func (q *queue) Prepare(offset uint64, mode SyncMode) {
q.lock.Lock()
defer q.lock.Unlock()
@@ -868,6 +882,5 @@ func (q *queue) Prepare(offset uint64, mode SyncMode, pivot uint64, head *types.
if q.resultOffset < offset {
q.resultOffset = offset
}
- q.fastSyncPivot = pivot
q.mode = mode
}
diff --git a/eth/downloader/statesync.go b/eth/downloader/statesync.go
index 937828b94..9cc65a208 100644
--- a/eth/downloader/statesync.go
+++ b/eth/downloader/statesync.go
@@ -20,7 +20,6 @@ import (
"fmt"
"hash"
"sync"
- "sync/atomic"
"time"
"github.com/ethereum/go-ethereum/common"
@@ -294,6 +293,9 @@ func (s *stateSync) loop() error {
case <-s.cancel:
return errCancelStateFetch
+ case <-s.d.cancelCh:
+ return errCancelStateFetch
+
case req := <-s.deliver:
// Response, disconnect or timeout triggered, drop the peer if stalling
log.Trace("Received node data response", "peer", req.peer.id, "count", len(req.response), "dropped", req.dropped, "timeout", !req.dropped && req.timedOut())
@@ -304,15 +306,11 @@ func (s *stateSync) loop() error {
s.d.dropPeer(req.peer.id)
}
// Process all the received blobs and check for stale delivery
- stale, err := s.process(req)
- if err != nil {
+ if err := s.process(req); err != nil {
log.Warn("Node data write error", "err", err)
return err
}
- // The the delivery contains requested data, mark the node idle (otherwise it's a timed out delivery)
- if !stale {
- req.peer.SetNodeDataIdle(len(req.response))
- }
+ req.peer.SetNodeDataIdle(len(req.response))
}
}
return s.commit(true)
@@ -352,6 +350,7 @@ func (s *stateSync) assignTasks() {
case s.d.trackStateReq <- req:
req.peer.FetchNodeData(req.items)
case <-s.cancel:
+ case <-s.d.cancelCh:
}
}
}
@@ -390,7 +389,7 @@ func (s *stateSync) fillTasks(n int, req *stateReq) {
// process iterates over a batch of delivered state data, injecting each item
// into a running state sync, re-queuing any items that were requested but not
// delivered.
-func (s *stateSync) process(req *stateReq) (bool, error) {
+func (s *stateSync) process(req *stateReq) error {
// Collect processing stats and update progress if valid data was received
duplicate, unexpected := 0, 0
@@ -401,7 +400,7 @@ func (s *stateSync) process(req *stateReq) (bool, error) {
}(time.Now())
// Iterate over all the delivered data and inject one-by-one into the trie
- progress, stale := false, len(req.response) > 0
+ progress := false
for _, blob := range req.response {
prog, hash, err := s.processNodeData(blob)
@@ -415,20 +414,12 @@ func (s *stateSync) process(req *stateReq) (bool, error) {
case trie.ErrAlreadyProcessed:
duplicate++
default:
- return stale, fmt.Errorf("invalid state node %s: %v", hash.TerminalString(), err)
+ return fmt.Errorf("invalid state node %s: %v", hash.TerminalString(), err)
}
- // If the node delivered a requested item, mark the delivery non-stale
if _, ok := req.tasks[hash]; ok {
delete(req.tasks, hash)
- stale = false
}
}
- // If we're inside the critical section, reset fail counter since we progressed.
- if progress && atomic.LoadUint32(&s.d.fsPivotFails) > 1 {
- log.Trace("Fast-sync progressed, resetting fail counter", "previous", atomic.LoadUint32(&s.d.fsPivotFails))
- atomic.StoreUint32(&s.d.fsPivotFails, 1) // Don't ever reset to 0, as that will unlock the pivot block
- }
-
// Put unfulfilled tasks back into the retry queue
npeers := s.d.peers.Len()
for hash, task := range req.tasks {
@@ -441,12 +432,12 @@ func (s *stateSync) process(req *stateReq) (bool, error) {
// If we've requested the node too many times already, it may be a malicious
// sync where nobody has the right data. Abort.
if len(task.attempts) >= npeers {
- return stale, fmt.Errorf("state node %s failed with all peers (%d tries, %d peers)", hash.TerminalString(), len(task.attempts), npeers)
+ return fmt.Errorf("state node %s failed with all peers (%d tries, %d peers)", hash.TerminalString(), len(task.attempts), npeers)
}
// Missing item, place into the retry queue.
s.tasks[hash] = task
}
- return stale, nil
+ return nil
}
// processNodeData tries to inject a trie node data blob delivered from a remote
diff --git a/eth/handler.go b/eth/handler.go
index fcd53c5a6..c2426544f 100644
--- a/eth/handler.go
+++ b/eth/handler.go
@@ -71,7 +71,6 @@ type ProtocolManager struct {
txpool txPool
blockchain *core.BlockChain
- chaindb ethdb.Database
chainconfig *params.ChainConfig
maxPeers int
@@ -106,7 +105,6 @@ func NewProtocolManager(config *params.ChainConfig, mode downloader.SyncMode, ne
eventMux: mux,
txpool: txpool,
blockchain: blockchain,
- chaindb: chaindb,
chainconfig: config,
peers: newPeerSet(),
newPeerCh: make(chan *peer),
@@ -538,7 +536,7 @@ func (pm *ProtocolManager) handleMsg(p *peer) error {
return errResp(ErrDecode, "msg %v: %v", msg, err)
}
// Retrieve the requested state entry, stopping if enough was found
- if entry, err := pm.chaindb.Get(hash.Bytes()); err == nil {
+ if entry, err := pm.blockchain.TrieNode(hash); err == nil {
data = append(data, entry)
bytes += len(entry)
}
@@ -576,7 +574,7 @@ func (pm *ProtocolManager) handleMsg(p *peer) error {
return errResp(ErrDecode, "msg %v: %v", msg, err)
}
// Retrieve the requested block's receipts, skipping if unknown to us
- results := core.GetBlockReceipts(pm.chaindb, hash, core.GetBlockNumber(pm.chaindb, hash))
+ results := pm.blockchain.GetReceiptsByHash(hash)
if results == nil {
if header := pm.blockchain.GetHeaderByHash(hash); header == nil || header.ReceiptHash != types.EmptyRootHash {
continue
diff --git a/eth/handler_test.go b/eth/handler_test.go
index 9a02eddfb..e336dfa28 100644
--- a/eth/handler_test.go
+++ b/eth/handler_test.go
@@ -56,7 +56,7 @@ func TestProtocolCompatibility(t *testing.T) {
for i, tt := range tests {
ProtocolVersions = []uint{tt.version}
- pm, err := newTestProtocolManager(tt.mode, 0, nil, nil)
+ pm, _, err := newTestProtocolManager(tt.mode, 0, nil, nil)
if pm != nil {
defer pm.Stop()
}
@@ -71,7 +71,7 @@ func TestGetBlockHeaders62(t *testing.T) { testGetBlockHeaders(t, 62) }
func TestGetBlockHeaders63(t *testing.T) { testGetBlockHeaders(t, 63) }
func testGetBlockHeaders(t *testing.T, protocol int) {
- pm := newTestProtocolManagerMust(t, downloader.FullSync, downloader.MaxHashFetch+15, nil, nil)
+ pm, _ := newTestProtocolManagerMust(t, downloader.FullSync, downloader.MaxHashFetch+15, nil, nil)
peer, _ := newTestPeer("peer", protocol, pm, true)
defer peer.close()
@@ -230,7 +230,7 @@ func TestGetBlockBodies62(t *testing.T) { testGetBlockBodies(t, 62) }
func TestGetBlockBodies63(t *testing.T) { testGetBlockBodies(t, 63) }
func testGetBlockBodies(t *testing.T, protocol int) {
- pm := newTestProtocolManagerMust(t, downloader.FullSync, downloader.MaxBlockFetch+15, nil, nil)
+ pm, _ := newTestProtocolManagerMust(t, downloader.FullSync, downloader.MaxBlockFetch+15, nil, nil)
peer, _ := newTestPeer("peer", protocol, pm, true)
defer peer.close()
@@ -337,13 +337,13 @@ func testGetNodeData(t *testing.T, protocol int) {
}
}
// Assemble the test environment
- pm := newTestProtocolManagerMust(t, downloader.FullSync, 4, generator, nil)
+ pm, db := newTestProtocolManagerMust(t, downloader.FullSync, 4, generator, nil)
peer, _ := newTestPeer("peer", protocol, pm, true)
defer peer.close()
// Fetch for now the entire chain db
hashes := []common.Hash{}
- for _, key := range pm.chaindb.(*ethdb.MemDatabase).Keys() {
+ for _, key := range db.Keys() {
if len(key) == len(common.Hash{}) {
hashes = append(hashes, common.BytesToHash(key))
}
@@ -429,7 +429,7 @@ func testGetReceipt(t *testing.T, protocol int) {
}
}
// Assemble the test environment
- pm := newTestProtocolManagerMust(t, downloader.FullSync, 4, generator, nil)
+ pm, _ := newTestProtocolManagerMust(t, downloader.FullSync, 4, generator, nil)
peer, _ := newTestPeer("peer", protocol, pm, true)
defer peer.close()
@@ -439,7 +439,7 @@ func testGetReceipt(t *testing.T, protocol int) {
block := pm.blockchain.GetBlockByNumber(i)
hashes = append(hashes, block.Hash())
- receipts = append(receipts, core.GetBlockReceipts(pm.chaindb, block.Hash(), block.NumberU64()))
+ receipts = append(receipts, pm.blockchain.GetReceiptsByHash(block.Hash()))
}
// Send the hash request and verify the response
p2p.Send(peer.app, 0x0f, hashes)
@@ -472,7 +472,7 @@ func testDAOChallenge(t *testing.T, localForked, remoteForked bool, timeout bool
config = &params.ChainConfig{DAOForkBlock: big.NewInt(1), DAOForkSupport: localForked}
gspec = &core.Genesis{Config: config}
genesis = gspec.MustCommit(db)
- blockchain, _ = core.NewBlockChain(db, config, pow, vm.Config{})
+ blockchain, _ = core.NewBlockChain(db, nil, config, pow, vm.Config{})
)
pm, err := NewProtocolManager(config, downloader.FullSync, DefaultConfig.NetworkId, evmux, new(testTxPool), pow, blockchain, db)
if err != nil {
diff --git a/eth/helper_test.go b/eth/helper_test.go
index 9a4dc9010..2b05cea80 100644
--- a/eth/helper_test.go
+++ b/eth/helper_test.go
@@ -49,7 +49,7 @@ var (
// newTestProtocolManager creates a new protocol manager for testing purposes,
// with the given number of blocks already known, and potential notification
// channels for different events.
-func newTestProtocolManager(mode downloader.SyncMode, blocks int, generator func(int, *core.BlockGen), newtx chan<- []*types.Transaction) (*ProtocolManager, error) {
+func newTestProtocolManager(mode downloader.SyncMode, blocks int, generator func(int, *core.BlockGen), newtx chan<- []*types.Transaction) (*ProtocolManager, *ethdb.MemDatabase, error) {
var (
evmux = new(event.TypeMux)
engine = ethash.NewFaker()
@@ -59,7 +59,7 @@ func newTestProtocolManager(mode downloader.SyncMode, blocks int, generator func
Alloc: core.GenesisAlloc{testBank: {Balance: big.NewInt(1000000)}},
}
genesis = gspec.MustCommit(db)
- blockchain, _ = core.NewBlockChain(db, gspec.Config, engine, vm.Config{})
+ blockchain, _ = core.NewBlockChain(db, nil, gspec.Config, engine, vm.Config{})
)
chain, _ := core.GenerateChain(gspec.Config, genesis, ethash.NewFaker(), db, blocks, generator)
if _, err := blockchain.InsertChain(chain); err != nil {
@@ -68,22 +68,22 @@ func newTestProtocolManager(mode downloader.SyncMode, blocks int, generator func
pm, err := NewProtocolManager(gspec.Config, mode, DefaultConfig.NetworkId, evmux, &testTxPool{added: newtx}, engine, blockchain, db)
if err != nil {
- return nil, err
+ return nil, nil, err
}
pm.Start(1000)
- return pm, nil
+ return pm, db, nil
}
// newTestProtocolManagerMust creates a new protocol manager for testing purposes,
// with the given number of blocks already known, and potential notification
// channels for different events. In case of an error, the constructor force-
// fails the test.
-func newTestProtocolManagerMust(t *testing.T, mode downloader.SyncMode, blocks int, generator func(int, *core.BlockGen), newtx chan<- []*types.Transaction) *ProtocolManager {
- pm, err := newTestProtocolManager(mode, blocks, generator, newtx)
+func newTestProtocolManagerMust(t *testing.T, mode downloader.SyncMode, blocks int, generator func(int, *core.BlockGen), newtx chan<- []*types.Transaction) (*ProtocolManager, *ethdb.MemDatabase) {
+ pm, db, err := newTestProtocolManager(mode, blocks, generator, newtx)
if err != nil {
t.Fatalf("Failed to create protocol manager: %v", err)
}
- return pm
+ return pm, db
}
// testTxPool is a fake, helper transaction pool for testing purposes
diff --git a/eth/protocol_test.go b/eth/protocol_test.go
index 7cbcba571..b2f93d8dd 100644
--- a/eth/protocol_test.go
+++ b/eth/protocol_test.go
@@ -41,7 +41,7 @@ func TestStatusMsgErrors62(t *testing.T) { testStatusMsgErrors(t, 62) }
func TestStatusMsgErrors63(t *testing.T) { testStatusMsgErrors(t, 63) }
func testStatusMsgErrors(t *testing.T, protocol int) {
- pm := newTestProtocolManagerMust(t, downloader.FullSync, 0, nil, nil)
+ pm, _ := newTestProtocolManagerMust(t, downloader.FullSync, 0, nil, nil)
var (
genesis = pm.blockchain.Genesis()
head = pm.blockchain.CurrentHeader()
@@ -98,7 +98,7 @@ func TestRecvTransactions63(t *testing.T) { testRecvTransactions(t, 63) }
func testRecvTransactions(t *testing.T, protocol int) {
txAdded := make(chan []*types.Transaction)
- pm := newTestProtocolManagerMust(t, downloader.FullSync, 0, nil, txAdded)
+ pm, _ := newTestProtocolManagerMust(t, downloader.FullSync, 0, nil, txAdded)
pm.acceptTxs = 1 // mark synced to accept transactions
p, _ := newTestPeer("peer", protocol, pm, true)
defer pm.Stop()
@@ -125,7 +125,7 @@ func TestSendTransactions62(t *testing.T) { testSendTransactions(t, 62) }
func TestSendTransactions63(t *testing.T) { testSendTransactions(t, 63) }
func testSendTransactions(t *testing.T, protocol int) {
- pm := newTestProtocolManagerMust(t, downloader.FullSync, 0, nil, nil)
+ pm, _ := newTestProtocolManagerMust(t, downloader.FullSync, 0, nil, nil)
defer pm.Stop()
// Fill the pool with big transactions.
diff --git a/eth/sync_test.go b/eth/sync_test.go
index 9eaa1156f..88c10c7f7 100644
--- a/eth/sync_test.go
+++ b/eth/sync_test.go
@@ -30,12 +30,12 @@ import (
// imported into the blockchain.
func TestFastSyncDisabling(t *testing.T) {
// Create a pristine protocol manager, check that fast sync is left enabled
- pmEmpty := newTestProtocolManagerMust(t, downloader.FastSync, 0, nil, nil)
+ pmEmpty, _ := newTestProtocolManagerMust(t, downloader.FastSync, 0, nil, nil)
if atomic.LoadUint32(&pmEmpty.fastSync) == 0 {
t.Fatalf("fast sync disabled on pristine blockchain")
}
// Create a full protocol manager, check that fast sync gets disabled
- pmFull := newTestProtocolManagerMust(t, downloader.FastSync, 1024, nil, nil)
+ pmFull, _ := newTestProtocolManagerMust(t, downloader.FastSync, 1024, nil, nil)
if atomic.LoadUint32(&pmFull.fastSync) == 1 {
t.Fatalf("fast sync not disabled on non-empty blockchain")
}
diff --git a/internal/ethapi/api.go b/internal/ethapi/api.go
index a4cba7a4d..314086335 100644
--- a/internal/ethapi/api.go
+++ b/internal/ethapi/api.go
@@ -808,7 +808,7 @@ func (s *PublicBlockChainAPI) rpcOutputBlock(b *types.Block, inclTx bool, fullTx
"difficulty": (*hexutil.Big)(head.Difficulty),
"totalDifficulty": (*hexutil.Big)(s.b.GetTd(b.Hash())),
"extraData": hexutil.Bytes(head.Extra),
- "size": hexutil.Uint64(uint64(b.Size().Int64())),
+ "size": hexutil.Uint64(b.Size()),
"gasLimit": hexutil.Uint64(head.GasLimit),
"gasUsed": hexutil.Uint64(head.GasUsed),
"timestamp": (*hexutil.Big)(head.Time),
diff --git a/les/handler.go b/les/handler.go
index 8cd37c7ab..5c93133fb 100644
--- a/les/handler.go
+++ b/les/handler.go
@@ -18,7 +18,6 @@
package les
import (
- "bytes"
"encoding/binary"
"errors"
"fmt"
@@ -78,6 +77,7 @@ type BlockChain interface {
GetHeaderByHash(hash common.Hash) *types.Header
CurrentHeader() *types.Header
GetTd(hash common.Hash, number uint64) *big.Int
+ State() (*state.StateDB, error)
InsertHeaderChain(chain []*types.Header, checkFreq int) (int, error)
Rollback(chain []common.Hash)
GetHeaderByNumber(number uint64) *types.Header
@@ -579,17 +579,19 @@ func (pm *ProtocolManager) handleMsg(p *peer) error {
for _, req := range req.Reqs {
// Retrieve the requested state entry, stopping if enough was found
if header := core.GetHeader(pm.chainDb, req.BHash, core.GetBlockNumber(pm.chainDb, req.BHash)); header != nil {
- if trie, _ := trie.New(header.Root, pm.chainDb); trie != nil {
- sdata := trie.Get(req.AccKey)
- var acc state.Account
- if err := rlp.DecodeBytes(sdata, &acc); err == nil {
- entry, _ := pm.chainDb.Get(acc.CodeHash)
- if bytes+len(entry) >= softResponseLimit {
- break
- }
- data = append(data, entry)
- bytes += len(entry)
- }
+ statedb, err := pm.blockchain.State()
+ if err != nil {
+ continue
+ }
+ account, err := pm.getAccount(statedb, header.Root, common.BytesToHash(req.AccKey))
+ if err != nil {
+ continue
+ }
+ code, _ := statedb.Database().TrieDB().Node(common.BytesToHash(account.CodeHash))
+
+ data = append(data, code)
+ if bytes += len(code); bytes >= softResponseLimit {
+ break
}
}
}
@@ -701,25 +703,29 @@ func (pm *ProtocolManager) handleMsg(p *peer) error {
return errResp(ErrRequestRejected, "")
}
for _, req := range req.Reqs {
- if bytes >= softResponseLimit {
- break
- }
// Retrieve the requested state entry, stopping if enough was found
if header := core.GetHeader(pm.chainDb, req.BHash, core.GetBlockNumber(pm.chainDb, req.BHash)); header != nil {
- if tr, _ := trie.New(header.Root, pm.chainDb); tr != nil {
- if len(req.AccKey) > 0 {
- sdata := tr.Get(req.AccKey)
- tr = nil
- var acc state.Account
- if err := rlp.DecodeBytes(sdata, &acc); err == nil {
- tr, _ = trie.New(acc.Root, pm.chainDb)
- }
+ statedb, err := pm.blockchain.State()
+ if err != nil {
+ continue
+ }
+ var trie state.Trie
+ if len(req.AccKey) > 0 {
+ account, err := pm.getAccount(statedb, header.Root, common.BytesToHash(req.AccKey))
+ if err != nil {
+ continue
}
- if tr != nil {
- var proof light.NodeList
- tr.Prove(req.Key, 0, &proof)
- proofs = append(proofs, proof)
- bytes += proof.DataSize()
+ trie, _ = statedb.Database().OpenStorageTrie(common.BytesToHash(req.AccKey), account.Root)
+ } else {
+ trie, _ = statedb.Database().OpenTrie(header.Root)
+ }
+ if trie != nil {
+ var proof light.NodeList
+ trie.Prove(req.Key, 0, &proof)
+
+ proofs = append(proofs, proof)
+ if bytes += proof.DataSize(); bytes >= softResponseLimit {
+ break
}
}
}
@@ -740,9 +746,9 @@ func (pm *ProtocolManager) handleMsg(p *peer) error {
}
// Gather state data until the fetch or network limits is reached
var (
- lastBHash common.Hash
- lastAccKey []byte
- tr, str *trie.Trie
+ lastBHash common.Hash
+ statedb *state.StateDB
+ root common.Hash
)
reqCnt := len(req.Reqs)
if reject(uint64(reqCnt), MaxProofsFetch) {
@@ -752,35 +758,36 @@ func (pm *ProtocolManager) handleMsg(p *peer) error {
nodes := light.NewNodeSet()
for _, req := range req.Reqs {
- if nodes.DataSize() >= softResponseLimit {
- break
- }
- if tr == nil || req.BHash != lastBHash {
+ // Look up the state belonging to the request
+ if statedb == nil || req.BHash != lastBHash {
+ statedb, root, lastBHash = nil, common.Hash{}, req.BHash
+
if header := core.GetHeader(pm.chainDb, req.BHash, core.GetBlockNumber(pm.chainDb, req.BHash)); header != nil {
- tr, _ = trie.New(header.Root, pm.chainDb)
- } else {
- tr = nil
+ statedb, _ = pm.blockchain.State()
+ root = header.Root
}
- lastBHash = req.BHash
- str = nil
}
- if tr != nil {
- if len(req.AccKey) > 0 {
- if str == nil || !bytes.Equal(req.AccKey, lastAccKey) {
- sdata := tr.Get(req.AccKey)
- str = nil
- var acc state.Account
- if err := rlp.DecodeBytes(sdata, &acc); err == nil {
- str, _ = trie.New(acc.Root, pm.chainDb)
- }
- lastAccKey = common.CopyBytes(req.AccKey)
- }
- if str != nil {
- str.Prove(req.Key, req.FromLevel, nodes)
- }
- } else {
- tr.Prove(req.Key, req.FromLevel, nodes)
+ if statedb == nil {
+ continue
+ }
+ // Pull the account or storage trie of the request
+ var trie state.Trie
+ if len(req.AccKey) > 0 {
+ account, err := pm.getAccount(statedb, root, common.BytesToHash(req.AccKey))
+ if err != nil {
+ continue
}
+ trie, _ = statedb.Database().OpenStorageTrie(common.BytesToHash(req.AccKey), account.Root)
+ } else {
+ trie, _ = statedb.Database().OpenTrie(root)
+ }
+ if trie == nil {
+ continue
+ }
+ // Prove the user's request from the account or stroage trie
+ trie.Prove(req.Key, req.FromLevel, nodes)
+ if nodes.DataSize() >= softResponseLimit {
+ break
}
}
proofs := nodes.NodeList()
@@ -849,23 +856,29 @@ func (pm *ProtocolManager) handleMsg(p *peer) error {
if reject(uint64(reqCnt), MaxHelperTrieProofsFetch) {
return errResp(ErrRequestRejected, "")
}
- trieDb := ethdb.NewTable(pm.chainDb, light.ChtTablePrefix)
for _, req := range req.Reqs {
- if bytes >= softResponseLimit {
- break
- }
-
if header := pm.blockchain.GetHeaderByNumber(req.BlockNum); header != nil {
sectionHead := core.GetCanonicalHash(pm.chainDb, req.ChtNum*light.ChtV1Frequency-1)
if root := light.GetChtRoot(pm.chainDb, req.ChtNum-1, sectionHead); root != (common.Hash{}) {
- if tr, _ := trie.New(root, trieDb); tr != nil {
- var encNumber [8]byte
- binary.BigEndian.PutUint64(encNumber[:], req.BlockNum)
- var proof light.NodeList
- tr.Prove(encNumber[:], 0, &proof)
- proofs = append(proofs, ChtResp{Header: header, Proof: proof})
- bytes += proof.DataSize() + estHeaderRlpSize
+ statedb, err := pm.blockchain.State()
+ if err != nil {
+ continue
}
+ trie, err := statedb.Database().OpenTrie(root)
+ if err != nil {
+ continue
+ }
+ var encNumber [8]byte
+ binary.BigEndian.PutUint64(encNumber[:], req.BlockNum)
+
+ var proof light.NodeList
+ trie.Prove(encNumber[:], 0, &proof)
+
+ proofs = append(proofs, ChtResp{Header: header, Proof: proof})
+ if bytes += proof.DataSize() + estHeaderRlpSize; bytes >= softResponseLimit {
+ break
+ }
+
}
}
}
@@ -897,25 +910,21 @@ func (pm *ProtocolManager) handleMsg(p *peer) error {
lastIdx uint64
lastType uint
root common.Hash
- tr *trie.Trie
+ statedb *state.StateDB
+ trie state.Trie
)
nodes := light.NewNodeSet()
for _, req := range req.Reqs {
- if nodes.DataSize()+auxBytes >= softResponseLimit {
- break
- }
- if tr == nil || req.HelperTrieType != lastType || req.TrieIdx != lastIdx {
- var prefix string
- root, prefix = pm.getHelperTrie(req.HelperTrieType, req.TrieIdx)
- if root != (common.Hash{}) {
- if t, err := trie.New(root, ethdb.NewTable(pm.chainDb, prefix)); err == nil {
- tr = t
+ if trie == nil || req.HelperTrieType != lastType || req.TrieIdx != lastIdx {
+ statedb, trie, lastType, lastIdx = nil, nil, req.HelperTrieType, req.TrieIdx
+
+ if root, _ = pm.getHelperTrie(req.HelperTrieType, req.TrieIdx); root != (common.Hash{}) {
+ if statedb, _ = pm.blockchain.State(); statedb != nil {
+ trie, _ = statedb.Database().OpenTrie(root)
}
}
- lastType = req.HelperTrieType
- lastIdx = req.TrieIdx
}
if req.AuxReq == auxRoot {
var data []byte
@@ -925,8 +934,8 @@ func (pm *ProtocolManager) handleMsg(p *peer) error {
auxData = append(auxData, data)
auxBytes += len(data)
} else {
- if tr != nil {
- tr.Prove(req.Key, req.FromLevel, nodes)
+ if trie != nil {
+ trie.Prove(req.Key, req.FromLevel, nodes)
}
if req.AuxReq != 0 {
data := pm.getHelperTrieAuxData(req)
@@ -934,6 +943,9 @@ func (pm *ProtocolManager) handleMsg(p *peer) error {
auxBytes += len(data)
}
}
+ if nodes.DataSize()+auxBytes >= softResponseLimit {
+ break
+ }
}
proofs := nodes.NodeList()
bv, rcost := p.fcClient.RequestProcessed(costs.baseCost + uint64(reqCnt)*costs.reqCost)
@@ -1090,6 +1102,23 @@ func (pm *ProtocolManager) handleMsg(p *peer) error {
return nil
}
+// getAccount retrieves an account from the state based at root.
+func (pm *ProtocolManager) getAccount(statedb *state.StateDB, root, hash common.Hash) (state.Account, error) {
+ trie, err := trie.New(root, statedb.Database().TrieDB())
+ if err != nil {
+ return state.Account{}, err
+ }
+ blob, err := trie.TryGet(hash[:])
+ if err != nil {
+ return state.Account{}, err
+ }
+ var account state.Account
+ if err = rlp.DecodeBytes(blob, &account); err != nil {
+ return state.Account{}, err
+ }
+ return account, nil
+}
+
// getHelperTrie returns the post-processed trie root for the given trie ID and section index
func (pm *ProtocolManager) getHelperTrie(id uint, idx uint64) (common.Hash, string) {
switch id {
diff --git a/les/handler_test.go b/les/handler_test.go
index 10e5499a3..e5446c031 100644
--- a/les/handler_test.go
+++ b/les/handler_test.go
@@ -359,7 +359,7 @@ func testGetProofs(t *testing.T, protocol int) {
for i := uint64(0); i <= bc.CurrentBlock().NumberU64(); i++ {
header := bc.GetHeaderByNumber(i)
root := header.Root
- trie, _ := trie.New(root, db)
+ trie, _ := trie.New(root, trie.NewDatabase(db))
for _, acc := range accounts {
req := ProofReq{
diff --git a/les/helper_test.go b/les/helper_test.go
index 1c1de64ad..bf08e1e2f 100644
--- a/les/helper_test.go
+++ b/les/helper_test.go
@@ -146,7 +146,7 @@ func newTestProtocolManager(lightSync bool, blocks int, generator func(int, *cor
if lightSync {
chain, _ = light.NewLightChain(odr, gspec.Config, engine)
} else {
- blockchain, _ := core.NewBlockChain(db, gspec.Config, engine, vm.Config{})
+ blockchain, _ := core.NewBlockChain(db, nil, gspec.Config, engine, vm.Config{})
gchain, _ := core.GenerateChain(gspec.Config, genesis, ethash.NewFaker(), db, blocks, generator)
if _, err := blockchain.InsertChain(gchain); err != nil {
panic(err)
diff --git a/les/odr_test.go b/les/odr_test.go
index cf609be88..88e121cda 100644
--- a/les/odr_test.go
+++ b/les/odr_test.go
@@ -101,7 +101,6 @@ func odrAccounts(ctx context.Context, db ethdb.Database, config *params.ChainCon
res = append(res, rlp...)
}
}
-
return res
}
diff --git a/light/lightchain.go b/light/lightchain.go
index f47957512..24529ef82 100644
--- a/light/lightchain.go
+++ b/light/lightchain.go
@@ -18,6 +18,7 @@ package light
import (
"context"
+ "errors"
"math/big"
"sync"
"sync/atomic"
@@ -26,6 +27,7 @@ import (
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/consensus"
"github.com/ethereum/go-ethereum/core"
+ "github.com/ethereum/go-ethereum/core/state"
"github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/ethdb"
"github.com/ethereum/go-ethereum/event"
@@ -212,6 +214,11 @@ func (bc *LightChain) Genesis() *types.Block {
return bc.genesisBlock
}
+// State returns a new mutable state based on the current HEAD block.
+func (bc *LightChain) State() (*state.StateDB, error) {
+ return nil, errors.New("not implemented, needs client/server interface split")
+}
+
// GetBody retrieves a block body (transactions and uncles) from the database
// or ODR service by hash, caching it if found.
func (self *LightChain) GetBody(ctx context.Context, hash common.Hash) (*types.Body, error) {
diff --git a/light/nodeset.go b/light/nodeset.go
index c530a4fbe..ffdb71bb7 100644
--- a/light/nodeset.go
+++ b/light/nodeset.go
@@ -22,8 +22,8 @@ import (
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/crypto"
+ "github.com/ethereum/go-ethereum/ethdb"
"github.com/ethereum/go-ethereum/rlp"
- "github.com/ethereum/go-ethereum/trie"
)
// NodeSet stores a set of trie nodes. It implements trie.Database and can also
@@ -99,7 +99,7 @@ func (db *NodeSet) NodeList() NodeList {
}
// Store writes the contents of the set to the given database
-func (db *NodeSet) Store(target trie.Database) {
+func (db *NodeSet) Store(target ethdb.Putter) {
db.lock.RLock()
defer db.lock.RUnlock()
@@ -108,11 +108,11 @@ func (db *NodeSet) Store(target trie.Database) {
}
}
-// NodeList stores an ordered list of trie nodes. It implements trie.DatabaseWriter.
+// NodeList stores an ordered list of trie nodes. It implements ethdb.Putter.
type NodeList []rlp.RawValue
// Store writes the contents of the list to the given database
-func (n NodeList) Store(db trie.Database) {
+func (n NodeList) Store(db ethdb.Putter) {
for _, node := range n {
db.Put(crypto.Keccak256(node), node)
}
diff --git a/light/odr_test.go b/light/odr_test.go
index e3d07518a..d3f9374fd 100644
--- a/light/odr_test.go
+++ b/light/odr_test.go
@@ -74,7 +74,7 @@ func (odr *testOdr) Retrieve(ctx context.Context, req OdrRequest) error {
case *ReceiptsRequest:
req.Receipts = core.GetBlockReceipts(odr.sdb, req.Hash, core.GetBlockNumber(odr.sdb, req.Hash))
case *TrieRequest:
- t, _ := trie.New(req.Id.Root, odr.sdb)
+ t, _ := trie.New(req.Id.Root, trie.NewDatabase(odr.sdb))
nodes := NewNodeSet()
t.Prove(req.Key, 0, nodes)
req.Proof = nodes
@@ -239,7 +239,7 @@ func testChainOdr(t *testing.T, protocol int, fn odrTestFn) {
)
gspec.MustCommit(ldb)
// Assemble the test environment
- blockchain, _ := core.NewBlockChain(sdb, params.TestChainConfig, ethash.NewFullFaker(), vm.Config{})
+ blockchain, _ := core.NewBlockChain(sdb, nil, params.TestChainConfig, ethash.NewFullFaker(), vm.Config{})
gchain, _ := core.GenerateChain(params.TestChainConfig, genesis, ethash.NewFaker(), sdb, 4, testChainGen)
if _, err := blockchain.InsertChain(gchain); err != nil {
t.Fatal(err)
diff --git a/light/postprocess.go b/light/postprocess.go
index 32dbc102b..bbac58d12 100644
--- a/light/postprocess.go
+++ b/light/postprocess.go
@@ -113,7 +113,8 @@ func StoreChtRoot(db ethdb.Database, sectionIdx uint64, sectionHead, root common
// ChtIndexerBackend implements core.ChainIndexerBackend
type ChtIndexerBackend struct {
- db, cdb ethdb.Database
+ diskdb ethdb.Database
+ triedb *trie.Database
section, sectionSize uint64
lastHash common.Hash
trie *trie.Trie
@@ -121,8 +122,6 @@ type ChtIndexerBackend struct {
// NewBloomTrieIndexer creates a BloomTrie chain indexer
func NewChtIndexer(db ethdb.Database, clientMode bool) *core.ChainIndexer {
- cdb := ethdb.NewTable(db, ChtTablePrefix)
- idb := ethdb.NewTable(db, "chtIndex-")
var sectionSize, confirmReq uint64
if clientMode {
sectionSize = ChtFrequency
@@ -131,17 +130,23 @@ func NewChtIndexer(db ethdb.Database, clientMode bool) *core.ChainIndexer {
sectionSize = ChtV1Frequency
confirmReq = HelperTrieProcessConfirmations
}
- return core.NewChainIndexer(db, idb, &ChtIndexerBackend{db: db, cdb: cdb, sectionSize: sectionSize}, sectionSize, confirmReq, time.Millisecond*100, "cht")
+ idb := ethdb.NewTable(db, "chtIndex-")
+ backend := &ChtIndexerBackend{
+ diskdb: db,
+ triedb: trie.NewDatabase(ethdb.NewTable(db, ChtTablePrefix)),
+ sectionSize: sectionSize,
+ }
+ return core.NewChainIndexer(db, idb, backend, sectionSize, confirmReq, time.Millisecond*100, "cht")
}
// Reset implements core.ChainIndexerBackend
func (c *ChtIndexerBackend) Reset(section uint64, lastSectionHead common.Hash) error {
var root common.Hash
if section > 0 {
- root = GetChtRoot(c.db, section-1, lastSectionHead)
+ root = GetChtRoot(c.diskdb, section-1, lastSectionHead)
}
var err error
- c.trie, err = trie.New(root, c.cdb)
+ c.trie, err = trie.New(root, c.triedb)
c.section = section
return err
}
@@ -151,7 +156,7 @@ func (c *ChtIndexerBackend) Process(header *types.Header) {
hash, num := header.Hash(), header.Number.Uint64()
c.lastHash = hash
- td := core.GetTd(c.db, hash, num)
+ td := core.GetTd(c.diskdb, hash, num)
if td == nil {
panic(nil)
}
@@ -163,17 +168,16 @@ func (c *ChtIndexerBackend) Process(header *types.Header) {
// Commit implements core.ChainIndexerBackend
func (c *ChtIndexerBackend) Commit() error {
- batch := c.cdb.NewBatch()
- root, err := c.trie.CommitTo(batch)
+ root, err := c.trie.Commit(nil)
if err != nil {
return err
- } else {
- batch.Write()
- if ((c.section+1)*c.sectionSize)%ChtFrequency == 0 {
- log.Info("Storing CHT", "idx", c.section*c.sectionSize/ChtFrequency, "sectionHead", fmt.Sprintf("%064x", c.lastHash), "root", fmt.Sprintf("%064x", root))
- }
- StoreChtRoot(c.db, c.section, c.lastHash, root)
}
+ c.triedb.Commit(root, false)
+
+ if ((c.section+1)*c.sectionSize)%ChtFrequency == 0 {
+ log.Info("Storing CHT", "idx", c.section*c.sectionSize/ChtFrequency, "sectionHead", fmt.Sprintf("%064x", c.lastHash), "root", fmt.Sprintf("%064x", root))
+ }
+ StoreChtRoot(c.diskdb, c.section, c.lastHash, root)
return nil
}
@@ -205,7 +209,8 @@ func StoreBloomTrieRoot(db ethdb.Database, sectionIdx uint64, sectionHead, root
// BloomTrieIndexerBackend implements core.ChainIndexerBackend
type BloomTrieIndexerBackend struct {
- db, cdb ethdb.Database
+ diskdb ethdb.Database
+ triedb *trie.Database
section, parentSectionSize, bloomTrieRatio uint64
trie *trie.Trie
sectionHeads []common.Hash
@@ -213,9 +218,12 @@ type BloomTrieIndexerBackend struct {
// NewBloomTrieIndexer creates a BloomTrie chain indexer
func NewBloomTrieIndexer(db ethdb.Database, clientMode bool) *core.ChainIndexer {
- cdb := ethdb.NewTable(db, BloomTrieTablePrefix)
+ backend := &BloomTrieIndexerBackend{
+ diskdb: db,
+ triedb: trie.NewDatabase(ethdb.NewTable(db, BloomTrieTablePrefix)),
+ }
idb := ethdb.NewTable(db, "bltIndex-")
- backend := &BloomTrieIndexerBackend{db: db, cdb: cdb}
+
var confirmReq uint64
if clientMode {
backend.parentSectionSize = BloomTrieFrequency
@@ -233,10 +241,10 @@ func NewBloomTrieIndexer(db ethdb.Database, clientMode bool) *core.ChainIndexer
func (b *BloomTrieIndexerBackend) Reset(section uint64, lastSectionHead common.Hash) error {
var root common.Hash
if section > 0 {
- root = GetBloomTrieRoot(b.db, section-1, lastSectionHead)
+ root = GetBloomTrieRoot(b.diskdb, section-1, lastSectionHead)
}
var err error
- b.trie, err = trie.New(root, b.cdb)
+ b.trie, err = trie.New(root, b.triedb)
b.section = section
return err
}
@@ -259,7 +267,7 @@ func (b *BloomTrieIndexerBackend) Commit() error {
binary.BigEndian.PutUint64(encKey[2:10], b.section)
var decomp []byte
for j := uint64(0); j < b.bloomTrieRatio; j++ {
- data, err := core.GetBloomBits(b.db, i, b.section*b.bloomTrieRatio+j, b.sectionHeads[j])
+ data, err := core.GetBloomBits(b.diskdb, i, b.section*b.bloomTrieRatio+j, b.sectionHeads[j])
if err != nil {
return err
}
@@ -279,17 +287,15 @@ func (b *BloomTrieIndexerBackend) Commit() error {
b.trie.Delete(encKey[:])
}
}
-
- batch := b.cdb.NewBatch()
- root, err := b.trie.CommitTo(batch)
+ root, err := b.trie.Commit(nil)
if err != nil {
return err
- } else {
- batch.Write()
- sectionHead := b.sectionHeads[b.bloomTrieRatio-1]
- log.Info("Storing BloomTrie", "section", b.section, "sectionHead", fmt.Sprintf("%064x", sectionHead), "root", fmt.Sprintf("%064x", root), "compression ratio", float64(compSize)/float64(decompSize))
- StoreBloomTrieRoot(b.db, b.section, sectionHead, root)
}
+ b.triedb.Commit(root, false)
+
+ sectionHead := b.sectionHeads[b.bloomTrieRatio-1]
+ log.Info("Storing BloomTrie", "section", b.section, "sectionHead", fmt.Sprintf("%064x", sectionHead), "root", fmt.Sprintf("%064x", root), "compression ratio", float64(compSize)/float64(decompSize))
+ StoreBloomTrieRoot(b.diskdb, b.section, sectionHead, root)
return nil
}
diff --git a/light/trie.go b/light/trie.go
index 7a9c86b98..c07e99461 100644
--- a/light/trie.go
+++ b/light/trie.go
@@ -18,12 +18,14 @@ package light
import (
"context"
+ "errors"
"fmt"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/state"
"github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/crypto"
+ "github.com/ethereum/go-ethereum/ethdb"
"github.com/ethereum/go-ethereum/trie"
)
@@ -83,6 +85,10 @@ func (db *odrDatabase) ContractCodeSize(addrHash, codeHash common.Hash) (int, er
return len(code), err
}
+func (db *odrDatabase) TrieDB() *trie.Database {
+ return nil
+}
+
type odrTrie struct {
db *odrDatabase
id *TrieID
@@ -113,11 +119,11 @@ func (t *odrTrie) TryDelete(key []byte) error {
})
}
-func (t *odrTrie) CommitTo(db trie.DatabaseWriter) (common.Hash, error) {
+func (t *odrTrie) Commit(onleaf trie.LeafCallback) (common.Hash, error) {
if t.trie == nil {
return t.id.Root, nil
}
- return t.trie.CommitTo(db)
+ return t.trie.Commit(onleaf)
}
func (t *odrTrie) Hash() common.Hash {
@@ -135,13 +141,17 @@ func (t *odrTrie) GetKey(sha []byte) []byte {
return nil
}
+func (t *odrTrie) Prove(key []byte, fromLevel uint, proofDb ethdb.Putter) error {
+ return errors.New("not implemented, needs client/server interface split")
+}
+
// do tries and retries to execute a function until it returns with no error or
// an error type other than MissingNodeError
func (t *odrTrie) do(key []byte, fn func() error) error {
for {
var err error
if t.trie == nil {
- t.trie, err = trie.New(t.id.Root, t.db.backend.Database())
+ t.trie, err = trie.New(t.id.Root, trie.NewDatabase(t.db.backend.Database()))
}
if err == nil {
err = fn()
@@ -167,7 +177,7 @@ func newNodeIterator(t *odrTrie, startkey []byte) trie.NodeIterator {
// Open the actual non-ODR trie if that hasn't happened yet.
if t.trie == nil {
it.do(func() error {
- t, err := trie.New(t.id.Root, t.db.backend.Database())
+ t, err := trie.New(t.id.Root, trie.NewDatabase(t.db.backend.Database()))
if err == nil {
it.t.trie = t
}
diff --git a/light/trie_test.go b/light/trie_test.go
index d99664718..0d6b2cc1d 100644
--- a/light/trie_test.go
+++ b/light/trie_test.go
@@ -40,7 +40,7 @@ func TestNodeIterator(t *testing.T) {
genesis = gspec.MustCommit(fulldb)
)
gspec.MustCommit(lightdb)
- blockchain, _ := core.NewBlockChain(fulldb, params.TestChainConfig, ethash.NewFullFaker(), vm.Config{})
+ blockchain, _ := core.NewBlockChain(fulldb, nil, params.TestChainConfig, ethash.NewFullFaker(), vm.Config{})
gchain, _ := core.GenerateChain(params.TestChainConfig, genesis, ethash.NewFaker(), fulldb, 4, testChainGen)
if _, err := blockchain.InsertChain(gchain); err != nil {
panic(err)
diff --git a/light/txpool_test.go b/light/txpool_test.go
index b343f79b0..13d7d3ceb 100644
--- a/light/txpool_test.go
+++ b/light/txpool_test.go
@@ -88,7 +88,7 @@ func TestTxPool(t *testing.T) {
)
gspec.MustCommit(ldb)
// Assemble the test environment
- blockchain, _ := core.NewBlockChain(sdb, params.TestChainConfig, ethash.NewFullFaker(), vm.Config{})
+ blockchain, _ := core.NewBlockChain(sdb, nil, params.TestChainConfig, ethash.NewFullFaker(), vm.Config{})
gchain, _ := core.GenerateChain(params.TestChainConfig, genesis, ethash.NewFaker(), sdb, poolTestBlocks, txPoolTestChainGen)
if _, err := blockchain.InsertChain(gchain); err != nil {
panic(err)
diff --git a/miner/worker.go b/miner/worker.go
index 1520277e1..15395ae0b 100644
--- a/miner/worker.go
+++ b/miner/worker.go
@@ -309,7 +309,7 @@ func (self *worker) wait() {
for _, log := range work.state.Logs() {
log.BlockHash = block.Hash()
}
- stat, err := self.chain.WriteBlockAndState(block, work.receipts, work.state)
+ stat, err := self.chain.WriteBlockWithState(block, work.receipts, work.state)
if err != nil {
log.Error("Failed writing block to chain", "err", err)
continue
diff --git a/tests/block_test_util.go b/tests/block_test_util.go
index 4bfd6433f..beba48483 100644
--- a/tests/block_test_util.go
+++ b/tests/block_test_util.go
@@ -110,7 +110,7 @@ func (t *BlockTest) Run() error {
return fmt.Errorf("genesis block state root does not match test: computed=%x, test=%x", gblock.Root().Bytes()[:6], t.json.Genesis.StateRoot[:6])
}
- chain, err := core.NewBlockChain(db, config, ethash.NewShared(), vm.Config{})
+ chain, err := core.NewBlockChain(db, nil, config, ethash.NewShared(), vm.Config{})
if err != nil {
return err
}
diff --git a/tests/state_test_util.go b/tests/state_test_util.go
index 78c05b024..18280d2a4 100644
--- a/tests/state_test_util.go
+++ b/tests/state_test_util.go
@@ -125,7 +125,7 @@ func (t *StateTest) Run(subtest StateSubtest, vmconfig vm.Config) (*state.StateD
if !ok {
return nil, UnsupportedForkError{subtest.Fork}
}
- block, _ := t.genesis(config).ToBlock()
+ block := t.genesis(config).ToBlock(nil)
db, _ := ethdb.NewMemDatabase()
statedb := MakePreState(db, t.json.Pre)
@@ -147,7 +147,7 @@ func (t *StateTest) Run(subtest StateSubtest, vmconfig vm.Config) (*state.StateD
if logs := rlpHash(statedb.Logs()); logs != common.Hash(post.Logs) {
return statedb, fmt.Errorf("post state logs hash mismatch: got %x, want %x", logs, post.Logs)
}
- root, _ := statedb.CommitTo(db, config.IsEIP158(block.Number()))
+ root, _ := statedb.Commit(config.IsEIP158(block.Number()))
if root != common.Hash(post.Root) {
return statedb, fmt.Errorf("post state root mismatch: got %x, want %x", root, post.Root)
}
@@ -170,7 +170,7 @@ func MakePreState(db ethdb.Database, accounts core.GenesisAlloc) *state.StateDB
}
}
// Commit and re-open to start with a clean state.
- root, _ := statedb.CommitTo(db, false)
+ root, _ := statedb.Commit(false)
statedb, _ = state.New(root, sdb)
return statedb
}
diff --git a/trie/database.go b/trie/database.go
new file mode 100644
index 000000000..d79120813
--- /dev/null
+++ b/trie/database.go
@@ -0,0 +1,355 @@
+// Copyright 2017 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package trie
+
+import (
+ "sync"
+ "time"
+
+ "github.com/ethereum/go-ethereum/common"
+ "github.com/ethereum/go-ethereum/ethdb"
+ "github.com/ethereum/go-ethereum/log"
+)
+
+// secureKeyPrefix is the database key prefix used to store trie node preimages.
+var secureKeyPrefix = []byte("secure-key-")
+
+// secureKeyLength is the length of the above prefix + 32byte hash.
+const secureKeyLength = 11 + 32
+
+// DatabaseReader wraps the Get and Has method of a backing store for the trie.
+type DatabaseReader interface {
+ // Get retrieves the value associated with key form the database.
+ Get(key []byte) (value []byte, err error)
+
+ // Has retrieves whether a key is present in the database.
+ Has(key []byte) (bool, error)
+}
+
+// Database is an intermediate write layer between the trie data structures and
+// the disk database. The aim is to accumulate trie writes in-memory and only
+// periodically flush a couple tries to disk, garbage collecting the remainder.
+type Database struct {
+ diskdb ethdb.Database // Persistent storage for matured trie nodes
+
+ nodes map[common.Hash]*cachedNode // Data and references relationships of a node
+ preimages map[common.Hash][]byte // Preimages of nodes from the secure trie
+ seckeybuf [secureKeyLength]byte // Ephemeral buffer for calculating preimage keys
+
+ gctime time.Duration // Time spent on garbage collection since last commit
+ gcnodes uint64 // Nodes garbage collected since last commit
+ gcsize common.StorageSize // Data storage garbage collected since last commit
+
+ nodesSize common.StorageSize // Storage size of the nodes cache
+ preimagesSize common.StorageSize // Storage size of the preimages cache
+
+ lock sync.RWMutex
+}
+
+// cachedNode is all the information we know about a single cached node in the
+// memory database write layer.
+type cachedNode struct {
+ blob []byte // Cached data block of the trie node
+ parents int // Number of live nodes referencing this one
+ children map[common.Hash]int // Children referenced by this nodes
+}
+
+// NewDatabase creates a new trie database to store ephemeral trie content before
+// its written out to disk or garbage collected.
+func NewDatabase(diskdb ethdb.Database) *Database {
+ return &Database{
+ diskdb: diskdb,
+ nodes: map[common.Hash]*cachedNode{
+ {}: {children: make(map[common.Hash]int)},
+ },
+ preimages: make(map[common.Hash][]byte),
+ }
+}
+
+// DiskDB retrieves the persistent storage backing the trie database.
+func (db *Database) DiskDB() DatabaseReader {
+ return db.diskdb
+}
+
+// Insert writes a new trie node to the memory database if it's yet unknown. The
+// method will make a copy of the slice.
+func (db *Database) Insert(hash common.Hash, blob []byte) {
+ db.lock.Lock()
+ defer db.lock.Unlock()
+
+ db.insert(hash, blob)
+}
+
+// insert is the private locked version of Insert.
+func (db *Database) insert(hash common.Hash, blob []byte) {
+ if _, ok := db.nodes[hash]; ok {
+ return
+ }
+ db.nodes[hash] = &cachedNode{
+ blob: common.CopyBytes(blob),
+ children: make(map[common.Hash]int),
+ }
+ db.nodesSize += common.StorageSize(common.HashLength + len(blob))
+}
+
+// insertPreimage writes a new trie node pre-image to the memory database if it's
+// yet unknown. The method will make a copy of the slice.
+//
+// Note, this method assumes that the database's lock is held!
+func (db *Database) insertPreimage(hash common.Hash, preimage []byte) {
+ if _, ok := db.preimages[hash]; ok {
+ return
+ }
+ db.preimages[hash] = common.CopyBytes(preimage)
+ db.preimagesSize += common.StorageSize(common.HashLength + len(preimage))
+}
+
+// Node retrieves a cached trie node from memory. If it cannot be found cached,
+// the method queries the persistent database for the content.
+func (db *Database) Node(hash common.Hash) ([]byte, error) {
+ // Retrieve the node from cache if available
+ db.lock.RLock()
+ node := db.nodes[hash]
+ db.lock.RUnlock()
+
+ if node != nil {
+ return node.blob, nil
+ }
+ // Content unavailable in memory, attempt to retrieve from disk
+ return db.diskdb.Get(hash[:])
+}
+
+// preimage retrieves a cached trie node pre-image from memory. If it cannot be
+// found cached, the method queries the persistent database for the content.
+func (db *Database) preimage(hash common.Hash) ([]byte, error) {
+ // Retrieve the node from cache if available
+ db.lock.RLock()
+ preimage := db.preimages[hash]
+ db.lock.RUnlock()
+
+ if preimage != nil {
+ return preimage, nil
+ }
+ // Content unavailable in memory, attempt to retrieve from disk
+ return db.diskdb.Get(db.secureKey(hash[:]))
+}
+
+// secureKey returns the database key for the preimage of key, as an ephemeral
+// buffer. The caller must not hold onto the return value because it will become
+// invalid on the next call.
+func (db *Database) secureKey(key []byte) []byte {
+ buf := append(db.seckeybuf[:0], secureKeyPrefix...)
+ buf = append(buf, key...)
+ return buf
+}
+
+// Nodes retrieves the hashes of all the nodes cached within the memory database.
+// This method is extremely expensive and should only be used to validate internal
+// states in test code.
+func (db *Database) Nodes() []common.Hash {
+ db.lock.RLock()
+ defer db.lock.RUnlock()
+
+ var hashes = make([]common.Hash, 0, len(db.nodes))
+ for hash := range db.nodes {
+ if hash != (common.Hash{}) { // Special case for "root" references/nodes
+ hashes = append(hashes, hash)
+ }
+ }
+ return hashes
+}
+
+// Reference adds a new reference from a parent node to a child node.
+func (db *Database) Reference(child common.Hash, parent common.Hash) {
+ db.lock.RLock()
+ defer db.lock.RUnlock()
+
+ db.reference(child, parent)
+}
+
+// reference is the private locked version of Reference.
+func (db *Database) reference(child common.Hash, parent common.Hash) {
+ // If the node does not exist, it's a node pulled from disk, skip
+ node, ok := db.nodes[child]
+ if !ok {
+ return
+ }
+ // If the reference already exists, only duplicate for roots
+ if _, ok = db.nodes[parent].children[child]; ok && parent != (common.Hash{}) {
+ return
+ }
+ node.parents++
+ db.nodes[parent].children[child]++
+}
+
+// Dereference removes an existing reference from a parent node to a child node.
+func (db *Database) Dereference(child common.Hash, parent common.Hash) {
+ db.lock.Lock()
+ defer db.lock.Unlock()
+
+ nodes, storage, start := len(db.nodes), db.nodesSize, time.Now()
+ db.dereference(child, parent)
+
+ db.gcnodes += uint64(nodes - len(db.nodes))
+ db.gcsize += storage - db.nodesSize
+ db.gctime += time.Since(start)
+
+ log.Debug("Dereferenced trie from memory database", "nodes", nodes-len(db.nodes), "size", storage-db.nodesSize, "time", time.Since(start),
+ "gcnodes", db.gcnodes, "gcsize", db.gcsize, "gctime", db.gctime, "livenodes", len(db.nodes), "livesize", db.nodesSize)
+}
+
+// dereference is the private locked version of Dereference.
+func (db *Database) dereference(child common.Hash, parent common.Hash) {
+ // Dereference the parent-child
+ node := db.nodes[parent]
+
+ node.children[child]--
+ if node.children[child] == 0 {
+ delete(node.children, child)
+ }
+ // If the node does not exist, it's a previously committed node.
+ node, ok := db.nodes[child]
+ if !ok {
+ return
+ }
+ // If there are no more references to the child, delete it and cascade
+ node.parents--
+ if node.parents == 0 {
+ for hash := range node.children {
+ db.dereference(hash, child)
+ }
+ delete(db.nodes, child)
+ db.nodesSize -= common.StorageSize(common.HashLength + len(node.blob))
+ }
+}
+
+// Commit iterates over all the children of a particular node, writes them out
+// to disk, forcefully tearing down all references in both directions.
+//
+// As a side effect, all pre-images accumulated up to this point are also written.
+func (db *Database) Commit(node common.Hash, report bool) error {
+ // Create a database batch to flush persistent data out. It is important that
+ // outside code doesn't see an inconsistent state (referenced data removed from
+ // memory cache during commit but not yet in persistent storage). This is ensured
+ // by only uncaching existing data when the database write finalizes.
+ db.lock.RLock()
+
+ start := time.Now()
+ batch := db.diskdb.NewBatch()
+
+ // Move all of the accumulated preimages into a write batch
+ for hash, preimage := range db.preimages {
+ if err := batch.Put(db.secureKey(hash[:]), preimage); err != nil {
+ log.Error("Failed to commit preimage from trie database", "err", err)
+ db.lock.RUnlock()
+ return err
+ }
+ if batch.ValueSize() > ethdb.IdealBatchSize {
+ if err := batch.Write(); err != nil {
+ return err
+ }
+ batch.Reset()
+ }
+ }
+ // Move the trie itself into the batch, flushing if enough data is accumulated
+ nodes, storage := len(db.nodes), db.nodesSize+db.preimagesSize
+ if err := db.commit(node, batch); err != nil {
+ log.Error("Failed to commit trie from trie database", "err", err)
+ db.lock.RUnlock()
+ return err
+ }
+ // Write batch ready, unlock for readers during persistence
+ if err := batch.Write(); err != nil {
+ log.Error("Failed to write trie to disk", "err", err)
+ db.lock.RUnlock()
+ return err
+ }
+ db.lock.RUnlock()
+
+ // Write successful, clear out the flushed data
+ db.lock.Lock()
+ defer db.lock.Unlock()
+
+ db.preimages = make(map[common.Hash][]byte)
+ db.preimagesSize = 0
+
+ db.uncache(node)
+
+ logger := log.Info
+ if !report {
+ logger = log.Debug
+ }
+ logger("Persisted trie from memory database", "nodes", nodes-len(db.nodes), "size", storage-db.nodesSize, "time", time.Since(start),
+ "gcnodes", db.gcnodes, "gcsize", db.gcsize, "gctime", db.gctime, "livenodes", len(db.nodes), "livesize", db.nodesSize)
+
+ // Reset the garbage collection statistics
+ db.gcnodes, db.gcsize, db.gctime = 0, 0, 0
+
+ return nil
+}
+
+// commit is the private locked version of Commit.
+func (db *Database) commit(hash common.Hash, batch ethdb.Batch) error {
+ // If the node does not exist, it's a previously committed node
+ node, ok := db.nodes[hash]
+ if !ok {
+ return nil
+ }
+ for child := range node.children {
+ if err := db.commit(child, batch); err != nil {
+ return err
+ }
+ }
+ if err := batch.Put(hash[:], node.blob); err != nil {
+ return err
+ }
+ // If we've reached an optimal match size, commit and start over
+ if batch.ValueSize() >= ethdb.IdealBatchSize {
+ if err := batch.Write(); err != nil {
+ return err
+ }
+ batch.Reset()
+ }
+ return nil
+}
+
+// uncache is the post-processing step of a commit operation where the already
+// persisted trie is removed from the cache. The reason behind the two-phase
+// commit is to ensure consistent data availability while moving from memory
+// to disk.
+func (db *Database) uncache(hash common.Hash) {
+ // If the node does not exist, we're done on this path
+ node, ok := db.nodes[hash]
+ if !ok {
+ return
+ }
+ // Otherwise uncache the node's subtries and remove the node itself too
+ for child := range node.children {
+ db.uncache(child)
+ }
+ delete(db.nodes, hash)
+ db.nodesSize -= common.StorageSize(common.HashLength + len(node.blob))
+}
+
+// Size returns the current storage size of the memory cache in front of the
+// persistent database layer.
+func (db *Database) Size() common.StorageSize {
+ db.lock.RLock()
+ defer db.lock.RUnlock()
+
+ return db.nodesSize + db.preimagesSize
+}
diff --git a/trie/hasher.go b/trie/hasher.go
index 4719aabf6..2fc44787a 100644
--- a/trie/hasher.go
+++ b/trie/hasher.go
@@ -27,21 +27,23 @@ import (
)
type hasher struct {
- tmp *bytes.Buffer
- sha hash.Hash
- cachegen, cachelimit uint16
+ tmp *bytes.Buffer
+ sha hash.Hash
+ cachegen uint16
+ cachelimit uint16
+ onleaf LeafCallback
}
-// hashers live in a global pool.
+// hashers live in a global db.
var hasherPool = sync.Pool{
New: func() interface{} {
return &hasher{tmp: new(bytes.Buffer), sha: sha3.NewKeccak256()}
},
}
-func newHasher(cachegen, cachelimit uint16) *hasher {
+func newHasher(cachegen, cachelimit uint16, onleaf LeafCallback) *hasher {
h := hasherPool.Get().(*hasher)
- h.cachegen, h.cachelimit = cachegen, cachelimit
+ h.cachegen, h.cachelimit, h.onleaf = cachegen, cachelimit, onleaf
return h
}
@@ -51,7 +53,7 @@ func returnHasherToPool(h *hasher) {
// hash collapses a node down into a hash node, also returning a copy of the
// original node initialized with the computed hash to replace the original one.
-func (h *hasher) hash(n node, db DatabaseWriter, force bool) (node, node, error) {
+func (h *hasher) hash(n node, db *Database, force bool) (node, node, error) {
// If we're not storing the node, just hashing, use available cached data
if hash, dirty := n.cache(); hash != nil {
if db == nil {
@@ -98,7 +100,7 @@ func (h *hasher) hash(n node, db DatabaseWriter, force bool) (node, node, error)
// hashChildren replaces the children of a node with their hashes if the encoded
// size of the child is larger than a hash, returning the collapsed node as well
// as a replacement for the original node with the child hashes cached in.
-func (h *hasher) hashChildren(original node, db DatabaseWriter) (node, node, error) {
+func (h *hasher) hashChildren(original node, db *Database) (node, node, error) {
var err error
switch n := original.(type) {
@@ -145,7 +147,10 @@ func (h *hasher) hashChildren(original node, db DatabaseWriter) (node, node, err
}
}
-func (h *hasher) store(n node, db DatabaseWriter, force bool) (node, error) {
+// store hashes the node n and if we have a storage layer specified, it writes
+// the key/value pair to it and tracks any node->child references as well as any
+// node->external trie references.
+func (h *hasher) store(n node, db *Database, force bool) (node, error) {
// Don't store hashes or empty nodes.
if _, isHash := n.(hashNode); n == nil || isHash {
return n, nil
@@ -155,7 +160,6 @@ func (h *hasher) store(n node, db DatabaseWriter, force bool) (node, error) {
if err := rlp.Encode(h.tmp, n); err != nil {
panic("encode error: " + err.Error())
}
-
if h.tmp.Len() < 32 && !force {
return n, nil // Nodes smaller than 32 bytes are stored inside their parent
}
@@ -167,7 +171,42 @@ func (h *hasher) store(n node, db DatabaseWriter, force bool) (node, error) {
hash = hashNode(h.sha.Sum(nil))
}
if db != nil {
- return hash, db.Put(hash, h.tmp.Bytes())
+ // We are pooling the trie nodes into an intermediate memory cache
+ db.lock.Lock()
+
+ hash := common.BytesToHash(hash)
+ db.insert(hash, h.tmp.Bytes())
+
+ // Track all direct parent->child node references
+ switch n := n.(type) {
+ case *shortNode:
+ if child, ok := n.Val.(hashNode); ok {
+ db.reference(common.BytesToHash(child), hash)
+ }
+ case *fullNode:
+ for i := 0; i < 16; i++ {
+ if child, ok := n.Children[i].(hashNode); ok {
+ db.reference(common.BytesToHash(child), hash)
+ }
+ }
+ }
+ db.lock.Unlock()
+
+ // Track external references from account->storage trie
+ if h.onleaf != nil {
+ switch n := n.(type) {
+ case *shortNode:
+ if child, ok := n.Val.(valueNode); ok {
+ h.onleaf(child, hash)
+ }
+ case *fullNode:
+ for i := 0; i < 16; i++ {
+ if child, ok := n.Children[i].(valueNode); ok {
+ h.onleaf(child, hash)
+ }
+ }
+ }
+ }
}
return hash, nil
}
diff --git a/trie/iterator_test.go b/trie/iterator_test.go
index 4808d8b0c..dce1c78b5 100644
--- a/trie/iterator_test.go
+++ b/trie/iterator_test.go
@@ -42,7 +42,7 @@ func TestIterator(t *testing.T) {
all[val.k] = val.v
trie.Update([]byte(val.k), []byte(val.v))
}
- trie.Commit()
+ trie.Commit(nil)
found := make(map[string]string)
it := NewIterator(trie.NodeIterator(nil))
@@ -109,11 +109,18 @@ func TestNodeIteratorCoverage(t *testing.T) {
}
// Cross check the hashes and the database itself
for hash := range hashes {
- if _, err := db.Get(hash.Bytes()); err != nil {
+ if _, err := db.Node(hash); err != nil {
t.Errorf("failed to retrieve reported node %x: %v", hash, err)
}
}
- for _, key := range db.(*ethdb.MemDatabase).Keys() {
+ for hash, obj := range db.nodes {
+ if obj != nil && hash != (common.Hash{}) {
+ if _, ok := hashes[hash]; !ok {
+ t.Errorf("state entry not reported %x", hash)
+ }
+ }
+ }
+ for _, key := range db.diskdb.(*ethdb.MemDatabase).Keys() {
if _, ok := hashes[common.BytesToHash(key)]; !ok {
t.Errorf("state entry not reported %x", key)
}
@@ -191,13 +198,13 @@ func TestDifferenceIterator(t *testing.T) {
for _, val := range testdata1 {
triea.Update([]byte(val.k), []byte(val.v))
}
- triea.Commit()
+ triea.Commit(nil)
trieb := newEmpty()
for _, val := range testdata2 {
trieb.Update([]byte(val.k), []byte(val.v))
}
- trieb.Commit()
+ trieb.Commit(nil)
found := make(map[string]string)
di, _ := NewDifferenceIterator(triea.NodeIterator(nil), trieb.NodeIterator(nil))
@@ -227,13 +234,13 @@ func TestUnionIterator(t *testing.T) {
for _, val := range testdata1 {
triea.Update([]byte(val.k), []byte(val.v))
}
- triea.Commit()
+ triea.Commit(nil)
trieb := newEmpty()
for _, val := range testdata2 {
trieb.Update([]byte(val.k), []byte(val.v))
}
- trieb.Commit()
+ trieb.Commit(nil)
di, _ := NewUnionIterator([]NodeIterator{triea.NodeIterator(nil), trieb.NodeIterator(nil)})
it := NewIterator(di)
@@ -278,43 +285,75 @@ func TestIteratorNoDups(t *testing.T) {
}
// This test checks that nodeIterator.Next can be retried after inserting missing trie nodes.
-func TestIteratorContinueAfterError(t *testing.T) {
- db, _ := ethdb.NewMemDatabase()
- tr, _ := New(common.Hash{}, db)
+func TestIteratorContinueAfterErrorDisk(t *testing.T) { testIteratorContinueAfterError(t, false) }
+func TestIteratorContinueAfterErrorMemonly(t *testing.T) { testIteratorContinueAfterError(t, true) }
+
+func testIteratorContinueAfterError(t *testing.T, memonly bool) {
+ diskdb, _ := ethdb.NewMemDatabase()
+ triedb := NewDatabase(diskdb)
+
+ tr, _ := New(common.Hash{}, triedb)
for _, val := range testdata1 {
tr.Update([]byte(val.k), []byte(val.v))
}
- tr.Commit()
+ tr.Commit(nil)
+ if !memonly {
+ triedb.Commit(tr.Hash(), true)
+ }
wantNodeCount := checkIteratorNoDups(t, tr.NodeIterator(nil), nil)
- keys := db.Keys()
- t.Log("node count", wantNodeCount)
+ var (
+ diskKeys [][]byte
+ memKeys []common.Hash
+ )
+ if memonly {
+ memKeys = triedb.Nodes()
+ } else {
+ diskKeys = diskdb.Keys()
+ }
for i := 0; i < 20; i++ {
// Create trie that will load all nodes from DB.
- tr, _ := New(tr.Hash(), db)
+ tr, _ := New(tr.Hash(), triedb)
// Remove a random node from the database. It can't be the root node
// because that one is already loaded.
- var rkey []byte
+ var (
+ rkey common.Hash
+ rval []byte
+ robj *cachedNode
+ )
for {
- if rkey = keys[rand.Intn(len(keys))]; !bytes.Equal(rkey, tr.Hash().Bytes()) {
+ if memonly {
+ rkey = memKeys[rand.Intn(len(memKeys))]
+ } else {
+ copy(rkey[:], diskKeys[rand.Intn(len(diskKeys))])
+ }
+ if rkey != tr.Hash() {
break
}
}
- rval, _ := db.Get(rkey)
- db.Delete(rkey)
-
+ if memonly {
+ robj = triedb.nodes[rkey]
+ delete(triedb.nodes, rkey)
+ } else {
+ rval, _ = diskdb.Get(rkey[:])
+ diskdb.Delete(rkey[:])
+ }
// Iterate until the error is hit.
seen := make(map[string]bool)
it := tr.NodeIterator(nil)
checkIteratorNoDups(t, it, seen)
missing, ok := it.Error().(*MissingNodeError)
- if !ok || !bytes.Equal(missing.NodeHash[:], rkey) {
+ if !ok || missing.NodeHash != rkey {
t.Fatal("didn't hit missing node, got", it.Error())
}
// Add the node back and continue iteration.
- db.Put(rkey, rval)
+ if memonly {
+ triedb.nodes[rkey] = robj
+ } else {
+ diskdb.Put(rkey[:], rval)
+ }
checkIteratorNoDups(t, it, seen)
if it.Error() != nil {
t.Fatal("unexpected error", it.Error())
@@ -328,21 +367,41 @@ func TestIteratorContinueAfterError(t *testing.T) {
// Similar to the test above, this one checks that failure to create nodeIterator at a
// certain key prefix behaves correctly when Next is called. The expectation is that Next
// should retry seeking before returning true for the first time.
-func TestIteratorContinueAfterSeekError(t *testing.T) {
+func TestIteratorContinueAfterSeekErrorDisk(t *testing.T) {
+ testIteratorContinueAfterSeekError(t, false)
+}
+func TestIteratorContinueAfterSeekErrorMemonly(t *testing.T) {
+ testIteratorContinueAfterSeekError(t, true)
+}
+
+func testIteratorContinueAfterSeekError(t *testing.T, memonly bool) {
// Commit test trie to db, then remove the node containing "bars".
- db, _ := ethdb.NewMemDatabase()
- ctr, _ := New(common.Hash{}, db)
+ diskdb, _ := ethdb.NewMemDatabase()
+ triedb := NewDatabase(diskdb)
+
+ ctr, _ := New(common.Hash{}, triedb)
for _, val := range testdata1 {
ctr.Update([]byte(val.k), []byte(val.v))
}
- root, _ := ctr.Commit()
+ root, _ := ctr.Commit(nil)
+ if !memonly {
+ triedb.Commit(root, true)
+ }
barNodeHash := common.HexToHash("05041990364eb72fcb1127652ce40d8bab765f2bfe53225b1170d276cc101c2e")
- barNode, _ := db.Get(barNodeHash[:])
- db.Delete(barNodeHash[:])
-
+ var (
+ barNodeBlob []byte
+ barNodeObj *cachedNode
+ )
+ if memonly {
+ barNodeObj = triedb.nodes[barNodeHash]
+ delete(triedb.nodes, barNodeHash)
+ } else {
+ barNodeBlob, _ = diskdb.Get(barNodeHash[:])
+ diskdb.Delete(barNodeHash[:])
+ }
// Create a new iterator that seeks to "bars". Seeking can't proceed because
// the node is missing.
- tr, _ := New(root, db)
+ tr, _ := New(root, triedb)
it := tr.NodeIterator([]byte("bars"))
missing, ok := it.Error().(*MissingNodeError)
if !ok {
@@ -350,10 +409,12 @@ func TestIteratorContinueAfterSeekError(t *testing.T) {
} else if missing.NodeHash != barNodeHash {
t.Fatal("wrong node missing")
}
-
// Reinsert the missing node.
- db.Put(barNodeHash[:], barNode[:])
-
+ if memonly {
+ triedb.nodes[barNodeHash] = barNodeObj
+ } else {
+ diskdb.Put(barNodeHash[:], barNodeBlob)
+ }
// Check that iteration produces the right set of values.
if err := checkIteratorOrder(testdata1[2:], NewIterator(it)); err != nil {
t.Fatal(err)
diff --git a/trie/proof.go b/trie/proof.go
index 5e886a259..508e4a6cf 100644
--- a/trie/proof.go
+++ b/trie/proof.go
@@ -22,20 +22,19 @@ import (
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/crypto"
+ "github.com/ethereum/go-ethereum/ethdb"
"github.com/ethereum/go-ethereum/log"
"github.com/ethereum/go-ethereum/rlp"
)
-// Prove constructs a merkle proof for key. The result contains all
-// encoded nodes on the path to the value at key. The value itself is
-// also included in the last node and can be retrieved by verifying
-// the proof.
+// Prove constructs a merkle proof for key. The result contains all encoded nodes
+// on the path to the value at key. The value itself is also included in the last
+// node and can be retrieved by verifying the proof.
//
-// If the trie does not contain a value for key, the returned proof
-// contains all nodes of the longest existing prefix of the key
-// (at least the root node), ending with the node that proves the
-// absence of the key.
-func (t *Trie) Prove(key []byte, fromLevel uint, proofDb DatabaseWriter) error {
+// If the trie does not contain a value for key, the returned proof contains all
+// nodes of the longest existing prefix of the key (at least the root node), ending
+// with the node that proves the absence of the key.
+func (t *Trie) Prove(key []byte, fromLevel uint, proofDb ethdb.Putter) error {
// Collect all nodes on the path to key.
key = keybytesToHex(key)
nodes := []node{}
@@ -66,7 +65,7 @@ func (t *Trie) Prove(key []byte, fromLevel uint, proofDb DatabaseWriter) error {
panic(fmt.Sprintf("%T: invalid node: %v", tn, tn))
}
}
- hasher := newHasher(0, 0)
+ hasher := newHasher(0, 0, nil)
for i, n := range nodes {
// Don't bother checking for errors here since hasher panics
// if encoding doesn't work and we're not writing to any database.
@@ -89,19 +88,29 @@ func (t *Trie) Prove(key []byte, fromLevel uint, proofDb DatabaseWriter) error {
return nil
}
-// VerifyProof checks merkle proofs. The given proof must contain the
-// value for key in a trie with the given root hash. VerifyProof
-// returns an error if the proof contains invalid trie nodes or the
-// wrong value.
+// Prove constructs a merkle proof for key. The result contains all encoded nodes
+// on the path to the value at key. The value itself is also included in the last
+// node and can be retrieved by verifying the proof.
+//
+// If the trie does not contain a value for key, the returned proof contains all
+// nodes of the longest existing prefix of the key (at least the root node), ending
+// with the node that proves the absence of the key.
+func (t *SecureTrie) Prove(key []byte, fromLevel uint, proofDb ethdb.Putter) error {
+ return t.trie.Prove(key, fromLevel, proofDb)
+}
+
+// VerifyProof checks merkle proofs. The given proof must contain the value for
+// key in a trie with the given root hash. VerifyProof returns an error if the
+// proof contains invalid trie nodes or the wrong value.
func VerifyProof(rootHash common.Hash, key []byte, proofDb DatabaseReader) (value []byte, err error, nodes int) {
key = keybytesToHex(key)
- wantHash := rootHash[:]
+ wantHash := rootHash
for i := 0; ; i++ {
- buf, _ := proofDb.Get(wantHash)
+ buf, _ := proofDb.Get(wantHash[:])
if buf == nil {
- return nil, fmt.Errorf("proof node %d (hash %064x) missing", i, wantHash[:]), i
+ return nil, fmt.Errorf("proof node %d (hash %064x) missing", i, wantHash), i
}
- n, err := decodeNode(wantHash, buf, 0)
+ n, err := decodeNode(wantHash[:], buf, 0)
if err != nil {
return nil, fmt.Errorf("bad proof node %d: %v", i, err), i
}
@@ -112,7 +121,7 @@ func VerifyProof(rootHash common.Hash, key []byte, proofDb DatabaseReader) (valu
return nil, nil, i
case hashNode:
key = keyrest
- wantHash = cld
+ copy(wantHash[:], cld)
case valueNode:
return cld, nil, i + 1
}
diff --git a/trie/secure_trie.go b/trie/secure_trie.go
index 20c303f31..3881ee18a 100644
--- a/trie/secure_trie.go
+++ b/trie/secure_trie.go
@@ -23,10 +23,6 @@ import (
"github.com/ethereum/go-ethereum/log"
)
-var secureKeyPrefix = []byte("secure-key-")
-
-const secureKeyLength = 11 + 32 // Length of the above prefix + 32byte hash
-
// SecureTrie wraps a trie with key hashing. In a secure trie, all
// access operations hash the key using keccak256. This prevents
// calling code from creating long chains of nodes that
@@ -39,25 +35,25 @@ const secureKeyLength = 11 + 32 // Length of the above prefix + 32byte hash
// SecureTrie is not safe for concurrent use.
type SecureTrie struct {
trie Trie
- hashKeyBuf [secureKeyLength]byte
- secKeyBuf [200]byte
+ hashKeyBuf [common.HashLength]byte
secKeyCache map[string][]byte
secKeyCacheOwner *SecureTrie // Pointer to self, replace the key cache on mismatch
}
-// NewSecure creates a trie with an existing root node from db.
+// NewSecure creates a trie with an existing root node from a backing database
+// and optional intermediate in-memory node pool.
//
// If root is the zero hash or the sha3 hash of an empty string, the
// trie is initially empty. Otherwise, New will panic if db is nil
// and returns MissingNodeError if the root node cannot be found.
//
-// Accessing the trie loads nodes from db on demand.
+// Accessing the trie loads nodes from the database or node pool on demand.
// Loaded nodes are kept around until their 'cache generation' expires.
// A new cache generation is created by each call to Commit.
// cachelimit sets the number of past cache generations to keep.
-func NewSecure(root common.Hash, db Database, cachelimit uint16) (*SecureTrie, error) {
+func NewSecure(root common.Hash, db *Database, cachelimit uint16) (*SecureTrie, error) {
if db == nil {
- panic("NewSecure called with nil database")
+ panic("trie.NewSecure called without a database")
}
trie, err := New(root, db)
if err != nil {
@@ -135,7 +131,7 @@ func (t *SecureTrie) GetKey(shaKey []byte) []byte {
if key, ok := t.getSecKeyCache()[string(shaKey)]; ok {
return key
}
- key, _ := t.trie.db.Get(t.secKey(shaKey))
+ key, _ := t.trie.db.preimage(common.BytesToHash(shaKey))
return key
}
@@ -144,8 +140,19 @@ func (t *SecureTrie) GetKey(shaKey []byte) []byte {
//
// Committing flushes nodes from memory. Subsequent Get calls will load nodes
// from the database.
-func (t *SecureTrie) Commit() (root common.Hash, err error) {
- return t.CommitTo(t.trie.db)
+func (t *SecureTrie) Commit(onleaf LeafCallback) (root common.Hash, err error) {
+ // Write all the pre-images to the actual disk database
+ if len(t.getSecKeyCache()) > 0 {
+ t.trie.db.lock.Lock()
+ for hk, key := range t.secKeyCache {
+ t.trie.db.insertPreimage(common.BytesToHash([]byte(hk)), key)
+ }
+ t.trie.db.lock.Unlock()
+
+ t.secKeyCache = make(map[string][]byte)
+ }
+ // Commit the trie to its intermediate node database
+ return t.trie.Commit(onleaf)
}
func (t *SecureTrie) Hash() common.Hash {
@@ -167,38 +174,11 @@ func (t *SecureTrie) NodeIterator(start []byte) NodeIterator {
return t.trie.NodeIterator(start)
}
-// CommitTo writes all nodes and the secure hash pre-images to the given database.
-// Nodes are stored with their sha3 hash as the key.
-//
-// Committing flushes nodes from memory. Subsequent Get calls will load nodes from
-// the trie's database. Calling code must ensure that the changes made to db are
-// written back to the trie's attached database before using the trie.
-func (t *SecureTrie) CommitTo(db DatabaseWriter) (root common.Hash, err error) {
- if len(t.getSecKeyCache()) > 0 {
- for hk, key := range t.secKeyCache {
- if err := db.Put(t.secKey([]byte(hk)), key); err != nil {
- return common.Hash{}, err
- }
- }
- t.secKeyCache = make(map[string][]byte)
- }
- return t.trie.CommitTo(db)
-}
-
-// secKey returns the database key for the preimage of key, as an ephemeral buffer.
-// The caller must not hold onto the return value because it will become
-// invalid on the next call to hashKey or secKey.
-func (t *SecureTrie) secKey(key []byte) []byte {
- buf := append(t.secKeyBuf[:0], secureKeyPrefix...)
- buf = append(buf, key...)
- return buf
-}
-
// hashKey returns the hash of key as an ephemeral buffer.
// The caller must not hold onto the return value because it will become
// invalid on the next call to hashKey or secKey.
func (t *SecureTrie) hashKey(key []byte) []byte {
- h := newHasher(0, 0)
+ h := newHasher(0, 0, nil)
h.sha.Reset()
h.sha.Write(key)
buf := h.sha.Sum(t.hashKeyBuf[:0])
diff --git a/trie/secure_trie_test.go b/trie/secure_trie_test.go
index d74102e2a..aedf5a1cd 100644
--- a/trie/secure_trie_test.go
+++ b/trie/secure_trie_test.go
@@ -28,16 +28,20 @@ import (
)
func newEmptySecure() *SecureTrie {
- db, _ := ethdb.NewMemDatabase()
- trie, _ := NewSecure(common.Hash{}, db, 0)
+ diskdb, _ := ethdb.NewMemDatabase()
+ triedb := NewDatabase(diskdb)
+
+ trie, _ := NewSecure(common.Hash{}, triedb, 0)
return trie
}
// makeTestSecureTrie creates a large enough secure trie for testing.
-func makeTestSecureTrie() (ethdb.Database, *SecureTrie, map[string][]byte) {
+func makeTestSecureTrie() (*Database, *SecureTrie, map[string][]byte) {
// Create an empty trie
- db, _ := ethdb.NewMemDatabase()
- trie, _ := NewSecure(common.Hash{}, db, 0)
+ diskdb, _ := ethdb.NewMemDatabase()
+ triedb := NewDatabase(diskdb)
+
+ trie, _ := NewSecure(common.Hash{}, triedb, 0)
// Fill it with some arbitrary data
content := make(map[string][]byte)
@@ -58,10 +62,10 @@ func makeTestSecureTrie() (ethdb.Database, *SecureTrie, map[string][]byte) {
trie.Update(key, val)
}
}
- trie.Commit()
+ trie.Commit(nil)
// Return the generated trie
- return db, trie, content
+ return triedb, trie, content
}
func TestSecureDelete(t *testing.T) {
@@ -137,7 +141,7 @@ func TestSecureTrieConcurrency(t *testing.T) {
tries[index].Update(key, val)
}
}
- tries[index].Commit()
+ tries[index].Commit(nil)
}(i)
}
// Wait for all threads to finish
diff --git a/trie/sync.go b/trie/sync.go
index fea10051f..b573a9f73 100644
--- a/trie/sync.go
+++ b/trie/sync.go
@@ -21,6 +21,7 @@ import (
"fmt"
"github.com/ethereum/go-ethereum/common"
+ "github.com/ethereum/go-ethereum/ethdb"
"gopkg.in/karalabe/cookiejar.v2/collections/prque"
)
@@ -42,7 +43,7 @@ type request struct {
depth int // Depth level within the trie the node is located to prioritise DFS
deps int // Number of dependencies before allowed to commit this node
- callback TrieSyncLeafCallback // Callback to invoke if a leaf node it reached on this branch
+ callback LeafCallback // Callback to invoke if a leaf node it reached on this branch
}
// SyncResult is a simple list to return missing nodes along with their request
@@ -67,11 +68,6 @@ func newSyncMemBatch() *syncMemBatch {
}
}
-// TrieSyncLeafCallback is a callback type invoked when a trie sync reaches a
-// leaf node. It's used by state syncing to check if the leaf node requires some
-// further data syncing.
-type TrieSyncLeafCallback func(leaf []byte, parent common.Hash) error
-
// TrieSync is the main state trie synchronisation scheduler, which provides yet
// unknown trie hashes to retrieve, accepts node data associated with said hashes
// and reconstructs the trie step by step until all is done.
@@ -83,7 +79,7 @@ type TrieSync struct {
}
// NewTrieSync creates a new trie data download scheduler.
-func NewTrieSync(root common.Hash, database DatabaseReader, callback TrieSyncLeafCallback) *TrieSync {
+func NewTrieSync(root common.Hash, database DatabaseReader, callback LeafCallback) *TrieSync {
ts := &TrieSync{
database: database,
membatch: newSyncMemBatch(),
@@ -95,7 +91,7 @@ func NewTrieSync(root common.Hash, database DatabaseReader, callback TrieSyncLea
}
// AddSubTrie registers a new trie to the sync code, rooted at the designated parent.
-func (s *TrieSync) AddSubTrie(root common.Hash, depth int, parent common.Hash, callback TrieSyncLeafCallback) {
+func (s *TrieSync) AddSubTrie(root common.Hash, depth int, parent common.Hash, callback LeafCallback) {
// Short circuit if the trie is empty or already known
if root == emptyRoot {
return
@@ -217,7 +213,7 @@ func (s *TrieSync) Process(results []SyncResult) (bool, int, error) {
// Commit flushes the data stored in the internal membatch out to persistent
// storage, returning th enumber of items written and any occurred error.
-func (s *TrieSync) Commit(dbw DatabaseWriter) (int, error) {
+func (s *TrieSync) Commit(dbw ethdb.Putter) (int, error) {
// Dump the membatch into a database dbw
for i, key := range s.membatch.order {
if err := dbw.Put(key[:], s.membatch.batch[key]); err != nil {
diff --git a/trie/sync_test.go b/trie/sync_test.go
index ec16a25bd..4a720612b 100644
--- a/trie/sync_test.go
+++ b/trie/sync_test.go
@@ -25,10 +25,11 @@ import (
)
// makeTestTrie create a sample test trie to test node-wise reconstruction.
-func makeTestTrie() (ethdb.Database, *Trie, map[string][]byte) {
+func makeTestTrie() (*Database, *Trie, map[string][]byte) {
// Create an empty trie
- db, _ := ethdb.NewMemDatabase()
- trie, _ := New(common.Hash{}, db)
+ diskdb, _ := ethdb.NewMemDatabase()
+ triedb := NewDatabase(diskdb)
+ trie, _ := New(common.Hash{}, triedb)
// Fill it with some arbitrary data
content := make(map[string][]byte)
@@ -49,15 +50,15 @@ func makeTestTrie() (ethdb.Database, *Trie, map[string][]byte) {
trie.Update(key, val)
}
}
- trie.Commit()
+ trie.Commit(nil)
// Return the generated trie
- return db, trie, content
+ return triedb, trie, content
}
// checkTrieContents cross references a reconstructed trie with an expected data
// content map.
-func checkTrieContents(t *testing.T, db Database, root []byte, content map[string][]byte) {
+func checkTrieContents(t *testing.T, db *Database, root []byte, content map[string][]byte) {
// Check root availability and trie contents
trie, err := New(common.BytesToHash(root), db)
if err != nil {
@@ -74,7 +75,7 @@ func checkTrieContents(t *testing.T, db Database, root []byte, content map[strin
}
// checkTrieConsistency checks that all nodes in a trie are indeed present.
-func checkTrieConsistency(db Database, root common.Hash) error {
+func checkTrieConsistency(db *Database, root common.Hash) error {
// Create and iterate a trie rooted in a subnode
trie, err := New(root, db)
if err != nil {
@@ -88,12 +89,18 @@ func checkTrieConsistency(db Database, root common.Hash) error {
// Tests that an empty trie is not scheduled for syncing.
func TestEmptyTrieSync(t *testing.T) {
- emptyA, _ := New(common.Hash{}, nil)
- emptyB, _ := New(emptyRoot, nil)
+ diskdbA, _ := ethdb.NewMemDatabase()
+ triedbA := NewDatabase(diskdbA)
+
+ diskdbB, _ := ethdb.NewMemDatabase()
+ triedbB := NewDatabase(diskdbB)
+
+ emptyA, _ := New(common.Hash{}, triedbA)
+ emptyB, _ := New(emptyRoot, triedbB)
for i, trie := range []*Trie{emptyA, emptyB} {
- db, _ := ethdb.NewMemDatabase()
- if req := NewTrieSync(common.BytesToHash(trie.Root()), db, nil).Missing(1); len(req) != 0 {
+ diskdb, _ := ethdb.NewMemDatabase()
+ if req := NewTrieSync(trie.Hash(), diskdb, nil).Missing(1); len(req) != 0 {
t.Errorf("test %d: content requested for empty trie: %v", i, req)
}
}
@@ -109,14 +116,15 @@ func testIterativeTrieSync(t *testing.T, batch int) {
srcDb, srcTrie, srcData := makeTestTrie()
// Create a destination trie and sync with the scheduler
- dstDb, _ := ethdb.NewMemDatabase()
- sched := NewTrieSync(common.BytesToHash(srcTrie.Root()), dstDb, nil)
+ diskdb, _ := ethdb.NewMemDatabase()
+ triedb := NewDatabase(diskdb)
+ sched := NewTrieSync(srcTrie.Hash(), diskdb, nil)
queue := append([]common.Hash{}, sched.Missing(batch)...)
for len(queue) > 0 {
results := make([]SyncResult, len(queue))
for i, hash := range queue {
- data, err := srcDb.Get(hash.Bytes())
+ data, err := srcDb.Node(hash)
if err != nil {
t.Fatalf("failed to retrieve node data for %x: %v", hash, err)
}
@@ -125,13 +133,13 @@ func testIterativeTrieSync(t *testing.T, batch int) {
if _, index, err := sched.Process(results); err != nil {
t.Fatalf("failed to process result #%d: %v", index, err)
}
- if index, err := sched.Commit(dstDb); err != nil {
+ if index, err := sched.Commit(diskdb); err != nil {
t.Fatalf("failed to commit data #%d: %v", index, err)
}
queue = append(queue[:0], sched.Missing(batch)...)
}
// Cross check that the two tries are in sync
- checkTrieContents(t, dstDb, srcTrie.Root(), srcData)
+ checkTrieContents(t, triedb, srcTrie.Root(), srcData)
}
// Tests that the trie scheduler can correctly reconstruct the state even if only
@@ -141,15 +149,16 @@ func TestIterativeDelayedTrieSync(t *testing.T) {
srcDb, srcTrie, srcData := makeTestTrie()
// Create a destination trie and sync with the scheduler
- dstDb, _ := ethdb.NewMemDatabase()
- sched := NewTrieSync(common.BytesToHash(srcTrie.Root()), dstDb, nil)
+ diskdb, _ := ethdb.NewMemDatabase()
+ triedb := NewDatabase(diskdb)
+ sched := NewTrieSync(srcTrie.Hash(), diskdb, nil)
queue := append([]common.Hash{}, sched.Missing(10000)...)
for len(queue) > 0 {
// Sync only half of the scheduled nodes
results := make([]SyncResult, len(queue)/2+1)
for i, hash := range queue[:len(results)] {
- data, err := srcDb.Get(hash.Bytes())
+ data, err := srcDb.Node(hash)
if err != nil {
t.Fatalf("failed to retrieve node data for %x: %v", hash, err)
}
@@ -158,13 +167,13 @@ func TestIterativeDelayedTrieSync(t *testing.T) {
if _, index, err := sched.Process(results); err != nil {
t.Fatalf("failed to process result #%d: %v", index, err)
}
- if index, err := sched.Commit(dstDb); err != nil {
+ if index, err := sched.Commit(diskdb); err != nil {
t.Fatalf("failed to commit data #%d: %v", index, err)
}
queue = append(queue[len(results):], sched.Missing(10000)...)
}
// Cross check that the two tries are in sync
- checkTrieContents(t, dstDb, srcTrie.Root(), srcData)
+ checkTrieContents(t, triedb, srcTrie.Root(), srcData)
}
// Tests that given a root hash, a trie can sync iteratively on a single thread,
@@ -178,8 +187,9 @@ func testIterativeRandomTrieSync(t *testing.T, batch int) {
srcDb, srcTrie, srcData := makeTestTrie()
// Create a destination trie and sync with the scheduler
- dstDb, _ := ethdb.NewMemDatabase()
- sched := NewTrieSync(common.BytesToHash(srcTrie.Root()), dstDb, nil)
+ diskdb, _ := ethdb.NewMemDatabase()
+ triedb := NewDatabase(diskdb)
+ sched := NewTrieSync(srcTrie.Hash(), diskdb, nil)
queue := make(map[common.Hash]struct{})
for _, hash := range sched.Missing(batch) {
@@ -189,7 +199,7 @@ func testIterativeRandomTrieSync(t *testing.T, batch int) {
// Fetch all the queued nodes in a random order
results := make([]SyncResult, 0, len(queue))
for hash := range queue {
- data, err := srcDb.Get(hash.Bytes())
+ data, err := srcDb.Node(hash)
if err != nil {
t.Fatalf("failed to retrieve node data for %x: %v", hash, err)
}
@@ -199,7 +209,7 @@ func testIterativeRandomTrieSync(t *testing.T, batch int) {
if _, index, err := sched.Process(results); err != nil {
t.Fatalf("failed to process result #%d: %v", index, err)
}
- if index, err := sched.Commit(dstDb); err != nil {
+ if index, err := sched.Commit(diskdb); err != nil {
t.Fatalf("failed to commit data #%d: %v", index, err)
}
queue = make(map[common.Hash]struct{})
@@ -208,7 +218,7 @@ func testIterativeRandomTrieSync(t *testing.T, batch int) {
}
}
// Cross check that the two tries are in sync
- checkTrieContents(t, dstDb, srcTrie.Root(), srcData)
+ checkTrieContents(t, triedb, srcTrie.Root(), srcData)
}
// Tests that the trie scheduler can correctly reconstruct the state even if only
@@ -218,8 +228,9 @@ func TestIterativeRandomDelayedTrieSync(t *testing.T) {
srcDb, srcTrie, srcData := makeTestTrie()
// Create a destination trie and sync with the scheduler
- dstDb, _ := ethdb.NewMemDatabase()
- sched := NewTrieSync(common.BytesToHash(srcTrie.Root()), dstDb, nil)
+ diskdb, _ := ethdb.NewMemDatabase()
+ triedb := NewDatabase(diskdb)
+ sched := NewTrieSync(srcTrie.Hash(), diskdb, nil)
queue := make(map[common.Hash]struct{})
for _, hash := range sched.Missing(10000) {
@@ -229,7 +240,7 @@ func TestIterativeRandomDelayedTrieSync(t *testing.T) {
// Sync only half of the scheduled nodes, even those in random order
results := make([]SyncResult, 0, len(queue)/2+1)
for hash := range queue {
- data, err := srcDb.Get(hash.Bytes())
+ data, err := srcDb.Node(hash)
if err != nil {
t.Fatalf("failed to retrieve node data for %x: %v", hash, err)
}
@@ -243,7 +254,7 @@ func TestIterativeRandomDelayedTrieSync(t *testing.T) {
if _, index, err := sched.Process(results); err != nil {
t.Fatalf("failed to process result #%d: %v", index, err)
}
- if index, err := sched.Commit(dstDb); err != nil {
+ if index, err := sched.Commit(diskdb); err != nil {
t.Fatalf("failed to commit data #%d: %v", index, err)
}
for _, result := range results {
@@ -254,7 +265,7 @@ func TestIterativeRandomDelayedTrieSync(t *testing.T) {
}
}
// Cross check that the two tries are in sync
- checkTrieContents(t, dstDb, srcTrie.Root(), srcData)
+ checkTrieContents(t, triedb, srcTrie.Root(), srcData)
}
// Tests that a trie sync will not request nodes multiple times, even if they
@@ -264,8 +275,9 @@ func TestDuplicateAvoidanceTrieSync(t *testing.T) {
srcDb, srcTrie, srcData := makeTestTrie()
// Create a destination trie and sync with the scheduler
- dstDb, _ := ethdb.NewMemDatabase()
- sched := NewTrieSync(common.BytesToHash(srcTrie.Root()), dstDb, nil)
+ diskdb, _ := ethdb.NewMemDatabase()
+ triedb := NewDatabase(diskdb)
+ sched := NewTrieSync(srcTrie.Hash(), diskdb, nil)
queue := append([]common.Hash{}, sched.Missing(0)...)
requested := make(map[common.Hash]struct{})
@@ -273,7 +285,7 @@ func TestDuplicateAvoidanceTrieSync(t *testing.T) {
for len(queue) > 0 {
results := make([]SyncResult, len(queue))
for i, hash := range queue {
- data, err := srcDb.Get(hash.Bytes())
+ data, err := srcDb.Node(hash)
if err != nil {
t.Fatalf("failed to retrieve node data for %x: %v", hash, err)
}
@@ -287,13 +299,13 @@ func TestDuplicateAvoidanceTrieSync(t *testing.T) {
if _, index, err := sched.Process(results); err != nil {
t.Fatalf("failed to process result #%d: %v", index, err)
}
- if index, err := sched.Commit(dstDb); err != nil {
+ if index, err := sched.Commit(diskdb); err != nil {
t.Fatalf("failed to commit data #%d: %v", index, err)
}
queue = append(queue[:0], sched.Missing(0)...)
}
// Cross check that the two tries are in sync
- checkTrieContents(t, dstDb, srcTrie.Root(), srcData)
+ checkTrieContents(t, triedb, srcTrie.Root(), srcData)
}
// Tests that at any point in time during a sync, only complete sub-tries are in
@@ -303,8 +315,9 @@ func TestIncompleteTrieSync(t *testing.T) {
srcDb, srcTrie, _ := makeTestTrie()
// Create a destination trie and sync with the scheduler
- dstDb, _ := ethdb.NewMemDatabase()
- sched := NewTrieSync(common.BytesToHash(srcTrie.Root()), dstDb, nil)
+ diskdb, _ := ethdb.NewMemDatabase()
+ triedb := NewDatabase(diskdb)
+ sched := NewTrieSync(srcTrie.Hash(), diskdb, nil)
added := []common.Hash{}
queue := append([]common.Hash{}, sched.Missing(1)...)
@@ -312,7 +325,7 @@ func TestIncompleteTrieSync(t *testing.T) {
// Fetch a batch of trie nodes
results := make([]SyncResult, len(queue))
for i, hash := range queue {
- data, err := srcDb.Get(hash.Bytes())
+ data, err := srcDb.Node(hash)
if err != nil {
t.Fatalf("failed to retrieve node data for %x: %v", hash, err)
}
@@ -322,7 +335,7 @@ func TestIncompleteTrieSync(t *testing.T) {
if _, index, err := sched.Process(results); err != nil {
t.Fatalf("failed to process result #%d: %v", index, err)
}
- if index, err := sched.Commit(dstDb); err != nil {
+ if index, err := sched.Commit(diskdb); err != nil {
t.Fatalf("failed to commit data #%d: %v", index, err)
}
for _, result := range results {
@@ -330,7 +343,7 @@ func TestIncompleteTrieSync(t *testing.T) {
}
// Check that all known sub-tries in the synced trie are complete
for _, root := range added {
- if err := checkTrieConsistency(dstDb, root); err != nil {
+ if err := checkTrieConsistency(triedb, root); err != nil {
t.Fatalf("trie inconsistent: %v", err)
}
}
@@ -340,12 +353,12 @@ func TestIncompleteTrieSync(t *testing.T) {
// Sanity check that removing any node from the database is detected
for _, node := range added[1:] {
key := node.Bytes()
- value, _ := dstDb.Get(key)
+ value, _ := diskdb.Get(key)
- dstDb.Delete(key)
- if err := checkTrieConsistency(dstDb, added[0]); err == nil {
+ diskdb.Delete(key)
+ if err := checkTrieConsistency(triedb, added[0]); err == nil {
t.Fatalf("trie inconsistency not caught, missing: %x", key)
}
- dstDb.Put(key, value)
+ diskdb.Put(key, value)
}
}
diff --git a/trie/trie.go b/trie/trie.go
index 8fe98d835..e37a1ae10 100644
--- a/trie/trie.go
+++ b/trie/trie.go
@@ -22,16 +22,17 @@ import (
"fmt"
"github.com/ethereum/go-ethereum/common"
- "github.com/ethereum/go-ethereum/crypto/sha3"
+ "github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/log"
"github.com/rcrowley/go-metrics"
)
var (
- // This is the known root hash of an empty trie.
+ // emptyRoot is the known root hash of an empty trie.
emptyRoot = common.HexToHash("56e81f171bcc55a6ff8345e692c0f86e5b48e01b996cadc001622fb5e363b421")
- // This is the known hash of an empty state trie entry.
- emptyState common.Hash
+
+ // emptyState is the known hash of an empty state trie entry.
+ emptyState = crypto.Keccak256Hash(nil)
)
var (
@@ -53,29 +54,10 @@ func CacheUnloads() int64 {
return cacheUnloadCounter.Count()
}
-func init() {
- sha3.NewKeccak256().Sum(emptyState[:0])
-}
-
-// Database must be implemented by backing stores for the trie.
-type Database interface {
- DatabaseReader
- DatabaseWriter
-}
-
-// DatabaseReader wraps the Get method of a backing store for the trie.
-type DatabaseReader interface {
- Get(key []byte) (value []byte, err error)
- Has(key []byte) (bool, error)
-}
-
-// DatabaseWriter wraps the Put method of a backing store for the trie.
-type DatabaseWriter interface {
- // Put stores the mapping key->value in the database.
- // Implementations must not hold onto the value bytes, the trie
- // will reuse the slice across calls to Put.
- Put(key, value []byte) error
-}
+// LeafCallback is a callback type invoked when a trie operation reaches a leaf
+// node. It's used by state sync and commit to allow handling external references
+// between account and storage tries.
+type LeafCallback func(leaf []byte, parent common.Hash) error
// Trie is a Merkle Patricia Trie.
// The zero value is an empty trie with no database.
@@ -83,8 +65,8 @@ type DatabaseWriter interface {
//
// Trie is not safe for concurrent use.
type Trie struct {
+ db *Database
root node
- db Database
originalRoot common.Hash
// Cache generation values.
@@ -111,12 +93,15 @@ func (t *Trie) newFlag() nodeFlag {
// trie is initially empty and does not require a database. Otherwise,
// New will panic if db is nil and returns a MissingNodeError if root does
// not exist in the database. Accessing the trie loads nodes from db on demand.
-func New(root common.Hash, db Database) (*Trie, error) {
- trie := &Trie{db: db, originalRoot: root}
+func New(root common.Hash, db *Database) (*Trie, error) {
+ if db == nil {
+ panic("trie.New called without a database")
+ }
+ trie := &Trie{
+ db: db,
+ originalRoot: root,
+ }
if (root != common.Hash{}) && root != emptyRoot {
- if db == nil {
- panic("trie.New: cannot use existing root without a database")
- }
rootnode, err := trie.resolveHash(root[:], nil)
if err != nil {
return nil, err
@@ -447,12 +432,13 @@ func (t *Trie) resolve(n node, prefix []byte) (node, error) {
func (t *Trie) resolveHash(n hashNode, prefix []byte) (node, error) {
cacheMissCounter.Inc(1)
- enc, err := t.db.Get(n)
+ hash := common.BytesToHash(n)
+
+ enc, err := t.db.Node(hash)
if err != nil || enc == nil {
- return nil, &MissingNodeError{NodeHash: common.BytesToHash(n), Path: prefix}
+ return nil, &MissingNodeError{NodeHash: hash, Path: prefix}
}
- dec := mustDecodeNode(n, enc, t.cachegen)
- return dec, nil
+ return mustDecodeNode(n, enc, t.cachegen), nil
}
// Root returns the root hash of the trie.
@@ -462,32 +448,18 @@ func (t *Trie) Root() []byte { return t.Hash().Bytes() }
// Hash returns the root hash of the trie. It does not write to the
// database and can be used even if the trie doesn't have one.
func (t *Trie) Hash() common.Hash {
- hash, cached, _ := t.hashRoot(nil)
+ hash, cached, _ := t.hashRoot(nil, nil)
t.root = cached
return common.BytesToHash(hash.(hashNode))
}
-// Commit writes all nodes to the trie's database.
-// Nodes are stored with their sha3 hash as the key.
-//
-// Committing flushes nodes from memory.
-// Subsequent Get calls will load nodes from the database.
-func (t *Trie) Commit() (root common.Hash, err error) {
+// Commit writes all nodes to the trie's memory database, tracking the internal
+// and external (for account tries) references.
+func (t *Trie) Commit(onleaf LeafCallback) (root common.Hash, err error) {
if t.db == nil {
- panic("Commit called on trie with nil database")
+ panic("commit called on trie with nil database")
}
- return t.CommitTo(t.db)
-}
-
-// CommitTo writes all nodes to the given database.
-// Nodes are stored with their sha3 hash as the key.
-//
-// Committing flushes nodes from memory. Subsequent Get calls will
-// load nodes from the trie's database. Calling code must ensure that
-// the changes made to db are written back to the trie's attached
-// database before using the trie.
-func (t *Trie) CommitTo(db DatabaseWriter) (root common.Hash, err error) {
- hash, cached, err := t.hashRoot(db)
+ hash, cached, err := t.hashRoot(t.db, onleaf)
if err != nil {
return common.Hash{}, err
}
@@ -496,11 +468,11 @@ func (t *Trie) CommitTo(db DatabaseWriter) (root common.Hash, err error) {
return common.BytesToHash(hash.(hashNode)), nil
}
-func (t *Trie) hashRoot(db DatabaseWriter) (node, node, error) {
+func (t *Trie) hashRoot(db *Database, onleaf LeafCallback) (node, node, error) {
if t.root == nil {
return hashNode(emptyRoot.Bytes()), nil, nil
}
- h := newHasher(t.cachegen, t.cachelimit)
+ h := newHasher(t.cachegen, t.cachelimit, onleaf)
defer returnHasherToPool(h)
return h.hash(t.root, db, true)
}
diff --git a/trie/trie_test.go b/trie/trie_test.go
index 1e28c3bc4..997222628 100644
--- a/trie/trie_test.go
+++ b/trie/trie_test.go
@@ -43,8 +43,8 @@ func init() {
// Used for testing
func newEmpty() *Trie {
- db, _ := ethdb.NewMemDatabase()
- trie, _ := New(common.Hash{}, db)
+ diskdb, _ := ethdb.NewMemDatabase()
+ trie, _ := New(common.Hash{}, NewDatabase(diskdb))
return trie
}
@@ -68,8 +68,8 @@ func TestNull(t *testing.T) {
}
func TestMissingRoot(t *testing.T) {
- db, _ := ethdb.NewMemDatabase()
- trie, err := New(common.HexToHash("0beec7b5ea3f0fdbc95d0dd47f3c5bc275da8a33"), db)
+ diskdb, _ := ethdb.NewMemDatabase()
+ trie, err := New(common.HexToHash("0beec7b5ea3f0fdbc95d0dd47f3c5bc275da8a33"), NewDatabase(diskdb))
if trie != nil {
t.Error("New returned non-nil trie for invalid root")
}
@@ -78,70 +78,75 @@ func TestMissingRoot(t *testing.T) {
}
}
-func TestMissingNode(t *testing.T) {
- db, _ := ethdb.NewMemDatabase()
- trie, _ := New(common.Hash{}, db)
+func TestMissingNodeDisk(t *testing.T) { testMissingNode(t, false) }
+func TestMissingNodeMemonly(t *testing.T) { testMissingNode(t, true) }
+
+func testMissingNode(t *testing.T, memonly bool) {
+ diskdb, _ := ethdb.NewMemDatabase()
+ triedb := NewDatabase(diskdb)
+
+ trie, _ := New(common.Hash{}, triedb)
updateString(trie, "120000", "qwerqwerqwerqwerqwerqwerqwerqwer")
updateString(trie, "123456", "asdfasdfasdfasdfasdfasdfasdfasdf")
- root, _ := trie.Commit()
+ root, _ := trie.Commit(nil)
+ if !memonly {
+ triedb.Commit(root, true)
+ }
- trie, _ = New(root, db)
+ trie, _ = New(root, triedb)
_, err := trie.TryGet([]byte("120000"))
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
-
- trie, _ = New(root, db)
+ trie, _ = New(root, triedb)
_, err = trie.TryGet([]byte("120099"))
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
-
- trie, _ = New(root, db)
+ trie, _ = New(root, triedb)
_, err = trie.TryGet([]byte("123456"))
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
-
- trie, _ = New(root, db)
+ trie, _ = New(root, triedb)
err = trie.TryUpdate([]byte("120099"), []byte("zxcvzxcvzxcvzxcvzxcvzxcvzxcvzxcv"))
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
-
- trie, _ = New(root, db)
+ trie, _ = New(root, triedb)
err = trie.TryDelete([]byte("123456"))
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
- db.Delete(common.FromHex("e1d943cc8f061a0c0b98162830b970395ac9315654824bf21b73b891365262f9"))
+ hash := common.HexToHash("0xe1d943cc8f061a0c0b98162830b970395ac9315654824bf21b73b891365262f9")
+ if memonly {
+ delete(triedb.nodes, hash)
+ } else {
+ diskdb.Delete(hash[:])
+ }
- trie, _ = New(root, db)
+ trie, _ = New(root, triedb)
_, err = trie.TryGet([]byte("120000"))
if _, ok := err.(*MissingNodeError); !ok {
t.Errorf("Wrong error: %v", err)
}
-
- trie, _ = New(root, db)
+ trie, _ = New(root, triedb)
_, err = trie.TryGet([]byte("120099"))
if _, ok := err.(*MissingNodeError); !ok {
t.Errorf("Wrong error: %v", err)
}
-
- trie, _ = New(root, db)
+ trie, _ = New(root, triedb)
_, err = trie.TryGet([]byte("123456"))
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
-
- trie, _ = New(root, db)
+ trie, _ = New(root, triedb)
err = trie.TryUpdate([]byte("120099"), []byte("zxcv"))
if _, ok := err.(*MissingNodeError); !ok {
t.Errorf("Wrong error: %v", err)
}
-
- trie, _ = New(root, db)
+ trie, _ = New(root, triedb)
err = trie.TryDelete([]byte("123456"))
if _, ok := err.(*MissingNodeError); !ok {
t.Errorf("Wrong error: %v", err)
@@ -165,7 +170,7 @@ func TestInsert(t *testing.T) {
updateString(trie, "A", "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa")
exp = common.HexToHash("d23786fb4a010da3ce639d66d5e904a11dbc02746d1ce25029e53290cabf28ab")
- root, err := trie.Commit()
+ root, err := trie.Commit(nil)
if err != nil {
t.Fatalf("commit error: %v", err)
}
@@ -194,7 +199,7 @@ func TestGet(t *testing.T) {
if i == 1 {
return
}
- trie.Commit()
+ trie.Commit(nil)
}
}
@@ -263,7 +268,7 @@ func TestReplication(t *testing.T) {
for _, val := range vals {
updateString(trie, val.k, val.v)
}
- exp, err := trie.Commit()
+ exp, err := trie.Commit(nil)
if err != nil {
t.Fatalf("commit error: %v", err)
}
@@ -278,7 +283,7 @@ func TestReplication(t *testing.T) {
t.Errorf("trie2 doesn't have %q => %q", kv.k, kv.v)
}
}
- hash, err := trie2.Commit()
+ hash, err := trie2.Commit(nil)
if err != nil {
t.Fatalf("commit error: %v", err)
}
@@ -314,7 +319,7 @@ func TestLargeValue(t *testing.T) {
}
type countingDB struct {
- Database
+ ethdb.Database
gets map[string]int
}
@@ -332,19 +337,20 @@ func TestCacheUnload(t *testing.T) {
key2 := "---some other branch"
updateString(trie, key1, "this is the branch of key1.")
updateString(trie, key2, "this is the branch of key2.")
- root, _ := trie.Commit()
+
+ root, _ := trie.Commit(nil)
+ trie.db.Commit(root, true)
// Commit the trie repeatedly and access key1.
// The branch containing it is loaded from DB exactly two times:
// in the 0th and 6th iteration.
- db := &countingDB{Database: trie.db, gets: make(map[string]int)}
- trie, _ = New(root, db)
+ db := &countingDB{Database: trie.db.diskdb, gets: make(map[string]int)}
+ trie, _ = New(root, NewDatabase(db))
trie.SetCacheLimit(5)
for i := 0; i < 12; i++ {
getString(trie, key1)
- trie.Commit()
+ trie.Commit(nil)
}
-
// Check that it got loaded two times.
for dbkey, count := range db.gets {
if count != 2 {
@@ -407,8 +413,10 @@ func (randTest) Generate(r *rand.Rand, size int) reflect.Value {
}
func runRandTest(rt randTest) bool {
- db, _ := ethdb.NewMemDatabase()
- tr, _ := New(common.Hash{}, db)
+ diskdb, _ := ethdb.NewMemDatabase()
+ triedb := NewDatabase(diskdb)
+
+ tr, _ := New(common.Hash{}, triedb)
values := make(map[string]string) // tracks content of the trie
for i, step := range rt {
@@ -426,23 +434,23 @@ func runRandTest(rt randTest) bool {
rt[i].err = fmt.Errorf("mismatch for key 0x%x, got 0x%x want 0x%x", step.key, v, want)
}
case opCommit:
- _, rt[i].err = tr.Commit()
+ _, rt[i].err = tr.Commit(nil)
case opHash:
tr.Hash()
case opReset:
- hash, err := tr.Commit()
+ hash, err := tr.Commit(nil)
if err != nil {
rt[i].err = err
return false
}
- newtr, err := New(hash, db)
+ newtr, err := New(hash, triedb)
if err != nil {
rt[i].err = err
return false
}
tr = newtr
case opItercheckhash:
- checktr, _ := New(common.Hash{}, nil)
+ checktr, _ := New(common.Hash{}, triedb)
it := NewIterator(tr.NodeIterator(nil))
for it.Next() {
checktr.Update(it.Key, it.Value)
@@ -524,7 +532,7 @@ func benchGet(b *testing.B, commit bool) {
}
binary.LittleEndian.PutUint64(k, benchElemCount/2)
if commit {
- trie.Commit()
+ trie.Commit(nil)
}
b.ResetTimer()
@@ -534,7 +542,7 @@ func benchGet(b *testing.B, commit bool) {
b.StopTimer()
if commit {
- ldb := trie.db.(*ethdb.LDBDatabase)
+ ldb := trie.db.diskdb.(*ethdb.LDBDatabase)
ldb.Close()
os.RemoveAll(ldb.Path())
}
@@ -585,16 +593,16 @@ func BenchmarkHash(b *testing.B) {
trie.Hash()
}
-func tempDB() (string, Database) {
+func tempDB() (string, *Database) {
dir, err := ioutil.TempDir("", "trie-bench")
if err != nil {
panic(fmt.Sprintf("can't create temporary directory: %v", err))
}
- db, err := ethdb.NewLDBDatabase(dir, 256, 0)
+ diskdb, err := ethdb.NewLDBDatabase(dir, 256, 0)
if err != nil {
panic(fmt.Sprintf("can't create temporary database: %v", err))
}
- return dir, db
+ return dir, NewDatabase(diskdb)
}
func getString(trie *Trie, k string) []byte {