diff options
96 files changed, 3244 insertions, 2798 deletions
diff --git a/build/update-license.go b/build/update-license.go index e28005cbd..04f52a13c 100644 --- a/build/update-license.go +++ b/build/update-license.go @@ -46,9 +46,10 @@ var ( skipPrefixes = []string{ // boring stuff "Godeps/", "tests/files/", "build/", - // don't relicense vendored packages + // don't relicense vendored sources "crypto/sha3/", "crypto/ecies/", "logger/glog/", "crypto/curve.go", + "trie/arc.go", } // paths with this prefix are licensed as GPL. all other files are LGPL. diff --git a/cmd/evm/main.go b/cmd/evm/main.go index 243dd6266..bf24da982 100644 --- a/cmd/evm/main.go +++ b/cmd/evm/main.go @@ -80,12 +80,17 @@ var ( Name: "sysstat", Usage: "display system stats", } + VerbosityFlag = cli.IntFlag{ + Name: "verbosity", + Usage: "sets the verbosity level", + } ) func init() { app = utils.NewApp("0.2", "the evm command line interface") app.Flags = []cli.Flag{ DebugFlag, + VerbosityFlag, ForceJitFlag, DisableJitFlag, SysStatFlag, @@ -105,6 +110,7 @@ func run(ctx *cli.Context) { vm.EnableJit = !ctx.GlobalBool(DisableJitFlag.Name) glog.SetToStderr(true) + glog.SetV(ctx.GlobalInt(VerbosityFlag.Name)) db, _ := ethdb.NewMemDatabase() statedb := state.New(common.Hash{}, db) @@ -179,18 +185,20 @@ func NewEnv(state *state.StateDB, transactor common.Address, value *big.Int) *VM } } -func (self *VMEnv) State() *state.StateDB { return self.state } -func (self *VMEnv) Origin() common.Address { return *self.transactor } -func (self *VMEnv) BlockNumber() *big.Int { return common.Big0 } -func (self *VMEnv) Coinbase() common.Address { return *self.transactor } -func (self *VMEnv) Time() *big.Int { return self.time } -func (self *VMEnv) Difficulty() *big.Int { return common.Big1 } -func (self *VMEnv) BlockHash() []byte { return make([]byte, 32) } -func (self *VMEnv) Value() *big.Int { return self.value } -func (self *VMEnv) GasLimit() *big.Int { return big.NewInt(1000000000) } -func (self *VMEnv) VmType() vm.Type { return vm.StdVmTy } -func (self *VMEnv) Depth() int { return 0 } -func (self *VMEnv) SetDepth(i int) { self.depth = i } +func (self *VMEnv) Db() vm.Database { return self.state } +func (self *VMEnv) MakeSnapshot() vm.Database { return self.state.Copy() } +func (self *VMEnv) SetSnapshot(db vm.Database) { self.state.Set(db.(*state.StateDB)) } +func (self *VMEnv) Origin() common.Address { return *self.transactor } +func (self *VMEnv) BlockNumber() *big.Int { return common.Big0 } +func (self *VMEnv) Coinbase() common.Address { return *self.transactor } +func (self *VMEnv) Time() *big.Int { return self.time } +func (self *VMEnv) Difficulty() *big.Int { return common.Big1 } +func (self *VMEnv) BlockHash() []byte { return make([]byte, 32) } +func (self *VMEnv) Value() *big.Int { return self.value } +func (self *VMEnv) GasLimit() *big.Int { return big.NewInt(1000000000) } +func (self *VMEnv) VmType() vm.Type { return vm.StdVmTy } +func (self *VMEnv) Depth() int { return 0 } +func (self *VMEnv) SetDepth(i int) { self.depth = i } func (self *VMEnv) GetHash(n uint64) common.Hash { if self.block.Number().Cmp(big.NewInt(int64(n))) == 0 { return self.block.Hash() @@ -203,34 +211,24 @@ func (self *VMEnv) AddStructLog(log vm.StructLog) { func (self *VMEnv) StructLogs() []vm.StructLog { return self.logs } -func (self *VMEnv) AddLog(log *state.Log) { +func (self *VMEnv) AddLog(log *vm.Log) { self.state.AddLog(log) } -func (self *VMEnv) CanTransfer(from vm.Account, balance *big.Int) bool { - return from.Balance().Cmp(balance) >= 0 +func (self *VMEnv) CanTransfer(from common.Address, balance *big.Int) bool { + return self.state.GetBalance(from).Cmp(balance) >= 0 } func (self *VMEnv) Transfer(from, to vm.Account, amount *big.Int) error { - return vm.Transfer(from, to, amount) -} - -func (self *VMEnv) vm(addr *common.Address, data []byte, gas, price, value *big.Int) *core.Execution { - return core.NewExecution(self, addr, data, gas, price, value) + return core.Transfer(from, to, amount) } -func (self *VMEnv) Call(caller vm.ContextRef, addr common.Address, data []byte, gas, price, value *big.Int) ([]byte, error) { - exe := self.vm(&addr, data, gas, price, value) - ret, err := exe.Call(addr, caller) - self.Gas = exe.Gas - - return ret, err +func (self *VMEnv) Call(caller vm.ContractRef, addr common.Address, data []byte, gas, price, value *big.Int) ([]byte, error) { + self.Gas = gas + return core.Call(self, caller, addr, data, gas, price, value) } -func (self *VMEnv) CallCode(caller vm.ContextRef, addr common.Address, data []byte, gas, price, value *big.Int) ([]byte, error) { - a := caller.Address() - exe := self.vm(&a, data, gas, price, value) - return exe.Call(addr, caller) +func (self *VMEnv) CallCode(caller vm.ContractRef, addr common.Address, data []byte, gas, price, value *big.Int) ([]byte, error) { + return core.CallCode(self, caller, addr, data, gas, price, value) } -func (self *VMEnv) Create(caller vm.ContextRef, data []byte, gas, price, value *big.Int) ([]byte, error, vm.ContextRef) { - exe := self.vm(nil, data, gas, price, value) - return exe.Create(caller) +func (self *VMEnv) Create(caller vm.ContractRef, data []byte, gas, price, value *big.Int) ([]byte, common.Address, error) { + return core.Create(self, caller, data, gas, price, value) } diff --git a/cmd/geth/blocktestcmd.go b/cmd/geth/blocktestcmd.go index d6195e025..e0a5becdc 100644 --- a/cmd/geth/blocktestcmd.go +++ b/cmd/geth/blocktestcmd.go @@ -118,7 +118,7 @@ func runOneBlockTest(ctx *cli.Context, test *tests.BlockTest) (*eth.Ethereum, er return ethereum, fmt.Errorf("InsertPreState: %v", err) } - cm := ethereum.ChainManager() + cm := ethereum.BlockChain() validBlocks, err := test.TryBlocksInsert(cm) if err != nil { return ethereum, fmt.Errorf("Block Test load error: %v", err) diff --git a/cmd/geth/js_test.go b/cmd/geth/js_test.go index 1f5b28e3a..2ad3d669c 100644 --- a/cmd/geth/js_test.go +++ b/cmd/geth/js_test.go @@ -196,7 +196,7 @@ func TestBlockChain(t *testing.T) { tmpfile := filepath.Join(extmp, "export.chain") tmpfileq := strconv.Quote(tmpfile) - ethereum.ChainManager().Reset() + ethereum.BlockChain().Reset() checkEvalJSON(t, repl, `admin.exportChain(`+tmpfileq+`)`, `true`) if _, err := os.Stat(tmpfile); err != nil { diff --git a/cmd/geth/main.go b/cmd/geth/main.go index daffda30c..fa9beafd0 100644 --- a/cmd/geth/main.go +++ b/cmd/geth/main.go @@ -48,9 +48,9 @@ import ( const ( ClientIdentifier = "Geth" - Version = "1.2.0" + Version = "1.3.0-dev" VersionMajor = 1 - VersionMinor = 2 + VersionMinor = 3 VersionPatch = 0 ) @@ -390,7 +390,7 @@ func makeDefaultExtra() []byte { } func run(ctx *cli.Context) { - utils.CheckLegalese(ctx.GlobalString(utils.DataDirFlag.Name)) + utils.CheckLegalese(utils.MustDataDir(ctx)) if ctx.GlobalBool(utils.OlympicFlag.Name) { utils.InitOlympic() } @@ -409,7 +409,7 @@ func run(ctx *cli.Context) { } func attach(ctx *cli.Context) { - utils.CheckLegalese(ctx.GlobalString(utils.DataDirFlag.Name)) + utils.CheckLegalese(utils.MustDataDir(ctx)) var client comms.EthereumClient var err error @@ -441,7 +441,7 @@ func attach(ctx *cli.Context) { } func console(ctx *cli.Context) { - utils.CheckLegalese(ctx.GlobalString(utils.DataDirFlag.Name)) + utils.CheckLegalese(utils.MustDataDir(ctx)) cfg := utils.MakeEthConfig(ClientIdentifier, nodeNameVersion, ctx) cfg.ExtraData = makeExtra(ctx) @@ -475,7 +475,7 @@ func console(ctx *cli.Context) { } func execJSFiles(ctx *cli.Context) { - utils.CheckLegalese(ctx.GlobalString(utils.DataDirFlag.Name)) + utils.CheckLegalese(utils.MustDataDir(ctx)) cfg := utils.MakeEthConfig(ClientIdentifier, nodeNameVersion, ctx) ethereum, err := eth.New(cfg) @@ -502,7 +502,7 @@ func execJSFiles(ctx *cli.Context) { } func unlockAccount(ctx *cli.Context, am *accounts.Manager, addr string, i int) (addrHex, auth string) { - utils.CheckLegalese(ctx.GlobalString(utils.DataDirFlag.Name)) + utils.CheckLegalese(utils.MustDataDir(ctx)) var err error addrHex, err = utils.ParamToAddress(addr, am) @@ -527,7 +527,7 @@ func unlockAccount(ctx *cli.Context, am *accounts.Manager, addr string, i int) ( } func blockRecovery(ctx *cli.Context) { - utils.CheckLegalese(ctx.GlobalString(utils.DataDirFlag.Name)) + utils.CheckLegalese(utils.MustDataDir(ctx)) arg := ctx.Args().First() if len(ctx.Args()) < 1 && len(arg) > 0 { @@ -593,7 +593,7 @@ func startEth(ctx *cli.Context, eth *eth.Ethereum) { } func accountList(ctx *cli.Context) { - utils.CheckLegalese(ctx.GlobalString(utils.DataDirFlag.Name)) + utils.CheckLegalese(utils.MustDataDir(ctx)) am := utils.MakeAccountManager(ctx) accts, err := am.Accounts() @@ -643,7 +643,7 @@ func getPassPhrase(ctx *cli.Context, desc string, confirmation bool, i int) (pas } func accountCreate(ctx *cli.Context) { - utils.CheckLegalese(ctx.GlobalString(utils.DataDirFlag.Name)) + utils.CheckLegalese(utils.MustDataDir(ctx)) am := utils.MakeAccountManager(ctx) passphrase := getPassPhrase(ctx, "Your new account is locked with a password. Please give a password. Do not forget this password.", true, 0) @@ -655,7 +655,7 @@ func accountCreate(ctx *cli.Context) { } func accountUpdate(ctx *cli.Context) { - utils.CheckLegalese(ctx.GlobalString(utils.DataDirFlag.Name)) + utils.CheckLegalese(utils.MustDataDir(ctx)) am := utils.MakeAccountManager(ctx) arg := ctx.Args().First() @@ -672,7 +672,7 @@ func accountUpdate(ctx *cli.Context) { } func importWallet(ctx *cli.Context) { - utils.CheckLegalese(ctx.GlobalString(utils.DataDirFlag.Name)) + utils.CheckLegalese(utils.MustDataDir(ctx)) keyfile := ctx.Args().First() if len(keyfile) == 0 { @@ -694,7 +694,7 @@ func importWallet(ctx *cli.Context) { } func accountImport(ctx *cli.Context) { - utils.CheckLegalese(ctx.GlobalString(utils.DataDirFlag.Name)) + utils.CheckLegalese(utils.MustDataDir(ctx)) keyfile := ctx.Args().First() if len(keyfile) == 0 { @@ -710,7 +710,7 @@ func accountImport(ctx *cli.Context) { } func makedag(ctx *cli.Context) { - utils.CheckLegalese(ctx.GlobalString(utils.DataDirFlag.Name)) + utils.CheckLegalese(utils.MustDataDir(ctx)) args := ctx.Args() wrongArgs := func() { diff --git a/cmd/utils/cmd.go b/cmd/utils/cmd.go index 983762db8..5e4bfc937 100644 --- a/cmd/utils/cmd.go +++ b/cmd/utils/cmd.go @@ -169,7 +169,7 @@ func FormatTransactionData(data string) []byte { return d } -func ImportChain(chain *core.ChainManager, fn string) error { +func ImportChain(chain *core.BlockChain, fn string) error { // Watch for Ctrl-C while the import is running. // If a signal is received, the import will stop at the next batch. interrupt := make(chan os.Signal, 1) @@ -244,7 +244,7 @@ func ImportChain(chain *core.ChainManager, fn string) error { return nil } -func hasAllBlocks(chain *core.ChainManager, bs []*types.Block) bool { +func hasAllBlocks(chain *core.BlockChain, bs []*types.Block) bool { for _, b := range bs { if !chain.HasBlock(b.Hash()) { return false @@ -253,21 +253,21 @@ func hasAllBlocks(chain *core.ChainManager, bs []*types.Block) bool { return true } -func ExportChain(chainmgr *core.ChainManager, fn string) error { +func ExportChain(blockchain *core.BlockChain, fn string) error { glog.Infoln("Exporting blockchain to", fn) fh, err := os.OpenFile(fn, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, os.ModePerm) if err != nil { return err } defer fh.Close() - if err := chainmgr.Export(fh); err != nil { + if err := blockchain.Export(fh); err != nil { return err } glog.Infoln("Exported blockchain to", fn) return nil } -func ExportAppendChain(chainmgr *core.ChainManager, fn string, first uint64, last uint64) error { +func ExportAppendChain(blockchain *core.BlockChain, fn string, first uint64, last uint64) error { glog.Infoln("Exporting blockchain to", fn) // TODO verify mode perms fh, err := os.OpenFile(fn, os.O_CREATE|os.O_APPEND|os.O_WRONLY, os.ModePerm) @@ -275,7 +275,7 @@ func ExportAppendChain(chainmgr *core.ChainManager, fn string, first uint64, las return err } defer fh.Close() - if err := chainmgr.ExportN(fh, first, last); err != nil { + if err := blockchain.ExportN(fh, first, last); err != nil { return err } glog.Infoln("Exported blockchain to", fn) diff --git a/cmd/utils/flags.go b/cmd/utils/flags.go index b45ef0af2..dea43bc5c 100644 --- a/cmd/utils/flags.go +++ b/cmd/utils/flags.go @@ -416,7 +416,7 @@ func MakeEthConfig(clientID, version string, ctx *cli.Context) *eth.Config { cfg := ð.Config{ Name: common.MakeName(clientID, version), - DataDir: ctx.GlobalString(DataDirFlag.Name), + DataDir: MustDataDir(ctx), GenesisNonce: ctx.GlobalInt(GenesisNonceFlag.Name), GenesisFile: ctx.GlobalString(GenesisFileFlag.Name), BlockChainVersion: ctx.GlobalInt(BlockchainVersionFlag.Name), @@ -508,8 +508,8 @@ func SetupEth(ctx *cli.Context) { } // MakeChain creates a chain manager from set command line flags. -func MakeChain(ctx *cli.Context) (chain *core.ChainManager, chainDb ethdb.Database) { - datadir := ctx.GlobalString(DataDirFlag.Name) +func MakeChain(ctx *cli.Context) (chain *core.BlockChain, chainDb ethdb.Database) { + datadir := MustDataDir(ctx) cache := ctx.GlobalInt(CacheFlag.Name) var err error @@ -527,7 +527,7 @@ func MakeChain(ctx *cli.Context) (chain *core.ChainManager, chainDb ethdb.Databa eventMux := new(event.TypeMux) pow := ethash.New() //genesis := core.GenesisBlock(uint64(ctx.GlobalInt(GenesisNonceFlag.Name)), blockDB) - chain, err = core.NewChainManager(chainDb, pow, eventMux) + chain, err = core.NewBlockChain(chainDb, pow, eventMux) if err != nil { Fatalf("Could not start chainmanager: %v", err) } @@ -539,11 +539,21 @@ func MakeChain(ctx *cli.Context) (chain *core.ChainManager, chainDb ethdb.Databa // MakeChain creates an account manager from set command line flags. func MakeAccountManager(ctx *cli.Context) *accounts.Manager { - dataDir := ctx.GlobalString(DataDirFlag.Name) + dataDir := MustDataDir(ctx) ks := crypto.NewKeyStorePassphrase(filepath.Join(dataDir, "keystore")) return accounts.NewManager(ks) } +// MustDataDir retrieves the currently requested data directory, terminating if +// none (or the empty string) is specified. +func MustDataDir(ctx *cli.Context) string { + if path := ctx.GlobalString(DataDirFlag.Name); path != "" { + return path + } + Fatalf("Cannot determine default data directory, please set manually (--datadir)") + return "" +} + func IpcSocketPath(ctx *cli.Context) (ipcpath string) { if runtime.GOOS == "windows" { ipcpath = common.DefaultIpcPath() diff --git a/common/path.go b/common/path.go index 8b3c7d14b..1253c424c 100644 --- a/common/path.go +++ b/common/path.go @@ -100,14 +100,24 @@ func DefaultAssetPath() string { } func DefaultDataDir() string { - usr, _ := user.Current() - if runtime.GOOS == "darwin" { - return filepath.Join(usr.HomeDir, "Library", "Ethereum") - } else if runtime.GOOS == "windows" { - return filepath.Join(usr.HomeDir, "AppData", "Roaming", "Ethereum") + // Try to place the data folder in the user's home dir + var home string + if usr, err := user.Current(); err == nil { + home = usr.HomeDir } else { - return filepath.Join(usr.HomeDir, ".ethereum") + home = os.Getenv("HOME") } + if home != "" { + if runtime.GOOS == "darwin" { + return filepath.Join(home, "Library", "Ethereum") + } else if runtime.GOOS == "windows" { + return filepath.Join(home, "AppData", "Roaming", "Ethereum") + } else { + return filepath.Join(home, ".ethereum") + } + } + // As we cannot guess a stable location, return empty and handle later + return "" } func DefaultIpcPath() string { diff --git a/core/bench_test.go b/core/bench_test.go index def4f0d2a..27f3e3158 100644 --- a/core/bench_test.go +++ b/core/bench_test.go @@ -168,7 +168,7 @@ func benchInsertChain(b *testing.B, disk bool, gen func(int, *BlockGen)) { // Time the insertion of the new chain. // State and blocks are stored in the same DB. evmux := new(event.TypeMux) - chainman, _ := NewChainManager(db, FakePow{}, evmux) + chainman, _ := NewBlockChain(db, FakePow{}, evmux) chainman.SetProcessor(NewBlockProcessor(db, FakePow{}, chainman, evmux)) defer chainman.Stop() b.ReportAllocs() diff --git a/core/block_processor.go b/core/block_processor.go index 238b2db95..783e15687 100644 --- a/core/block_processor.go +++ b/core/block_processor.go @@ -25,6 +25,7 @@ import ( "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/state" "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/core/vm" "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/event" @@ -46,7 +47,7 @@ type BlockProcessor struct { // Mutex for locking the block processor. Blocks can only be handled one at a time mutex sync.Mutex // Canonical block chain - bc *ChainManager + bc *BlockChain // non-persistent key/value memory storage mem map[string]*big.Int // Proof of work used for validating @@ -69,12 +70,12 @@ type GasPool interface { SubGas(gas, price *big.Int) error } -func NewBlockProcessor(db ethdb.Database, pow pow.PoW, chainManager *ChainManager, eventMux *event.TypeMux) *BlockProcessor { +func NewBlockProcessor(db ethdb.Database, pow pow.PoW, blockchain *BlockChain, eventMux *event.TypeMux) *BlockProcessor { sm := &BlockProcessor{ chainDb: db, mem: make(map[string]*big.Int), Pow: pow, - bc: chainManager, + bc: blockchain, eventMux: eventMux, } return sm @@ -100,10 +101,8 @@ func (self *BlockProcessor) ApplyTransaction(gp GasPool, statedb *state.StateDB, } // Update the state with pending changes - statedb.SyncIntermediate() - usedGas.Add(usedGas, gas) - receipt := types.NewReceipt(statedb.Root().Bytes(), usedGas) + receipt := types.NewReceipt(statedb.IntermediateRoot().Bytes(), usedGas) receipt.TxHash = tx.Hash() receipt.GasUsed = new(big.Int).Set(gas) if MessageCreatesContract(tx) { @@ -125,7 +124,7 @@ func (self *BlockProcessor) ApplyTransaction(gp GasPool, statedb *state.StateDB, return receipt, gas, err } -func (self *BlockProcessor) ChainManager() *ChainManager { +func (self *BlockProcessor) BlockChain() *BlockChain { return self.bc } @@ -165,7 +164,7 @@ func (self *BlockProcessor) ApplyTransactions(gp GasPool, statedb *state.StateDB return receipts, err } -func (sm *BlockProcessor) RetryProcess(block *types.Block) (logs state.Logs, err error) { +func (sm *BlockProcessor) RetryProcess(block *types.Block) (logs vm.Logs, err error) { // Processing a blocks may never happen simultaneously sm.mutex.Lock() defer sm.mutex.Unlock() @@ -190,7 +189,7 @@ func (sm *BlockProcessor) RetryProcess(block *types.Block) (logs state.Logs, err // Process block will attempt to process the given block's transactions and applies them // on top of the block's parent state (given it exists) and will return wether it was // successful or not. -func (sm *BlockProcessor) Process(block *types.Block) (logs state.Logs, receipts types.Receipts, err error) { +func (sm *BlockProcessor) Process(block *types.Block) (logs vm.Logs, receipts types.Receipts, err error) { // Processing a blocks may never happen simultaneously sm.mutex.Lock() defer sm.mutex.Unlock() @@ -206,7 +205,7 @@ func (sm *BlockProcessor) Process(block *types.Block) (logs state.Logs, receipts return sm.processWithParent(block, parent) } -func (sm *BlockProcessor) processWithParent(block, parent *types.Block) (logs state.Logs, receipts types.Receipts, err error) { +func (sm *BlockProcessor) processWithParent(block, parent *types.Block) (logs vm.Logs, receipts types.Receipts, err error) { // Create a new state based on the parent's root (e.g., create copy) state := state.New(parent.Root(), sm.chainDb) header := block.Header() @@ -265,16 +264,16 @@ func (sm *BlockProcessor) processWithParent(block, parent *types.Block) (logs st // Accumulate static rewards; block reward, uncle's and uncle inclusion. AccumulateRewards(state, header, uncles) - // Commit state objects/accounts to a temporary trie (does not save) - // used to calculate the state root. - state.SyncObjects() - if header.Root != state.Root() { - err = fmt.Errorf("invalid merkle root. received=%x got=%x", header.Root, state.Root()) - return + // Commit state objects/accounts to a database batch and calculate + // the state root. The database is not modified if the root + // doesn't match. + root, batch := state.CommitBatch() + if header.Root != root { + return nil, nil, fmt.Errorf("invalid merkle root: header=%x computed=%x", header.Root, root) } - // Sync the current block's state to the database - state.Sync() + // Execute the database writes. + batch.Write() return state.Logs(), receipts, nil } @@ -348,7 +347,7 @@ func (sm *BlockProcessor) VerifyUncles(statedb *state.StateDB, block, parent *ty // GetBlockReceipts returns the receipts beloniging to the block hash func (sm *BlockProcessor) GetBlockReceipts(bhash common.Hash) types.Receipts { - if block := sm.ChainManager().GetBlock(bhash); block != nil { + if block := sm.BlockChain().GetBlock(bhash); block != nil { return GetBlockReceipts(sm.chainDb, block.Hash()) } @@ -358,7 +357,7 @@ func (sm *BlockProcessor) GetBlockReceipts(bhash common.Hash) types.Receipts { // GetLogs returns the logs of the given block. This method is using a two step approach // where it tries to get it from the (updated) method which gets them from the receipts or // the depricated way by re-processing the block. -func (sm *BlockProcessor) GetLogs(block *types.Block) (logs state.Logs, err error) { +func (sm *BlockProcessor) GetLogs(block *types.Block) (logs vm.Logs, err error) { receipts := GetBlockReceipts(sm.chainDb, block.Hash()) // coalesce logs for _, receipt := range receipts { diff --git a/core/block_processor_test.go b/core/block_processor_test.go index 538cf4ee5..ba8bd7bcd 100644 --- a/core/block_processor_test.go +++ b/core/block_processor_test.go @@ -24,21 +24,22 @@ import ( "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/state" "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/core/vm" "github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/event" "github.com/ethereum/go-ethereum/pow/ezp" ) -func proc() (*BlockProcessor, *ChainManager) { +func proc() (*BlockProcessor, *BlockChain) { db, _ := ethdb.NewMemDatabase() var mux event.TypeMux WriteTestNetGenesisBlock(db, 0) - chainMan, err := NewChainManager(db, thePow(), &mux) + blockchain, err := NewBlockChain(db, thePow(), &mux) if err != nil { fmt.Println(err) } - return NewBlockProcessor(db, ezp.New(), chainMan, &mux), chainMan + return NewBlockProcessor(db, ezp.New(), blockchain, &mux), blockchain } func TestNumber(t *testing.T) { @@ -69,7 +70,7 @@ func TestPutReceipt(t *testing.T) { hash[0] = 2 receipt := new(types.Receipt) - receipt.SetLogs(state.Logs{&state.Log{ + receipt.SetLogs(vm.Logs{&vm.Log{ Address: addr, Topics: []common.Hash{hash}, Data: []byte("hi"), diff --git a/core/chain_manager.go b/core/blockchain.go index 383fce70c..ad545cf69 100644 --- a/core/chain_manager.go +++ b/core/blockchain.go @@ -55,11 +55,9 @@ const ( blockCacheLimit = 256 maxFutureBlocks = 256 maxTimeFutureBlocks = 30 - checkpointLimit = 200 ) -type ChainManager struct { - //eth EthManager +type BlockChain struct { chainDb ethdb.Database processor types.BlockProcessor eventMux *event.TypeMux @@ -69,7 +67,6 @@ type ChainManager struct { chainmu sync.RWMutex tsmu sync.RWMutex - checkpoint int // checkpoint counts towards the new checkpoint td *big.Int currentBlock *types.Block currentGasLimit *big.Int @@ -90,7 +87,7 @@ type ChainManager struct { pow pow.PoW } -func NewChainManager(chainDb ethdb.Database, pow pow.PoW, mux *event.TypeMux) (*ChainManager, error) { +func NewBlockChain(chainDb ethdb.Database, pow pow.PoW, mux *event.TypeMux) (*BlockChain, error) { headerCache, _ := lru.New(headerCacheLimit) bodyCache, _ := lru.New(bodyCacheLimit) bodyRLPCache, _ := lru.New(bodyCacheLimit) @@ -98,7 +95,7 @@ func NewChainManager(chainDb ethdb.Database, pow pow.PoW, mux *event.TypeMux) (* blockCache, _ := lru.New(blockCacheLimit) futureBlocks, _ := lru.New(maxFutureBlocks) - bc := &ChainManager{ + bc := &BlockChain{ chainDb: chainDb, eventMux: mux, quit: make(chan struct{}), @@ -144,7 +141,7 @@ func NewChainManager(chainDb ethdb.Database, pow pow.PoW, mux *event.TypeMux) (* return bc, nil } -func (bc *ChainManager) SetHead(head *types.Block) { +func (bc *BlockChain) SetHead(head *types.Block) { bc.mu.Lock() defer bc.mu.Unlock() @@ -163,80 +160,55 @@ func (bc *ChainManager) SetHead(head *types.Block) { bc.setLastState() } -func (self *ChainManager) Td() *big.Int { +func (self *BlockChain) Td() *big.Int { self.mu.RLock() defer self.mu.RUnlock() return new(big.Int).Set(self.td) } -func (self *ChainManager) GasLimit() *big.Int { +func (self *BlockChain) GasLimit() *big.Int { self.mu.RLock() defer self.mu.RUnlock() return self.currentBlock.GasLimit() } -func (self *ChainManager) LastBlockHash() common.Hash { +func (self *BlockChain) LastBlockHash() common.Hash { self.mu.RLock() defer self.mu.RUnlock() return self.currentBlock.Hash() } -func (self *ChainManager) CurrentBlock() *types.Block { +func (self *BlockChain) CurrentBlock() *types.Block { self.mu.RLock() defer self.mu.RUnlock() return self.currentBlock } -func (self *ChainManager) Status() (td *big.Int, currentBlock common.Hash, genesisBlock common.Hash) { +func (self *BlockChain) Status() (td *big.Int, currentBlock common.Hash, genesisBlock common.Hash) { self.mu.RLock() defer self.mu.RUnlock() return new(big.Int).Set(self.td), self.currentBlock.Hash(), self.genesisBlock.Hash() } -func (self *ChainManager) SetProcessor(proc types.BlockProcessor) { +func (self *BlockChain) SetProcessor(proc types.BlockProcessor) { self.processor = proc } -func (self *ChainManager) State() *state.StateDB { +func (self *BlockChain) State() *state.StateDB { return state.New(self.CurrentBlock().Root(), self.chainDb) } -func (bc *ChainManager) recover() bool { - data, _ := bc.chainDb.Get([]byte("checkpoint")) - if len(data) != 0 { - block := bc.GetBlock(common.BytesToHash(data)) - if block != nil { - if err := WriteCanonicalHash(bc.chainDb, block.Hash(), block.NumberU64()); err != nil { - glog.Fatalf("failed to write database head number: %v", err) - } - if err := WriteHeadBlockHash(bc.chainDb, block.Hash()); err != nil { - glog.Fatalf("failed to write database head hash: %v", err) - } - bc.currentBlock = block - return true - } - } - return false -} - -func (bc *ChainManager) setLastState() error { +func (bc *BlockChain) setLastState() error { head := GetHeadBlockHash(bc.chainDb) if head != (common.Hash{}) { block := bc.GetBlock(head) if block != nil { bc.currentBlock = block - } else { - glog.Infof("LastBlock (%x) not found. Recovering...\n", head) - if bc.recover() { - glog.Infof("Recover successful") - } else { - glog.Fatalf("Recover failed. Please report") - } } } else { bc.Reset() @@ -252,13 +224,13 @@ func (bc *ChainManager) setLastState() error { } // Reset purges the entire blockchain, restoring it to its genesis state. -func (bc *ChainManager) Reset() { +func (bc *BlockChain) Reset() { bc.ResetWithGenesisBlock(bc.genesisBlock) } // ResetWithGenesisBlock purges the entire blockchain, restoring it to the // specified genesis state. -func (bc *ChainManager) ResetWithGenesisBlock(genesis *types.Block) { +func (bc *BlockChain) ResetWithGenesisBlock(genesis *types.Block) { bc.mu.Lock() defer bc.mu.Unlock() @@ -286,7 +258,7 @@ func (bc *ChainManager) ResetWithGenesisBlock(genesis *types.Block) { } // Export writes the active chain to the given writer. -func (self *ChainManager) Export(w io.Writer) error { +func (self *BlockChain) Export(w io.Writer) error { if err := self.ExportN(w, uint64(0), self.currentBlock.NumberU64()); err != nil { return err } @@ -294,7 +266,7 @@ func (self *ChainManager) Export(w io.Writer) error { } // ExportN writes a subset of the active chain to the given writer. -func (self *ChainManager) ExportN(w io.Writer, first uint64, last uint64) error { +func (self *BlockChain) ExportN(w io.Writer, first uint64, last uint64) error { self.mu.RLock() defer self.mu.RUnlock() @@ -320,7 +292,7 @@ func (self *ChainManager) ExportN(w io.Writer, first uint64, last uint64) error // insert injects a block into the current chain block chain. Note, this function // assumes that the `mu` mutex is held! -func (bc *ChainManager) insert(block *types.Block) { +func (bc *BlockChain) insert(block *types.Block) { // Add the block to the canonical chain number scheme and mark as the head if err := WriteCanonicalHash(bc.chainDb, block.Hash(), block.NumberU64()); err != nil { glog.Fatalf("failed to insert block number: %v", err) @@ -328,32 +300,23 @@ func (bc *ChainManager) insert(block *types.Block) { if err := WriteHeadBlockHash(bc.chainDb, block.Hash()); err != nil { glog.Fatalf("failed to insert block number: %v", err) } - // Add a new restore point if we reached some limit - bc.checkpoint++ - if bc.checkpoint > checkpointLimit { - if err := bc.chainDb.Put([]byte("checkpoint"), block.Hash().Bytes()); err != nil { - glog.Fatalf("failed to create checkpoint: %v", err) - } - bc.checkpoint = 0 - } - // Update the internal internal state with the head block bc.currentBlock = block } // Accessors -func (bc *ChainManager) Genesis() *types.Block { +func (bc *BlockChain) Genesis() *types.Block { return bc.genesisBlock } // HasHeader checks if a block header is present in the database or not, caching // it if present. -func (bc *ChainManager) HasHeader(hash common.Hash) bool { +func (bc *BlockChain) HasHeader(hash common.Hash) bool { return bc.GetHeader(hash) != nil } // GetHeader retrieves a block header from the database by hash, caching it if // found. -func (self *ChainManager) GetHeader(hash common.Hash) *types.Header { +func (self *BlockChain) GetHeader(hash common.Hash) *types.Header { // Short circuit if the header's already in the cache, retrieve otherwise if header, ok := self.headerCache.Get(hash); ok { return header.(*types.Header) @@ -369,7 +332,7 @@ func (self *ChainManager) GetHeader(hash common.Hash) *types.Header { // GetHeaderByNumber retrieves a block header from the database by number, // caching it (associated with its hash) if found. -func (self *ChainManager) GetHeaderByNumber(number uint64) *types.Header { +func (self *BlockChain) GetHeaderByNumber(number uint64) *types.Header { hash := GetCanonicalHash(self.chainDb, number) if hash == (common.Hash{}) { return nil @@ -379,7 +342,7 @@ func (self *ChainManager) GetHeaderByNumber(number uint64) *types.Header { // GetBody retrieves a block body (transactions and uncles) from the database by // hash, caching it if found. -func (self *ChainManager) GetBody(hash common.Hash) *types.Body { +func (self *BlockChain) GetBody(hash common.Hash) *types.Body { // Short circuit if the body's already in the cache, retrieve otherwise if cached, ok := self.bodyCache.Get(hash); ok { body := cached.(*types.Body) @@ -396,7 +359,7 @@ func (self *ChainManager) GetBody(hash common.Hash) *types.Body { // GetBodyRLP retrieves a block body in RLP encoding from the database by hash, // caching it if found. -func (self *ChainManager) GetBodyRLP(hash common.Hash) rlp.RawValue { +func (self *BlockChain) GetBodyRLP(hash common.Hash) rlp.RawValue { // Short circuit if the body's already in the cache, retrieve otherwise if cached, ok := self.bodyRLPCache.Get(hash); ok { return cached.(rlp.RawValue) @@ -412,7 +375,7 @@ func (self *ChainManager) GetBodyRLP(hash common.Hash) rlp.RawValue { // GetTd retrieves a block's total difficulty in the canonical chain from the // database by hash, caching it if found. -func (self *ChainManager) GetTd(hash common.Hash) *big.Int { +func (self *BlockChain) GetTd(hash common.Hash) *big.Int { // Short circuit if the td's already in the cache, retrieve otherwise if cached, ok := self.tdCache.Get(hash); ok { return cached.(*big.Int) @@ -428,12 +391,12 @@ func (self *ChainManager) GetTd(hash common.Hash) *big.Int { // HasBlock checks if a block is fully present in the database or not, caching // it if present. -func (bc *ChainManager) HasBlock(hash common.Hash) bool { +func (bc *BlockChain) HasBlock(hash common.Hash) bool { return bc.GetBlock(hash) != nil } // GetBlock retrieves a block from the database by hash, caching it if found. -func (self *ChainManager) GetBlock(hash common.Hash) *types.Block { +func (self *BlockChain) GetBlock(hash common.Hash) *types.Block { // Short circuit if the block's already in the cache, retrieve otherwise if block, ok := self.blockCache.Get(hash); ok { return block.(*types.Block) @@ -449,7 +412,7 @@ func (self *ChainManager) GetBlock(hash common.Hash) *types.Block { // GetBlockByNumber retrieves a block from the database by number, caching it // (associated with its hash) if found. -func (self *ChainManager) GetBlockByNumber(number uint64) *types.Block { +func (self *BlockChain) GetBlockByNumber(number uint64) *types.Block { hash := GetCanonicalHash(self.chainDb, number) if hash == (common.Hash{}) { return nil @@ -459,7 +422,7 @@ func (self *ChainManager) GetBlockByNumber(number uint64) *types.Block { // GetBlockHashesFromHash retrieves a number of block hashes starting at a given // hash, fetching towards the genesis block. -func (self *ChainManager) GetBlockHashesFromHash(hash common.Hash, max uint64) []common.Hash { +func (self *BlockChain) GetBlockHashesFromHash(hash common.Hash, max uint64) []common.Hash { // Get the origin header from which to fetch header := self.GetHeader(hash) if header == nil { @@ -481,7 +444,7 @@ func (self *ChainManager) GetBlockHashesFromHash(hash common.Hash, max uint64) [ // [deprecated by eth/62] // GetBlocksFromHash returns the block corresponding to hash and up to n-1 ancestors. -func (self *ChainManager) GetBlocksFromHash(hash common.Hash, n int) (blocks []*types.Block) { +func (self *BlockChain) GetBlocksFromHash(hash common.Hash, n int) (blocks []*types.Block) { for i := 0; i < n; i++ { block := self.GetBlock(hash) if block == nil { @@ -493,7 +456,7 @@ func (self *ChainManager) GetBlocksFromHash(hash common.Hash, n int) (blocks []* return } -func (self *ChainManager) GetUnclesInChain(block *types.Block, length int) (uncles []*types.Header) { +func (self *BlockChain) GetUnclesInChain(block *types.Block, length int) (uncles []*types.Header) { for i := 0; block != nil && i < length; i++ { uncles = append(uncles, block.Uncles()...) block = self.GetBlock(block.ParentHash()) @@ -504,11 +467,11 @@ func (self *ChainManager) GetUnclesInChain(block *types.Block, length int) (uncl // setTotalDifficulty updates the TD of the chain manager. Note, this function // assumes that the `mu` mutex is held! -func (bc *ChainManager) setTotalDifficulty(td *big.Int) { +func (bc *BlockChain) setTotalDifficulty(td *big.Int) { bc.td = new(big.Int).Set(td) } -func (bc *ChainManager) Stop() { +func (bc *BlockChain) Stop() { if !atomic.CompareAndSwapInt32(&bc.running, 0, 1) { return } @@ -527,7 +490,7 @@ type queueEvent struct { splitCount int } -func (self *ChainManager) procFutureBlocks() { +func (self *BlockChain) procFutureBlocks() { blocks := make([]*types.Block, self.futureBlocks.Len()) for i, hash := range self.futureBlocks.Keys() { block, _ := self.futureBlocks.Get(hash) @@ -549,7 +512,7 @@ const ( ) // WriteBlock writes the block to the chain. -func (self *ChainManager) WriteBlock(block *types.Block) (status writeStatus, err error) { +func (self *BlockChain) WriteBlock(block *types.Block) (status writeStatus, err error) { self.wg.Add(1) defer self.wg.Done() @@ -599,7 +562,7 @@ func (self *ChainManager) WriteBlock(block *types.Block) (status writeStatus, er // InsertChain will attempt to insert the given chain in to the canonical chain or, otherwise, create a fork. It an error is returned // it will return the index number of the failing block as well an error describing what went wrong (for possible errors see core/errors.go). -func (self *ChainManager) InsertChain(chain types.Blocks) (int, error) { +func (self *BlockChain) InsertChain(chain types.Blocks) (int, error) { self.wg.Add(1) defer self.wg.Done() @@ -730,7 +693,7 @@ func (self *ChainManager) InsertChain(chain types.Blocks) (int, error) { // reorgs takes two blocks, an old chain and a new chain and will reconstruct the blocks and inserts them // to be part of the new canonical chain and accumulates potential missing transactions and post an // event about them -func (self *ChainManager) reorg(oldBlock, newBlock *types.Block) error { +func (self *BlockChain) reorg(oldBlock, newBlock *types.Block) error { self.mu.Lock() defer self.mu.Unlock() @@ -804,12 +767,14 @@ func (self *ChainManager) reorg(oldBlock, newBlock *types.Block) error { DeleteReceipt(self.chainDb, tx.Hash()) DeleteTransaction(self.chainDb, tx.Hash()) } - self.eventMux.Post(RemovedTransactionEvent{diff}) + // Must be posted in a goroutine because of the transaction pool trying + // to acquire the chain manager lock + go self.eventMux.Post(RemovedTransactionEvent{diff}) return nil } -func (self *ChainManager) update() { +func (self *BlockChain) update() { events := self.eventMux.Subscribe(queueEvent{}) futureTimer := time.Tick(5 * time.Second) out: @@ -840,8 +805,8 @@ out: } func blockErr(block *types.Block, err error) { - h := block.Header() - glog.V(logger.Error).Infof("Bad block #%v (%x)\n", h.Number, h.Hash().Bytes()) - glog.V(logger.Error).Infoln(err) - glog.V(logger.Debug).Infoln(verifyNonces) + if glog.V(logger.Error) { + glog.Errorf("Bad block #%v (%s)\n", block.Number(), block.Hash().Hex()) + glog.Errorf(" %v", err) + } } diff --git a/core/chain_manager_test.go b/core/blockchain_test.go index 6cfafb8c0..13971ccba 100644 --- a/core/chain_manager_test.go +++ b/core/blockchain_test.go @@ -28,8 +28,8 @@ import ( "github.com/ethereum/ethash" "github.com/ethereum/go-ethereum/common" - "github.com/ethereum/go-ethereum/core/state" "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/core/vm" "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/event" @@ -48,19 +48,19 @@ func thePow() pow.PoW { return pow } -func theChainManager(db ethdb.Database, t *testing.T) *ChainManager { +func theBlockChain(db ethdb.Database, t *testing.T) *BlockChain { var eventMux event.TypeMux WriteTestNetGenesisBlock(db, 0) - chainMan, err := NewChainManager(db, thePow(), &eventMux) + blockchain, err := NewBlockChain(db, thePow(), &eventMux) if err != nil { t.Error("failed creating chainmanager:", err) t.FailNow() return nil } - blockMan := NewBlockProcessor(db, nil, chainMan, &eventMux) - chainMan.SetProcessor(blockMan) + blockMan := NewBlockProcessor(db, nil, blockchain, &eventMux) + blockchain.SetProcessor(blockMan) - return chainMan + return blockchain } // Test fork of length N starting from block i @@ -104,7 +104,7 @@ func testFork(t *testing.T, bman *BlockProcessor, i, N int, f func(td1, td2 *big // Loop over parents making sure reconstruction is done properly } -func printChain(bc *ChainManager) { +func printChain(bc *BlockChain) { for i := bc.CurrentBlock().Number().Uint64(); i > 0; i-- { b := bc.GetBlockByNumber(uint64(i)) fmt.Printf("\t%x %v\n", b.Hash(), b.Difficulty()) @@ -144,8 +144,8 @@ func loadChain(fn string, t *testing.T) (types.Blocks, error) { return chain, nil } -func insertChain(done chan bool, chainMan *ChainManager, chain types.Blocks, t *testing.T) { - _, err := chainMan.InsertChain(chain) +func insertChain(done chan bool, blockchain *BlockChain, chain types.Blocks, t *testing.T) { + _, err := blockchain.InsertChain(chain) if err != nil { fmt.Println(err) t.FailNow() @@ -153,6 +153,19 @@ func insertChain(done chan bool, chainMan *ChainManager, chain types.Blocks, t * done <- true } +func TestLastBlock(t *testing.T) { + db, err := ethdb.NewMemDatabase() + if err != nil { + t.Fatal("Failed to create db:", err) + } + bchain := theBlockChain(db, t) + block := makeChain(bchain.CurrentBlock(), 1, db, 0)[0] + bchain.insert(block) + if block.Hash() != GetHeadBlockHash(db) { + t.Errorf("Write/Get HeadBlockHash failed") + } +} + func TestExtendCanonical(t *testing.T) { CanonicalLength := 5 db, err := ethdb.NewMemDatabase() @@ -294,23 +307,23 @@ func TestChainInsertions(t *testing.T) { t.FailNow() } - chainMan := theChainManager(db, t) + blockchain := theBlockChain(db, t) const max = 2 done := make(chan bool, max) - go insertChain(done, chainMan, chain1, t) - go insertChain(done, chainMan, chain2, t) + go insertChain(done, blockchain, chain1, t) + go insertChain(done, blockchain, chain2, t) for i := 0; i < max; i++ { <-done } - if chain2[len(chain2)-1].Hash() != chainMan.CurrentBlock().Hash() { + if chain2[len(chain2)-1].Hash() != blockchain.CurrentBlock().Hash() { t.Error("chain2 is canonical and shouldn't be") } - if chain1[len(chain1)-1].Hash() != chainMan.CurrentBlock().Hash() { + if chain1[len(chain1)-1].Hash() != blockchain.CurrentBlock().Hash() { t.Error("chain1 isn't canonical and should be") } } @@ -337,7 +350,7 @@ func TestChainMultipleInsertions(t *testing.T) { } } - chainMan := theChainManager(db, t) + blockchain := theBlockChain(db, t) done := make(chan bool, max) for i, chain := range chains { @@ -345,7 +358,7 @@ func TestChainMultipleInsertions(t *testing.T) { i := i chain := chain go func() { - insertChain(done, chainMan, chain, t) + insertChain(done, blockchain, chain, t) fmt.Println(i, "done") }() } @@ -354,14 +367,14 @@ func TestChainMultipleInsertions(t *testing.T) { <-done } - if chains[longest][len(chains[longest])-1].Hash() != chainMan.CurrentBlock().Hash() { + if chains[longest][len(chains[longest])-1].Hash() != blockchain.CurrentBlock().Hash() { t.Error("Invalid canonical chain") } } type bproc struct{} -func (bproc) Process(*types.Block) (state.Logs, types.Receipts, error) { return nil, nil, nil } +func (bproc) Process(*types.Block) (vm.Logs, types.Receipts, error) { return nil, nil, nil } func makeChainWithDiff(genesis *types.Block, d []int, seed byte) []*types.Block { var chain []*types.Block @@ -382,9 +395,9 @@ func makeChainWithDiff(genesis *types.Block, d []int, seed byte) []*types.Block return chain } -func chm(genesis *types.Block, db ethdb.Database) *ChainManager { +func chm(genesis *types.Block, db ethdb.Database) *BlockChain { var eventMux event.TypeMux - bc := &ChainManager{chainDb: db, genesisBlock: genesis, eventMux: &eventMux, pow: FakePow{}} + bc := &BlockChain{chainDb: db, genesisBlock: genesis, eventMux: &eventMux, pow: FakePow{}} bc.headerCache, _ = lru.New(100) bc.bodyCache, _ = lru.New(100) bc.bodyRLPCache, _ = lru.New(100) @@ -459,7 +472,7 @@ func TestReorgBadHashes(t *testing.T) { BadHashes[chain[3].Header().Hash()] = true var eventMux event.TypeMux - ncm, err := NewChainManager(db, FakePow{}, &eventMux) + ncm, err := NewBlockChain(db, FakePow{}, &eventMux) if err != nil { t.Errorf("NewChainManager err: %s", err) } @@ -593,7 +606,7 @@ func TestChainTxReorgs(t *testing.T) { }) // Import the chain. This runs all block validation rules. evmux := &event.TypeMux{} - chainman, _ := NewChainManager(db, FakePow{}, evmux) + chainman, _ := NewBlockChain(db, FakePow{}, evmux) chainman.SetProcessor(NewBlockProcessor(db, FakePow{}, chainman, evmux)) if i, err := chainman.InsertChain(chain); err != nil { t.Fatalf("failed to insert original chain[%d]: %v", i, err) diff --git a/core/chain_makers.go b/core/chain_makers.go index 70233438d..ba09b3029 100644 --- a/core/chain_makers.go +++ b/core/chain_makers.go @@ -17,6 +17,7 @@ package core import ( + "fmt" "math/big" "github.com/ethereum/go-ethereum/common" @@ -94,9 +95,9 @@ func (b *BlockGen) AddTx(tx *types.Transaction) { if err != nil { panic(err) } - b.statedb.SyncIntermediate() + root := b.statedb.IntermediateRoot() b.header.GasUsed.Add(b.header.GasUsed, gas) - receipt := types.NewReceipt(b.statedb.Root().Bytes(), b.header.GasUsed) + receipt := types.NewReceipt(root.Bytes(), b.header.GasUsed) logs := b.statedb.GetLogs(tx.Hash()) receipt.SetLogs(logs) receipt.Bloom = types.CreateBloom(types.Receipts{receipt}) @@ -152,7 +153,7 @@ func (b *BlockGen) OffsetTime(seconds int64) { // and their coinbase will be the zero address. // // Blocks created by GenerateChain do not contain valid proof of work -// values. Inserting them into ChainManager requires use of FakePow or +// values. Inserting them into BlockChain requires use of FakePow or // a similar non-validating proof of work implementation. func GenerateChain(parent *types.Block, db ethdb.Database, n int, gen func(int, *BlockGen)) []*types.Block { statedb := state.New(parent.Root(), db) @@ -163,8 +164,11 @@ func GenerateChain(parent *types.Block, db ethdb.Database, n int, gen func(int, gen(i, b) } AccumulateRewards(statedb, h, b.uncles) - statedb.SyncIntermediate() - h.Root = statedb.Root() + root, err := statedb.Commit() + if err != nil { + panic(fmt.Sprintf("state write error: %v", err)) + } + h.Root = root return types.NewBlock(h, b.txs, b.uncles, b.receipts) } for i := 0; i < n; i++ { @@ -184,7 +188,7 @@ func makeHeader(parent *types.Block, state *state.StateDB) *types.Header { time = new(big.Int).Add(parent.Time(), big.NewInt(10)) // block time is fixed at 10 seconds } return &types.Header{ - Root: state.Root(), + Root: state.IntermediateRoot(), ParentHash: parent.Hash(), Coinbase: parent.Coinbase(), Difficulty: CalcDifficulty(time.Uint64(), new(big.Int).Sub(time, big.NewInt(10)).Uint64(), parent.Number(), parent.Difficulty()), @@ -201,7 +205,7 @@ func newCanonical(n int, db ethdb.Database) (*BlockProcessor, error) { evmux := &event.TypeMux{} WriteTestNetGenesisBlock(db, 0) - chainman, _ := NewChainManager(db, FakePow{}, evmux) + chainman, _ := NewBlockChain(db, FakePow{}, evmux) bman := NewBlockProcessor(db, FakePow{}, chainman, evmux) bman.bc.SetProcessor(bman) parent := bman.bc.CurrentBlock() diff --git a/core/chain_makers_test.go b/core/chain_makers_test.go index ac18e5e0b..b33af8d87 100644 --- a/core/chain_makers_test.go +++ b/core/chain_makers_test.go @@ -77,7 +77,7 @@ func ExampleGenerateChain() { // Import the chain. This runs all block validation rules. evmux := &event.TypeMux{} - chainman, _ := NewChainManager(db, FakePow{}, evmux) + chainman, _ := NewBlockChain(db, FakePow{}, evmux) chainman.SetProcessor(NewBlockProcessor(db, FakePow{}, chainman, evmux)) if i, err := chainman.InsertChain(chain); err != nil { fmt.Printf("insert error (block %d): %v\n", i, err) diff --git a/core/error.go b/core/error.go index ff58d69d6..5e32124a7 100644 --- a/core/error.go +++ b/core/error.go @@ -181,7 +181,7 @@ func IsValueTransferErr(e error) bool { type BadHashError common.Hash func (h BadHashError) Error() string { - return fmt.Sprintf("Found known bad hash in chain %x", h) + return fmt.Sprintf("Found known bad hash in chain %x", h[:]) } func IsBadHashError(err error) bool { diff --git a/core/events.go b/core/events.go index e142b6dba..8cf230dda 100644 --- a/core/events.go +++ b/core/events.go @@ -20,8 +20,8 @@ import ( "math/big" "github.com/ethereum/go-ethereum/common" - "github.com/ethereum/go-ethereum/core/state" "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/core/vm" ) // TxPreEvent is posted when a transaction enters the transaction pool. @@ -42,23 +42,23 @@ type RemovedTransactionEvent struct{ Txs types.Transactions } // ChainSplit is posted when a new head is detected type ChainSplitEvent struct { Block *types.Block - Logs state.Logs + Logs vm.Logs } type ChainEvent struct { Block *types.Block Hash common.Hash - Logs state.Logs + Logs vm.Logs } type ChainSideEvent struct { Block *types.Block - Logs state.Logs + Logs vm.Logs } type PendingBlockEvent struct { Block *types.Block - Logs state.Logs + Logs vm.Logs } type ChainUncleEvent struct { diff --git a/core/execution.go b/core/execution.go index 3a136515d..e3c00a2ea 100644 --- a/core/execution.go +++ b/core/execution.go @@ -17,108 +17,104 @@ package core import ( + "errors" "math/big" "github.com/ethereum/go-ethereum/common" - "github.com/ethereum/go-ethereum/core/state" "github.com/ethereum/go-ethereum/core/vm" "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/params" ) -// Execution is the execution environment for the given call or create action. -type Execution struct { - env vm.Environment - address *common.Address - input []byte - evm vm.VirtualMachine - - Gas, price, value *big.Int +// Call executes within the given contract +func Call(env vm.Environment, caller vm.ContractRef, addr common.Address, input []byte, gas, gasPrice, value *big.Int) (ret []byte, err error) { + ret, _, err = exec(env, caller, &addr, &addr, input, env.Db().GetCode(addr), gas, gasPrice, value) + return ret, err } -// NewExecution returns a new execution environment that handles all calling -// and creation logic defined by the YP. -func NewExecution(env vm.Environment, address *common.Address, input []byte, gas, gasPrice, value *big.Int) *Execution { - exe := &Execution{env: env, address: address, input: input, Gas: gas, price: gasPrice, value: value} - exe.evm = vm.NewVm(env) - return exe +// CallCode executes the given address' code as the given contract address +func CallCode(env vm.Environment, caller vm.ContractRef, addr common.Address, input []byte, gas, gasPrice, value *big.Int) (ret []byte, err error) { + prev := caller.Address() + ret, _, err = exec(env, caller, &prev, &addr, input, env.Db().GetCode(addr), gas, gasPrice, value) + return ret, err } -// Call executes within the given context -func (self *Execution) Call(codeAddr common.Address, caller vm.ContextRef) ([]byte, error) { - // Retrieve the executing code - code := self.env.State().GetCode(codeAddr) - - return self.exec(&codeAddr, code, caller) -} - -// Create creates a new contract and runs the initialisation procedure of the -// contract. This returns the returned code for the contract and is stored -// elsewhere. -func (self *Execution) Create(caller vm.ContextRef) (ret []byte, err error, account *state.StateObject) { - // Input must be nil for create - code := self.input - self.input = nil - ret, err = self.exec(nil, code, caller) +// Create creates a new contract with the given code +func Create(env vm.Environment, caller vm.ContractRef, code []byte, gas, gasPrice, value *big.Int) (ret []byte, address common.Address, err error) { + ret, address, err = exec(env, caller, nil, nil, nil, code, gas, gasPrice, value) // Here we get an error if we run into maximum stack depth, // See: https://github.com/ethereum/yellowpaper/pull/131 // and YP definitions for CREATE instruction if err != nil { - return nil, err, nil + return nil, address, err } - account = self.env.State().GetStateObject(*self.address) - return + return ret, address, err } -// exec executes the given code and executes within the contextAddr context. -func (self *Execution) exec(contextAddr *common.Address, code []byte, caller vm.ContextRef) (ret []byte, err error) { - env := self.env - evm := self.evm +func exec(env vm.Environment, caller vm.ContractRef, address, codeAddr *common.Address, input, code []byte, gas, gasPrice, value *big.Int) (ret []byte, addr common.Address, err error) { + evm := vm.NewVm(env) + // Depth check execution. Fail if we're trying to execute above the // limit. if env.Depth() > int(params.CallCreateDepth.Int64()) { - caller.ReturnGas(self.Gas, self.price) + caller.ReturnGas(gas, gasPrice) - return nil, vm.DepthError + return nil, common.Address{}, vm.DepthError } - if !env.CanTransfer(env.State().GetStateObject(caller.Address()), self.value) { - caller.ReturnGas(self.Gas, self.price) + if !env.CanTransfer(caller.Address(), value) { + caller.ReturnGas(gas, gasPrice) - return nil, ValueTransferErr("insufficient funds to transfer value. Req %v, has %v", self.value, env.State().GetBalance(caller.Address())) + return nil, common.Address{}, ValueTransferErr("insufficient funds to transfer value. Req %v, has %v", value, env.Db().GetBalance(caller.Address())) } var createAccount bool - if self.address == nil { + if address == nil { // Generate a new address - nonce := env.State().GetNonce(caller.Address()) - env.State().SetNonce(caller.Address(), nonce+1) + nonce := env.Db().GetNonce(caller.Address()) + env.Db().SetNonce(caller.Address(), nonce+1) - addr := crypto.CreateAddress(caller.Address(), nonce) + addr = crypto.CreateAddress(caller.Address(), nonce) - self.address = &addr + address = &addr createAccount = true } - snapshot := env.State().Copy() + snapshot := env.MakeSnapshot() var ( - from = env.State().GetStateObject(caller.Address()) - to *state.StateObject + from = env.Db().GetAccount(caller.Address()) + to vm.Account ) if createAccount { - to = env.State().CreateAccount(*self.address) + to = env.Db().CreateAccount(*address) } else { - to = env.State().GetOrNewStateObject(*self.address) + if !env.Db().Exist(*address) { + to = env.Db().CreateAccount(*address) + } else { + to = env.Db().GetAccount(*address) + } } - vm.Transfer(from, to, self.value) + env.Transfer(from, to, value) - context := vm.NewContext(caller, to, self.value, self.Gas, self.price) - context.SetCallCode(contextAddr, code) + contract := vm.NewContract(caller, to, value, gas, gasPrice) + contract.SetCallCode(codeAddr, code) - ret, err = evm.Run(context, self.input) + ret, err = evm.Run(contract, input) if err != nil { - env.State().Set(snapshot) + env.SetSnapshot(snapshot) //env.Db().Set(snapshot) } - return + return ret, addr, err +} + +// generic transfer method +func Transfer(from, to vm.Account, amount *big.Int) error { + if from.Balance().Cmp(amount) < 0 { + return errors.New("Insufficient balance in account") + } + + from.SubBalance(amount) + to.AddBalance(amount) + + return nil } diff --git a/core/genesis.go b/core/genesis.go index 11dbdee6d..62e039d1a 100644 --- a/core/genesis.go +++ b/core/genesis.go @@ -69,7 +69,7 @@ func WriteGenesisBlock(chainDb ethdb.Database, reader io.Reader) (*types.Block, statedb.SetState(address, common.HexToHash(key), common.HexToHash(value)) } } - statedb.SyncObjects() + root, stateBatch := statedb.CommitBatch() difficulty := common.String2Big(genesis.Difficulty) block := types.NewBlock(&types.Header{ @@ -81,7 +81,7 @@ func WriteGenesisBlock(chainDb ethdb.Database, reader io.Reader) (*types.Block, Difficulty: difficulty, MixDigest: common.HexToHash(genesis.Mixhash), Coinbase: common.HexToAddress(genesis.Coinbase), - Root: statedb.Root(), + Root: root, }, nil, nil, nil) if block := GetBlock(chainDb, block.Hash()); block != nil { @@ -92,8 +92,10 @@ func WriteGenesisBlock(chainDb ethdb.Database, reader io.Reader) (*types.Block, } return block, nil } - statedb.Sync() + if err := stateBatch.Write(); err != nil { + return nil, fmt.Errorf("cannot write state: %v", err) + } if err := WriteTd(chainDb, block.Hash(), difficulty); err != nil { return nil, err } @@ -118,12 +120,14 @@ func GenesisBlockForTesting(db ethdb.Database, addr common.Address, balance *big statedb := state.New(common.Hash{}, db) obj := statedb.GetOrNewStateObject(addr) obj.SetBalance(balance) - statedb.SyncObjects() - statedb.Sync() + root, err := statedb.Commit() + if err != nil { + panic(fmt.Sprintf("cannot write state: %v", err)) + } block := types.NewBlock(&types.Header{ Difficulty: params.GenesisDifficulty, GasLimit: params.GenesisGasLimit, - Root: statedb.Root(), + Root: root, }, nil, nil, nil) return block } diff --git a/core/helper_test.go b/core/helper_test.go index 81ea6fc22..fd6a5491c 100644 --- a/core/helper_test.go +++ b/core/helper_test.go @@ -34,7 +34,7 @@ type TestManager struct { db ethdb.Database txPool *TxPool - blockChain *ChainManager + blockChain *BlockChain Blocks []*types.Block } @@ -54,7 +54,7 @@ func (s *TestManager) Peers() *list.List { return list.New() } -func (s *TestManager) ChainManager() *ChainManager { +func (s *TestManager) BlockChain() *BlockChain { return s.blockChain } @@ -89,7 +89,7 @@ func NewTestManager() *TestManager { testManager.eventMux = new(event.TypeMux) testManager.db = db // testManager.txPool = NewTxPool(testManager) - // testManager.blockChain = NewChainManager(testManager) + // testManager.blockChain = NewBlockChain(testManager) // testManager.stateManager = NewStateManager(testManager) return testManager diff --git a/core/manager.go b/core/manager.go index 0f108a6de..289c87c11 100644 --- a/core/manager.go +++ b/core/manager.go @@ -26,7 +26,7 @@ import ( type Backend interface { AccountManager() *accounts.Manager BlockProcessor() *BlockProcessor - ChainManager() *ChainManager + BlockChain() *BlockChain TxPool() *TxPool ChainDb() ethdb.Database DappDb() ethdb.Database diff --git a/core/state/state_object.go b/core/state/state_object.go index 353f2357b..40af9ed9c 100644 --- a/core/state/state_object.go +++ b/core/state/state_object.go @@ -90,15 +90,13 @@ type StateObject struct { func NewStateObject(address common.Address, db ethdb.Database) *StateObject { object := &StateObject{db: db, address: address, balance: new(big.Int), gasPool: new(big.Int), dirty: true} - object.trie = trie.NewSecure((common.Hash{}).Bytes(), db) + object.trie, _ = trie.NewSecure(common.Hash{}, db) object.storage = make(Storage) object.gasPool = new(big.Int) - return object } func NewStateObjectFromBytes(address common.Address, data []byte, db ethdb.Database) *StateObject { - // TODO clean me up var extobject struct { Nonce uint64 Balance *big.Int @@ -107,7 +105,13 @@ func NewStateObjectFromBytes(address common.Address, data []byte, db ethdb.Datab } err := rlp.Decode(bytes.NewReader(data), &extobject) if err != nil { - fmt.Println(err) + glog.Errorf("can't decode state object %x: %v", address, err) + return nil + } + trie, err := trie.NewSecure(extobject.Root, db) + if err != nil { + // TODO: bubble this up or panic + glog.Errorf("can't create account trie with root %x: %v", extobject.Root[:], err) return nil } @@ -115,11 +119,10 @@ func NewStateObjectFromBytes(address common.Address, data []byte, db ethdb.Datab object.nonce = extobject.Nonce object.balance = extobject.Balance object.codeHash = extobject.CodeHash - object.trie = trie.NewSecure(extobject.Root[:], db) + object.trie = trie object.storage = make(map[string]common.Hash) object.gasPool = new(big.Int) object.code, _ = db.Get(extobject.CodeHash) - return object } @@ -215,6 +218,7 @@ func (c *StateObject) ReturnGas(gas, price *big.Int) {} func (self *StateObject) SetGasLimit(gasLimit *big.Int) { self.gasPool = new(big.Int).Set(gasLimit) + self.dirty = true if glog.V(logger.Core) { glog.Infof("%x: gas (+ %v)", self.Address(), self.gasPool) @@ -225,19 +229,14 @@ func (self *StateObject) SubGas(gas, price *big.Int) error { if self.gasPool.Cmp(gas) < 0 { return GasLimitError(self.gasPool, gas) } - self.gasPool.Sub(self.gasPool, gas) - - rGas := new(big.Int).Set(gas) - rGas.Mul(rGas, price) - self.dirty = true - return nil } func (self *StateObject) AddGas(gas, price *big.Int) { self.gasPool.Add(self.gasPool, gas) + self.dirty = true } func (self *StateObject) Copy() *StateObject { diff --git a/core/state/state_test.go b/core/state/state_test.go index 60836738e..b5a7f4081 100644 --- a/core/state/state_test.go +++ b/core/state/state_test.go @@ -89,8 +89,7 @@ func TestNull(t *testing.T) { //value := common.FromHex("0x823140710bf13990e4500136726d8b55") var value common.Hash state.SetState(address, common.Hash{}, value) - state.SyncIntermediate() - state.Sync() + state.Commit() value = state.GetState(address, common.Hash{}) if !common.EmptyHash(value) { t.Errorf("expected empty hash. got %x", value) diff --git a/core/state/statedb.go b/core/state/statedb.go index 24f97e32a..499ea5f52 100644 --- a/core/state/statedb.go +++ b/core/state/statedb.go @@ -21,6 +21,7 @@ import ( "math/big" "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/vm" "github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/logger" "github.com/ethereum/go-ethereum/logger/glog" @@ -35,7 +36,6 @@ import ( type StateDB struct { db ethdb.Database trie *trie.SecureTrie - root common.Hash stateObjects map[string]*StateObject @@ -43,18 +43,25 @@ type StateDB struct { thash, bhash common.Hash txIndex int - logs map[common.Hash]Logs + logs map[common.Hash]vm.Logs logSize uint } // Create a new state from a given trie func New(root common.Hash, db ethdb.Database) *StateDB { - trie := trie.NewSecure(root[:], db) - return &StateDB{root: root, db: db, trie: trie, stateObjects: make(map[string]*StateObject), refund: new(big.Int), logs: make(map[common.Hash]Logs)} -} - -func (self *StateDB) PrintRoot() { - self.trie.Trie.PrintRoot() + tr, err := trie.NewSecure(root, db) + if err != nil { + // TODO: bubble this up + tr, _ = trie.NewSecure(common.Hash{}, db) + glog.Errorf("can't create state trie with root %x: %v", root[:], err) + } + return &StateDB{ + db: db, + trie: tr, + stateObjects: make(map[string]*StateObject), + refund: new(big.Int), + logs: make(map[common.Hash]vm.Logs), + } } func (self *StateDB) StartRecord(thash, bhash common.Hash, ti int) { @@ -63,7 +70,7 @@ func (self *StateDB) StartRecord(thash, bhash common.Hash, ti int) { self.txIndex = ti } -func (self *StateDB) AddLog(log *Log) { +func (self *StateDB) AddLog(log *vm.Log) { log.TxHash = self.thash log.BlockHash = self.bhash log.TxIndex = uint(self.txIndex) @@ -72,30 +79,34 @@ func (self *StateDB) AddLog(log *Log) { self.logSize++ } -func (self *StateDB) GetLogs(hash common.Hash) Logs { +func (self *StateDB) GetLogs(hash common.Hash) vm.Logs { return self.logs[hash] } -func (self *StateDB) Logs() Logs { - var logs Logs +func (self *StateDB) Logs() vm.Logs { + var logs vm.Logs for _, lgs := range self.logs { logs = append(logs, lgs...) } return logs } -func (self *StateDB) Refund(gas *big.Int) { +func (self *StateDB) AddRefund(gas *big.Int) { self.refund.Add(self.refund, gas) } -/* - * GETTERS - */ - func (self *StateDB) HasAccount(addr common.Address) bool { return self.GetStateObject(addr) != nil } +func (self *StateDB) Exist(addr common.Address) bool { + return self.GetStateObject(addr) != nil +} + +func (self *StateDB) GetAccount(addr common.Address) vm.Account { + return self.GetStateObject(addr) +} + // Retrieve the balance from the given address or 0 if object not found func (self *StateDB) GetBalance(addr common.Address) *big.Int { stateObject := self.GetStateObject(addr) @@ -196,7 +207,6 @@ func (self *StateDB) UpdateStateObject(stateObject *StateObject) { if len(stateObject.CodeHash()) > 0 { self.db.Put(stateObject.CodeHash(), stateObject.code) } - addr := stateObject.Address() self.trie.Update(addr[:], stateObject.RlpEncode()) } @@ -207,6 +217,7 @@ func (self *StateDB) DeleteStateObject(stateObject *StateObject) { addr := stateObject.Address() self.trie.Delete(addr[:]) + //delete(self.stateObjects, addr.Str()) } // Retrieve a state object given my the address. Nil if not found @@ -239,7 +250,7 @@ func (self *StateDB) SetStateObject(object *StateObject) { func (self *StateDB) GetOrNewStateObject(addr common.Address) *StateObject { stateObject := self.GetStateObject(addr) if stateObject == nil || stateObject.deleted { - stateObject = self.CreateAccount(addr) + stateObject = self.CreateStateObject(addr) } return stateObject @@ -258,7 +269,7 @@ func (self *StateDB) newStateObject(addr common.Address) *StateObject { } // Creates creates a new state object and takes ownership. This is different from "NewStateObject" -func (self *StateDB) CreateAccount(addr common.Address) *StateObject { +func (self *StateDB) CreateStateObject(addr common.Address) *StateObject { // Get previous (if any) so := self.GetStateObject(addr) // Create a new one @@ -272,6 +283,10 @@ func (self *StateDB) CreateAccount(addr common.Address) *StateObject { return newSo } +func (self *StateDB) CreateAccount(addr common.Address) vm.Account { + return self.CreateStateObject(addr) +} + // // Setting, copying of the state methods // @@ -286,7 +301,7 @@ func (self *StateDB) Copy() *StateDB { state.refund.Set(self.refund) for hash, logs := range self.logs { - state.logs[hash] = make(Logs, len(logs)) + state.logs[hash] = make(vm.Logs, len(logs)) copy(state.logs[hash], logs) } state.logSize = self.logSize @@ -303,65 +318,71 @@ func (self *StateDB) Set(state *StateDB) { self.logSize = state.logSize } -func (s *StateDB) Root() common.Hash { - return common.BytesToHash(s.trie.Root()) -} - -// Syncs the trie and all siblings -func (s *StateDB) Sync() { - // Sync all nested states - for _, stateObject := range s.stateObjects { - stateObject.trie.Commit() - } - - s.trie.Commit() - - s.Empty() -} - -func (self *StateDB) Empty() { - self.stateObjects = make(map[string]*StateObject) - self.refund = new(big.Int) -} - -func (self *StateDB) Refunds() *big.Int { +func (self *StateDB) GetRefund() *big.Int { return self.refund } -// SyncIntermediate updates the intermediate state and all mid steps -func (self *StateDB) SyncIntermediate() { - self.refund = new(big.Int) - - for _, stateObject := range self.stateObjects { +// IntermediateRoot computes the current root hash of the state trie. +// It is called in between transactions to get the root hash that +// goes into transaction receipts. +func (s *StateDB) IntermediateRoot() common.Hash { + s.refund = new(big.Int) + for _, stateObject := range s.stateObjects { if stateObject.dirty { if stateObject.remove { - self.DeleteStateObject(stateObject) + s.DeleteStateObject(stateObject) } else { stateObject.Update() - - self.UpdateStateObject(stateObject) + s.UpdateStateObject(stateObject) } stateObject.dirty = false } } + return s.trie.Hash() +} + +// Commit commits all state changes to the database. +func (s *StateDB) Commit() (root common.Hash, err error) { + return s.commit(s.db) } -// SyncObjects syncs the changed objects to the trie -func (self *StateDB) SyncObjects() { - self.trie = trie.NewSecure(self.root[:], self.db) +// CommitBatch commits all state changes to a write batch but does not +// execute the batch. It is used to validate state changes against +// the root hash stored in a block. +func (s *StateDB) CommitBatch() (root common.Hash, batch ethdb.Batch) { + batch = s.db.NewBatch() + root, _ = s.commit(batch) + return root, batch +} - self.refund = new(big.Int) +func (s *StateDB) commit(db trie.DatabaseWriter) (common.Hash, error) { + s.refund = new(big.Int) - for _, stateObject := range self.stateObjects { + for _, stateObject := range s.stateObjects { if stateObject.remove { - self.DeleteStateObject(stateObject) + // If the object has been removed, don't bother syncing it + // and just mark it for deletion in the trie. + s.DeleteStateObject(stateObject) } else { + // Write any storage changes in the state object to its trie. stateObject.Update() - - self.UpdateStateObject(stateObject) + // Commit the trie of the object to the batch. + // This updates the trie root internally, so + // getting the root hash of the storage trie + // through UpdateStateObject is fast. + if _, err := stateObject.trie.CommitTo(db); err != nil { + return common.Hash{}, err + } + // Update the object in the account trie. + s.UpdateStateObject(stateObject) } stateObject.dirty = false } + return s.trie.CommitTo(db) +} + +func (self *StateDB) Refunds() *big.Int { + return self.refund } // Debug stuff diff --git a/core/state_transition.go b/core/state_transition.go index 6ff7fa1ff..e83019229 100644 --- a/core/state_transition.go +++ b/core/state_transition.go @@ -51,7 +51,7 @@ type StateTransition struct { initialGas *big.Int value *big.Int data []byte - state *state.StateDB + state vm.Database env vm.Environment } @@ -95,11 +95,7 @@ func IntrinsicGas(data []byte) *big.Int { } func ApplyMessage(env vm.Environment, msg Message, gp GasPool) ([]byte, *big.Int, error) { - return NewStateTransition(env, msg, gp).transitionState() -} - -func NewStateTransition(env vm.Environment, msg Message, gp GasPool) *StateTransition { - return &StateTransition{ + var st = StateTransition{ gp: gp, env: env, msg: msg, @@ -108,18 +104,22 @@ func NewStateTransition(env vm.Environment, msg Message, gp GasPool) *StateTrans initialGas: new(big.Int), value: msg.Value(), data: msg.Data(), - state: env.State(), + state: env.Db(), } + return st.transitionDb() } -func (self *StateTransition) From() (*state.StateObject, error) { +func (self *StateTransition) from() (vm.Account, error) { f, err := self.msg.From() if err != nil { return nil, err } - return self.state.GetOrNewStateObject(f), nil + if !self.state.Exist(f) { + return self.state.CreateAccount(f), nil + } + return self.state.GetAccount(f), nil } -func (self *StateTransition) To() *state.StateObject { +func (self *StateTransition) to() vm.Account { if self.msg == nil { return nil } @@ -127,10 +127,14 @@ func (self *StateTransition) To() *state.StateObject { if to == nil { return nil // contract creation } - return self.state.GetOrNewStateObject(*to) + + if !self.state.Exist(*to) { + return self.state.CreateAccount(*to) + } + return self.state.GetAccount(*to) } -func (self *StateTransition) UseGas(amount *big.Int) error { +func (self *StateTransition) useGas(amount *big.Int) error { if self.gas.Cmp(amount) < 0 { return vm.OutOfGasError } @@ -139,15 +143,15 @@ func (self *StateTransition) UseGas(amount *big.Int) error { return nil } -func (self *StateTransition) AddGas(amount *big.Int) { +func (self *StateTransition) addGas(amount *big.Int) { self.gas.Add(self.gas, amount) } -func (self *StateTransition) BuyGas() error { +func (self *StateTransition) buyGas() error { mgas := self.msg.Gas() mgval := new(big.Int).Mul(mgas, self.gasPrice) - sender, err := self.From() + sender, err := self.from() if err != nil { return err } @@ -157,7 +161,7 @@ func (self *StateTransition) BuyGas() error { if err = self.gp.SubGas(mgas, self.gasPrice); err != nil { return err } - self.AddGas(mgas) + self.addGas(mgas) self.initialGas.Set(mgas) sender.SubBalance(mgval) return nil @@ -165,18 +169,19 @@ func (self *StateTransition) BuyGas() error { func (self *StateTransition) preCheck() (err error) { msg := self.msg - sender, err := self.From() + sender, err := self.from() if err != nil { return err } // Make sure this transaction's nonce is correct - if sender.Nonce() != msg.Nonce() { - return NonceError(msg.Nonce(), sender.Nonce()) + //if sender.Nonce() != msg.Nonce() { + if n := self.state.GetNonce(sender.Address()); n != msg.Nonce() { + return NonceError(msg.Nonce(), n) } // Pre-pay gas / Buy gas of the coinbase account - if err = self.BuyGas(); err != nil { + if err = self.buyGas(); err != nil { if state.IsGasLimitErr(err) { return err } @@ -186,28 +191,28 @@ func (self *StateTransition) preCheck() (err error) { return nil } -func (self *StateTransition) transitionState() (ret []byte, usedGas *big.Int, err error) { +func (self *StateTransition) transitionDb() (ret []byte, usedGas *big.Int, err error) { if err = self.preCheck(); err != nil { return } msg := self.msg - sender, _ := self.From() // err checked in preCheck + sender, _ := self.from() // err checked in preCheck // Pay intrinsic gas - if err = self.UseGas(IntrinsicGas(self.data)); err != nil { + if err = self.useGas(IntrinsicGas(self.data)); err != nil { return nil, nil, InvalidTxError(err) } vmenv := self.env - var ref vm.ContextRef + var addr common.Address if MessageCreatesContract(msg) { - ret, err, ref = vmenv.Create(sender, self.data, self.gas, self.gasPrice, self.value) + ret, addr, err = vmenv.Create(sender, self.data, self.gas, self.gasPrice, self.value) if err == nil { dataGas := big.NewInt(int64(len(ret))) dataGas.Mul(dataGas, params.CreateDataGas) - if err := self.UseGas(dataGas); err == nil { - ref.SetCode(ret) + if err := self.useGas(dataGas); err == nil { + self.state.SetCode(addr, ret) } else { ret = nil // does not affect consensus but useful for StateTests validations glog.V(logger.Core).Infoln("Insufficient gas for creating code. Require", dataGas, "and have", self.gas) @@ -216,8 +221,8 @@ func (self *StateTransition) transitionState() (ret []byte, usedGas *big.Int, er glog.V(logger.Core).Infoln("VM create err:", err) } else { // Increment the nonce for the next transaction - self.state.SetNonce(sender.Address(), sender.Nonce()+1) - ret, err = vmenv.Call(sender, self.To().Address(), self.data, self.gas, self.gasPrice, self.value) + self.state.SetNonce(sender.Address(), self.state.GetNonce(sender.Address())+1) + ret, err = vmenv.Call(sender, self.to().Address(), self.data, self.gas, self.gasPrice, self.value) glog.V(logger.Core).Infoln("VM call err:", err) } @@ -241,13 +246,13 @@ func (self *StateTransition) transitionState() (ret []byte, usedGas *big.Int, er } func (self *StateTransition) refundGas() { - sender, _ := self.From() // err already checked + sender, _ := self.from() // err already checked // Return remaining gas remaining := new(big.Int).Mul(self.gas, self.gasPrice) sender.AddBalance(remaining) uhalf := remaining.Div(self.gasUsed(), common.Big2) - refund := common.BigMin(uhalf, self.state.Refunds()) + refund := common.BigMin(uhalf, self.state.GetRefund()) self.gas.Add(self.gas, refund) self.state.AddBalance(sender.Address(), refund.Mul(refund, self.gasPrice)) diff --git a/core/types/bloom9.go b/core/types/bloom9.go index 0629b31d4..f87ae58e6 100644 --- a/core/types/bloom9.go +++ b/core/types/bloom9.go @@ -20,7 +20,7 @@ import ( "math/big" "github.com/ethereum/go-ethereum/common" - "github.com/ethereum/go-ethereum/core/state" + "github.com/ethereum/go-ethereum/core/vm" "github.com/ethereum/go-ethereum/crypto" ) @@ -37,7 +37,7 @@ func CreateBloom(receipts Receipts) Bloom { return BytesToBloom(bin.Bytes()) } -func LogsBloom(logs state.Logs) *big.Int { +func LogsBloom(logs vm.Logs) *big.Int { bin := new(big.Int) for _, log := range logs { data := make([]common.Hash, len(log.Topics)) diff --git a/core/types/common.go b/core/types/common.go index de6efcd86..dc428c00c 100644 --- a/core/types/common.go +++ b/core/types/common.go @@ -19,14 +19,14 @@ package types import ( "math/big" - "github.com/ethereum/go-ethereum/common" - "github.com/ethereum/go-ethereum/core/state" - "fmt" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/vm" ) type BlockProcessor interface { - Process(*Block) (state.Logs, Receipts, error) + Process(*Block) (vm.Logs, Receipts, error) } const bloomLength = 256 diff --git a/core/types/derive_sha.go b/core/types/derive_sha.go index 478edb0e8..00c42c5bc 100644 --- a/core/types/derive_sha.go +++ b/core/types/derive_sha.go @@ -17,8 +17,9 @@ package types import ( + "bytes" + "github.com/ethereum/go-ethereum/common" - "github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/rlp" "github.com/ethereum/go-ethereum/trie" ) @@ -29,12 +30,12 @@ type DerivableList interface { } func DeriveSha(list DerivableList) common.Hash { - db, _ := ethdb.NewMemDatabase() - trie := trie.New(nil, db) + keybuf := new(bytes.Buffer) + trie := new(trie.Trie) for i := 0; i < list.Len(); i++ { - key, _ := rlp.EncodeToBytes(uint(i)) - trie.Update(key, list.GetRlp(i)) + keybuf.Reset() + rlp.Encode(keybuf, uint(i)) + trie.Update(keybuf.Bytes(), list.GetRlp(i)) } - - return common.BytesToHash(trie.Root()) + return trie.Hash() } diff --git a/core/types/receipt.go b/core/types/receipt.go index e01d69005..bcb4bd8a5 100644 --- a/core/types/receipt.go +++ b/core/types/receipt.go @@ -23,7 +23,7 @@ import ( "math/big" "github.com/ethereum/go-ethereum/common" - "github.com/ethereum/go-ethereum/core/state" + "github.com/ethereum/go-ethereum/core/vm" "github.com/ethereum/go-ethereum/rlp" ) @@ -33,7 +33,7 @@ type Receipt struct { Bloom Bloom TxHash common.Hash ContractAddress common.Address - logs state.Logs + logs vm.Logs GasUsed *big.Int } @@ -41,11 +41,11 @@ func NewReceipt(root []byte, cumalativeGasUsed *big.Int) *Receipt { return &Receipt{PostState: common.CopyBytes(root), CumulativeGasUsed: new(big.Int).Set(cumalativeGasUsed)} } -func (self *Receipt) SetLogs(logs state.Logs) { +func (self *Receipt) SetLogs(logs vm.Logs) { self.logs = logs } -func (self *Receipt) Logs() state.Logs { +func (self *Receipt) Logs() vm.Logs { return self.logs } @@ -60,7 +60,7 @@ func (self *Receipt) DecodeRLP(s *rlp.Stream) error { Bloom Bloom TxHash common.Hash ContractAddress common.Address - Logs state.Logs + Logs vm.Logs GasUsed *big.Int } if err := s.Decode(&r); err != nil { @@ -74,9 +74,9 @@ func (self *Receipt) DecodeRLP(s *rlp.Stream) error { type ReceiptForStorage Receipt func (self *ReceiptForStorage) EncodeRLP(w io.Writer) error { - storageLogs := make([]*state.LogForStorage, len(self.logs)) + storageLogs := make([]*vm.LogForStorage, len(self.logs)) for i, log := range self.logs { - storageLogs[i] = (*state.LogForStorage)(log) + storageLogs[i] = (*vm.LogForStorage)(log) } return rlp.Encode(w, []interface{}{self.PostState, self.CumulativeGasUsed, self.Bloom, self.TxHash, self.ContractAddress, storageLogs, self.GasUsed}) } diff --git a/core/vm/asm.go b/core/vm/asm.go index 639201e50..065d3eb97 100644 --- a/core/vm/asm.go +++ b/core/vm/asm.go @@ -23,6 +23,8 @@ import ( "github.com/ethereum/go-ethereum/common" ) +// Dissassemble dissassembles the byte code and returns the string +// representation (human readable opcodes). func Disassemble(script []byte) (asm []string) { pc := new(big.Int) for { diff --git a/core/vm/common.go b/core/vm/common.go index 2e03ec80b..2d1aa9332 100644 --- a/core/vm/common.go +++ b/core/vm/common.go @@ -22,34 +22,34 @@ import ( "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/logger/glog" + "github.com/ethereum/go-ethereum/params" ) // Global Debug flag indicating Debug VM (full logging) var Debug bool +// Type is the VM type accepted by **NewVm** type Type byte const ( - StdVmTy Type = iota - JitVmTy + StdVmTy Type = iota // Default standard VM + JitVmTy // LLVM JIT VM MaxVmTy - - LogTyPretty byte = 0x1 - LogTyDiff byte = 0x2 ) var ( - Pow256 = common.BigPow(2, 256) + Pow256 = common.BigPow(2, 256) // Pow256 is 2**256 - U256 = common.U256 - S256 = common.S256 + U256 = common.U256 // Shortcut to common.U256 + S256 = common.S256 // Shortcut to common.S256 - Zero = common.Big0 - One = common.Big1 + Zero = common.Big0 // Shortcut to common.Big0 + One = common.Big1 // Shortcut to common.Big1 - max = big.NewInt(math.MaxInt64) + max = big.NewInt(math.MaxInt64) // Maximum 64 bit integer ) +// NewVm returns a new VM based on the Environment func NewVm(env Environment) VirtualMachine { switch env.VmType() { case JitVmTy: @@ -62,6 +62,7 @@ func NewVm(env Environment) VirtualMachine { } } +// calculates the memory size required for a step func calcMemSize(off, l *big.Int) *big.Int { if l.Cmp(common.Big0) == 0 { return common.Big0 @@ -70,6 +71,32 @@ func calcMemSize(off, l *big.Int) *big.Int { return new(big.Int).Add(off, l) } +// calculates the quadratic gas +func quadMemGas(mem *Memory, newMemSize, gas *big.Int) { + if newMemSize.Cmp(common.Big0) > 0 { + newMemSizeWords := toWordSize(newMemSize) + newMemSize.Mul(newMemSizeWords, u256(32)) + + if newMemSize.Cmp(u256(int64(mem.Len()))) > 0 { + // be careful reusing variables here when changing. + // The order has been optimised to reduce allocation + oldSize := toWordSize(big.NewInt(int64(mem.Len()))) + pow := new(big.Int).Exp(oldSize, common.Big2, Zero) + linCoef := oldSize.Mul(oldSize, params.MemoryGas) + quadCoef := new(big.Int).Div(pow, params.QuadCoeffDiv) + oldTotalFee := new(big.Int).Add(linCoef, quadCoef) + + pow.Exp(newMemSizeWords, common.Big2, Zero) + linCoef = linCoef.Mul(newMemSizeWords, params.MemoryGas) + quadCoef = quadCoef.Div(pow, params.QuadCoeffDiv) + newTotalFee := linCoef.Add(linCoef, quadCoef) + + fee := newTotalFee.Sub(newTotalFee, oldTotalFee) + gas.Add(gas, fee) + } + } +} + // Simple helper func u256(n int64) *big.Int { return big.NewInt(n) @@ -86,6 +113,8 @@ func toValue(val *big.Int) interface{} { return val } +// getData returns a slice from the data based on the start and size and pads +// up to size with zero's. This function is overflow safe. func getData(data []byte, start, size *big.Int) []byte { dlen := big.NewInt(int64(len(data))) @@ -94,7 +123,9 @@ func getData(data []byte, start, size *big.Int) []byte { return common.RightPadBytes(data[s.Uint64():e.Uint64()], int(size.Uint64())) } -func UseGas(gas, amount *big.Int) bool { +// useGas attempts to subtract the amount of gas and returns whether it was +// successful +func useGas(gas, amount *big.Int) bool { if gas.Cmp(amount) < 0 { return false } diff --git a/core/vm/context.go b/core/vm/contract.go index d17934ba5..95417e747 100644 --- a/core/vm/context.go +++ b/core/vm/contract.go @@ -22,15 +22,18 @@ import ( "github.com/ethereum/go-ethereum/common" ) -type ContextRef interface { +// ContractRef is a reference to the contract's backing object +type ContractRef interface { ReturnGas(*big.Int, *big.Int) Address() common.Address SetCode([]byte) } -type Context struct { - caller ContextRef - self ContextRef +// Contract represents an ethereum contract in the state database. It contains +// the the contract code, calling arguments. Contract implements ContractReg +type Contract struct { + caller ContractRef + self ContractRef jumpdests destinations // result of JUMPDEST analysis. @@ -44,10 +47,10 @@ type Context struct { } // Create a new context for the given data items. -func NewContext(caller ContextRef, object ContextRef, value, gas, price *big.Int) *Context { - c := &Context{caller: caller, self: object, Args: nil} +func NewContract(caller ContractRef, object ContractRef, value, gas, price *big.Int) *Contract { + c := &Contract{caller: caller, self: object, Args: nil} - if parent, ok := caller.(*Context); ok { + if parent, ok := caller.(*Contract); ok { // Reuse JUMPDEST analysis from parent context if available. c.jumpdests = parent.jumpdests } else { @@ -66,11 +69,13 @@ func NewContext(caller ContextRef, object ContextRef, value, gas, price *big.Int return c } -func (c *Context) GetOp(n uint64) OpCode { +// GetOp returns the n'th element in the contract's byte array +func (c *Contract) GetOp(n uint64) OpCode { return OpCode(c.GetByte(n)) } -func (c *Context) GetByte(n uint64) byte { +// GetByte returns the n'th byte in the contract's byte array +func (c *Contract) GetByte(n uint64) byte { if n < uint64(len(c.Code)) { return c.Code[n] } @@ -78,43 +83,44 @@ func (c *Context) GetByte(n uint64) byte { return 0 } -func (c *Context) Return(ret []byte) []byte { +// Return returns the given ret argument and returns any remaining gas to the +// caller +func (c *Contract) Return(ret []byte) []byte { // Return the remaining gas to the caller c.caller.ReturnGas(c.Gas, c.Price) return ret } -/* - * Gas functions - */ -func (c *Context) UseGas(gas *big.Int) (ok bool) { - ok = UseGas(c.Gas, gas) +// UseGas attempts the use gas and subtracts it and returns true on success +func (c *Contract) UseGas(gas *big.Int) (ok bool) { + ok = useGas(c.Gas, gas) if ok { c.UsedGas.Add(c.UsedGas, gas) } return } -// Implement the caller interface -func (c *Context) ReturnGas(gas, price *big.Int) { +// ReturnGas adds the given gas back to itself. +func (c *Contract) ReturnGas(gas, price *big.Int) { // Return the gas to the context c.Gas.Add(c.Gas, gas) c.UsedGas.Sub(c.UsedGas, gas) } -/* - * Set / Get - */ -func (c *Context) Address() common.Address { +// Address returns the contracts address +func (c *Contract) Address() common.Address { return c.self.Address() } -func (self *Context) SetCode(code []byte) { +// SetCode sets the code to the contract +func (self *Contract) SetCode(code []byte) { self.Code = code } -func (self *Context) SetCallCode(addr *common.Address, code []byte) { +// SetCallCode sets the code of the contract and address of the backing data +// object +func (self *Contract) SetCallCode(addr *common.Address, code []byte) { self.Code = code self.CodeAddr = addr } diff --git a/core/vm/contracts.go b/core/vm/contracts.go index b965fa095..22cb9eab2 100644 --- a/core/vm/contracts.go +++ b/core/vm/contracts.go @@ -26,22 +26,22 @@ import ( "github.com/ethereum/go-ethereum/params" ) -type Address interface { - Call(in []byte) []byte -} - +// PrecompiledAccount represents a native ethereum contract type PrecompiledAccount struct { Gas func(l int) *big.Int fn func(in []byte) []byte } +// Call calls the native function func (self PrecompiledAccount) Call(in []byte) []byte { return self.fn(in) } +// Precompiled contains the default set of ethereum contracts var Precompiled = PrecompiledContracts() -// XXX Could set directly. Testing requires resetting and setting of pre compiled contracts. +// PrecompiledContracts returns the default set of precompiled ethereum +// contracts defined by the ethereum yellow paper. func PrecompiledContracts() map[string]*PrecompiledAccount { return map[string]*PrecompiledAccount{ // ECRECOVER diff --git a/core/vm/settings.go b/core/vm/doc.go index f9296f6c8..ab87bf934 100644 --- a/core/vm/settings.go +++ b/core/vm/doc.go @@ -14,12 +14,19 @@ // You should have received a copy of the GNU Lesser General Public License // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. -package vm +/* +Package vm implements the Ethereum Virtual Machine. -var ( - EnableJit bool // Enables the JIT VM - ForceJit bool // Force the JIT, skip byte VM - MaxProgSize int // Max cache size for JIT Programs -) +The vm package implements two EVMs, a byte code VM and a JIT VM. The BC +(Byte Code) VM loops over a set of bytes and executes them according to the set +of rules defined in the Ethereum yellow paper. When the BC VM is invoked it +invokes the JIT VM in a seperate goroutine and compiles the byte code in JIT +instructions. -const defaultJitMaxCache int = 64 +The JIT VM, when invoked, loops around a set of pre-defined instructions until +it either runs of gas, causes an internal error, returns or stops. At a later +stage the JIT VM will see some additional features that will cause sets of +instructions to be compiled down to segments. Segments are sets of instructions +that can be run in one go saving precious time during execution. +*/ +package vm diff --git a/core/vm/environment.go b/core/vm/environment.go index 916081f51..f8e19baea 100644 --- a/core/vm/environment.go +++ b/core/vm/environment.go @@ -17,39 +17,86 @@ package vm import ( - "errors" "math/big" "github.com/ethereum/go-ethereum/common" - "github.com/ethereum/go-ethereum/core/state" ) // Environment is is required by the virtual machine to get information from -// it's own isolated environment. For an example see `core.VMEnv` -type Environment interface { - State() *state.StateDB +// it's own isolated environment. +// Environment is an EVM requirement and helper which allows access to outside +// information such as states. +type Environment interface { + // The state database + Db() Database + // Creates a restorable snapshot + MakeSnapshot() Database + // Set database to previous snapshot + SetSnapshot(Database) + // Address of the original invoker (first occurance of the VM invoker) Origin() common.Address + // The block number this VM is invoken on BlockNumber() *big.Int + // The n'th hash ago from this block number GetHash(n uint64) common.Hash + // The handler's address Coinbase() common.Address + // The current time (block time) Time() *big.Int + // Difficulty set on the current block Difficulty() *big.Int + // The gas limit of the block GasLimit() *big.Int - CanTransfer(from Account, balance *big.Int) bool + // Determines whether it's possible to transact + CanTransfer(from common.Address, balance *big.Int) bool + // Transfers amount from one account to the other Transfer(from, to Account, amount *big.Int) error - AddLog(*state.Log) + // Adds a LOG to the state + AddLog(*Log) + // Adds a structured log to the env AddStructLog(StructLog) + // Returns all coalesced structured logs StructLogs() []StructLog + // Type of the VM VmType() Type + // Current calling depth Depth() int SetDepth(i int) - Call(me ContextRef, addr common.Address, data []byte, gas, price, value *big.Int) ([]byte, error) - CallCode(me ContextRef, addr common.Address, data []byte, gas, price, value *big.Int) ([]byte, error) - Create(me ContextRef, data []byte, gas, price, value *big.Int) ([]byte, error, ContextRef) + // Call another contract + Call(me ContractRef, addr common.Address, data []byte, gas, price, value *big.Int) ([]byte, error) + // Take another's contract code and execute within our own context + CallCode(me ContractRef, addr common.Address, data []byte, gas, price, value *big.Int) ([]byte, error) + // Create a new contract + Create(me ContractRef, data []byte, gas, price, value *big.Int) ([]byte, common.Address, error) +} + +// Database is a EVM database for full state querying +type Database interface { + GetAccount(common.Address) Account + CreateAccount(common.Address) Account + + AddBalance(common.Address, *big.Int) + GetBalance(common.Address) *big.Int + + GetNonce(common.Address) uint64 + SetNonce(common.Address, uint64) + + GetCode(common.Address) []byte + SetCode(common.Address, []byte) + + AddRefund(*big.Int) + GetRefund() *big.Int + + GetState(common.Address, common.Hash) common.Hash + SetState(common.Address, common.Hash, common.Hash) + + Delete(common.Address) bool + Exist(common.Address) bool + IsDeleted(common.Address) bool } // StructLog is emited to the Environment each cycle and lists information about the curent internal state @@ -68,18 +115,10 @@ type StructLog struct { type Account interface { SubBalance(amount *big.Int) AddBalance(amount *big.Int) + SetBalance(*big.Int) + SetNonce(uint64) Balance() *big.Int Address() common.Address -} - -// generic transfer method -func Transfer(from, to Account, amount *big.Int) error { - if from.Balance().Cmp(amount) < 0 { - return errors.New("Insufficient balance in account") - } - - from.SubBalance(amount) - to.AddBalance(amount) - - return nil + ReturnGas(*big.Int, *big.Int) + SetCode([]byte) } diff --git a/core/vm/gas.go b/core/vm/gas.go index b2f068e6e..bff0ac91b 100644 --- a/core/vm/gas.go +++ b/core/vm/gas.go @@ -37,6 +37,7 @@ var ( GasContractByte = big.NewInt(200) ) +// baseCheck checks for any stack error underflows func baseCheck(op OpCode, stack *stack, gas *big.Int) error { // PUSH and DUP are a bit special. They all cost the same but we do want to have checking on stack push limit // PUSH is also allowed to calculate the same price for all PUSHes @@ -63,6 +64,7 @@ func baseCheck(op OpCode, stack *stack, gas *big.Int) error { return nil } +// casts a arbitrary number to the amount of words (sets of 32 bytes) func toWordSize(size *big.Int) *big.Int { tmp := new(big.Int) tmp.Add(size, u256(31)) diff --git a/core/vm/instructions.go b/core/vm/instructions.go index aa0117cc8..6c6039f74 100644 --- a/core/vm/instructions.go +++ b/core/vm/instructions.go @@ -20,46 +20,52 @@ import ( "math/big" "github.com/ethereum/go-ethereum/common" - "github.com/ethereum/go-ethereum/core/state" "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/params" ) -type instrFn func(instr instruction, env Environment, context *Context, memory *Memory, stack *stack) -type instrExFn func(instr instruction, ret *big.Int, env Environment, context *Context, memory *Memory, stack *stack) +type programInstruction interface { + Do(instr instruction, pc *uint64, env Environment, contract *Contract, memory *Memory, stack *stack) +} + +type instrFn func(instr instruction, pc *uint64, env Environment, contract *Contract, memory *Memory, stack *stack) + +// Do executes the function. This implements programInstruction +func (fn instrFn) Do(instr instruction, pc *uint64, env Environment, contract *Contract, memory *Memory, stack *stack) { + fn(instr, pc, env, contract, memory, stack) +} type instruction struct { - op OpCode - pc uint64 - fn instrFn - specFn instrExFn - data *big.Int + op OpCode + pc uint64 + fn instrFn + data *big.Int gas *big.Int spop int spush int } -func opStaticJump(instr instruction, ret *big.Int, env Environment, context *Context, memory *Memory, stack *stack) { +func opStaticJump(instr instruction, pc *uint64, ret *big.Int, env Environment, contract *Contract, memory *Memory, stack *stack) { ret.Set(instr.data) } -func opAdd(instr instruction, env Environment, context *Context, memory *Memory, stack *stack) { +func opAdd(instr instruction, pc *uint64, env Environment, contract *Contract, memory *Memory, stack *stack) { x, y := stack.pop(), stack.pop() stack.push(U256(x.Add(x, y))) } -func opSub(instr instruction, env Environment, context *Context, memory *Memory, stack *stack) { +func opSub(instr instruction, pc *uint64, env Environment, contract *Contract, memory *Memory, stack *stack) { x, y := stack.pop(), stack.pop() stack.push(U256(x.Sub(x, y))) } -func opMul(instr instruction, env Environment, context *Context, memory *Memory, stack *stack) { +func opMul(instr instruction, pc *uint64, env Environment, contract *Contract, memory *Memory, stack *stack) { x, y := stack.pop(), stack.pop() stack.push(U256(x.Mul(x, y))) } -func opDiv(instr instruction, env Environment, context *Context, memory *Memory, stack *stack) { +func opDiv(instr instruction, pc *uint64, env Environment, contract *Contract, memory *Memory, stack *stack) { x, y := stack.pop(), stack.pop() if y.Cmp(common.Big0) != 0 { stack.push(U256(x.Div(x, y))) @@ -68,7 +74,7 @@ func opDiv(instr instruction, env Environment, context *Context, memory *Memory, } } -func opSdiv(instr instruction, env Environment, context *Context, memory *Memory, stack *stack) { +func opSdiv(instr instruction, pc *uint64, env Environment, contract *Contract, memory *Memory, stack *stack) { x, y := S256(stack.pop()), S256(stack.pop()) if y.Cmp(common.Big0) == 0 { stack.push(new(big.Int)) @@ -88,7 +94,7 @@ func opSdiv(instr instruction, env Environment, context *Context, memory *Memory } } -func opMod(instr instruction, env Environment, context *Context, memory *Memory, stack *stack) { +func opMod(instr instruction, pc *uint64, env Environment, contract *Contract, memory *Memory, stack *stack) { x, y := stack.pop(), stack.pop() if y.Cmp(common.Big0) == 0 { stack.push(new(big.Int)) @@ -97,7 +103,7 @@ func opMod(instr instruction, env Environment, context *Context, memory *Memory, } } -func opSmod(instr instruction, env Environment, context *Context, memory *Memory, stack *stack) { +func opSmod(instr instruction, pc *uint64, env Environment, contract *Contract, memory *Memory, stack *stack) { x, y := S256(stack.pop()), S256(stack.pop()) if y.Cmp(common.Big0) == 0 { @@ -117,12 +123,12 @@ func opSmod(instr instruction, env Environment, context *Context, memory *Memory } } -func opExp(instr instruction, env Environment, context *Context, memory *Memory, stack *stack) { +func opExp(instr instruction, pc *uint64, env Environment, contract *Contract, memory *Memory, stack *stack) { x, y := stack.pop(), stack.pop() stack.push(U256(x.Exp(x, y, Pow256))) } -func opSignExtend(instr instruction, env Environment, context *Context, memory *Memory, stack *stack) { +func opSignExtend(instr instruction, pc *uint64, env Environment, contract *Contract, memory *Memory, stack *stack) { back := stack.pop() if back.Cmp(big.NewInt(31)) < 0 { bit := uint(back.Uint64()*8 + 7) @@ -139,12 +145,12 @@ func opSignExtend(instr instruction, env Environment, context *Context, memory * } } -func opNot(instr instruction, env Environment, context *Context, memory *Memory, stack *stack) { +func opNot(instr instruction, pc *uint64, env Environment, contract *Contract, memory *Memory, stack *stack) { x := stack.pop() stack.push(U256(x.Not(x))) } -func opLt(instr instruction, env Environment, context *Context, memory *Memory, stack *stack) { +func opLt(instr instruction, pc *uint64, env Environment, contract *Contract, memory *Memory, stack *stack) { x, y := stack.pop(), stack.pop() if x.Cmp(y) < 0 { stack.push(big.NewInt(1)) @@ -153,7 +159,7 @@ func opLt(instr instruction, env Environment, context *Context, memory *Memory, } } -func opGt(instr instruction, env Environment, context *Context, memory *Memory, stack *stack) { +func opGt(instr instruction, pc *uint64, env Environment, contract *Contract, memory *Memory, stack *stack) { x, y := stack.pop(), stack.pop() if x.Cmp(y) > 0 { stack.push(big.NewInt(1)) @@ -162,7 +168,7 @@ func opGt(instr instruction, env Environment, context *Context, memory *Memory, } } -func opSlt(instr instruction, env Environment, context *Context, memory *Memory, stack *stack) { +func opSlt(instr instruction, pc *uint64, env Environment, contract *Contract, memory *Memory, stack *stack) { x, y := S256(stack.pop()), S256(stack.pop()) if x.Cmp(S256(y)) < 0 { stack.push(big.NewInt(1)) @@ -171,7 +177,7 @@ func opSlt(instr instruction, env Environment, context *Context, memory *Memory, } } -func opSgt(instr instruction, env Environment, context *Context, memory *Memory, stack *stack) { +func opSgt(instr instruction, pc *uint64, env Environment, contract *Contract, memory *Memory, stack *stack) { x, y := S256(stack.pop()), S256(stack.pop()) if x.Cmp(y) > 0 { stack.push(big.NewInt(1)) @@ -180,7 +186,7 @@ func opSgt(instr instruction, env Environment, context *Context, memory *Memory, } } -func opEq(instr instruction, env Environment, context *Context, memory *Memory, stack *stack) { +func opEq(instr instruction, pc *uint64, env Environment, contract *Contract, memory *Memory, stack *stack) { x, y := stack.pop(), stack.pop() if x.Cmp(y) == 0 { stack.push(big.NewInt(1)) @@ -189,7 +195,7 @@ func opEq(instr instruction, env Environment, context *Context, memory *Memory, } } -func opIszero(instr instruction, env Environment, context *Context, memory *Memory, stack *stack) { +func opIszero(instr instruction, pc *uint64, env Environment, contract *Contract, memory *Memory, stack *stack) { x := stack.pop() if x.Cmp(common.Big0) > 0 { stack.push(new(big.Int)) @@ -198,19 +204,19 @@ func opIszero(instr instruction, env Environment, context *Context, memory *Memo } } -func opAnd(instr instruction, env Environment, context *Context, memory *Memory, stack *stack) { +func opAnd(instr instruction, pc *uint64, env Environment, contract *Contract, memory *Memory, stack *stack) { x, y := stack.pop(), stack.pop() stack.push(x.And(x, y)) } -func opOr(instr instruction, env Environment, context *Context, memory *Memory, stack *stack) { +func opOr(instr instruction, pc *uint64, env Environment, contract *Contract, memory *Memory, stack *stack) { x, y := stack.pop(), stack.pop() stack.push(x.Or(x, y)) } -func opXor(instr instruction, env Environment, context *Context, memory *Memory, stack *stack) { +func opXor(instr instruction, pc *uint64, env Environment, contract *Contract, memory *Memory, stack *stack) { x, y := stack.pop(), stack.pop() stack.push(x.Xor(x, y)) } -func opByte(instr instruction, env Environment, context *Context, memory *Memory, stack *stack) { +func opByte(instr instruction, pc *uint64, env Environment, contract *Contract, memory *Memory, stack *stack) { th, val := stack.pop(), stack.pop() if th.Cmp(big.NewInt(32)) < 0 { byte := big.NewInt(int64(common.LeftPadBytes(val.Bytes(), 32)[th.Int64()])) @@ -219,7 +225,7 @@ func opByte(instr instruction, env Environment, context *Context, memory *Memory stack.push(new(big.Int)) } } -func opAddmod(instr instruction, env Environment, context *Context, memory *Memory, stack *stack) { +func opAddmod(instr instruction, pc *uint64, env Environment, contract *Contract, memory *Memory, stack *stack) { x, y, z := stack.pop(), stack.pop(), stack.pop() if z.Cmp(Zero) > 0 { add := x.Add(x, y) @@ -229,7 +235,7 @@ func opAddmod(instr instruction, env Environment, context *Context, memory *Memo stack.push(new(big.Int)) } } -func opMulmod(instr instruction, env Environment, context *Context, memory *Memory, stack *stack) { +func opMulmod(instr instruction, pc *uint64, env Environment, contract *Contract, memory *Memory, stack *stack) { x, y, z := stack.pop(), stack.pop(), stack.pop() if z.Cmp(Zero) > 0 { mul := x.Mul(x, y) @@ -240,92 +246,92 @@ func opMulmod(instr instruction, env Environment, context *Context, memory *Memo } } -func opSha3(instr instruction, env Environment, context *Context, memory *Memory, stack *stack) { +func opSha3(instr instruction, pc *uint64, env Environment, contract *Contract, memory *Memory, stack *stack) { offset, size := stack.pop(), stack.pop() hash := crypto.Sha3(memory.Get(offset.Int64(), size.Int64())) stack.push(common.BytesToBig(hash)) } -func opAddress(instr instruction, env Environment, context *Context, memory *Memory, stack *stack) { - stack.push(common.Bytes2Big(context.Address().Bytes())) +func opAddress(instr instruction, pc *uint64, env Environment, contract *Contract, memory *Memory, stack *stack) { + stack.push(common.Bytes2Big(contract.Address().Bytes())) } -func opBalance(instr instruction, env Environment, context *Context, memory *Memory, stack *stack) { +func opBalance(instr instruction, pc *uint64, env Environment, contract *Contract, memory *Memory, stack *stack) { addr := common.BigToAddress(stack.pop()) - balance := env.State().GetBalance(addr) + balance := env.Db().GetBalance(addr) stack.push(new(big.Int).Set(balance)) } -func opOrigin(instr instruction, env Environment, context *Context, memory *Memory, stack *stack) { +func opOrigin(instr instruction, pc *uint64, env Environment, contract *Contract, memory *Memory, stack *stack) { stack.push(env.Origin().Big()) } -func opCaller(instr instruction, env Environment, context *Context, memory *Memory, stack *stack) { - stack.push(common.Bytes2Big(context.caller.Address().Bytes())) +func opCaller(instr instruction, pc *uint64, env Environment, contract *Contract, memory *Memory, stack *stack) { + stack.push(common.Bytes2Big(contract.caller.Address().Bytes())) } -func opCallValue(instr instruction, env Environment, context *Context, memory *Memory, stack *stack) { - stack.push(new(big.Int).Set(context.value)) +func opCallValue(instr instruction, pc *uint64, env Environment, contract *Contract, memory *Memory, stack *stack) { + stack.push(new(big.Int).Set(contract.value)) } -func opCalldataLoad(instr instruction, env Environment, context *Context, memory *Memory, stack *stack) { - stack.push(common.Bytes2Big(getData(context.Input, stack.pop(), common.Big32))) +func opCalldataLoad(instr instruction, pc *uint64, env Environment, contract *Contract, memory *Memory, stack *stack) { + stack.push(common.Bytes2Big(getData(contract.Input, stack.pop(), common.Big32))) } -func opCalldataSize(instr instruction, env Environment, context *Context, memory *Memory, stack *stack) { - stack.push(big.NewInt(int64(len(context.Input)))) +func opCalldataSize(instr instruction, pc *uint64, env Environment, contract *Contract, memory *Memory, stack *stack) { + stack.push(big.NewInt(int64(len(contract.Input)))) } -func opCalldataCopy(instr instruction, env Environment, context *Context, memory *Memory, stack *stack) { +func opCalldataCopy(instr instruction, pc *uint64, env Environment, contract *Contract, memory *Memory, stack *stack) { var ( mOff = stack.pop() cOff = stack.pop() l = stack.pop() ) - memory.Set(mOff.Uint64(), l.Uint64(), getData(context.Input, cOff, l)) + memory.Set(mOff.Uint64(), l.Uint64(), getData(contract.Input, cOff, l)) } -func opExtCodeSize(instr instruction, env Environment, context *Context, memory *Memory, stack *stack) { +func opExtCodeSize(instr instruction, pc *uint64, env Environment, contract *Contract, memory *Memory, stack *stack) { addr := common.BigToAddress(stack.pop()) - l := big.NewInt(int64(len(env.State().GetCode(addr)))) + l := big.NewInt(int64(len(env.Db().GetCode(addr)))) stack.push(l) } -func opCodeSize(instr instruction, env Environment, context *Context, memory *Memory, stack *stack) { - l := big.NewInt(int64(len(context.Code))) +func opCodeSize(instr instruction, pc *uint64, env Environment, contract *Contract, memory *Memory, stack *stack) { + l := big.NewInt(int64(len(contract.Code))) stack.push(l) } -func opCodeCopy(instr instruction, env Environment, context *Context, memory *Memory, stack *stack) { +func opCodeCopy(instr instruction, pc *uint64, env Environment, contract *Contract, memory *Memory, stack *stack) { var ( mOff = stack.pop() cOff = stack.pop() l = stack.pop() ) - codeCopy := getData(context.Code, cOff, l) + codeCopy := getData(contract.Code, cOff, l) memory.Set(mOff.Uint64(), l.Uint64(), codeCopy) } -func opExtCodeCopy(instr instruction, env Environment, context *Context, memory *Memory, stack *stack) { +func opExtCodeCopy(instr instruction, pc *uint64, env Environment, contract *Contract, memory *Memory, stack *stack) { var ( addr = common.BigToAddress(stack.pop()) mOff = stack.pop() cOff = stack.pop() l = stack.pop() ) - codeCopy := getData(env.State().GetCode(addr), cOff, l) + codeCopy := getData(env.Db().GetCode(addr), cOff, l) memory.Set(mOff.Uint64(), l.Uint64(), codeCopy) } -func opGasprice(instr instruction, env Environment, context *Context, memory *Memory, stack *stack) { - stack.push(new(big.Int).Set(context.Price)) +func opGasprice(instr instruction, pc *uint64, env Environment, contract *Contract, memory *Memory, stack *stack) { + stack.push(new(big.Int).Set(contract.Price)) } -func opBlockhash(instr instruction, env Environment, context *Context, memory *Memory, stack *stack) { +func opBlockhash(instr instruction, pc *uint64, env Environment, contract *Contract, memory *Memory, stack *stack) { num := stack.pop() n := new(big.Int).Sub(env.BlockNumber(), common.Big257) @@ -336,43 +342,43 @@ func opBlockhash(instr instruction, env Environment, context *Context, memory *M } } -func opCoinbase(instr instruction, env Environment, context *Context, memory *Memory, stack *stack) { +func opCoinbase(instr instruction, pc *uint64, env Environment, contract *Contract, memory *Memory, stack *stack) { stack.push(env.Coinbase().Big()) } -func opTimestamp(instr instruction, env Environment, context *Context, memory *Memory, stack *stack) { +func opTimestamp(instr instruction, pc *uint64, env Environment, contract *Contract, memory *Memory, stack *stack) { stack.push(U256(new(big.Int).Set(env.Time()))) } -func opNumber(instr instruction, env Environment, context *Context, memory *Memory, stack *stack) { +func opNumber(instr instruction, pc *uint64, env Environment, contract *Contract, memory *Memory, stack *stack) { stack.push(U256(new(big.Int).Set(env.BlockNumber()))) } -func opDifficulty(instr instruction, env Environment, context *Context, memory *Memory, stack *stack) { +func opDifficulty(instr instruction, pc *uint64, env Environment, contract *Contract, memory *Memory, stack *stack) { stack.push(U256(new(big.Int).Set(env.Difficulty()))) } -func opGasLimit(instr instruction, env Environment, context *Context, memory *Memory, stack *stack) { +func opGasLimit(instr instruction, pc *uint64, env Environment, contract *Contract, memory *Memory, stack *stack) { stack.push(U256(new(big.Int).Set(env.GasLimit()))) } -func opPop(instr instruction, env Environment, context *Context, memory *Memory, stack *stack) { +func opPop(instr instruction, pc *uint64, env Environment, contract *Contract, memory *Memory, stack *stack) { stack.pop() } -func opPush(instr instruction, env Environment, context *Context, memory *Memory, stack *stack) { +func opPush(instr instruction, pc *uint64, env Environment, contract *Contract, memory *Memory, stack *stack) { stack.push(new(big.Int).Set(instr.data)) } -func opDup(instr instruction, env Environment, context *Context, memory *Memory, stack *stack) { +func opDup(instr instruction, pc *uint64, env Environment, contract *Contract, memory *Memory, stack *stack) { stack.dup(int(instr.data.Int64())) } -func opSwap(instr instruction, env Environment, context *Context, memory *Memory, stack *stack) { +func opSwap(instr instruction, pc *uint64, env Environment, contract *Contract, memory *Memory, stack *stack) { stack.swap(int(instr.data.Int64())) } -func opLog(instr instruction, env Environment, context *Context, memory *Memory, stack *stack) { +func opLog(instr instruction, pc *uint64, env Environment, contract *Contract, memory *Memory, stack *stack) { n := int(instr.data.Int64()) topics := make([]common.Hash, n) mStart, mSize := stack.pop(), stack.pop() @@ -381,85 +387,88 @@ func opLog(instr instruction, env Environment, context *Context, memory *Memory, } d := memory.Get(mStart.Int64(), mSize.Int64()) - log := state.NewLog(context.Address(), topics, d, env.BlockNumber().Uint64()) + log := NewLog(contract.Address(), topics, d, env.BlockNumber().Uint64()) env.AddLog(log) } -func opMload(instr instruction, env Environment, context *Context, memory *Memory, stack *stack) { +func opMload(instr instruction, pc *uint64, env Environment, contract *Contract, memory *Memory, stack *stack) { offset := stack.pop() val := common.BigD(memory.Get(offset.Int64(), 32)) stack.push(val) } -func opMstore(instr instruction, env Environment, context *Context, memory *Memory, stack *stack) { +func opMstore(instr instruction, pc *uint64, env Environment, contract *Contract, memory *Memory, stack *stack) { // pop value of the stack mStart, val := stack.pop(), stack.pop() memory.Set(mStart.Uint64(), 32, common.BigToBytes(val, 256)) } -func opMstore8(instr instruction, env Environment, context *Context, memory *Memory, stack *stack) { +func opMstore8(instr instruction, pc *uint64, env Environment, contract *Contract, memory *Memory, stack *stack) { off, val := stack.pop().Int64(), stack.pop().Int64() memory.store[off] = byte(val & 0xff) } -func opSload(instr instruction, env Environment, context *Context, memory *Memory, stack *stack) { +func opSload(instr instruction, pc *uint64, env Environment, contract *Contract, memory *Memory, stack *stack) { loc := common.BigToHash(stack.pop()) - val := env.State().GetState(context.Address(), loc).Big() + val := env.Db().GetState(contract.Address(), loc).Big() stack.push(val) } -func opSstore(instr instruction, env Environment, context *Context, memory *Memory, stack *stack) { +func opSstore(instr instruction, pc *uint64, env Environment, contract *Contract, memory *Memory, stack *stack) { loc := common.BigToHash(stack.pop()) val := stack.pop() - env.State().SetState(context.Address(), loc, common.BigToHash(val)) + env.Db().SetState(contract.Address(), loc, common.BigToHash(val)) } -func opJump(instr instruction, env Environment, context *Context, memory *Memory, stack *stack) {} -func opJumpi(instr instruction, env Environment, context *Context, memory *Memory, stack *stack) {} -func opJumpdest(instr instruction, env Environment, context *Context, memory *Memory, stack *stack) {} +func opJump(instr instruction, pc *uint64, env Environment, contract *Contract, memory *Memory, stack *stack) { +} +func opJumpi(instr instruction, pc *uint64, env Environment, contract *Contract, memory *Memory, stack *stack) { +} +func opJumpdest(instr instruction, pc *uint64, env Environment, contract *Contract, memory *Memory, stack *stack) { +} -func opPc(instr instruction, env Environment, context *Context, memory *Memory, stack *stack) { +func opPc(instr instruction, pc *uint64, env Environment, contract *Contract, memory *Memory, stack *stack) { stack.push(new(big.Int).Set(instr.data)) } -func opMsize(instr instruction, env Environment, context *Context, memory *Memory, stack *stack) { +func opMsize(instr instruction, pc *uint64, env Environment, contract *Contract, memory *Memory, stack *stack) { stack.push(big.NewInt(int64(memory.Len()))) } -func opGas(instr instruction, env Environment, context *Context, memory *Memory, stack *stack) { - stack.push(new(big.Int).Set(context.Gas)) +func opGas(instr instruction, pc *uint64, env Environment, contract *Contract, memory *Memory, stack *stack) { + stack.push(new(big.Int).Set(contract.Gas)) } -func opCreate(instr instruction, env Environment, context *Context, memory *Memory, stack *stack) { +func opCreate(instr instruction, pc *uint64, env Environment, contract *Contract, memory *Memory, stack *stack) { var ( value = stack.pop() offset, size = stack.pop(), stack.pop() input = memory.Get(offset.Int64(), size.Int64()) - gas = new(big.Int).Set(context.Gas) + gas = new(big.Int).Set(contract.Gas) addr common.Address + ret []byte + suberr error ) - context.UseGas(context.Gas) - ret, suberr, ref := env.Create(context, input, gas, context.Price, value) + contract.UseGas(contract.Gas) + ret, addr, suberr = env.Create(contract, input, gas, contract.Price, value) if suberr != nil { stack.push(new(big.Int)) - } else { // gas < len(ret) * Createinstr.dataGas == NO_CODE dataGas := big.NewInt(int64(len(ret))) dataGas.Mul(dataGas, params.CreateDataGas) - if context.UseGas(dataGas) { - ref.SetCode(ret) + if contract.UseGas(dataGas) { + env.Db().SetCode(addr, ret) } - addr = ref.Address() stack.push(addr.Big()) } } -func opCall(instr instruction, env Environment, context *Context, memory *Memory, stack *stack) { +func opCall(instr instruction, pc *uint64, env Environment, contract *Contract, memory *Memory, stack *stack) { gas := stack.pop() // pop gas and value of the stack. addr, value := stack.pop(), stack.pop() @@ -478,7 +487,7 @@ func opCall(instr instruction, env Environment, context *Context, memory *Memory gas.Add(gas, params.CallStipend) } - ret, err := env.Call(context, address, args, gas, context.Price, value) + ret, err := env.Call(contract, address, args, gas, contract.Price, value) if err != nil { stack.push(new(big.Int)) @@ -490,7 +499,7 @@ func opCall(instr instruction, env Environment, context *Context, memory *Memory } } -func opCallCode(instr instruction, env Environment, context *Context, memory *Memory, stack *stack) { +func opCallCode(instr instruction, pc *uint64, env Environment, contract *Contract, memory *Memory, stack *stack) { gas := stack.pop() // pop gas and value of the stack. addr, value := stack.pop(), stack.pop() @@ -509,7 +518,7 @@ func opCallCode(instr instruction, env Environment, context *Context, memory *Me gas.Add(gas, params.CallStipend) } - ret, err := env.CallCode(context, address, args, gas, context.Price, value) + ret, err := env.CallCode(contract, address, args, gas, contract.Price, value) if err != nil { stack.push(new(big.Int)) @@ -521,14 +530,58 @@ func opCallCode(instr instruction, env Environment, context *Context, memory *Me } } -func opReturn(instr instruction, env Environment, context *Context, memory *Memory, stack *stack) {} -func opStop(instr instruction, env Environment, context *Context, memory *Memory, stack *stack) {} +func opReturn(instr instruction, pc *uint64, env Environment, contract *Contract, memory *Memory, stack *stack) { +} +func opStop(instr instruction, pc *uint64, env Environment, contract *Contract, memory *Memory, stack *stack) { +} + +func opSuicide(instr instruction, pc *uint64, env Environment, contract *Contract, memory *Memory, stack *stack) { + //receiver := env.Db().GetOrNewStateObject(common.BigToAddress(stack.pop())) + //receiver.AddBalance(balance) + balance := env.Db().GetBalance(contract.Address()) + env.Db().AddBalance(common.BigToAddress(stack.pop()), balance) + + env.Db().Delete(contract.Address()) +} + +// following functions are used by the instruction jump table + +// make log instruction function +func makeLog(size int) instrFn { + return func(instr instruction, pc *uint64, env Environment, contract *Contract, memory *Memory, stack *stack) { + topics := make([]common.Hash, size) + mStart, mSize := stack.pop(), stack.pop() + for i := 0; i < size; i++ { + topics[i] = common.BigToHash(stack.pop()) + } + + d := memory.Get(mStart.Int64(), mSize.Int64()) + log := NewLog(contract.Address(), topics, d, env.BlockNumber().Uint64()) + env.AddLog(log) + } +} -func opSuicide(instr instruction, env Environment, context *Context, memory *Memory, stack *stack) { - receiver := env.State().GetOrNewStateObject(common.BigToAddress(stack.pop())) - balance := env.State().GetBalance(context.Address()) +// make push instruction function +func makePush(size uint64, bsize *big.Int) instrFn { + return func(instr instruction, pc *uint64, env Environment, contract *Contract, memory *Memory, stack *stack) { + byts := getData(contract.Code, new(big.Int).SetUint64(*pc+1), bsize) + stack.push(common.Bytes2Big(byts)) + *pc += size + } +} - receiver.AddBalance(balance) +// make push instruction function +func makeDup(size int64) instrFn { + return func(instr instruction, pc *uint64, env Environment, contract *Contract, memory *Memory, stack *stack) { + stack.dup(int(size)) + } +} - env.State().Delete(context.Address()) +// make swap instruction function +func makeSwap(size int64) instrFn { + // switch n + 1 otherwise n would be swapped with n + size += 1 + return func(instr instruction, pc *uint64, env Environment, contract *Contract, memory *Memory, stack *stack) { + stack.swap(int(size)) + } } diff --git a/core/vm/jit.go b/core/vm/jit.go index 084d2a3f3..6ad574917 100644 --- a/core/vm/jit.go +++ b/core/vm/jit.go @@ -20,10 +20,12 @@ import ( "fmt" "math/big" "sync/atomic" + "time" "github.com/ethereum/go-ethereum/common" - "github.com/ethereum/go-ethereum/core/state" "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/logger" + "github.com/ethereum/go-ethereum/logger/glog" "github.com/ethereum/go-ethereum/params" "github.com/hashicorp/golang-lru" ) @@ -35,6 +37,14 @@ const ( progCompile progReady progError + + defaultJitMaxCache int = 64 +) + +var ( + EnableJit bool // Enables the JIT VM + ForceJit bool // Force the JIT, skip byte VM + MaxProgSize int // Max cache size for JIT Programs ) var programs *lru.Cache @@ -74,7 +84,7 @@ type Program struct { Id common.Hash // Id of the program status int32 // status should be accessed atomically - context *Context + contract *Contract instructions []instruction // instruction set mapping map[uint64]int // real PC mapping to array indices @@ -108,7 +118,7 @@ func (p *Program) addInstr(op OpCode, pc uint64, fn instrFn, data *big.Int) { baseOp = DUP1 } base := _baseCheck[baseOp] - instr := instruction{op, pc, fn, nil, data, base.gas, base.stackPop, base.stackPush} + instr := instruction{op, pc, fn, data, base.gas, base.stackPop, base.stackPush} p.instructions = append(p.instructions, instr) p.mapping[pc] = len(p.instructions) - 1 @@ -127,6 +137,13 @@ func CompileProgram(program *Program) (err error) { atomic.StoreInt32(&program.status, int32(progReady)) } }() + if glog.V(logger.Debug) { + glog.Infof("compiling %x\n", program.Id[:4]) + tstart := time.Now() + defer func() { + glog.Infof("compiled %x instrc: %d time: %v\n", program.Id[:4], len(program.instructions), time.Since(tstart)) + }() + } // loop thru the opcodes and "compile" in to instructions for pc := uint64(0); pc < uint64(len(program.code)); pc++ { @@ -264,7 +281,7 @@ func CompileProgram(program *Program) (err error) { program.addInstr(op, pc, opReturn, nil) case SUICIDE: program.addInstr(op, pc, opSuicide, nil) - case STOP: // Stop the context + case STOP: // Stop the contract program.addInstr(op, pc, opStop, nil) default: program.addInstr(op, pc, nil, nil) @@ -274,23 +291,24 @@ func CompileProgram(program *Program) (err error) { return nil } -// RunProgram runs the program given the enviroment and context and returns an +// RunProgram runs the program given the enviroment and contract and returns an // error if the execution failed (non-consensus) -func RunProgram(program *Program, env Environment, context *Context, input []byte) ([]byte, error) { - return runProgram(program, 0, NewMemory(), newstack(), env, context, input) +func RunProgram(program *Program, env Environment, contract *Contract, input []byte) ([]byte, error) { + return runProgram(program, 0, NewMemory(), newstack(), env, contract, input) } -func runProgram(program *Program, pcstart uint64, mem *Memory, stack *stack, env Environment, context *Context, input []byte) ([]byte, error) { - context.Input = input +func runProgram(program *Program, pcstart uint64, mem *Memory, stack *stack, env Environment, contract *Contract, input []byte) ([]byte, error) { + contract.Input = input var ( - caller = context.caller - statedb = env.State() - pc int = program.mapping[pcstart] + caller = contract.caller + statedb = env.Db() + pc int = program.mapping[pcstart] + instrCount = 0 jump = func(to *big.Int) error { if !validDest(program.destinations, to) { - nop := context.GetOp(to.Uint64()) + nop := contract.GetOp(to.Uint64()) return fmt.Errorf("invalid jump destination (%v) %v", nop, to) } @@ -300,18 +318,28 @@ func runProgram(program *Program, pcstart uint64, mem *Memory, stack *stack, env } ) + if glog.V(logger.Debug) { + glog.Infof("running JIT program %x\n", program.Id[:4]) + tstart := time.Now() + defer func() { + glog.Infof("JIT program %x done. time: %v instrc: %v\n", program.Id[:4], time.Since(tstart), instrCount) + }() + } + for pc < len(program.instructions) { + instrCount++ + instr := program.instructions[pc] // calculate the new memory size and gas price for the current executing opcode - newMemSize, cost, err := jitCalculateGasAndSize(env, context, caller, instr, statedb, mem, stack) + newMemSize, cost, err := jitCalculateGasAndSize(env, contract, caller, instr, statedb, mem, stack) if err != nil { return nil, err } // Use the calculated gas. When insufficient gas is present, use all gas and return an // Out Of Gas error - if !context.UseGas(cost) { + if !contract.UseGas(cost) { return nil, OutOfGasError } // Resize the memory calculated previously @@ -338,27 +366,27 @@ func runProgram(program *Program, pcstart uint64, mem *Memory, stack *stack, env offset, size := stack.pop(), stack.pop() ret := mem.GetPtr(offset.Int64(), size.Int64()) - return context.Return(ret), nil + return contract.Return(ret), nil case SUICIDE: - instr.fn(instr, env, context, mem, stack) + instr.fn(instr, nil, env, contract, mem, stack) - return context.Return(nil), nil + return contract.Return(nil), nil case STOP: - return context.Return(nil), nil + return contract.Return(nil), nil default: if instr.fn == nil { return nil, fmt.Errorf("Invalid opcode %x", instr.op) } - instr.fn(instr, env, context, mem, stack) + instr.fn(instr, nil, env, contract, mem, stack) } pc++ } - context.Input = nil + contract.Input = nil - return context.Return(nil), nil + return contract.Return(nil), nil } // validDest checks if the given distination is a valid one given the @@ -375,7 +403,7 @@ func validDest(dests map[uint64]struct{}, dest *big.Int) bool { // jitCalculateGasAndSize calculates the required given the opcode and stack items calculates the new memorysize for // the operation. This does not reduce gas or resizes the memory. -func jitCalculateGasAndSize(env Environment, context *Context, caller ContextRef, instr instruction, statedb *state.StateDB, mem *Memory, stack *stack) (*big.Int, *big.Int, error) { +func jitCalculateGasAndSize(env Environment, contract *Contract, caller ContractRef, instr instruction, statedb Database, mem *Memory, stack *stack) (*big.Int, *big.Int, error) { var ( gas = new(big.Int) newMemSize *big.Int = new(big.Int) @@ -426,27 +454,25 @@ func jitCalculateGasAndSize(env Environment, context *Context, caller ContextRef var g *big.Int y, x := stack.data[stack.len()-2], stack.data[stack.len()-1] - val := statedb.GetState(context.Address(), common.BigToHash(x)) + val := statedb.GetState(contract.Address(), common.BigToHash(x)) // This checks for 3 scenario's and calculates gas accordingly // 1. From a zero-value address to a non-zero value (NEW VALUE) // 2. From a non-zero value address to a zero-value address (DELETE) // 3. From a nen-zero to a non-zero (CHANGE) if common.EmptyHash(val) && !common.EmptyHash(common.BigToHash(y)) { - // 0 => non 0 g = params.SstoreSetGas } else if !common.EmptyHash(val) && common.EmptyHash(common.BigToHash(y)) { - statedb.Refund(params.SstoreRefundGas) + statedb.AddRefund(params.SstoreRefundGas) g = params.SstoreClearGas } else { - // non 0 => non 0 (or 0 => 0) g = params.SstoreClearGas } gas.Set(g) case SUICIDE: - if !statedb.IsDeleted(context.Address()) { - statedb.Refund(params.SuicideRefundGas) + if !statedb.IsDeleted(contract.Address()) { + statedb.AddRefund(params.SuicideRefundGas) } case MLOAD: newMemSize = calcMemSize(stack.peek(), u256(32)) @@ -483,7 +509,8 @@ func jitCalculateGasAndSize(env Environment, context *Context, caller ContextRef gas.Add(gas, stack.data[stack.len()-1]) if op == CALL { - if env.State().GetStateObject(common.BigToAddress(stack.data[stack.len()-2])) == nil { + //if env.Db().GetStateObject(common.BigToAddress(stack.data[stack.len()-2])) == nil { + if !env.Db().Exist(common.BigToAddress(stack.data[stack.len()-2])) { gas.Add(gas, params.CallNewAccountGas) } } @@ -497,29 +524,7 @@ func jitCalculateGasAndSize(env Environment, context *Context, caller ContextRef newMemSize = common.BigMax(x, y) } - - if newMemSize.Cmp(common.Big0) > 0 { - newMemSizeWords := toWordSize(newMemSize) - newMemSize.Mul(newMemSizeWords, u256(32)) - - if newMemSize.Cmp(u256(int64(mem.Len()))) > 0 { - // be careful reusing variables here when changing. - // The order has been optimised to reduce allocation - oldSize := toWordSize(big.NewInt(int64(mem.Len()))) - pow := new(big.Int).Exp(oldSize, common.Big2, Zero) - linCoef := oldSize.Mul(oldSize, params.MemoryGas) - quadCoef := new(big.Int).Div(pow, params.QuadCoeffDiv) - oldTotalFee := new(big.Int).Add(linCoef, quadCoef) - - pow.Exp(newMemSizeWords, common.Big2, Zero) - linCoef = linCoef.Mul(newMemSizeWords, params.MemoryGas) - quadCoef = quadCoef.Div(pow, params.QuadCoeffDiv) - newTotalFee := linCoef.Add(linCoef, quadCoef) - - fee := newTotalFee.Sub(newTotalFee, oldTotalFee) - gas.Add(gas, fee) - } - } + quadMemGas(mem, newMemSize, gas) return newMemSize, gas, nil } diff --git a/core/vm/jit_test.go b/core/vm/jit_test.go index d8e442637..8c45f2ce7 100644 --- a/core/vm/jit_test.go +++ b/core/vm/jit_test.go @@ -21,13 +21,56 @@ import ( "time" "github.com/ethereum/go-ethereum/common" - "github.com/ethereum/go-ethereum/core/state" "github.com/ethereum/go-ethereum/crypto" - "github.com/ethereum/go-ethereum/ethdb" ) const maxRun = 1000 +func TestCompiling(t *testing.T) { + prog := NewProgram([]byte{0x60, 0x10}) + err := CompileProgram(prog) + if err != nil { + t.Error("didn't expect compile error") + } + + if len(prog.instructions) != 1 { + t.Error("exected 1 compiled instruction, got", len(prog.instructions)) + } +} + +func TestResetInput(t *testing.T) { + var sender account + + env := NewEnv() + contract := NewContract(sender, sender, big.NewInt(100), big.NewInt(10000), big.NewInt(0)) + contract.CodeAddr = &common.Address{} + + program := NewProgram([]byte{}) + RunProgram(program, env, contract, []byte{0xbe, 0xef}) + if contract.Input != nil { + t.Errorf("expected input to be nil, got %x", contract.Input) + } +} + +func TestPcMappingToInstruction(t *testing.T) { + program := NewProgram([]byte{byte(PUSH2), 0xbe, 0xef, byte(ADD)}) + CompileProgram(program) + if program.mapping[3] != 1 { + t.Error("expected mapping PC 4 to me instr no. 2, got", program.mapping[4]) + } +} + +var benchmarks = map[string]vmBench{ + "pushes": vmBench{ + false, false, false, + common.Hex2Bytes("600a600a01600a600a01600a600a01600a600a01600a600a01600a600a01600a600a01600a600a01600a600a01600a600a01"), nil, + }, +} + +func BenchmarkPushes(b *testing.B) { + runVmBench(benchmarks["pushes"], b) +} + type vmBench struct { precompile bool // compile prior to executing nojit bool // ignore jit (sets DisbaleJit = true @@ -37,9 +80,19 @@ type vmBench struct { input []byte } +type account struct{} + +func (account) SubBalance(amount *big.Int) {} +func (account) AddBalance(amount *big.Int) {} +func (account) SetBalance(*big.Int) {} +func (account) SetNonce(uint64) {} +func (account) Balance() *big.Int { return nil } +func (account) Address() common.Address { return common.Address{} } +func (account) ReturnGas(*big.Int, *big.Int) {} +func (account) SetCode([]byte) {} + func runVmBench(test vmBench, b *testing.B) { - db, _ := ethdb.NewMemDatabase() - sender := state.NewStateObject(common.Address{}, db) + var sender account if test.precompile && !test.forcejit { NewProgram(test.code) @@ -52,7 +105,7 @@ func runVmBench(test vmBench, b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - context := NewContext(sender, sender, big.NewInt(100), big.NewInt(10000), big.NewInt(0)) + context := NewContract(sender, sender, big.NewInt(100), big.NewInt(10000), big.NewInt(0)) context.Code = test.code context.CodeAddr = &common.Address{} _, err := New(env).Run(context, test.input) @@ -63,17 +116,6 @@ func runVmBench(test vmBench, b *testing.B) { } } -var benchmarks = map[string]vmBench{ - "pushes": vmBench{ - false, false, false, - common.Hex2Bytes("600a600a01600a600a01600a600a01600a600a01600a600a01600a600a01600a600a01600a600a01600a600a01600a600a01"), nil, - }, -} - -func BenchmarkPushes(b *testing.B) { - runVmBench(benchmarks["pushes"], b) -} - type Env struct { gasLimit *big.Int depth int @@ -93,30 +135,32 @@ func (self *Env) StructLogs() []StructLog { //func (self *Env) PrevHash() []byte { return self.parent } func (self *Env) Coinbase() common.Address { return common.Address{} } +func (self *Env) MakeSnapshot() Database { return nil } +func (self *Env) SetSnapshot(Database) {} func (self *Env) Time() *big.Int { return big.NewInt(time.Now().Unix()) } func (self *Env) Difficulty() *big.Int { return big.NewInt(0) } -func (self *Env) State() *state.StateDB { return nil } +func (self *Env) Db() Database { return nil } func (self *Env) GasLimit() *big.Int { return self.gasLimit } func (self *Env) VmType() Type { return StdVmTy } func (self *Env) GetHash(n uint64) common.Hash { return common.BytesToHash(crypto.Sha3([]byte(big.NewInt(int64(n)).String()))) } -func (self *Env) AddLog(log *state.Log) { +func (self *Env) AddLog(log *Log) { } func (self *Env) Depth() int { return self.depth } func (self *Env) SetDepth(i int) { self.depth = i } -func (self *Env) CanTransfer(from Account, balance *big.Int) bool { - return from.Balance().Cmp(balance) >= 0 +func (self *Env) CanTransfer(from common.Address, balance *big.Int) bool { + return true } func (self *Env) Transfer(from, to Account, amount *big.Int) error { return nil } -func (self *Env) Call(caller ContextRef, addr common.Address, data []byte, gas, price, value *big.Int) ([]byte, error) { +func (self *Env) Call(caller ContractRef, addr common.Address, data []byte, gas, price, value *big.Int) ([]byte, error) { return nil, nil } -func (self *Env) CallCode(caller ContextRef, addr common.Address, data []byte, gas, price, value *big.Int) ([]byte, error) { +func (self *Env) CallCode(caller ContractRef, addr common.Address, data []byte, gas, price, value *big.Int) ([]byte, error) { return nil, nil } -func (self *Env) Create(caller ContextRef, data []byte, gas, price, value *big.Int) ([]byte, error, ContextRef) { - return nil, nil, nil +func (self *Env) Create(caller ContractRef, data []byte, gas, price, value *big.Int) ([]byte, common.Address, error) { + return nil, common.Address{}, nil } diff --git a/core/vm/jump_table.go b/core/vm/jump_table.go new file mode 100644 index 000000000..ab899647f --- /dev/null +++ b/core/vm/jump_table.go @@ -0,0 +1,143 @@ +package vm + +import "math/big" + +type jumpPtr struct { + fn instrFn + valid bool +} + +var jumpTable [256]jumpPtr + +func init() { + jumpTable[ADD] = jumpPtr{opAdd, true} + jumpTable[SUB] = jumpPtr{opSub, true} + jumpTable[MUL] = jumpPtr{opMul, true} + jumpTable[DIV] = jumpPtr{opDiv, true} + jumpTable[SDIV] = jumpPtr{opSdiv, true} + jumpTable[MOD] = jumpPtr{opMod, true} + jumpTable[SMOD] = jumpPtr{opSmod, true} + jumpTable[EXP] = jumpPtr{opExp, true} + jumpTable[SIGNEXTEND] = jumpPtr{opSignExtend, true} + jumpTable[NOT] = jumpPtr{opNot, true} + jumpTable[LT] = jumpPtr{opLt, true} + jumpTable[GT] = jumpPtr{opGt, true} + jumpTable[SLT] = jumpPtr{opSlt, true} + jumpTable[SGT] = jumpPtr{opSgt, true} + jumpTable[EQ] = jumpPtr{opEq, true} + jumpTable[ISZERO] = jumpPtr{opIszero, true} + jumpTable[AND] = jumpPtr{opAnd, true} + jumpTable[OR] = jumpPtr{opOr, true} + jumpTable[XOR] = jumpPtr{opXor, true} + jumpTable[BYTE] = jumpPtr{opByte, true} + jumpTable[ADDMOD] = jumpPtr{opAddmod, true} + jumpTable[MULMOD] = jumpPtr{opMulmod, true} + jumpTable[SHA3] = jumpPtr{opSha3, true} + jumpTable[ADDRESS] = jumpPtr{opAddress, true} + jumpTable[BALANCE] = jumpPtr{opBalance, true} + jumpTable[ORIGIN] = jumpPtr{opOrigin, true} + jumpTable[CALLER] = jumpPtr{opCaller, true} + jumpTable[CALLVALUE] = jumpPtr{opCallValue, true} + jumpTable[CALLDATALOAD] = jumpPtr{opCalldataLoad, true} + jumpTable[CALLDATASIZE] = jumpPtr{opCalldataSize, true} + jumpTable[CALLDATACOPY] = jumpPtr{opCalldataCopy, true} + jumpTable[CODESIZE] = jumpPtr{opCodeSize, true} + jumpTable[EXTCODESIZE] = jumpPtr{opExtCodeSize, true} + jumpTable[CODECOPY] = jumpPtr{opCodeCopy, true} + jumpTable[EXTCODECOPY] = jumpPtr{opExtCodeCopy, true} + jumpTable[GASPRICE] = jumpPtr{opGasprice, true} + jumpTable[BLOCKHASH] = jumpPtr{opBlockhash, true} + jumpTable[COINBASE] = jumpPtr{opCoinbase, true} + jumpTable[TIMESTAMP] = jumpPtr{opTimestamp, true} + jumpTable[NUMBER] = jumpPtr{opNumber, true} + jumpTable[DIFFICULTY] = jumpPtr{opDifficulty, true} + jumpTable[GASLIMIT] = jumpPtr{opGasLimit, true} + jumpTable[POP] = jumpPtr{opPop, true} + jumpTable[MLOAD] = jumpPtr{opMload, true} + jumpTable[MSTORE] = jumpPtr{opMstore, true} + jumpTable[MSTORE8] = jumpPtr{opMstore8, true} + jumpTable[SLOAD] = jumpPtr{opSload, true} + jumpTable[SSTORE] = jumpPtr{opSstore, true} + jumpTable[JUMPDEST] = jumpPtr{opJumpdest, true} + jumpTable[PC] = jumpPtr{nil, true} + jumpTable[MSIZE] = jumpPtr{opMsize, true} + jumpTable[GAS] = jumpPtr{opGas, true} + jumpTable[CREATE] = jumpPtr{opCreate, true} + jumpTable[CALL] = jumpPtr{opCall, true} + jumpTable[CALLCODE] = jumpPtr{opCallCode, true} + jumpTable[LOG0] = jumpPtr{makeLog(0), true} + jumpTable[LOG1] = jumpPtr{makeLog(1), true} + jumpTable[LOG2] = jumpPtr{makeLog(2), true} + jumpTable[LOG3] = jumpPtr{makeLog(3), true} + jumpTable[LOG4] = jumpPtr{makeLog(4), true} + jumpTable[SWAP1] = jumpPtr{makeSwap(1), true} + jumpTable[SWAP2] = jumpPtr{makeSwap(2), true} + jumpTable[SWAP3] = jumpPtr{makeSwap(3), true} + jumpTable[SWAP4] = jumpPtr{makeSwap(4), true} + jumpTable[SWAP5] = jumpPtr{makeSwap(5), true} + jumpTable[SWAP6] = jumpPtr{makeSwap(6), true} + jumpTable[SWAP7] = jumpPtr{makeSwap(7), true} + jumpTable[SWAP8] = jumpPtr{makeSwap(8), true} + jumpTable[SWAP9] = jumpPtr{makeSwap(9), true} + jumpTable[SWAP10] = jumpPtr{makeSwap(10), true} + jumpTable[SWAP11] = jumpPtr{makeSwap(11), true} + jumpTable[SWAP12] = jumpPtr{makeSwap(12), true} + jumpTable[SWAP13] = jumpPtr{makeSwap(13), true} + jumpTable[SWAP14] = jumpPtr{makeSwap(14), true} + jumpTable[SWAP15] = jumpPtr{makeSwap(15), true} + jumpTable[SWAP16] = jumpPtr{makeSwap(16), true} + jumpTable[PUSH1] = jumpPtr{makePush(1, big.NewInt(1)), true} + jumpTable[PUSH2] = jumpPtr{makePush(2, big.NewInt(2)), true} + jumpTable[PUSH3] = jumpPtr{makePush(3, big.NewInt(3)), true} + jumpTable[PUSH4] = jumpPtr{makePush(4, big.NewInt(4)), true} + jumpTable[PUSH5] = jumpPtr{makePush(5, big.NewInt(5)), true} + jumpTable[PUSH6] = jumpPtr{makePush(6, big.NewInt(6)), true} + jumpTable[PUSH7] = jumpPtr{makePush(7, big.NewInt(7)), true} + jumpTable[PUSH8] = jumpPtr{makePush(8, big.NewInt(8)), true} + jumpTable[PUSH9] = jumpPtr{makePush(9, big.NewInt(9)), true} + jumpTable[PUSH10] = jumpPtr{makePush(10, big.NewInt(10)), true} + jumpTable[PUSH11] = jumpPtr{makePush(11, big.NewInt(11)), true} + jumpTable[PUSH12] = jumpPtr{makePush(12, big.NewInt(12)), true} + jumpTable[PUSH13] = jumpPtr{makePush(13, big.NewInt(13)), true} + jumpTable[PUSH14] = jumpPtr{makePush(14, big.NewInt(14)), true} + jumpTable[PUSH15] = jumpPtr{makePush(15, big.NewInt(15)), true} + jumpTable[PUSH16] = jumpPtr{makePush(16, big.NewInt(16)), true} + jumpTable[PUSH17] = jumpPtr{makePush(17, big.NewInt(17)), true} + jumpTable[PUSH18] = jumpPtr{makePush(18, big.NewInt(18)), true} + jumpTable[PUSH19] = jumpPtr{makePush(19, big.NewInt(19)), true} + jumpTable[PUSH20] = jumpPtr{makePush(20, big.NewInt(20)), true} + jumpTable[PUSH21] = jumpPtr{makePush(21, big.NewInt(21)), true} + jumpTable[PUSH22] = jumpPtr{makePush(22, big.NewInt(22)), true} + jumpTable[PUSH23] = jumpPtr{makePush(23, big.NewInt(23)), true} + jumpTable[PUSH24] = jumpPtr{makePush(24, big.NewInt(24)), true} + jumpTable[PUSH25] = jumpPtr{makePush(25, big.NewInt(25)), true} + jumpTable[PUSH26] = jumpPtr{makePush(26, big.NewInt(26)), true} + jumpTable[PUSH27] = jumpPtr{makePush(27, big.NewInt(27)), true} + jumpTable[PUSH28] = jumpPtr{makePush(28, big.NewInt(28)), true} + jumpTable[PUSH29] = jumpPtr{makePush(29, big.NewInt(29)), true} + jumpTable[PUSH30] = jumpPtr{makePush(30, big.NewInt(30)), true} + jumpTable[PUSH31] = jumpPtr{makePush(31, big.NewInt(31)), true} + jumpTable[PUSH32] = jumpPtr{makePush(32, big.NewInt(32)), true} + jumpTable[DUP1] = jumpPtr{makeDup(1), true} + jumpTable[DUP2] = jumpPtr{makeDup(2), true} + jumpTable[DUP3] = jumpPtr{makeDup(3), true} + jumpTable[DUP4] = jumpPtr{makeDup(4), true} + jumpTable[DUP5] = jumpPtr{makeDup(5), true} + jumpTable[DUP6] = jumpPtr{makeDup(6), true} + jumpTable[DUP7] = jumpPtr{makeDup(7), true} + jumpTable[DUP8] = jumpPtr{makeDup(8), true} + jumpTable[DUP9] = jumpPtr{makeDup(9), true} + jumpTable[DUP10] = jumpPtr{makeDup(10), true} + jumpTable[DUP11] = jumpPtr{makeDup(11), true} + jumpTable[DUP12] = jumpPtr{makeDup(12), true} + jumpTable[DUP13] = jumpPtr{makeDup(13), true} + jumpTable[DUP14] = jumpPtr{makeDup(14), true} + jumpTable[DUP15] = jumpPtr{makeDup(15), true} + jumpTable[DUP16] = jumpPtr{makeDup(16), true} + + jumpTable[RETURN] = jumpPtr{nil, true} + jumpTable[SUICIDE] = jumpPtr{nil, true} + jumpTable[JUMP] = jumpPtr{nil, true} + jumpTable[JUMPI] = jumpPtr{nil, true} + jumpTable[STOP] = jumpPtr{nil, true} +} diff --git a/core/state/log.go b/core/vm/log.go index 5d7d7357d..354f0ad35 100644 --- a/core/state/log.go +++ b/core/vm/log.go @@ -14,7 +14,7 @@ // You should have received a copy of the GNU Lesser General Public License // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. -package state +package vm import ( "fmt" diff --git a/core/vm/logger.go b/core/vm/logger.go index 736f595f6..2bd02319f 100644 --- a/core/vm/logger.go +++ b/core/vm/logger.go @@ -24,6 +24,7 @@ import ( "github.com/ethereum/go-ethereum/common" ) +// StdErrFormat formats a slice of StructLogs to human readable format func StdErrFormat(logs []StructLog) { fmt.Fprintf(os.Stderr, "VM STAT %d OPs\n", len(logs)) for _, log := range logs { diff --git a/core/vm/memory.go b/core/vm/memory.go index 0109050d7..d01188417 100644 --- a/core/vm/memory.go +++ b/core/vm/memory.go @@ -18,6 +18,7 @@ package vm import "fmt" +// Memory implements a simple memory model for the ethereum virtual machine. type Memory struct { store []byte } @@ -26,6 +27,7 @@ func NewMemory() *Memory { return &Memory{nil} } +// Set sets offset + size to value func (m *Memory) Set(offset, size uint64, value []byte) { // length of store may never be less than offset + size. // The store should be resized PRIOR to setting the memory @@ -40,12 +42,14 @@ func (m *Memory) Set(offset, size uint64, value []byte) { } } +// Resize resizes the memory to size func (m *Memory) Resize(size uint64) { if uint64(m.Len()) < size { m.store = append(m.store, make([]byte, size-uint64(m.Len()))...) } } +// Get returns offset + size as a new slice func (self *Memory) Get(offset, size int64) (cpy []byte) { if size == 0 { return nil @@ -61,6 +65,7 @@ func (self *Memory) Get(offset, size int64) (cpy []byte) { return } +// GetPtr returns the offset + size func (self *Memory) GetPtr(offset, size int64) []byte { if size == 0 { return nil @@ -73,10 +78,12 @@ func (self *Memory) GetPtr(offset, size int64) []byte { return nil } +// Len returns the length of the backing slice func (m *Memory) Len() int { return len(m.store) } +// Data returns the backing slice func (m *Memory) Data() []byte { return m.store } diff --git a/core/vm/opcodes.go b/core/vm/opcodes.go index ecced3650..986c35ef8 100644 --- a/core/vm/opcodes.go +++ b/core/vm/opcodes.go @@ -20,9 +20,9 @@ import ( "fmt" ) +// OpCode is an EVM opcode type OpCode byte -// Op codes const ( // 0x0 range - arithmetic ops STOP OpCode = iota diff --git a/core/vm/virtual_machine.go b/core/vm/virtual_machine.go index 047723744..9b3340bb2 100644 --- a/core/vm/virtual_machine.go +++ b/core/vm/virtual_machine.go @@ -16,7 +16,8 @@ package vm +// VirtualMachine is an EVM interface type VirtualMachine interface { Env() Environment - Run(context *Context, data []byte) ([]byte, error) + Run(*Contract, []byte) ([]byte, error) } diff --git a/core/vm/vm.go b/core/vm/vm.go index d9e1a0ce5..57dd4dac3 100644 --- a/core/vm/vm.go +++ b/core/vm/vm.go @@ -14,33 +14,32 @@ // You should have received a copy of the GNU Lesser General Public License // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. -// Package vm implements the Ethereum Virtual Machine. package vm import ( "fmt" "math/big" + "time" "github.com/ethereum/go-ethereum/common" - "github.com/ethereum/go-ethereum/core/state" "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/logger" "github.com/ethereum/go-ethereum/logger/glog" "github.com/ethereum/go-ethereum/params" ) -// Vm implements VirtualMachine +// Vm is an EVM and implements VirtualMachine type Vm struct { env Environment } -// New returns a new Virtual Machine +// New returns a new Vm func New(env Environment) *Vm { return &Vm{env: env} } // Run loops and evaluates the contract's code with the given input data -func (self *Vm) Run(context *Context, input []byte) (ret []byte, err error) { +func (self *Vm) Run(contract *Contract, input []byte) (ret []byte, err error) { self.env.SetDepth(self.env.Depth() + 1) defer self.env.SetDepth(self.env.Depth() - 1) @@ -48,42 +47,48 @@ func (self *Vm) Run(context *Context, input []byte) (ret []byte, err error) { defer func() { if err != nil { // In case of a VM exception (known exceptions) all gas consumed (panics NOT included). - context.UseGas(context.Gas) + contract.UseGas(contract.Gas) - ret = context.Return(nil) + ret = contract.Return(nil) } }() - if context.CodeAddr != nil { - if p := Precompiled[context.CodeAddr.Str()]; p != nil { - return self.RunPrecompiled(p, input, context) + if contract.CodeAddr != nil { + if p := Precompiled[contract.CodeAddr.Str()]; p != nil { + return self.RunPrecompiled(p, input, contract) } } + // Don't bother with the execution if there's no code. + if len(contract.Code) == 0 { + return contract.Return(nil), nil + } + var ( - codehash = crypto.Sha3Hash(context.Code) // codehash is used when doing jump dest caching + codehash = crypto.Sha3Hash(contract.Code) // codehash is used when doing jump dest caching program *Program ) if EnableJit { - // Fetch program status. - // * If ready run using JIT - // * If unknown, compile in a seperate goroutine - // * If forced wait for compilation and run once done - if status := GetProgramStatus(codehash); status == progReady { - return RunProgram(GetProgram(codehash), self.env, context, input) - } else if status == progUnknown { + // If the JIT is enabled check the status of the JIT program, + // if it doesn't exist compile a new program in a seperate + // goroutine or wait for compilation to finish if the JIT is + // forced. + switch GetProgramStatus(codehash) { + case progReady: + return RunProgram(GetProgram(codehash), self.env, contract, input) + case progUnknown: if ForceJit { // Create and compile program - program = NewProgram(context.Code) + program = NewProgram(contract.Code) perr := CompileProgram(program) if perr == nil { - return RunProgram(program, self.env, context, input) + return RunProgram(program, self.env, contract, input) } glog.V(logger.Info).Infoln("error compiling program", err) } else { // create and compile the program. Compilation // is done in a seperate goroutine - program = NewProgram(context.Code) + program = NewProgram(contract.Code) go func() { err := CompileProgram(program) if err != nil { @@ -96,15 +101,14 @@ func (self *Vm) Run(context *Context, input []byte) (ret []byte, err error) { } var ( - caller = context.caller - code = context.Code - value = context.value - price = context.Price - - op OpCode // current opcode - mem = NewMemory() // bound memory - stack = newstack() // local stack - statedb = self.env.State() // current state + caller = contract.caller + code = contract.Code + instrCount = 0 + + op OpCode // current opcode + mem = NewMemory() // bound memory + stack = newstack() // local stack + statedb = self.env.Db() // current state // For optimisation reason we're using uint64 as the program counter. // It's theoretically possible to go above 2^64. The YP defines the PC to be uint256. Pratically much less so feasible. pc = uint64(0) // program counter @@ -112,8 +116,8 @@ func (self *Vm) Run(context *Context, input []byte) (ret []byte, err error) { // jump evaluates and checks whether the given jump destination is a valid one // if valid move the `pc` otherwise return an error. jump = func(from uint64, to *big.Int) error { - if !context.jumpdests.has(codehash, code, to) { - nop := context.GetOp(to.Uint64()) + if !contract.jumpdests.has(codehash, code, to) { + nop := contract.GetOp(to.Uint64()) return fmt.Errorf("invalid jump destination (%v) %v", nop, to) } @@ -125,552 +129,92 @@ func (self *Vm) Run(context *Context, input []byte) (ret []byte, err error) { newMemSize *big.Int cost *big.Int ) + contract.Input = input // User defer pattern to check for an error and, based on the error being nil or not, use all gas and return. defer func() { if err != nil { - self.log(pc, op, context.Gas, cost, mem, stack, context, err) + self.log(pc, op, contract.Gas, cost, mem, stack, contract, err) } }() - // Don't bother with the execution if there's no code. - if len(code) == 0 { - return context.Return(nil), nil + if glog.V(logger.Debug) { + glog.Infof("running byte VM %x\n", codehash[:4]) + tstart := time.Now() + defer func() { + glog.Infof("byte VM %x done. time: %v instrc: %v\n", codehash[:4], time.Since(tstart), instrCount) + }() } - for { - // Overhead of the atomic read might not be worth it - /* TODO this still causes a few issues in the tests - if program != nil && progStatus(atomic.LoadInt32(&program.status)) == progReady { - // move execution - glog.V(logger.Info).Infoln("Moved execution to JIT") - return runProgram(program, pc, mem, stack, self.env, context, input) - } + for ; ; instrCount++ { + /* + if EnableJit && it%100 == 0 { + if program != nil && progStatus(atomic.LoadInt32(&program.status)) == progReady { + // move execution + fmt.Println("moved", it) + glog.V(logger.Info).Infoln("Moved execution to JIT") + return runProgram(program, pc, mem, stack, self.env, contract, input) + } + } */ - // The base for all big integer arithmetic - base := new(big.Int) // Get the memory location of pc - op = context.GetOp(pc) + op = contract.GetOp(pc) // calculate the new memory size and gas price for the current executing opcode - newMemSize, cost, err = calculateGasAndSize(self.env, context, caller, op, statedb, mem, stack) + newMemSize, cost, err = calculateGasAndSize(self.env, contract, caller, op, statedb, mem, stack) if err != nil { return nil, err } // Use the calculated gas. When insufficient gas is present, use all gas and return an // Out Of Gas error - if !context.UseGas(cost) { + if !contract.UseGas(cost) { return nil, OutOfGasError } // Resize the memory calculated previously mem.Resize(newMemSize.Uint64()) // Add a log message - self.log(pc, op, context.Gas, cost, mem, stack, context, nil) - - switch op { - case ADD: - x, y := stack.pop(), stack.pop() - - base.Add(x, y) - - U256(base) - - // pop result back on the stack - stack.push(base) - case SUB: - x, y := stack.pop(), stack.pop() - - base.Sub(x, y) - - U256(base) - - // pop result back on the stack - stack.push(base) - case MUL: - x, y := stack.pop(), stack.pop() - - base.Mul(x, y) - - U256(base) - - // pop result back on the stack - stack.push(base) - case DIV: - x, y := stack.pop(), stack.pop() - - if y.Cmp(common.Big0) != 0 { - base.Div(x, y) - } - - U256(base) - - // pop result back on the stack - stack.push(base) - case SDIV: - x, y := S256(stack.pop()), S256(stack.pop()) - - if y.Cmp(common.Big0) == 0 { - base.Set(common.Big0) - } else { - n := new(big.Int) - if new(big.Int).Mul(x, y).Cmp(common.Big0) < 0 { - n.SetInt64(-1) - } else { - n.SetInt64(1) - } - - base.Div(x.Abs(x), y.Abs(y)).Mul(base, n) - - U256(base) - } - - stack.push(base) - case MOD: - x, y := stack.pop(), stack.pop() - - if y.Cmp(common.Big0) == 0 { - base.Set(common.Big0) - } else { - base.Mod(x, y) - } - - U256(base) - - stack.push(base) - case SMOD: - x, y := S256(stack.pop()), S256(stack.pop()) - - if y.Cmp(common.Big0) == 0 { - base.Set(common.Big0) - } else { - n := new(big.Int) - if x.Cmp(common.Big0) < 0 { - n.SetInt64(-1) - } else { - n.SetInt64(1) - } - - base.Mod(x.Abs(x), y.Abs(y)).Mul(base, n) - - U256(base) - } - - stack.push(base) - - case EXP: - x, y := stack.pop(), stack.pop() - - base.Exp(x, y, Pow256) - - U256(base) - - stack.push(base) - case SIGNEXTEND: - back := stack.pop() - if back.Cmp(big.NewInt(31)) < 0 { - bit := uint(back.Uint64()*8 + 7) - num := stack.pop() - mask := new(big.Int).Lsh(common.Big1, bit) - mask.Sub(mask, common.Big1) - if common.BitTest(num, int(bit)) { - num.Or(num, mask.Not(mask)) - } else { - num.And(num, mask) - } - - num = U256(num) - - stack.push(num) - } - case NOT: - stack.push(U256(new(big.Int).Not(stack.pop()))) - case LT: - x, y := stack.pop(), stack.pop() - - // x < y - if x.Cmp(y) < 0 { - stack.push(common.BigTrue) - } else { - stack.push(common.BigFalse) - } - case GT: - x, y := stack.pop(), stack.pop() - - // x > y - if x.Cmp(y) > 0 { - stack.push(common.BigTrue) - } else { - stack.push(common.BigFalse) - } - - case SLT: - x, y := S256(stack.pop()), S256(stack.pop()) - - // x < y - if x.Cmp(S256(y)) < 0 { - stack.push(common.BigTrue) - } else { - stack.push(common.BigFalse) - } - case SGT: - x, y := S256(stack.pop()), S256(stack.pop()) - - // x > y - if x.Cmp(y) > 0 { - stack.push(common.BigTrue) - } else { - stack.push(common.BigFalse) - } - - case EQ: - x, y := stack.pop(), stack.pop() - - // x == y - if x.Cmp(y) == 0 { - stack.push(common.BigTrue) - } else { - stack.push(common.BigFalse) - } - case ISZERO: - x := stack.pop() - if x.Cmp(common.BigFalse) > 0 { - stack.push(common.BigFalse) - } else { - stack.push(common.BigTrue) - } - - case AND: - x, y := stack.pop(), stack.pop() - - stack.push(base.And(x, y)) - case OR: - x, y := stack.pop(), stack.pop() - - stack.push(base.Or(x, y)) - case XOR: - x, y := stack.pop(), stack.pop() - - stack.push(base.Xor(x, y)) - case BYTE: - th, val := stack.pop(), stack.pop() - - if th.Cmp(big.NewInt(32)) < 0 { - byt := big.NewInt(int64(common.LeftPadBytes(val.Bytes(), 32)[th.Int64()])) - - base.Set(byt) - } else { - base.Set(common.BigFalse) - } - - stack.push(base) - case ADDMOD: - x := stack.pop() - y := stack.pop() - z := stack.pop() - - if z.Cmp(Zero) > 0 { - add := new(big.Int).Add(x, y) - base.Mod(add, z) - - base = U256(base) - } - - stack.push(base) - case MULMOD: - x := stack.pop() - y := stack.pop() - z := stack.pop() - - if z.Cmp(Zero) > 0 { - mul := new(big.Int).Mul(x, y) - base.Mod(mul, z) - - U256(base) - } - - stack.push(base) - - case SHA3: - offset, size := stack.pop(), stack.pop() - data := crypto.Sha3(mem.Get(offset.Int64(), size.Int64())) - - stack.push(common.BigD(data)) + self.log(pc, op, contract.Gas, cost, mem, stack, contract, nil) - case ADDRESS: - stack.push(common.Bytes2Big(context.Address().Bytes())) - - case BALANCE: - addr := common.BigToAddress(stack.pop()) - balance := statedb.GetBalance(addr) - - stack.push(new(big.Int).Set(balance)) - - case ORIGIN: - origin := self.env.Origin() - - stack.push(origin.Big()) - - case CALLER: - caller := context.caller.Address() - stack.push(common.Bytes2Big(caller.Bytes())) - - case CALLVALUE: - stack.push(new(big.Int).Set(value)) - - case CALLDATALOAD: - data := getData(input, stack.pop(), common.Big32) - - stack.push(common.Bytes2Big(data)) - case CALLDATASIZE: - l := int64(len(input)) - stack.push(big.NewInt(l)) - - case CALLDATACOPY: - var ( - mOff = stack.pop() - cOff = stack.pop() - l = stack.pop() - ) - data := getData(input, cOff, l) - - mem.Set(mOff.Uint64(), l.Uint64(), data) - - case CODESIZE, EXTCODESIZE: - var code []byte - if op == EXTCODESIZE { - addr := common.BigToAddress(stack.pop()) - - code = statedb.GetCode(addr) - } else { - code = context.Code - } - - l := big.NewInt(int64(len(code))) - stack.push(l) - - case CODECOPY, EXTCODECOPY: - var code []byte - if op == EXTCODECOPY { - addr := common.BigToAddress(stack.pop()) - code = statedb.GetCode(addr) + if opPtr := jumpTable[op]; opPtr.valid { + if opPtr.fn != nil { + opPtr.fn(instruction{}, &pc, self.env, contract, mem, stack) } else { - code = context.Code - } - - var ( - mOff = stack.pop() - cOff = stack.pop() - l = stack.pop() - ) - - codeCopy := getData(code, cOff, l) - - mem.Set(mOff.Uint64(), l.Uint64(), codeCopy) - - case GASPRICE: - stack.push(new(big.Int).Set(context.Price)) - - case BLOCKHASH: - num := stack.pop() - - n := new(big.Int).Sub(self.env.BlockNumber(), common.Big257) - if num.Cmp(n) > 0 && num.Cmp(self.env.BlockNumber()) < 0 { - stack.push(self.env.GetHash(num.Uint64()).Big()) - } else { - stack.push(common.Big0) - } - - case COINBASE: - coinbase := self.env.Coinbase() - - stack.push(coinbase.Big()) - - case TIMESTAMP: - time := self.env.Time() - - stack.push(new(big.Int).Set(time)) - - case NUMBER: - number := self.env.BlockNumber() - - stack.push(U256(number)) - - case DIFFICULTY: - difficulty := self.env.Difficulty() - - stack.push(new(big.Int).Set(difficulty)) - - case GASLIMIT: - - stack.push(new(big.Int).Set(self.env.GasLimit())) - - case PUSH1, PUSH2, PUSH3, PUSH4, PUSH5, PUSH6, PUSH7, PUSH8, PUSH9, PUSH10, PUSH11, PUSH12, PUSH13, PUSH14, PUSH15, PUSH16, PUSH17, PUSH18, PUSH19, PUSH20, PUSH21, PUSH22, PUSH23, PUSH24, PUSH25, PUSH26, PUSH27, PUSH28, PUSH29, PUSH30, PUSH31, PUSH32: - size := uint64(op - PUSH1 + 1) - byts := getData(code, new(big.Int).SetUint64(pc+1), new(big.Int).SetUint64(size)) - // push value to stack - stack.push(common.Bytes2Big(byts)) - pc += size - - case POP: - stack.pop() - case DUP1, DUP2, DUP3, DUP4, DUP5, DUP6, DUP7, DUP8, DUP9, DUP10, DUP11, DUP12, DUP13, DUP14, DUP15, DUP16: - n := int(op - DUP1 + 1) - stack.dup(n) - - case SWAP1, SWAP2, SWAP3, SWAP4, SWAP5, SWAP6, SWAP7, SWAP8, SWAP9, SWAP10, SWAP11, SWAP12, SWAP13, SWAP14, SWAP15, SWAP16: - n := int(op - SWAP1 + 2) - stack.swap(n) - - case LOG0, LOG1, LOG2, LOG3, LOG4: - n := int(op - LOG0) - topics := make([]common.Hash, n) - mStart, mSize := stack.pop(), stack.pop() - for i := 0; i < n; i++ { - topics[i] = common.BigToHash(stack.pop()) - } - - data := mem.Get(mStart.Int64(), mSize.Int64()) - log := state.NewLog(context.Address(), topics, data, self.env.BlockNumber().Uint64()) - self.env.AddLog(log) - - case MLOAD: - offset := stack.pop() - val := common.BigD(mem.Get(offset.Int64(), 32)) - stack.push(val) - - case MSTORE: - // pop value of the stack - mStart, val := stack.pop(), stack.pop() - mem.Set(mStart.Uint64(), 32, common.BigToBytes(val, 256)) - - case MSTORE8: - off, val := stack.pop().Int64(), stack.pop().Int64() - - mem.store[off] = byte(val & 0xff) - - case SLOAD: - loc := common.BigToHash(stack.pop()) - val := statedb.GetState(context.Address(), loc).Big() - stack.push(val) + switch op { + case PC: + opPc(instruction{data: new(big.Int).SetUint64(pc)}, &pc, self.env, contract, mem, stack) + case JUMP: + if err := jump(pc, stack.pop()); err != nil { + return nil, err + } - case SSTORE: - loc := common.BigToHash(stack.pop()) - val := stack.pop() + continue + case JUMPI: + pos, cond := stack.pop(), stack.pop() - statedb.SetState(context.Address(), loc, common.BigToHash(val)) + if cond.Cmp(common.BigTrue) >= 0 { + if err := jump(pc, pos); err != nil { + return nil, err + } - case JUMP: - if err := jump(pc, stack.pop()); err != nil { - return nil, err - } + continue + } + case RETURN: + offset, size := stack.pop(), stack.pop() + ret := mem.GetPtr(offset.Int64(), size.Int64()) - continue - case JUMPI: - pos, cond := stack.pop(), stack.pop() + return contract.Return(ret), nil + case SUICIDE: + opSuicide(instruction{}, nil, self.env, contract, mem, stack) - if cond.Cmp(common.BigTrue) >= 0 { - if err := jump(pc, pos); err != nil { - return nil, err + fallthrough + case STOP: // Stop the contract + return contract.Return(nil), nil } - - continue - } - - case JUMPDEST: - case PC: - stack.push(new(big.Int).SetUint64(pc)) - case MSIZE: - stack.push(big.NewInt(int64(mem.Len()))) - case GAS: - stack.push(new(big.Int).Set(context.Gas)) - case CREATE: - - var ( - value = stack.pop() - offset, size = stack.pop(), stack.pop() - input = mem.Get(offset.Int64(), size.Int64()) - gas = new(big.Int).Set(context.Gas) - addr common.Address - ) - - context.UseGas(context.Gas) - ret, suberr, ref := self.env.Create(context, input, gas, price, value) - if suberr != nil { - stack.push(common.BigFalse) - - } else { - // gas < len(ret) * CreateDataGas == NO_CODE - dataGas := big.NewInt(int64(len(ret))) - dataGas.Mul(dataGas, params.CreateDataGas) - if context.UseGas(dataGas) { - ref.SetCode(ret) - } - addr = ref.Address() - - stack.push(addr.Big()) - - } - - case CALL, CALLCODE: - gas := stack.pop() - // pop gas and value of the stack. - addr, value := stack.pop(), stack.pop() - value = U256(value) - // pop input size and offset - inOffset, inSize := stack.pop(), stack.pop() - // pop return size and offset - retOffset, retSize := stack.pop(), stack.pop() - - address := common.BigToAddress(addr) - - // Get the arguments from the memory - args := mem.Get(inOffset.Int64(), inSize.Int64()) - - if len(value.Bytes()) > 0 { - gas.Add(gas, params.CallStipend) - } - - var ( - ret []byte - err error - ) - if op == CALLCODE { - ret, err = self.env.CallCode(context, address, args, gas, price, value) - } else { - ret, err = self.env.Call(context, address, args, gas, price, value) - } - - if err != nil { - stack.push(common.BigFalse) - - } else { - stack.push(common.BigTrue) - - mem.Set(retOffset.Uint64(), retSize.Uint64(), ret) } - - case RETURN: - offset, size := stack.pop(), stack.pop() - ret := mem.GetPtr(offset.Int64(), size.Int64()) - - return context.Return(ret), nil - case SUICIDE: - receiver := statedb.GetOrNewStateObject(common.BigToAddress(stack.pop())) - balance := statedb.GetBalance(context.Address()) - - receiver.AddBalance(balance) - - statedb.Delete(context.Address()) - - fallthrough - case STOP: // Stop the context - - return context.Return(nil), nil - default: - + } else { return nil, fmt.Errorf("Invalid opcode %x", op) } @@ -681,7 +225,7 @@ func (self *Vm) Run(context *Context, input []byte) (ret []byte, err error) { // calculateGasAndSize calculates the required given the opcode and stack items calculates the new memorysize for // the operation. This does not reduce gas or resizes the memory. -func calculateGasAndSize(env Environment, context *Context, caller ContextRef, op OpCode, statedb *state.StateDB, mem *Memory, stack *stack) (*big.Int, *big.Int, error) { +func calculateGasAndSize(env Environment, contract *Contract, caller ContractRef, op OpCode, statedb Database, mem *Memory, stack *stack) (*big.Int, *big.Int, error) { var ( gas = new(big.Int) newMemSize *big.Int = new(big.Int) @@ -731,7 +275,7 @@ func calculateGasAndSize(env Environment, context *Context, caller ContextRef, o var g *big.Int y, x := stack.data[stack.len()-2], stack.data[stack.len()-1] - val := statedb.GetState(context.Address(), common.BigToHash(x)) + val := statedb.GetState(contract.Address(), common.BigToHash(x)) // This checks for 3 scenario's and calculates gas accordingly // 1. From a zero-value address to a non-zero value (NEW VALUE) @@ -741,7 +285,7 @@ func calculateGasAndSize(env Environment, context *Context, caller ContextRef, o // 0 => non 0 g = params.SstoreSetGas } else if !common.EmptyHash(val) && common.EmptyHash(common.BigToHash(y)) { - statedb.Refund(params.SstoreRefundGas) + statedb.AddRefund(params.SstoreRefundGas) g = params.SstoreClearGas } else { @@ -750,8 +294,8 @@ func calculateGasAndSize(env Environment, context *Context, caller ContextRef, o } gas.Set(g) case SUICIDE: - if !statedb.IsDeleted(context.Address()) { - statedb.Refund(params.SuicideRefundGas) + if !statedb.IsDeleted(contract.Address()) { + statedb.AddRefund(params.SuicideRefundGas) } case MLOAD: newMemSize = calcMemSize(stack.peek(), u256(32)) @@ -788,7 +332,8 @@ func calculateGasAndSize(env Environment, context *Context, caller ContextRef, o gas.Add(gas, stack.data[stack.len()-1]) if op == CALL { - if env.State().GetStateObject(common.BigToAddress(stack.data[stack.len()-2])) == nil { + //if env.Db().GetStateObject(common.BigToAddress(stack.data[stack.len()-2])) == nil { + if !env.Db().Exist(common.BigToAddress(stack.data[stack.len()-2])) { gas.Add(gas, params.CallNewAccountGas) } } @@ -802,38 +347,18 @@ func calculateGasAndSize(env Environment, context *Context, caller ContextRef, o newMemSize = common.BigMax(x, y) } - - if newMemSize.Cmp(common.Big0) > 0 { - newMemSizeWords := toWordSize(newMemSize) - newMemSize.Mul(newMemSizeWords, u256(32)) - - if newMemSize.Cmp(u256(int64(mem.Len()))) > 0 { - oldSize := toWordSize(big.NewInt(int64(mem.Len()))) - pow := new(big.Int).Exp(oldSize, common.Big2, Zero) - linCoef := new(big.Int).Mul(oldSize, params.MemoryGas) - quadCoef := new(big.Int).Div(pow, params.QuadCoeffDiv) - oldTotalFee := new(big.Int).Add(linCoef, quadCoef) - - pow.Exp(newMemSizeWords, common.Big2, Zero) - linCoef = new(big.Int).Mul(newMemSizeWords, params.MemoryGas) - quadCoef = new(big.Int).Div(pow, params.QuadCoeffDiv) - newTotalFee := new(big.Int).Add(linCoef, quadCoef) - - fee := new(big.Int).Sub(newTotalFee, oldTotalFee) - gas.Add(gas, fee) - } - } + quadMemGas(mem, newMemSize, gas) return newMemSize, gas, nil } // RunPrecompile runs and evaluate the output of a precompiled contract defined in contracts.go -func (self *Vm) RunPrecompiled(p *PrecompiledAccount, input []byte, context *Context) (ret []byte, err error) { +func (self *Vm) RunPrecompiled(p *PrecompiledAccount, input []byte, contract *Contract) (ret []byte, err error) { gas := p.Gas(len(input)) - if context.UseGas(gas) { + if contract.UseGas(gas) { ret = p.Call(input) - return context.Return(ret), nil + return contract.Return(ret), nil } else { return nil, OutOfGasError } @@ -841,18 +366,20 @@ func (self *Vm) RunPrecompiled(p *PrecompiledAccount, input []byte, context *Con // log emits a log event to the environment for each opcode encountered. This is not to be confused with the // LOG* opcode. -func (self *Vm) log(pc uint64, op OpCode, gas, cost *big.Int, memory *Memory, stack *stack, context *Context, err error) { +func (self *Vm) log(pc uint64, op OpCode, gas, cost *big.Int, memory *Memory, stack *stack, contract *Contract, err error) { if Debug { mem := make([]byte, len(memory.Data())) copy(mem, memory.Data()) stck := make([]*big.Int, len(stack.Data())) copy(stck, stack.Data()) - object := context.self.(*state.StateObject) storage := make(map[common.Hash][]byte) - object.EachStorage(func(k, v []byte) { - storage[common.BytesToHash(k)] = v - }) + /* + object := contract.self.(*state.StateObject) + object.EachStorage(func(k, v []byte) { + storage[common.BytesToHash(k)] = v + }) + */ self.env.AddStructLog(StructLog{pc, op, new(big.Int).Set(gas), cost, mem, stck, storage, err}) } diff --git a/core/vm/vm_jit.go b/core/vm/vm_jit.go index 339cb8ea8..07cb52d4a 100644 --- a/core/vm/vm_jit.go +++ b/core/vm/vm_jit.go @@ -30,6 +30,7 @@ void evmjit_destroy(void* _jit); */ import "C" +/* import ( "bytes" "errors" @@ -385,4 +386,4 @@ func env_extcode(_vm unsafe.Pointer, _addr unsafe.Pointer, o_size *uint64) *byte code := vm.Env().State().GetCode(addr) *o_size = uint64(len(code)) return getDataPtr(code) -} +}*/ diff --git a/core/vm_env.go b/core/vm_env.go index a08f024fe..467e34c6b 100644 --- a/core/vm_env.go +++ b/core/vm_env.go @@ -30,13 +30,13 @@ type VMEnv struct { header *types.Header msg Message depth int - chain *ChainManager + chain *BlockChain typ vm.Type // structured logging logs []vm.StructLog } -func NewEnv(state *state.StateDB, chain *ChainManager, msg Message, header *types.Header) *VMEnv { +func NewEnv(state *state.StateDB, chain *BlockChain, msg Message, header *types.Header) *VMEnv { return &VMEnv{ chain: chain, state: state, @@ -53,7 +53,7 @@ func (self *VMEnv) Time() *big.Int { return self.header.Time } func (self *VMEnv) Difficulty() *big.Int { return self.header.Difficulty } func (self *VMEnv) GasLimit() *big.Int { return self.header.GasLimit } func (self *VMEnv) Value() *big.Int { return self.msg.Value() } -func (self *VMEnv) State() *state.StateDB { return self.state } +func (self *VMEnv) Db() vm.Database { return self.state } func (self *VMEnv) Depth() int { return self.depth } func (self *VMEnv) SetDepth(i int) { self.depth = i } func (self *VMEnv) VmType() vm.Type { return self.typ } @@ -66,30 +66,34 @@ func (self *VMEnv) GetHash(n uint64) common.Hash { return common.Hash{} } -func (self *VMEnv) AddLog(log *state.Log) { +func (self *VMEnv) AddLog(log *vm.Log) { self.state.AddLog(log) } -func (self *VMEnv) CanTransfer(from vm.Account, balance *big.Int) bool { - return from.Balance().Cmp(balance) >= 0 +func (self *VMEnv) CanTransfer(from common.Address, balance *big.Int) bool { + return self.state.GetBalance(from).Cmp(balance) >= 0 +} + +func (self *VMEnv) MakeSnapshot() vm.Database { + return self.state.Copy() +} + +func (self *VMEnv) SetSnapshot(copy vm.Database) { + self.state.Set(copy.(*state.StateDB)) } func (self *VMEnv) Transfer(from, to vm.Account, amount *big.Int) error { - return vm.Transfer(from, to, amount) + return Transfer(from, to, amount) } -func (self *VMEnv) Call(me vm.ContextRef, addr common.Address, data []byte, gas, price, value *big.Int) ([]byte, error) { - exe := NewExecution(self, &addr, data, gas, price, value) - return exe.Call(addr, me) +func (self *VMEnv) Call(me vm.ContractRef, addr common.Address, data []byte, gas, price, value *big.Int) ([]byte, error) { + return Call(self, me, addr, data, gas, price, value) } -func (self *VMEnv) CallCode(me vm.ContextRef, addr common.Address, data []byte, gas, price, value *big.Int) ([]byte, error) { - maddr := me.Address() - exe := NewExecution(self, &maddr, data, gas, price, value) - return exe.Call(addr, me) +func (self *VMEnv) CallCode(me vm.ContractRef, addr common.Address, data []byte, gas, price, value *big.Int) ([]byte, error) { + return CallCode(self, me, addr, data, gas, price, value) } -func (self *VMEnv) Create(me vm.ContextRef, data []byte, gas, price, value *big.Int) ([]byte, error, vm.ContextRef) { - exe := NewExecution(self, nil, data, gas, price, value) - return exe.Create(me) +func (self *VMEnv) Create(me vm.ContractRef, data []byte, gas, price, value *big.Int) ([]byte, common.Address, error) { + return Create(self, me, data, gas, price, value) } func (self *VMEnv) StructLogs() []vm.StructLog { diff --git a/eth/backend.go b/eth/backend.go index 349dfa613..a480b4931 100644 --- a/eth/backend.go +++ b/eth/backend.go @@ -217,7 +217,7 @@ type Ethereum struct { // State manager for processing new blocks and managing the over all states blockProcessor *core.BlockProcessor txPool *core.TxPool - chainManager *core.ChainManager + blockchain *core.BlockChain accountManager *accounts.Manager whisper *whisper.Whisper pow *ethash.Ethash @@ -365,7 +365,7 @@ func New(config *Config) (*Ethereum, error) { eth.pow = ethash.New() } //genesis := core.GenesisBlock(uint64(config.GenesisNonce), stateDb) - eth.chainManager, err = core.NewChainManager(chainDb, eth.pow, eth.EventMux()) + eth.blockchain, err = core.NewBlockChain(chainDb, eth.pow, eth.EventMux()) if err != nil { if err == core.ErrNoGenesis { return nil, fmt.Errorf(`Genesis block not found. Please supply a genesis block with the "--genesis /path/to/file" argument`) @@ -373,11 +373,11 @@ func New(config *Config) (*Ethereum, error) { return nil, err } - eth.txPool = core.NewTxPool(eth.EventMux(), eth.chainManager.State, eth.chainManager.GasLimit) + eth.txPool = core.NewTxPool(eth.EventMux(), eth.blockchain.State, eth.blockchain.GasLimit) - eth.blockProcessor = core.NewBlockProcessor(chainDb, eth.pow, eth.chainManager, eth.EventMux()) - eth.chainManager.SetProcessor(eth.blockProcessor) - eth.protocolManager = NewProtocolManager(config.NetworkId, eth.eventMux, eth.txPool, eth.pow, eth.chainManager, chainDb) + eth.blockProcessor = core.NewBlockProcessor(chainDb, eth.pow, eth.blockchain, eth.EventMux()) + eth.blockchain.SetProcessor(eth.blockProcessor) + eth.protocolManager = NewProtocolManager(config.NetworkId, eth.eventMux, eth.txPool, eth.pow, eth.blockchain, chainDb) eth.miner = miner.New(eth, eth.EventMux(), eth.pow) eth.miner.SetGasPrice(config.GasPrice) @@ -441,7 +441,7 @@ func (s *Ethereum) NodeInfo() *NodeInfo { DiscPort: int(node.UDP), TCPPort: int(node.TCP), ListenAddr: s.net.ListenAddr, - Td: s.ChainManager().Td().String(), + Td: s.BlockChain().Td().String(), } } @@ -478,7 +478,7 @@ func (s *Ethereum) PeersInfo() (peersinfo []*PeerInfo) { } func (s *Ethereum) ResetWithGenesisBlock(gb *types.Block) { - s.chainManager.ResetWithGenesisBlock(gb) + s.blockchain.ResetWithGenesisBlock(gb) } func (s *Ethereum) StartMining(threads int) error { @@ -518,7 +518,7 @@ func (s *Ethereum) Miner() *miner.Miner { return s.miner } // func (s *Ethereum) Logger() logger.LogSystem { return s.logger } func (s *Ethereum) Name() string { return s.net.Name } func (s *Ethereum) AccountManager() *accounts.Manager { return s.accountManager } -func (s *Ethereum) ChainManager() *core.ChainManager { return s.chainManager } +func (s *Ethereum) BlockChain() *core.BlockChain { return s.blockchain } func (s *Ethereum) BlockProcessor() *core.BlockProcessor { return s.blockProcessor } func (s *Ethereum) TxPool() *core.TxPool { return s.txPool } func (s *Ethereum) Whisper() *whisper.Whisper { return s.whisper } @@ -581,7 +581,7 @@ func (self *Ethereum) AddPeer(nodeURL string) error { func (s *Ethereum) Stop() { s.net.Stop() - s.chainManager.Stop() + s.blockchain.Stop() s.protocolManager.Stop() s.txPool.Stop() s.eventMux.Stop() @@ -622,7 +622,7 @@ func (self *Ethereum) StartAutoDAG() { select { case <-timer: glog.V(logger.Info).Infof("checking DAG (ethash dir: %s)", ethash.DefaultDir) - currentBlock := self.ChainManager().CurrentBlock().NumberU64() + currentBlock := self.BlockChain().CurrentBlock().NumberU64() thisEpoch := currentBlock / epochLength if nextEpoch <= thisEpoch { if currentBlock%epochLength > autoDAGepochHeight { diff --git a/eth/downloader/downloader.go b/eth/downloader/downloader.go index f038e24e4..64fb1b57b 100644 --- a/eth/downloader/downloader.go +++ b/eth/downloader/downloader.go @@ -154,7 +154,7 @@ type Downloader struct { blockCh chan blockPack // [eth/61] Channel receiving inbound blocks headerCh chan headerPack // [eth/62] Channel receiving inbound block headers bodyCh chan bodyPack // [eth/62] Channel receiving inbound block bodies - processCh chan bool // Channel to signal the block fetcher of new or finished work + wakeCh chan bool // Channel to signal the block/body fetcher of new tasks cancelCh chan struct{} // Channel to cancel mid-flight syncs cancelLock sync.RWMutex // Lock to protect the cancel channel in delivers @@ -188,7 +188,7 @@ func New(mux *event.TypeMux, hasBlock hashCheckFn, getBlock blockRetrievalFn, he blockCh: make(chan blockPack, 1), headerCh: make(chan headerPack, 1), bodyCh: make(chan bodyPack, 1), - processCh: make(chan bool, 1), + wakeCh: make(chan bool, 1), } } @@ -282,6 +282,10 @@ func (d *Downloader) synchronise(id string, hash common.Hash, td *big.Int) error d.queue.Reset() d.peers.Reset() + select { + case <-d.wakeCh: + default: + } // Create cancel channel for aborting mid-flight d.cancelLock.Lock() d.cancelCh = make(chan struct{}) @@ -633,7 +637,7 @@ func (d *Downloader) fetchHashes61(p *peer, td *big.Int, from uint64) error { glog.V(logger.Debug).Infof("%v: no available hashes", p) select { - case d.processCh <- false: + case d.wakeCh <- false: case <-d.cancelCh: } // If no hashes were retrieved at all, the peer violated it's TD promise that it had a @@ -664,12 +668,18 @@ func (d *Downloader) fetchHashes61(p *peer, td *big.Int, from uint64) error { return errBadPeer } // Notify the block fetcher of new hashes, but stop if queue is full - cont := d.queue.Pending() < maxQueuedHashes - select { - case d.processCh <- cont: - default: - } - if !cont { + if d.queue.Pending() < maxQueuedHashes { + // We still have hashes to fetch, send continuation wake signal (potential) + select { + case d.wakeCh <- true: + default: + } + } else { + // Hash limit reached, send a termination wake signal (enforced) + select { + case d.wakeCh <- false: + case <-d.cancelCh: + } return nil } // Queue not yet full, fetch the next batch @@ -766,7 +776,7 @@ func (d *Downloader) fetchBlocks61(from uint64) error { default: } - case cont := <-d.processCh: + case cont := <-d.wakeCh: // The hash fetcher sent a continuation flag, check if it's done if !cont { finished = true @@ -806,7 +816,7 @@ func (d *Downloader) fetchBlocks61(from uint64) error { } // Send a download request to all idle peers, until throttled throttled := false - for _, peer := range d.peers.IdlePeers() { + for _, peer := range d.peers.IdlePeers(eth61) { // Short circuit if throttling activated if d.queue.Throttle() { throttled = true @@ -1053,7 +1063,7 @@ func (d *Downloader) fetchHeaders(p *peer, td *big.Int, from uint64) error { glog.V(logger.Debug).Infof("%v: no available headers", p) select { - case d.processCh <- false: + case d.wakeCh <- false: case <-d.cancelCh: } // If no headers were retrieved at all, the peer violated it's TD promise that it had a @@ -1084,12 +1094,18 @@ func (d *Downloader) fetchHeaders(p *peer, td *big.Int, from uint64) error { return errBadPeer } // Notify the block fetcher of new headers, but stop if queue is full - cont := d.queue.Pending() < maxQueuedHeaders - select { - case d.processCh <- cont: - default: - } - if !cont { + if d.queue.Pending() < maxQueuedHeaders { + // We still have headers to fetch, send continuation wake signal (potential) + select { + case d.wakeCh <- true: + default: + } + } else { + // Header limit reached, send a termination wake signal (enforced) + select { + case d.wakeCh <- false: + case <-d.cancelCh: + } return nil } // Queue not yet full, fetch the next batch @@ -1104,8 +1120,8 @@ func (d *Downloader) fetchHeaders(p *peer, td *big.Int, from uint64) error { // Finish the sync gracefully instead of dumping the gathered data though select { - case d.processCh <- false: - default: + case d.wakeCh <- false: + case <-d.cancelCh: } return nil } @@ -1199,7 +1215,7 @@ func (d *Downloader) fetchBodies(from uint64) error { default: } - case cont := <-d.processCh: + case cont := <-d.wakeCh: // The header fetcher sent a continuation flag, check if it's done if !cont { finished = true @@ -1239,7 +1255,7 @@ func (d *Downloader) fetchBodies(from uint64) error { } // Send a download request to all idle peers, until throttled queuedEmptyBlocks, throttled := false, false - for _, peer := range d.peers.IdlePeers() { + for _, peer := range d.peers.IdlePeers(eth62) { // Short circuit if throttling activated if d.queue.Throttle() { throttled = true diff --git a/eth/downloader/downloader_test.go b/eth/downloader/downloader_test.go index 885fab8bd..96096527e 100644 --- a/eth/downloader/downloader_test.go +++ b/eth/downloader/downloader_test.go @@ -205,9 +205,17 @@ func (dl *downloadTester) newSlowPeer(id string, version int, hashes []common.Ha dl.lock.Lock() defer dl.lock.Unlock() - err := dl.downloader.RegisterPeer(id, version, hashes[0], - dl.peerGetRelHashesFn(id, delay), dl.peerGetAbsHashesFn(id, delay), dl.peerGetBlocksFn(id, delay), - dl.peerGetRelHeadersFn(id, delay), dl.peerGetAbsHeadersFn(id, delay), dl.peerGetBodiesFn(id, delay)) + var err error + switch version { + case 61: + err = dl.downloader.RegisterPeer(id, version, hashes[0], dl.peerGetRelHashesFn(id, delay), dl.peerGetAbsHashesFn(id, delay), dl.peerGetBlocksFn(id, delay), nil, nil, nil) + case 62: + err = dl.downloader.RegisterPeer(id, version, hashes[0], nil, nil, nil, dl.peerGetRelHeadersFn(id, delay), dl.peerGetAbsHeadersFn(id, delay), dl.peerGetBodiesFn(id, delay)) + case 63: + err = dl.downloader.RegisterPeer(id, version, hashes[0], nil, nil, nil, dl.peerGetRelHeadersFn(id, delay), dl.peerGetAbsHeadersFn(id, delay), dl.peerGetBodiesFn(id, delay)) + case 64: + err = dl.downloader.RegisterPeer(id, version, hashes[0], nil, nil, nil, dl.peerGetRelHeadersFn(id, delay), dl.peerGetAbsHeadersFn(id, delay), dl.peerGetBodiesFn(id, delay)) + } if err == nil { // Assign the owned hashes and blocks to the peer (deep copy) dl.peerHashes[id] = make([]common.Hash, len(hashes)) @@ -618,6 +626,41 @@ func testMultiSynchronisation(t *testing.T, protocol int) { } } +// Tests that synchronisations behave well in multi-version protocol environments +// and not wreak havok on other nodes in the network. +func TestMultiProtocolSynchronisation61(t *testing.T) { testMultiProtocolSynchronisation(t, 61) } +func TestMultiProtocolSynchronisation62(t *testing.T) { testMultiProtocolSynchronisation(t, 62) } +func TestMultiProtocolSynchronisation63(t *testing.T) { testMultiProtocolSynchronisation(t, 63) } +func TestMultiProtocolSynchronisation64(t *testing.T) { testMultiProtocolSynchronisation(t, 64) } + +func testMultiProtocolSynchronisation(t *testing.T, protocol int) { + // Create a small enough block chain to download + targetBlocks := blockCacheLimit - 15 + hashes, blocks := makeChain(targetBlocks, 0, genesis) + + // Create peers of every type + tester := newTester() + tester.newPeer("peer 61", 61, hashes, blocks) + tester.newPeer("peer 62", 62, hashes, blocks) + tester.newPeer("peer 63", 63, hashes, blocks) + tester.newPeer("peer 64", 64, hashes, blocks) + + // Synchronise with the requestd peer and make sure all blocks were retrieved + if err := tester.sync(fmt.Sprintf("peer %d", protocol), nil); err != nil { + t.Fatalf("failed to synchronise blocks: %v", err) + } + if imported := len(tester.ownBlocks); imported != targetBlocks+1 { + t.Fatalf("synchronised block mismatch: have %v, want %v", imported, targetBlocks+1) + } + // Check that no peers have been dropped off + for _, version := range []int{61, 62, 63, 64} { + peer := fmt.Sprintf("peer %d", version) + if _, ok := tester.peerHashes[peer]; !ok { + t.Errorf("%s dropped", peer) + } + } +} + // Tests that if a block is empty (i.e. header only), no body request should be // made, and instead the header should be assembled into a whole block in itself. func TestEmptyBlockShortCircuit62(t *testing.T) { testEmptyBlockShortCircuit(t, 62) } diff --git a/eth/downloader/peer.go b/eth/downloader/peer.go index 8fd1f9a99..c1d20ac61 100644 --- a/eth/downloader/peer.go +++ b/eth/downloader/peer.go @@ -312,14 +312,16 @@ func (ps *peerSet) AllPeers() []*peer { // IdlePeers retrieves a flat list of all the currently idle peers within the // active peer set, ordered by their reputation. -func (ps *peerSet) IdlePeers() []*peer { +func (ps *peerSet) IdlePeers(version int) []*peer { ps.lock.RLock() defer ps.lock.RUnlock() list := make([]*peer, 0, len(ps.peers)) for _, p := range ps.peers { - if atomic.LoadInt32(&p.idle) == 0 { - list = append(list, p) + if (version == eth61 && p.version == eth61) || (version >= eth62 && p.version >= eth62) { + if atomic.LoadInt32(&p.idle) == 0 { + list = append(list, p) + } } } for i := 0; i < len(list); i++ { diff --git a/core/filter.go b/eth/filters/filter.go index b328ffff3..2bcf20d0c 100644 --- a/core/filter.go +++ b/eth/filters/filter.go @@ -1,4 +1,4 @@ -// Copyright 2014 The go-ethereum Authors +// Copyright 2015 The go-ethereum Authors // This file is part of the go-ethereum library. // // The go-ethereum library is free software: you can redistribute it and/or modify @@ -14,16 +14,16 @@ // You should have received a copy of the GNU Lesser General Public License // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. -package core +package filters import ( "math" "github.com/ethereum/go-ethereum/common" - "github.com/ethereum/go-ethereum/core/state" + "github.com/ethereum/go-ethereum/core" "github.com/ethereum/go-ethereum/core/types" - "github.com/ethereum/go-ethereum/logger" - "github.com/ethereum/go-ethereum/logger/glog" + "github.com/ethereum/go-ethereum/core/vm" + "github.com/ethereum/go-ethereum/ethdb" ) type AccountChange struct { @@ -32,7 +32,7 @@ type AccountChange struct { // Filtering interface type Filter struct { - eth Backend + db ethdb.Database earliest int64 latest int64 skip int @@ -40,15 +40,15 @@ type Filter struct { max int topics [][]common.Hash - BlockCallback func(*types.Block, state.Logs) + BlockCallback func(*types.Block, vm.Logs) TransactionCallback func(*types.Transaction) - LogsCallback func(state.Logs) + LogsCallback func(vm.Logs) } // Create a new filter which uses a bloom filter on blocks to figure out whether a particular block // is interesting or not. -func NewFilter(eth Backend) *Filter { - return &Filter{eth: eth} +func New(db ethdb.Database) *Filter { + return &Filter{db: db} } // Set the earliest and latest block for filtering. @@ -79,8 +79,8 @@ func (self *Filter) SetSkip(skip int) { } // Run filters logs with the current parameters set -func (self *Filter) Find() state.Logs { - earliestBlock := self.eth.ChainManager().CurrentBlock() +func (self *Filter) Find() vm.Logs { + earliestBlock := core.GetBlock(self.db, core.GetHeadBlockHash(self.db)) var earliestBlockNo uint64 = uint64(self.earliest) if self.earliest == -1 { earliestBlockNo = earliestBlock.NumberU64() @@ -91,9 +91,13 @@ func (self *Filter) Find() state.Logs { } var ( - logs state.Logs - block = self.eth.ChainManager().GetBlockByNumber(latestBlockNo) + logs vm.Logs + block *types.Block ) + hash := core.GetCanonicalHash(self.db, latestBlockNo) + if hash != (common.Hash{}) { + block = core.GetBlock(self.db, hash) + } done: for i := 0; block != nil; i++ { @@ -111,17 +115,17 @@ done: // current parameters if self.bloomFilter(block) { // Get the logs of the block - unfiltered, err := self.eth.BlockProcessor().GetLogs(block) - if err != nil { - glog.V(logger.Warn).Infoln("err: filter get logs ", err) - - break + var ( + receipts = core.GetBlockReceipts(self.db, block.Hash()) + unfiltered vm.Logs + ) + for _, receipt := range receipts { + unfiltered = append(unfiltered, receipt.Logs()...) } - logs = append(logs, self.FilterLogs(unfiltered)...) } - block = self.eth.ChainManager().GetBlock(block.ParentHash()) + block = core.GetBlock(self.db, block.ParentHash()) } skip := int(math.Min(float64(len(logs)), float64(self.skip))) @@ -139,8 +143,8 @@ func includes(addresses []common.Address, a common.Address) bool { return false } -func (self *Filter) FilterLogs(logs state.Logs) state.Logs { - var ret state.Logs +func (self *Filter) FilterLogs(logs vm.Logs) vm.Logs { + var ret vm.Logs // Filter the logs for interesting stuff Logs: diff --git a/eth/filters/filter_system.go b/eth/filters/filter_system.go new file mode 100644 index 000000000..4972dcd59 --- /dev/null +++ b/eth/filters/filter_system.go @@ -0,0 +1,133 @@ +// Copyright 2014 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. + +// package filters implements an ethereum filtering system for block, +// transactions and log events. +package filters + +import ( + "sync" + + "github.com/ethereum/go-ethereum/core" + "github.com/ethereum/go-ethereum/core/vm" + "github.com/ethereum/go-ethereum/event" +) + +// FilterSystem manages filters that filter specific events such as +// block, transaction and log events. The Filtering system can be used to listen +// for specific LOG events fired by the EVM (Ethereum Virtual Machine). +type FilterSystem struct { + eventMux *event.TypeMux + + filterMu sync.RWMutex + filterId int + filters map[int]*Filter + + quit chan struct{} +} + +// NewFilterSystem returns a newly allocated filter manager +func NewFilterSystem(mux *event.TypeMux) *FilterSystem { + fs := &FilterSystem{ + eventMux: mux, + filters: make(map[int]*Filter), + } + go fs.filterLoop() + return fs +} + +// Stop quits the filter loop required for polling events +func (fs *FilterSystem) Stop() { + close(fs.quit) +} + +// Add adds a filter to the filter manager +func (fs *FilterSystem) Add(filter *Filter) (id int) { + fs.filterMu.Lock() + defer fs.filterMu.Unlock() + id = fs.filterId + fs.filters[id] = filter + fs.filterId++ + + return id +} + +// Remove removes a filter by filter id +func (fs *FilterSystem) Remove(id int) { + fs.filterMu.Lock() + defer fs.filterMu.Unlock() + if _, ok := fs.filters[id]; ok { + delete(fs.filters, id) + } +} + +// Get retrieves a filter installed using Add The filter may not be modified. +func (fs *FilterSystem) Get(id int) *Filter { + fs.filterMu.RLock() + defer fs.filterMu.RUnlock() + return fs.filters[id] +} + +// filterLoop waits for specific events from ethereum and fires their handlers +// when the filter matches the requirements. +func (fs *FilterSystem) filterLoop() { + // Subscribe to events + events := fs.eventMux.Subscribe( + //core.PendingBlockEvent{}, + core.ChainEvent{}, + core.TxPreEvent{}, + vm.Logs(nil)) + +out: + for { + select { + case <-fs.quit: + break out + case event := <-events.Chan(): + switch event := event.(type) { + case core.ChainEvent: + fs.filterMu.RLock() + for _, filter := range fs.filters { + if filter.BlockCallback != nil { + filter.BlockCallback(event.Block, event.Logs) + } + } + fs.filterMu.RUnlock() + + case core.TxPreEvent: + fs.filterMu.RLock() + for _, filter := range fs.filters { + if filter.TransactionCallback != nil { + filter.TransactionCallback(event.Tx) + } + } + fs.filterMu.RUnlock() + + case vm.Logs: + fs.filterMu.RLock() + for _, filter := range fs.filters { + if filter.LogsCallback != nil { + msgs := filter.FilterLogs(event) + if len(msgs) > 0 { + filter.LogsCallback(msgs) + } + } + } + fs.filterMu.RUnlock() + } + } + } +} diff --git a/eth/gasprice.go b/eth/gasprice.go index 3caad73c6..c08b96129 100644 --- a/eth/gasprice.go +++ b/eth/gasprice.go @@ -36,7 +36,7 @@ type blockPriceInfo struct { type GasPriceOracle struct { eth *Ethereum - chain *core.ChainManager + chain *core.BlockChain events event.Subscription blocks map[uint64]*blockPriceInfo firstProcessed, lastProcessed uint64 @@ -48,7 +48,7 @@ func NewGasPriceOracle(eth *Ethereum) (self *GasPriceOracle) { self = &GasPriceOracle{} self.blocks = make(map[uint64]*blockPriceInfo) self.eth = eth - self.chain = eth.chainManager + self.chain = eth.blockchain self.events = eth.EventMux().Subscribe( core.ChainEvent{}, core.ChainSplitEvent{}, diff --git a/eth/handler.go b/eth/handler.go index 52c9c4151..fc92338b4 100644 --- a/eth/handler.go +++ b/eth/handler.go @@ -60,9 +60,9 @@ func (ep extProt) GetHashes(hash common.Hash) error { return ep.getHashes(has func (ep extProt) GetBlock(hashes []common.Hash) error { return ep.getBlocks(hashes) } type ProtocolManager struct { - txpool txPool - chainman *core.ChainManager - chaindb ethdb.Database + txpool txPool + blockchain *core.BlockChain + chaindb ethdb.Database downloader *downloader.Downloader fetcher *fetcher.Fetcher @@ -87,17 +87,17 @@ type ProtocolManager struct { // NewProtocolManager returns a new ethereum sub protocol manager. The Ethereum sub protocol manages peers capable // with the ethereum network. -func NewProtocolManager(networkId int, mux *event.TypeMux, txpool txPool, pow pow.PoW, chainman *core.ChainManager, chaindb ethdb.Database) *ProtocolManager { +func NewProtocolManager(networkId int, mux *event.TypeMux, txpool txPool, pow pow.PoW, blockchain *core.BlockChain, chaindb ethdb.Database) *ProtocolManager { // Create the protocol manager with the base fields manager := &ProtocolManager{ - eventMux: mux, - txpool: txpool, - chainman: chainman, - chaindb: chaindb, - peers: newPeerSet(), - newPeerCh: make(chan *peer, 1), - txsyncCh: make(chan *txsync), - quitSync: make(chan struct{}), + eventMux: mux, + txpool: txpool, + blockchain: blockchain, + chaindb: chaindb, + peers: newPeerSet(), + newPeerCh: make(chan *peer, 1), + txsyncCh: make(chan *txsync), + quitSync: make(chan struct{}), } // Initiate a sub-protocol for every implemented version we can handle manager.SubProtocols = make([]p2p.Protocol, len(ProtocolVersions)) @@ -116,15 +116,15 @@ func NewProtocolManager(networkId int, mux *event.TypeMux, txpool txPool, pow po } } // Construct the different synchronisation mechanisms - manager.downloader = downloader.New(manager.eventMux, manager.chainman.HasBlock, manager.chainman.GetBlock, manager.chainman.CurrentBlock, manager.chainman.GetTd, manager.chainman.InsertChain, manager.removePeer) + manager.downloader = downloader.New(manager.eventMux, manager.blockchain.HasBlock, manager.blockchain.GetBlock, manager.blockchain.CurrentBlock, manager.blockchain.GetTd, manager.blockchain.InsertChain, manager.removePeer) validator := func(block *types.Block, parent *types.Block) error { return core.ValidateHeader(pow, block.Header(), parent.Header(), true, false) } heighter := func() uint64 { - return manager.chainman.CurrentBlock().NumberU64() + return manager.blockchain.CurrentBlock().NumberU64() } - manager.fetcher = fetcher.New(manager.chainman.GetBlock, validator, manager.BroadcastBlock, heighter, manager.chainman.InsertChain, manager.removePeer) + manager.fetcher = fetcher.New(manager.blockchain.GetBlock, validator, manager.BroadcastBlock, heighter, manager.blockchain.InsertChain, manager.removePeer) return manager } @@ -187,7 +187,7 @@ func (pm *ProtocolManager) handle(p *peer) error { glog.V(logger.Debug).Infof("%v: peer connected [%s]", p, p.Name()) // Execute the Ethereum handshake - td, head, genesis := pm.chainman.Status() + td, head, genesis := pm.blockchain.Status() if err := p.Handshake(td, head, genesis); err != nil { glog.V(logger.Debug).Infof("%v: handshake failed: %v", p, err) return err @@ -252,7 +252,7 @@ func (pm *ProtocolManager) handleMsg(p *peer) error { request.Amount = uint64(downloader.MaxHashFetch) } // Retrieve the hashes from the block chain and return them - hashes := pm.chainman.GetBlockHashesFromHash(request.Hash, request.Amount) + hashes := pm.blockchain.GetBlockHashesFromHash(request.Hash, request.Amount) if len(hashes) == 0 { glog.V(logger.Debug).Infof("invalid block hash %x", request.Hash.Bytes()[:4]) } @@ -268,9 +268,9 @@ func (pm *ProtocolManager) handleMsg(p *peer) error { request.Amount = uint64(downloader.MaxHashFetch) } // Calculate the last block that should be retrieved, and short circuit if unavailable - last := pm.chainman.GetBlockByNumber(request.Number + request.Amount - 1) + last := pm.blockchain.GetBlockByNumber(request.Number + request.Amount - 1) if last == nil { - last = pm.chainman.CurrentBlock() + last = pm.blockchain.CurrentBlock() request.Amount = last.NumberU64() - request.Number + 1 } if last.NumberU64() < request.Number { @@ -278,7 +278,7 @@ func (pm *ProtocolManager) handleMsg(p *peer) error { } // Retrieve the hashes from the last block backwards, reverse and return hashes := []common.Hash{last.Hash()} - hashes = append(hashes, pm.chainman.GetBlockHashesFromHash(last.Hash(), request.Amount-1)...) + hashes = append(hashes, pm.blockchain.GetBlockHashesFromHash(last.Hash(), request.Amount-1)...) for i := 0; i < len(hashes)/2; i++ { hashes[i], hashes[len(hashes)-1-i] = hashes[len(hashes)-1-i], hashes[i] @@ -318,7 +318,7 @@ func (pm *ProtocolManager) handleMsg(p *peer) error { return errResp(ErrDecode, "msg %v: %v", msg, err) } // Retrieve the requested block, stopping if enough was found - if block := pm.chainman.GetBlock(hash); block != nil { + if block := pm.blockchain.GetBlock(hash); block != nil { blocks = append(blocks, block) bytes += block.Size() } @@ -358,9 +358,9 @@ func (pm *ProtocolManager) handleMsg(p *peer) error { // Retrieve the next header satisfying the query var origin *types.Header if query.Origin.Hash != (common.Hash{}) { - origin = pm.chainman.GetHeader(query.Origin.Hash) + origin = pm.blockchain.GetHeader(query.Origin.Hash) } else { - origin = pm.chainman.GetHeaderByNumber(query.Origin.Number) + origin = pm.blockchain.GetHeaderByNumber(query.Origin.Number) } if origin == nil { break @@ -373,7 +373,7 @@ func (pm *ProtocolManager) handleMsg(p *peer) error { case query.Origin.Hash != (common.Hash{}) && query.Reverse: // Hash based traversal towards the genesis block for i := 0; i < int(query.Skip)+1; i++ { - if header := pm.chainman.GetHeader(query.Origin.Hash); header != nil { + if header := pm.blockchain.GetHeader(query.Origin.Hash); header != nil { query.Origin.Hash = header.ParentHash } else { unknown = true @@ -382,8 +382,8 @@ func (pm *ProtocolManager) handleMsg(p *peer) error { } case query.Origin.Hash != (common.Hash{}) && !query.Reverse: // Hash based traversal towards the leaf block - if header := pm.chainman.GetHeaderByNumber(origin.Number.Uint64() + query.Skip + 1); header != nil { - if pm.chainman.GetBlockHashesFromHash(header.Hash(), query.Skip+1)[query.Skip] == query.Origin.Hash { + if header := pm.blockchain.GetHeaderByNumber(origin.Number.Uint64() + query.Skip + 1); header != nil { + if pm.blockchain.GetBlockHashesFromHash(header.Hash(), query.Skip+1)[query.Skip] == query.Origin.Hash { query.Origin.Hash = header.Hash() } else { unknown = true @@ -466,7 +466,7 @@ func (pm *ProtocolManager) handleMsg(p *peer) error { return errResp(ErrDecode, "msg %v: %v", msg, err) } // Retrieve the requested block body, stopping if enough was found - if data := pm.chainman.GetBodyRLP(hash); len(data) != 0 { + if data := pm.blockchain.GetBodyRLP(hash); len(data) != 0 { bodies = append(bodies, data) bytes += len(data) } @@ -562,7 +562,7 @@ func (pm *ProtocolManager) handleMsg(p *peer) error { // Schedule all the unknown hashes for retrieval unknown := make([]announce, 0, len(announces)) for _, block := range announces { - if !pm.chainman.HasBlock(block.Hash) { + if !pm.blockchain.HasBlock(block.Hash) { unknown = append(unknown, block) } } @@ -586,7 +586,7 @@ func (pm *ProtocolManager) handleMsg(p *peer) error { request.Block.ReceivedAt = msg.ReceivedAt // Mark the block's arrival for whatever reason - _, chainHead, _ := pm.chainman.Status() + _, chainHead, _ := pm.blockchain.Status() jsonlogger.LogJson(&logger.EthChainReceivedNewBlock{ BlockHash: request.Block.Hash().Hex(), BlockNumber: request.Block.Number(), @@ -603,7 +603,7 @@ func (pm *ProtocolManager) handleMsg(p *peer) error { // Update the peers total difficulty if needed, schedule a download if gapped if request.TD.Cmp(p.Td()) > 0 { p.SetTd(request.TD) - if request.TD.Cmp(new(big.Int).Add(pm.chainman.Td(), request.Block.Difficulty())) > 0 { + if request.TD.Cmp(new(big.Int).Add(pm.blockchain.Td(), request.Block.Difficulty())) > 0 { go pm.synchronise(p) } } @@ -645,8 +645,8 @@ func (pm *ProtocolManager) BroadcastBlock(block *types.Block, propagate bool) { if propagate { // Calculate the TD of the block (it's not imported yet, so block.Td is not valid) var td *big.Int - if parent := pm.chainman.GetBlock(block.ParentHash()); parent != nil { - td = new(big.Int).Add(block.Difficulty(), pm.chainman.GetTd(block.ParentHash())) + if parent := pm.blockchain.GetBlock(block.ParentHash()); parent != nil { + td = new(big.Int).Add(block.Difficulty(), pm.blockchain.GetTd(block.ParentHash())) } else { glog.V(logger.Error).Infof("propagating dangling block #%d [%x]", block.NumberU64(), hash[:4]) return @@ -659,7 +659,7 @@ func (pm *ProtocolManager) BroadcastBlock(block *types.Block, propagate bool) { glog.V(logger.Detail).Infof("propagated block %x to %d peers in %v", hash[:4], len(transfer), time.Since(block.ReceivedAt)) } // Otherwise if the block is indeed in out own chain, announce it - if pm.chainman.HasBlock(hash) { + if pm.blockchain.HasBlock(hash) { for _, peer := range peers { if peer.version < eth62 { peer.SendNewBlockHashes61([]common.Hash{hash}) diff --git a/eth/handler_test.go b/eth/handler_test.go index 6400d4e78..2b8c6168a 100644 --- a/eth/handler_test.go +++ b/eth/handler_test.go @@ -33,23 +33,23 @@ func testGetBlockHashes(t *testing.T, protocol int) { number int result int }{ - {common.Hash{}, 1, 0}, // Make sure non existent hashes don't return results - {pm.chainman.Genesis().Hash(), 1, 0}, // There are no hashes to retrieve up from the genesis - {pm.chainman.GetBlockByNumber(5).Hash(), 5, 5}, // All the hashes including the genesis requested - {pm.chainman.GetBlockByNumber(5).Hash(), 10, 5}, // More hashes than available till the genesis requested - {pm.chainman.GetBlockByNumber(100).Hash(), 10, 10}, // All hashes available from the middle of the chain - {pm.chainman.CurrentBlock().Hash(), 10, 10}, // All hashes available from the head of the chain - {pm.chainman.CurrentBlock().Hash(), limit, limit}, // Request the maximum allowed hash count - {pm.chainman.CurrentBlock().Hash(), limit + 1, limit}, // Request more than the maximum allowed hash count + {common.Hash{}, 1, 0}, // Make sure non existent hashes don't return results + {pm.blockchain.Genesis().Hash(), 1, 0}, // There are no hashes to retrieve up from the genesis + {pm.blockchain.GetBlockByNumber(5).Hash(), 5, 5}, // All the hashes including the genesis requested + {pm.blockchain.GetBlockByNumber(5).Hash(), 10, 5}, // More hashes than available till the genesis requested + {pm.blockchain.GetBlockByNumber(100).Hash(), 10, 10}, // All hashes available from the middle of the chain + {pm.blockchain.CurrentBlock().Hash(), 10, 10}, // All hashes available from the head of the chain + {pm.blockchain.CurrentBlock().Hash(), limit, limit}, // Request the maximum allowed hash count + {pm.blockchain.CurrentBlock().Hash(), limit + 1, limit}, // Request more than the maximum allowed hash count } // Run each of the tests and verify the results against the chain for i, tt := range tests { // Assemble the hash response we would like to receive resp := make([]common.Hash, tt.result) if len(resp) > 0 { - from := pm.chainman.GetBlock(tt.origin).NumberU64() - 1 + from := pm.blockchain.GetBlock(tt.origin).NumberU64() - 1 for j := 0; j < len(resp); j++ { - resp[j] = pm.chainman.GetBlockByNumber(uint64(int(from) - j)).Hash() + resp[j] = pm.blockchain.GetBlockByNumber(uint64(int(from) - j)).Hash() } } // Send the hash request and verify the response @@ -76,11 +76,11 @@ func testGetBlockHashesFromNumber(t *testing.T, protocol int) { number int result int }{ - {pm.chainman.CurrentBlock().NumberU64() + 1, 1, 0}, // Out of bounds requests should return empty - {pm.chainman.CurrentBlock().NumberU64(), 1, 1}, // Make sure the head hash can be retrieved - {pm.chainman.CurrentBlock().NumberU64() - 4, 5, 5}, // All hashes, including the head hash requested - {pm.chainman.CurrentBlock().NumberU64() - 4, 10, 5}, // More hashes requested than available till the head - {pm.chainman.CurrentBlock().NumberU64() - 100, 10, 10}, // All hashes available from the middle of the chain + {pm.blockchain.CurrentBlock().NumberU64() + 1, 1, 0}, // Out of bounds requests should return empty + {pm.blockchain.CurrentBlock().NumberU64(), 1, 1}, // Make sure the head hash can be retrieved + {pm.blockchain.CurrentBlock().NumberU64() - 4, 5, 5}, // All hashes, including the head hash requested + {pm.blockchain.CurrentBlock().NumberU64() - 4, 10, 5}, // More hashes requested than available till the head + {pm.blockchain.CurrentBlock().NumberU64() - 100, 10, 10}, // All hashes available from the middle of the chain {0, 10, 10}, // All hashes available from the root of the chain {0, limit, limit}, // Request the maximum allowed hash count {0, limit + 1, limit}, // Request more than the maximum allowed hash count @@ -91,7 +91,7 @@ func testGetBlockHashesFromNumber(t *testing.T, protocol int) { // Assemble the hash response we would like to receive resp := make([]common.Hash, tt.result) for j := 0; j < len(resp); j++ { - resp[j] = pm.chainman.GetBlockByNumber(tt.origin + uint64(j)).Hash() + resp[j] = pm.blockchain.GetBlockByNumber(tt.origin + uint64(j)).Hash() } // Send the hash request and verify the response p2p.Send(peer.app, 0x08, getBlockHashesFromNumberData{tt.origin, uint64(tt.number)}) @@ -117,22 +117,22 @@ func testGetBlocks(t *testing.T, protocol int) { available []bool // Availability of explicitly requested blocks expected int // Total number of existing blocks to expect }{ - {1, nil, nil, 1}, // A single random block should be retrievable - {10, nil, nil, 10}, // Multiple random blocks should be retrievable - {limit, nil, nil, limit}, // The maximum possible blocks should be retrievable - {limit + 1, nil, nil, limit}, // No more that the possible block count should be returned - {0, []common.Hash{pm.chainman.Genesis().Hash()}, []bool{true}, 1}, // The genesis block should be retrievable - {0, []common.Hash{pm.chainman.CurrentBlock().Hash()}, []bool{true}, 1}, // The chains head block should be retrievable - {0, []common.Hash{common.Hash{}}, []bool{false}, 0}, // A non existent block should not be returned + {1, nil, nil, 1}, // A single random block should be retrievable + {10, nil, nil, 10}, // Multiple random blocks should be retrievable + {limit, nil, nil, limit}, // The maximum possible blocks should be retrievable + {limit + 1, nil, nil, limit}, // No more than the possible block count should be returned + {0, []common.Hash{pm.blockchain.Genesis().Hash()}, []bool{true}, 1}, // The genesis block should be retrievable + {0, []common.Hash{pm.blockchain.CurrentBlock().Hash()}, []bool{true}, 1}, // The chains head block should be retrievable + {0, []common.Hash{common.Hash{}}, []bool{false}, 0}, // A non existent block should not be returned // Existing and non-existing blocks interleaved should not cause problems {0, []common.Hash{ common.Hash{}, - pm.chainman.GetBlockByNumber(1).Hash(), + pm.blockchain.GetBlockByNumber(1).Hash(), common.Hash{}, - pm.chainman.GetBlockByNumber(10).Hash(), + pm.blockchain.GetBlockByNumber(10).Hash(), common.Hash{}, - pm.chainman.GetBlockByNumber(100).Hash(), + pm.blockchain.GetBlockByNumber(100).Hash(), common.Hash{}, }, []bool{false, true, false, true, false, true, false}, 3}, } @@ -144,11 +144,11 @@ func testGetBlocks(t *testing.T, protocol int) { for j := 0; j < tt.random; j++ { for { - num := rand.Int63n(int64(pm.chainman.CurrentBlock().NumberU64())) + num := rand.Int63n(int64(pm.blockchain.CurrentBlock().NumberU64())) if !seen[num] { seen[num] = true - block := pm.chainman.GetBlockByNumber(uint64(num)) + block := pm.blockchain.GetBlockByNumber(uint64(num)) hashes = append(hashes, block.Hash()) if len(blocks) < tt.expected { blocks = append(blocks, block) @@ -160,7 +160,7 @@ func testGetBlocks(t *testing.T, protocol int) { for j, hash := range tt.explicit { hashes = append(hashes, hash) if tt.available[j] && len(blocks) < tt.expected { - blocks = append(blocks, pm.chainman.GetBlock(hash)) + blocks = append(blocks, pm.blockchain.GetBlock(hash)) } } // Send the hash request and verify the response @@ -194,83 +194,83 @@ func testGetBlockHeaders(t *testing.T, protocol int) { }{ // A single random block should be retrievable by hash and number too { - &getBlockHeadersData{Origin: hashOrNumber{Hash: pm.chainman.GetBlockByNumber(limit / 2).Hash()}, Amount: 1}, - []common.Hash{pm.chainman.GetBlockByNumber(limit / 2).Hash()}, + &getBlockHeadersData{Origin: hashOrNumber{Hash: pm.blockchain.GetBlockByNumber(limit / 2).Hash()}, Amount: 1}, + []common.Hash{pm.blockchain.GetBlockByNumber(limit / 2).Hash()}, }, { &getBlockHeadersData{Origin: hashOrNumber{Number: limit / 2}, Amount: 1}, - []common.Hash{pm.chainman.GetBlockByNumber(limit / 2).Hash()}, + []common.Hash{pm.blockchain.GetBlockByNumber(limit / 2).Hash()}, }, // Multiple headers should be retrievable in both directions { &getBlockHeadersData{Origin: hashOrNumber{Number: limit / 2}, Amount: 3}, []common.Hash{ - pm.chainman.GetBlockByNumber(limit / 2).Hash(), - pm.chainman.GetBlockByNumber(limit/2 + 1).Hash(), - pm.chainman.GetBlockByNumber(limit/2 + 2).Hash(), + pm.blockchain.GetBlockByNumber(limit / 2).Hash(), + pm.blockchain.GetBlockByNumber(limit/2 + 1).Hash(), + pm.blockchain.GetBlockByNumber(limit/2 + 2).Hash(), }, }, { &getBlockHeadersData{Origin: hashOrNumber{Number: limit / 2}, Amount: 3, Reverse: true}, []common.Hash{ - pm.chainman.GetBlockByNumber(limit / 2).Hash(), - pm.chainman.GetBlockByNumber(limit/2 - 1).Hash(), - pm.chainman.GetBlockByNumber(limit/2 - 2).Hash(), + pm.blockchain.GetBlockByNumber(limit / 2).Hash(), + pm.blockchain.GetBlockByNumber(limit/2 - 1).Hash(), + pm.blockchain.GetBlockByNumber(limit/2 - 2).Hash(), }, }, // Multiple headers with skip lists should be retrievable { &getBlockHeadersData{Origin: hashOrNumber{Number: limit / 2}, Skip: 3, Amount: 3}, []common.Hash{ - pm.chainman.GetBlockByNumber(limit / 2).Hash(), - pm.chainman.GetBlockByNumber(limit/2 + 4).Hash(), - pm.chainman.GetBlockByNumber(limit/2 + 8).Hash(), + pm.blockchain.GetBlockByNumber(limit / 2).Hash(), + pm.blockchain.GetBlockByNumber(limit/2 + 4).Hash(), + pm.blockchain.GetBlockByNumber(limit/2 + 8).Hash(), }, }, { &getBlockHeadersData{Origin: hashOrNumber{Number: limit / 2}, Skip: 3, Amount: 3, Reverse: true}, []common.Hash{ - pm.chainman.GetBlockByNumber(limit / 2).Hash(), - pm.chainman.GetBlockByNumber(limit/2 - 4).Hash(), - pm.chainman.GetBlockByNumber(limit/2 - 8).Hash(), + pm.blockchain.GetBlockByNumber(limit / 2).Hash(), + pm.blockchain.GetBlockByNumber(limit/2 - 4).Hash(), + pm.blockchain.GetBlockByNumber(limit/2 - 8).Hash(), }, }, // The chain endpoints should be retrievable { &getBlockHeadersData{Origin: hashOrNumber{Number: 0}, Amount: 1}, - []common.Hash{pm.chainman.GetBlockByNumber(0).Hash()}, + []common.Hash{pm.blockchain.GetBlockByNumber(0).Hash()}, }, { - &getBlockHeadersData{Origin: hashOrNumber{Number: pm.chainman.CurrentBlock().NumberU64()}, Amount: 1}, - []common.Hash{pm.chainman.CurrentBlock().Hash()}, + &getBlockHeadersData{Origin: hashOrNumber{Number: pm.blockchain.CurrentBlock().NumberU64()}, Amount: 1}, + []common.Hash{pm.blockchain.CurrentBlock().Hash()}, }, // Ensure protocol limits are honored { - &getBlockHeadersData{Origin: hashOrNumber{Number: pm.chainman.CurrentBlock().NumberU64() - 1}, Amount: limit + 10, Reverse: true}, - pm.chainman.GetBlockHashesFromHash(pm.chainman.CurrentBlock().Hash(), limit), + &getBlockHeadersData{Origin: hashOrNumber{Number: pm.blockchain.CurrentBlock().NumberU64() - 1}, Amount: limit + 10, Reverse: true}, + pm.blockchain.GetBlockHashesFromHash(pm.blockchain.CurrentBlock().Hash(), limit), }, // Check that requesting more than available is handled gracefully { - &getBlockHeadersData{Origin: hashOrNumber{Number: pm.chainman.CurrentBlock().NumberU64() - 4}, Skip: 3, Amount: 3}, + &getBlockHeadersData{Origin: hashOrNumber{Number: pm.blockchain.CurrentBlock().NumberU64() - 4}, Skip: 3, Amount: 3}, []common.Hash{ - pm.chainman.GetBlockByNumber(pm.chainman.CurrentBlock().NumberU64() - 4).Hash(), - pm.chainman.GetBlockByNumber(pm.chainman.CurrentBlock().NumberU64()).Hash(), + pm.blockchain.GetBlockByNumber(pm.blockchain.CurrentBlock().NumberU64() - 4).Hash(), + pm.blockchain.GetBlockByNumber(pm.blockchain.CurrentBlock().NumberU64()).Hash(), }, }, { &getBlockHeadersData{Origin: hashOrNumber{Number: 4}, Skip: 3, Amount: 3, Reverse: true}, []common.Hash{ - pm.chainman.GetBlockByNumber(4).Hash(), - pm.chainman.GetBlockByNumber(0).Hash(), + pm.blockchain.GetBlockByNumber(4).Hash(), + pm.blockchain.GetBlockByNumber(0).Hash(), }, }, // Check that requesting more than available is handled gracefully, even if mid skip { - &getBlockHeadersData{Origin: hashOrNumber{Number: pm.chainman.CurrentBlock().NumberU64() - 4}, Skip: 2, Amount: 3}, + &getBlockHeadersData{Origin: hashOrNumber{Number: pm.blockchain.CurrentBlock().NumberU64() - 4}, Skip: 2, Amount: 3}, []common.Hash{ - pm.chainman.GetBlockByNumber(pm.chainman.CurrentBlock().NumberU64() - 4).Hash(), - pm.chainman.GetBlockByNumber(pm.chainman.CurrentBlock().NumberU64() - 1).Hash(), + pm.blockchain.GetBlockByNumber(pm.blockchain.CurrentBlock().NumberU64() - 4).Hash(), + pm.blockchain.GetBlockByNumber(pm.blockchain.CurrentBlock().NumberU64() - 1).Hash(), }, }, { &getBlockHeadersData{Origin: hashOrNumber{Number: 4}, Skip: 2, Amount: 3, Reverse: true}, []common.Hash{ - pm.chainman.GetBlockByNumber(4).Hash(), - pm.chainman.GetBlockByNumber(1).Hash(), + pm.blockchain.GetBlockByNumber(4).Hash(), + pm.blockchain.GetBlockByNumber(1).Hash(), }, }, // Check that non existing headers aren't returned @@ -278,7 +278,7 @@ func testGetBlockHeaders(t *testing.T, protocol int) { &getBlockHeadersData{Origin: hashOrNumber{Hash: unknown}, Amount: 1}, []common.Hash{}, }, { - &getBlockHeadersData{Origin: hashOrNumber{Number: pm.chainman.CurrentBlock().NumberU64() + 1}, Amount: 1}, + &getBlockHeadersData{Origin: hashOrNumber{Number: pm.blockchain.CurrentBlock().NumberU64() + 1}, Amount: 1}, []common.Hash{}, }, } @@ -287,7 +287,7 @@ func testGetBlockHeaders(t *testing.T, protocol int) { // Collect the headers to expect in the response headers := []*types.Header{} for _, hash := range tt.expect { - headers = append(headers, pm.chainman.GetBlock(hash).Header()) + headers = append(headers, pm.blockchain.GetBlock(hash).Header()) } // Send the hash request and verify the response p2p.Send(peer.app, 0x03, tt.query) @@ -315,22 +315,22 @@ func testGetBlockBodies(t *testing.T, protocol int) { available []bool // Availability of explicitly requested blocks expected int // Total number of existing blocks to expect }{ - {1, nil, nil, 1}, // A single random block should be retrievable - {10, nil, nil, 10}, // Multiple random blocks should be retrievable - {limit, nil, nil, limit}, // The maximum possible blocks should be retrievable - {limit + 1, nil, nil, limit}, // No more that the possible block count should be returned - {0, []common.Hash{pm.chainman.Genesis().Hash()}, []bool{true}, 1}, // The genesis block should be retrievable - {0, []common.Hash{pm.chainman.CurrentBlock().Hash()}, []bool{true}, 1}, // The chains head block should be retrievable - {0, []common.Hash{common.Hash{}}, []bool{false}, 0}, // A non existent block should not be returned + {1, nil, nil, 1}, // A single random block should be retrievable + {10, nil, nil, 10}, // Multiple random blocks should be retrievable + {limit, nil, nil, limit}, // The maximum possible blocks should be retrievable + {limit + 1, nil, nil, limit}, // No more than the possible block count should be returned + {0, []common.Hash{pm.blockchain.Genesis().Hash()}, []bool{true}, 1}, // The genesis block should be retrievable + {0, []common.Hash{pm.blockchain.CurrentBlock().Hash()}, []bool{true}, 1}, // The chains head block should be retrievable + {0, []common.Hash{common.Hash{}}, []bool{false}, 0}, // A non existent block should not be returned // Existing and non-existing blocks interleaved should not cause problems {0, []common.Hash{ common.Hash{}, - pm.chainman.GetBlockByNumber(1).Hash(), + pm.blockchain.GetBlockByNumber(1).Hash(), common.Hash{}, - pm.chainman.GetBlockByNumber(10).Hash(), + pm.blockchain.GetBlockByNumber(10).Hash(), common.Hash{}, - pm.chainman.GetBlockByNumber(100).Hash(), + pm.blockchain.GetBlockByNumber(100).Hash(), common.Hash{}, }, []bool{false, true, false, true, false, true, false}, 3}, } @@ -342,11 +342,11 @@ func testGetBlockBodies(t *testing.T, protocol int) { for j := 0; j < tt.random; j++ { for { - num := rand.Int63n(int64(pm.chainman.CurrentBlock().NumberU64())) + num := rand.Int63n(int64(pm.blockchain.CurrentBlock().NumberU64())) if !seen[num] { seen[num] = true - block := pm.chainman.GetBlockByNumber(uint64(num)) + block := pm.blockchain.GetBlockByNumber(uint64(num)) hashes = append(hashes, block.Hash()) if len(bodies) < tt.expected { bodies = append(bodies, &blockBody{Transactions: block.Transactions(), Uncles: block.Uncles()}) @@ -358,7 +358,7 @@ func testGetBlockBodies(t *testing.T, protocol int) { for j, hash := range tt.explicit { hashes = append(hashes, hash) if tt.available[j] && len(bodies) < tt.expected { - block := pm.chainman.GetBlock(hash) + block := pm.blockchain.GetBlock(hash) bodies = append(bodies, &blockBody{Transactions: block.Transactions(), Uncles: block.Uncles()}) } } @@ -442,11 +442,11 @@ func testGetNodeData(t *testing.T, protocol int) { statedb.Put(hashes[i].Bytes(), data[i]) } accounts := []common.Address{testBankAddress, acc1Addr, acc2Addr} - for i := uint64(0); i <= pm.chainman.CurrentBlock().NumberU64(); i++ { - trie := state.New(pm.chainman.GetBlockByNumber(i).Root(), statedb) + for i := uint64(0); i <= pm.blockchain.CurrentBlock().NumberU64(); i++ { + trie := state.New(pm.blockchain.GetBlockByNumber(i).Root(), statedb) for j, acc := range accounts { - bw := pm.chainman.State().GetBalance(acc) + bw := pm.blockchain.State().GetBalance(acc) bh := trie.GetBalance(acc) if (bw != nil && bh == nil) || (bw == nil && bh != nil) { @@ -505,8 +505,8 @@ func testGetReceipt(t *testing.T, protocol int) { // Collect the hashes to request, and the response to expect hashes := []common.Hash{} - for i := uint64(0); i <= pm.chainman.CurrentBlock().NumberU64(); i++ { - for _, tx := range pm.chainman.GetBlockByNumber(i).Transactions() { + for i := uint64(0); i <= pm.blockchain.CurrentBlock().NumberU64(); i++ { + for _, tx := range pm.blockchain.GetBlockByNumber(i).Transactions() { hashes = append(hashes, tx.Hash()) } } diff --git a/eth/helper_test.go b/eth/helper_test.go index 034751f7f..e42fa1f82 100644 --- a/eth/helper_test.go +++ b/eth/helper_test.go @@ -30,18 +30,18 @@ var ( // channels for different events. func newTestProtocolManager(blocks int, generator func(int, *core.BlockGen), newtx chan<- []*types.Transaction) *ProtocolManager { var ( - evmux = new(event.TypeMux) - pow = new(core.FakePow) - db, _ = ethdb.NewMemDatabase() - genesis = core.WriteGenesisBlockForTesting(db, core.GenesisAccount{testBankAddress, testBankFunds}) - chainman, _ = core.NewChainManager(db, pow, evmux) - blockproc = core.NewBlockProcessor(db, pow, chainman, evmux) + evmux = new(event.TypeMux) + pow = new(core.FakePow) + db, _ = ethdb.NewMemDatabase() + genesis = core.WriteGenesisBlockForTesting(db, core.GenesisAccount{testBankAddress, testBankFunds}) + blockchain, _ = core.NewBlockChain(db, pow, evmux) + blockproc = core.NewBlockProcessor(db, pow, blockchain, evmux) ) - chainman.SetProcessor(blockproc) - if _, err := chainman.InsertChain(core.GenerateChain(genesis, db, blocks, generator)); err != nil { + blockchain.SetProcessor(blockproc) + if _, err := blockchain.InsertChain(core.GenerateChain(genesis, db, blocks, generator)); err != nil { panic(err) } - pm := NewProtocolManager(NetworkId, evmux, &testTxPool{added: newtx}, pow, chainman, db) + pm := NewProtocolManager(NetworkId, evmux, &testTxPool{added: newtx}, pow, blockchain, db) pm.Start() return pm } @@ -116,7 +116,7 @@ func newTestPeer(name string, version int, pm *ProtocolManager, shake bool) (*te } // Execute any implicitly requested handshakes and return if shake { - td, head, genesis := pm.chainman.Status() + td, head, genesis := pm.blockchain.Status() tp.handshake(nil, td, head, genesis) } return tp, errc diff --git a/eth/protocol_test.go b/eth/protocol_test.go index bc3b5acfc..523e6c1eb 100644 --- a/eth/protocol_test.go +++ b/eth/protocol_test.go @@ -45,7 +45,7 @@ func TestStatusMsgErrors64(t *testing.T) { testStatusMsgErrors(t, 64) } func testStatusMsgErrors(t *testing.T, protocol int) { pm := newTestProtocolManager(0, nil, nil) - td, currentBlock, genesis := pm.chainman.Status() + td, currentBlock, genesis := pm.blockchain.Status() defer pm.Stop() tests := []struct { diff --git a/eth/sync.go b/eth/sync.go index b4dea4b0f..5a2031c68 100644 --- a/eth/sync.go +++ b/eth/sync.go @@ -160,7 +160,7 @@ func (pm *ProtocolManager) synchronise(peer *peer) { return } // Make sure the peer's TD is higher than our own. If not drop. - if peer.Td().Cmp(pm.chainman.Td()) <= 0 { + if peer.Td().Cmp(pm.blockchain.Td()) <= 0 { return } // Otherwise try to sync with the downloader diff --git a/event/filter/eth_filter.go b/event/filter/eth_filter.go deleted file mode 100644 index 6f61e2b60..000000000 --- a/event/filter/eth_filter.go +++ /dev/null @@ -1,127 +0,0 @@ -// Copyright 2014 The go-ethereum Authors -// This file is part of the go-ethereum library. -// -// The go-ethereum library is free software: you can redistribute it and/or modify -// it under the terms of the GNU Lesser General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// The go-ethereum library is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU Lesser General Public License for more details. -// -// You should have received a copy of the GNU Lesser General Public License -// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. - -package filter - -// TODO make use of the generic filtering system - -import ( - "sync" - - "github.com/ethereum/go-ethereum/core" - "github.com/ethereum/go-ethereum/core/state" - "github.com/ethereum/go-ethereum/event" -) - -type FilterManager struct { - eventMux *event.TypeMux - - filterMu sync.RWMutex - filterId int - filters map[int]*core.Filter - - quit chan struct{} -} - -func NewFilterManager(mux *event.TypeMux) *FilterManager { - return &FilterManager{ - eventMux: mux, - filters: make(map[int]*core.Filter), - } -} - -func (self *FilterManager) Start() { - go self.filterLoop() -} - -func (self *FilterManager) Stop() { - close(self.quit) -} - -func (self *FilterManager) InstallFilter(filter *core.Filter) (id int) { - self.filterMu.Lock() - defer self.filterMu.Unlock() - id = self.filterId - self.filters[id] = filter - self.filterId++ - - return id -} - -func (self *FilterManager) UninstallFilter(id int) { - self.filterMu.Lock() - defer self.filterMu.Unlock() - if _, ok := self.filters[id]; ok { - delete(self.filters, id) - } -} - -// GetFilter retrieves a filter installed using InstallFilter. -// The filter may not be modified. -func (self *FilterManager) GetFilter(id int) *core.Filter { - self.filterMu.RLock() - defer self.filterMu.RUnlock() - return self.filters[id] -} - -func (self *FilterManager) filterLoop() { - // Subscribe to events - events := self.eventMux.Subscribe( - //core.PendingBlockEvent{}, - core.ChainEvent{}, - core.TxPreEvent{}, - state.Logs(nil)) - -out: - for { - select { - case <-self.quit: - break out - case event := <-events.Chan(): - switch event := event.(type) { - case core.ChainEvent: - self.filterMu.RLock() - for _, filter := range self.filters { - if filter.BlockCallback != nil { - filter.BlockCallback(event.Block, event.Logs) - } - } - self.filterMu.RUnlock() - - case core.TxPreEvent: - self.filterMu.RLock() - for _, filter := range self.filters { - if filter.TransactionCallback != nil { - filter.TransactionCallback(event.Tx) - } - } - self.filterMu.RUnlock() - - case state.Logs: - self.filterMu.RLock() - for _, filter := range self.filters { - if filter.LogsCallback != nil { - msgs := filter.FilterLogs(event) - if len(msgs) > 0 { - filter.LogsCallback(msgs) - } - } - } - self.filterMu.RUnlock() - } - } - } -} diff --git a/miner/worker.go b/miner/worker.go index 22d0b9b6e..8be2db93e 100644 --- a/miner/worker.go +++ b/miner/worker.go @@ -29,6 +29,7 @@ import ( "github.com/ethereum/go-ethereum/core" "github.com/ethereum/go-ethereum/core/state" "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/core/vm" "github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/event" "github.com/ethereum/go-ethereum/logger" @@ -99,7 +100,7 @@ type worker struct { pow pow.PoW eth core.Backend - chain *core.ChainManager + chain *core.BlockChain proc *core.BlockProcessor chainDb ethdb.Database @@ -130,7 +131,7 @@ func newWorker(coinbase common.Address, eth core.Backend) *worker { chainDb: eth.ChainDb(), recv: make(chan *Result, resultQueueSize), gasPrice: new(big.Int), - chain: eth.ChainManager(), + chain: eth.BlockChain(), proc: eth.BlockProcessor(), possibleUncles: make(map[common.Hash]*types.Block), coinbase: coinbase, @@ -266,7 +267,6 @@ func (self *worker) wait() { block := result.Block work := result.Work - work.state.Sync() if self.fullValidation { if _, err := self.chain.InsertChain(types.Blocks{block}); err != nil { glog.V(logger.Error).Infoln("mining err", err) @@ -274,6 +274,7 @@ func (self *worker) wait() { } go self.mux.Post(core.NewMinedBlockEvent{block}) } else { + work.state.Commit() parent := self.chain.GetBlock(block.ParentHash()) if parent == nil { glog.V(logger.Error).Infoln("Invalid block found during mining") @@ -298,7 +299,7 @@ func (self *worker) wait() { } // broadcast before waiting for validation - go func(block *types.Block, logs state.Logs, receipts []*types.Receipt) { + go func(block *types.Block, logs vm.Logs, receipts []*types.Receipt) { self.mux.Post(core.NewMinedBlockEvent{block}) self.mux.Post(core.ChainEvent{block, block.Hash(), logs}) if stat == core.CanonStatTy { @@ -528,8 +529,7 @@ func (self *worker) commitNewWork() { if atomic.LoadInt32(&self.mining) == 1 { // commit state root after all state transitions. core.AccumulateRewards(work.state, header, uncles) - work.state.SyncObjects() - header.Root = work.state.Root() + header.Root = work.state.IntermediateRoot() } // create the new block whose nonce will be mined. diff --git a/p2p/discover/database.go b/p2p/discover/database.go index d5c594364..e8e3371ff 100644 --- a/p2p/discover/database.go +++ b/p2p/discover/database.go @@ -21,6 +21,7 @@ package discover import ( "bytes" + "crypto/rand" "encoding/binary" "os" "sync" @@ -46,11 +47,8 @@ var ( // nodeDB stores all nodes we know about. type nodeDB struct { - lvl *leveldb.DB // Interface to the database itself - seeder iterator.Iterator // Iterator for fetching possible seed nodes - - self NodeID // Own node id to prevent adding it into the database - + lvl *leveldb.DB // Interface to the database itself + self NodeID // Own node id to prevent adding it into the database runner sync.Once // Ensures we can start at most one expirer quit chan struct{} // Channel to signal the expiring thread to stop } @@ -302,52 +300,70 @@ func (db *nodeDB) updateFindFails(id NodeID, fails int) error { return db.storeInt64(makeKey(id, nodeDBDiscoverFindFails), int64(fails)) } -// querySeeds retrieves a batch of nodes to be used as potential seed servers -// during bootstrapping the node into the network. -// -// Ideal seeds are the most recently seen nodes (highest probability to be still -// alive), but yet untried. However, since leveldb only supports dumb iteration -// we will instead start pulling in potential seeds that haven't been yet pinged -// since the start of the boot procedure. -// -// If the database runs out of potential seeds, we restart the startup counter -// and start iterating over the peers again. -func (db *nodeDB) querySeeds(n int) []*Node { - // Create a new seed iterator if none exists - if db.seeder == nil { - db.seeder = db.lvl.NewIterator(nil, nil) +// querySeeds retrieves random nodes to be used as potential seed nodes +// for bootstrapping. +func (db *nodeDB) querySeeds(n int, maxAge time.Duration) []*Node { + var ( + now = time.Now() + nodes = make([]*Node, 0, n) + it = db.lvl.NewIterator(nil, nil) + id NodeID + ) + defer it.Release() + +seek: + for seeks := 0; len(nodes) < n && seeks < n*5; seeks++ { + // Seek to a random entry. The first byte is incremented by a + // random amount each time in order to increase the likelihood + // of hitting all existing nodes in very small databases. + ctr := id[0] + rand.Read(id[:]) + id[0] = ctr + id[0]%16 + it.Seek(makeKey(id, nodeDBDiscoverRoot)) + + n := nextNode(it) + if n == nil { + id[0] = 0 + continue seek // iterator exhausted + } + if n.ID == db.self { + continue seek + } + if now.Sub(db.lastPong(n.ID)) > maxAge { + continue seek + } + for i := range nodes { + if nodes[i].ID == n.ID { + continue seek // duplicate + } + } + nodes = append(nodes, n) } - // Iterate over the nodes and find suitable seeds - nodes := make([]*Node, 0, n) - for len(nodes) < n && db.seeder.Next() { - // Iterate until a discovery node is found - id, field := splitKey(db.seeder.Key()) + return nodes +} + +// reads the next node record from the iterator, skipping over other +// database entries. +func nextNode(it iterator.Iterator) *Node { + for end := false; !end; end = !it.Next() { + id, field := splitKey(it.Key()) if field != nodeDBDiscoverRoot { continue } - // Dump it if its a self reference - if bytes.Compare(id[:], db.self[:]) == 0 { - db.deleteNode(id) + var n Node + if err := rlp.DecodeBytes(it.Value(), &n); err != nil { + if glog.V(logger.Warn) { + glog.Errorf("invalid node %x: %v", id, err) + } continue } - // Load it as a potential seed - if node := db.node(id); node != nil { - nodes = append(nodes, node) - } - } - // Release the iterator if we reached the end - if len(nodes) == 0 { - db.seeder.Release() - db.seeder = nil + return &n } - return nodes + return nil } // close flushes and closes the database files. func (db *nodeDB) close() { - if db.seeder != nil { - db.seeder.Release() - } close(db.quit) db.lvl.Close() } diff --git a/p2p/discover/database_test.go b/p2p/discover/database_test.go index 569585903..80c1a6ff2 100644 --- a/p2p/discover/database_test.go +++ b/p2p/discover/database_test.go @@ -162,9 +162,33 @@ var nodeDBSeedQueryNodes = []struct { node *Node pong time.Time }{ + // This one should not be in the result set because its last + // pong time is too far in the past. { node: newNode( - MustHexID("0x01d9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439"), + MustHexID("0x84d9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439"), + net.IP{127, 0, 0, 3}, + 30303, + 30303, + ), + pong: time.Now().Add(-3 * time.Hour), + }, + // This one shouldn't be in in the result set because its + // nodeID is the local node's ID. + { + node: newNode( + MustHexID("0x57d9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439"), + net.IP{127, 0, 0, 3}, + 30303, + 30303, + ), + pong: time.Now().Add(-4 * time.Second), + }, + + // These should be in the result set. + { + node: newNode( + MustHexID("0x22d9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439"), net.IP{127, 0, 0, 1}, 30303, 30303, @@ -173,7 +197,7 @@ var nodeDBSeedQueryNodes = []struct { }, { node: newNode( - MustHexID("0x02d9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439"), + MustHexID("0x44d9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439"), net.IP{127, 0, 0, 2}, 30303, 30303, @@ -182,7 +206,7 @@ var nodeDBSeedQueryNodes = []struct { }, { node: newNode( - MustHexID("0x03d9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439"), + MustHexID("0xe2d9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439"), net.IP{127, 0, 0, 3}, 30303, 30303, @@ -192,7 +216,7 @@ var nodeDBSeedQueryNodes = []struct { } func TestNodeDBSeedQuery(t *testing.T) { - db, _ := newNodeDB("", Version, NodeID{}) + db, _ := newNodeDB("", Version, nodeDBSeedQueryNodes[1].node.ID) defer db.close() // Insert a batch of nodes for querying @@ -200,20 +224,24 @@ func TestNodeDBSeedQuery(t *testing.T) { if err := db.updateNode(seed.node); err != nil { t.Fatalf("node %d: failed to insert: %v", i, err) } + if err := db.updateLastPong(seed.node.ID, seed.pong); err != nil { + t.Fatalf("node %d: failed to insert lastPong: %v", i, err) + } } + // Retrieve the entire batch and check for duplicates - seeds := db.querySeeds(2 * len(nodeDBSeedQueryNodes)) - if len(seeds) != len(nodeDBSeedQueryNodes) { - t.Errorf("seed count mismatch: have %v, want %v", len(seeds), len(nodeDBSeedQueryNodes)) - } + seeds := db.querySeeds(len(nodeDBSeedQueryNodes)*2, time.Hour) have := make(map[NodeID]struct{}) for _, seed := range seeds { have[seed.ID] = struct{}{} } want := make(map[NodeID]struct{}) - for _, seed := range nodeDBSeedQueryNodes { + for _, seed := range nodeDBSeedQueryNodes[2:] { want[seed.node.ID] = struct{}{} } + if len(seeds) != len(want) { + t.Errorf("seed count mismatch: have %v, want %v", len(seeds), len(want)) + } for id, _ := range have { if _, ok := want[id]; !ok { t.Errorf("extra seed: %v", id) @@ -224,63 +252,6 @@ func TestNodeDBSeedQuery(t *testing.T) { t.Errorf("missing seed: %v", id) } } - // Make sure the next batch is empty (seed EOF) - seeds = db.querySeeds(2 * len(nodeDBSeedQueryNodes)) - if len(seeds) != 0 { - t.Errorf("seed count mismatch: have %v, want %v", len(seeds), 0) - } -} - -func TestNodeDBSeedQueryContinuation(t *testing.T) { - db, _ := newNodeDB("", Version, NodeID{}) - defer db.close() - - // Insert a batch of nodes for querying - for i, seed := range nodeDBSeedQueryNodes { - if err := db.updateNode(seed.node); err != nil { - t.Fatalf("node %d: failed to insert: %v", i, err) - } - } - // Iteratively retrieve the batch, checking for an empty batch on reset - for i := 0; i < len(nodeDBSeedQueryNodes); i++ { - if seeds := db.querySeeds(1); len(seeds) != 1 { - t.Errorf("1st iteration %d: seed count mismatch: have %v, want %v", i, len(seeds), 1) - } - } - if seeds := db.querySeeds(1); len(seeds) != 0 { - t.Errorf("reset: seed count mismatch: have %v, want %v", len(seeds), 0) - } - for i := 0; i < len(nodeDBSeedQueryNodes); i++ { - if seeds := db.querySeeds(1); len(seeds) != 1 { - t.Errorf("2nd iteration %d: seed count mismatch: have %v, want %v", i, len(seeds), 1) - } - } -} - -func TestNodeDBSelfSeedQuery(t *testing.T) { - // Assign a node as self to verify evacuation - self := nodeDBSeedQueryNodes[0].node.ID - db, _ := newNodeDB("", Version, self) - defer db.close() - - // Insert a batch of nodes for querying - for i, seed := range nodeDBSeedQueryNodes { - if err := db.updateNode(seed.node); err != nil { - t.Fatalf("node %d: failed to insert: %v", i, err) - } - } - // Retrieve the entire batch and check that self was evacuated - seeds := db.querySeeds(2 * len(nodeDBSeedQueryNodes)) - if len(seeds) != len(nodeDBSeedQueryNodes)-1 { - t.Errorf("seed count mismatch: have %v, want %v", len(seeds), len(nodeDBSeedQueryNodes)-1) - } - have := make(map[NodeID]struct{}) - for _, seed := range seeds { - have[seed.ID] = struct{}{} - } - if _, ok := have[self]; ok { - t.Errorf("self not evacuated") - } } func TestNodeDBPersistency(t *testing.T) { diff --git a/p2p/discover/table.go b/p2p/discover/table.go index 972bc1077..c128c2ed1 100644 --- a/p2p/discover/table.go +++ b/p2p/discover/table.go @@ -44,6 +44,10 @@ const ( maxBondingPingPongs = 16 maxFindnodeFailures = 5 + + autoRefreshInterval = 1 * time.Hour + seedCount = 30 + seedMaxAge = 5 * 24 * time.Hour ) type Table struct { @@ -52,6 +56,10 @@ type Table struct { nursery []*Node // bootstrap nodes db *nodeDB // database of known nodes + refreshReq chan struct{} + closeReq chan struct{} + closed chan struct{} + bondmu sync.Mutex bonding map[NodeID]*bondproc bondslots chan struct{} // limits total number of active bonding processes @@ -80,10 +88,7 @@ type transport interface { // bucket contains nodes, ordered by their last activity. the entry // that was most recently active is the first element in entries. -type bucket struct { - lastLookup time.Time - entries []*Node -} +type bucket struct{ entries []*Node } func newTable(t transport, ourID NodeID, ourAddr *net.UDPAddr, nodeDBPath string) *Table { // If no node database was given, use an in-memory one @@ -93,11 +98,14 @@ func newTable(t transport, ourID NodeID, ourAddr *net.UDPAddr, nodeDBPath string db, _ = newNodeDB("", Version, ourID) } tab := &Table{ - net: t, - db: db, - self: newNode(ourID, ourAddr.IP, uint16(ourAddr.Port), uint16(ourAddr.Port)), - bonding: make(map[NodeID]*bondproc), - bondslots: make(chan struct{}, maxBondingPingPongs), + net: t, + db: db, + self: newNode(ourID, ourAddr.IP, uint16(ourAddr.Port), uint16(ourAddr.Port)), + bonding: make(map[NodeID]*bondproc), + bondslots: make(chan struct{}, maxBondingPingPongs), + refreshReq: make(chan struct{}), + closeReq: make(chan struct{}), + closed: make(chan struct{}), } for i := 0; i < cap(tab.bondslots); i++ { tab.bondslots <- struct{}{} @@ -105,6 +113,7 @@ func newTable(t transport, ourID NodeID, ourAddr *net.UDPAddr, nodeDBPath string for i := range tab.buckets { tab.buckets[i] = new(bucket) } + go tab.refreshLoop() return tab } @@ -163,10 +172,12 @@ func randUint(max uint32) uint32 { // Close terminates the network listener and flushes the node database. func (tab *Table) Close() { - if tab.net != nil { - tab.net.close() + select { + case <-tab.closed: + // already closed. + case tab.closeReq <- struct{}{}: + <-tab.closed // wait for refreshLoop to end. } - tab.db.close() } // Bootstrap sets the bootstrap nodes. These nodes are used to connect @@ -183,7 +194,7 @@ func (tab *Table) Bootstrap(nodes []*Node) { tab.nursery = append(tab.nursery, &cpy) } tab.mutex.Unlock() - tab.refresh() + tab.requestRefresh() } // Lookup performs a network search for nodes close @@ -204,15 +215,13 @@ func (tab *Table) Lookup(targetID NodeID) []*Node { asked[tab.self.ID] = true tab.mutex.Lock() - // update last lookup stamp (for refresh logic) - tab.buckets[logdist(tab.self.sha, target)].lastLookup = time.Now() // generate initial result set result := tab.closest(target, bucketSize) tab.mutex.Unlock() - // If the result set is empty, all nodes were dropped, refresh + // If the result set is empty, all nodes were dropped, refresh. if len(result.entries) == 0 { - tab.refresh() + tab.requestRefresh() return nil } @@ -257,56 +266,86 @@ func (tab *Table) Lookup(targetID NodeID) []*Node { return result.entries } -// refresh performs a lookup for a random target to keep buckets full, or seeds -// the table if it is empty (initial bootstrap or discarded faulty peers). -func (tab *Table) refresh() { - seed := true +func (tab *Table) requestRefresh() { + select { + case tab.refreshReq <- struct{}{}: + case <-tab.closed: + } +} - // If the discovery table is empty, seed with previously known nodes - tab.mutex.Lock() - for _, bucket := range tab.buckets { - if len(bucket.entries) > 0 { - seed = false - break +func (tab *Table) refreshLoop() { + defer func() { + tab.db.close() + if tab.net != nil { + tab.net.close() } - } - tab.mutex.Unlock() + close(tab.closed) + }() - // If the table is not empty, try to refresh using the live entries - if !seed { - // The Kademlia paper specifies that the bucket refresh should - // perform a refresh in the least recently used bucket. We cannot - // adhere to this because the findnode target is a 512bit value - // (not hash-sized) and it is not easily possible to generate a - // sha3 preimage that falls into a chosen bucket. - // - // We perform a lookup with a random target instead. - var target NodeID - rand.Read(target[:]) - - result := tab.Lookup(target) - if len(result) == 0 { - // Lookup failed, seed after all - seed = true + timer := time.NewTicker(autoRefreshInterval) + var done chan struct{} + for { + select { + case <-timer.C: + if done == nil { + done = make(chan struct{}) + go tab.doRefresh(done) + } + case <-tab.refreshReq: + if done == nil { + done = make(chan struct{}) + go tab.doRefresh(done) + } + case <-done: + done = nil + case <-tab.closeReq: + if done != nil { + <-done + } + return } } +} - if seed { - // Pick a batch of previously know seeds to lookup with - seeds := tab.db.querySeeds(10) - for _, seed := range seeds { - glog.V(logger.Debug).Infoln("Seeding network with", seed) - } - nodes := append(tab.nursery, seeds...) +// doRefresh performs a lookup for a random target to keep buckets +// full. seed nodes are inserted if the table is empty (initial +// bootstrap or discarded faulty peers). +func (tab *Table) doRefresh(done chan struct{}) { + defer close(done) + + // The Kademlia paper specifies that the bucket refresh should + // perform a lookup in the least recently used bucket. We cannot + // adhere to this because the findnode target is a 512bit value + // (not hash-sized) and it is not easily possible to generate a + // sha3 preimage that falls into a chosen bucket. + // We perform a lookup with a random target instead. + var target NodeID + rand.Read(target[:]) + result := tab.Lookup(target) + if len(result) > 0 { + return + } - // Bond with all the seed nodes (will pingpong only if failed recently) - bonded := tab.bondall(nodes) - if len(bonded) > 0 { - tab.Lookup(tab.self.ID) + // The table is empty. Load nodes from the database and insert + // them. This should yield a few previously seen nodes that are + // (hopefully) still alive. + seeds := tab.db.querySeeds(seedCount, seedMaxAge) + seeds = tab.bondall(append(seeds, tab.nursery...)) + if glog.V(logger.Debug) { + if len(seeds) == 0 { + glog.Infof("no seed nodes found") + } + for _, n := range seeds { + age := time.Since(tab.db.lastPong(n.ID)) + glog.Infof("seed node (age %v): %v", age, n) } - // TODO: the Kademlia paper says that we're supposed to perform - // random lookups in all buckets further away than our closest neighbor. } + tab.mutex.Lock() + tab.stuff(seeds) + tab.mutex.Unlock() + + // Finally, do a self lookup to fill up the buckets. + tab.Lookup(tab.self.ID) } // closest returns the n nodes in the table that are closest to the @@ -373,8 +412,9 @@ func (tab *Table) bond(pinged bool, id NodeID, addr *net.UDPAddr, tcpPort uint16 } // If the node is unknown (non-bonded) or failed (remotely unknown), bond from scratch var result error - if node == nil || fails > 0 { - glog.V(logger.Detail).Infof("Bonding %x: known=%v, fails=%v", id[:8], node != nil, fails) + age := time.Since(tab.db.lastPong(id)) + if node == nil || fails > 0 || age > nodeDBNodeExpiration { + glog.V(logger.Detail).Infof("Bonding %x: known=%t, fails=%d age=%v", id[:8], node != nil, fails, age) tab.bondmu.Lock() w := tab.bonding[id] @@ -435,13 +475,17 @@ func (tab *Table) pingpong(w *bondproc, pinged bool, id NodeID, addr *net.UDPAdd // ping a remote endpoint and wait for a reply, also updating the node // database accordingly. func (tab *Table) ping(id NodeID, addr *net.UDPAddr) error { - // Update the last ping and send the message tab.db.updateLastPing(id, time.Now()) if err := tab.net.ping(id, addr); err != nil { return err } - // Pong received, update the database and return tab.db.updateLastPong(id, time.Now()) + + // Start the background expiration goroutine after the first + // successful communication. Subsequent calls have no effect if it + // is already running. We do this here instead of somewhere else + // so that the search for seed nodes also considers older nodes + // that would otherwise be removed by the expiration. tab.db.ensureExpirer() return nil } diff --git a/p2p/discover/table_test.go b/p2p/discover/table_test.go index 426f4e9cc..84962a1a5 100644 --- a/p2p/discover/table_test.go +++ b/p2p/discover/table_test.go @@ -514,9 +514,6 @@ func (tn *preminedTestnet) findnode(toid NodeID, toaddr *net.UDPAddr, target Nod if toaddr.Port == 0 { panic("query to node at distance 0") } - if target != tn.target { - panic("findnode with wrong target") - } next := uint16(toaddr.Port) - 1 var result []*Node for i, id := range tn.dists[toaddr.Port] { diff --git a/p2p/discover/udp.go b/p2p/discover/udp.go index afb31ee69..20f69cf08 100644 --- a/p2p/discover/udp.go +++ b/p2p/discover/udp.go @@ -39,7 +39,6 @@ var ( errPacketTooSmall = errors.New("too small") errBadHash = errors.New("bad hash") errExpired = errors.New("expired") - errBadVersion = errors.New("version mismatch") errUnsolicitedReply = errors.New("unsolicited reply") errUnknownNode = errors.New("unknown node") errTimeout = errors.New("RPC timeout") @@ -52,8 +51,6 @@ const ( respTimeout = 500 * time.Millisecond sendTimeout = 500 * time.Millisecond expiration = 20 * time.Second - - refreshInterval = 1 * time.Hour ) // RPC packet types @@ -312,10 +309,8 @@ func (t *udp) loop() { plist = list.New() timeout = time.NewTimer(0) nextTimeout *pending // head of plist when timeout was last reset - refresh = time.NewTicker(refreshInterval) ) <-timeout.C // ignore first timeout - defer refresh.Stop() defer timeout.Stop() resetTimeout := func() { @@ -344,9 +339,6 @@ func (t *udp) loop() { resetTimeout() select { - case <-refresh.C: - go t.refresh() - case <-t.closing: for el := plist.Front(); el != nil; el = el.Next() { el.Value.(*pending).errc <- errClosed @@ -529,9 +521,6 @@ func (req *ping) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) er if expired(req.Expiration) { return errExpired } - if req.Version != Version { - return errBadVersion - } t.send(from, pongPacket, pong{ To: makeEndpoint(from, req.From.TCP), ReplyTok: mac, diff --git a/p2p/discover/udp_test.go b/p2p/discover/udp_test.go index a86d3737b..913199c26 100644 --- a/p2p/discover/udp_test.go +++ b/p2p/discover/udp_test.go @@ -122,7 +122,6 @@ func TestUDP_packetErrors(t *testing.T) { defer test.table.Close() test.packetIn(errExpired, pingPacket, &ping{From: testRemote, To: testLocalAnnounced, Version: Version}) - test.packetIn(errBadVersion, pingPacket, &ping{From: testRemote, To: testLocalAnnounced, Version: 99, Expiration: futureExp}) test.packetIn(errUnsolicitedReply, pongPacket, &pong{ReplyTok: []byte{}, Expiration: futureExp}) test.packetIn(errUnknownNode, findnodePacket, &findnode{Expiration: futureExp}) test.packetIn(errUnsolicitedReply, neighborsPacket, &neighbors{Expiration: futureExp}) diff --git a/rpc/api/admin.go b/rpc/api/admin.go index 8af69b189..6aa04e667 100644 --- a/rpc/api/admin.go +++ b/rpc/api/admin.go @@ -151,7 +151,7 @@ func (self *adminApi) DataDir(req *shared.Request) (interface{}, error) { return self.ethereum.DataDir, nil } -func hasAllBlocks(chain *core.ChainManager, bs []*types.Block) bool { +func hasAllBlocks(chain *core.BlockChain, bs []*types.Block) bool { for _, b := range bs { if !chain.HasBlock(b.Hash()) { return false @@ -193,10 +193,10 @@ func (self *adminApi) ImportChain(req *shared.Request) (interface{}, error) { break } // Import the batch. - if hasAllBlocks(self.ethereum.ChainManager(), blocks[:i]) { + if hasAllBlocks(self.ethereum.BlockChain(), blocks[:i]) { continue } - if _, err := self.ethereum.ChainManager().InsertChain(blocks[:i]); err != nil { + if _, err := self.ethereum.BlockChain().InsertChain(blocks[:i]); err != nil { return false, fmt.Errorf("invalid block %d: %v", n, err) } } @@ -214,7 +214,7 @@ func (self *adminApi) ExportChain(req *shared.Request) (interface{}, error) { return false, err } defer fh.Close() - if err := self.ethereum.ChainManager().Export(fh); err != nil { + if err := self.ethereum.BlockChain().Export(fh); err != nil { return false, err } diff --git a/rpc/api/debug.go b/rpc/api/debug.go index d325b1720..e193f7ad2 100644 --- a/rpc/api/debug.go +++ b/rpc/api/debug.go @@ -152,7 +152,7 @@ func (self *debugApi) SetHead(req *shared.Request) (interface{}, error) { return nil, fmt.Errorf("block #%d not found", args.BlockNumber) } - self.ethereum.ChainManager().SetHead(block) + self.ethereum.BlockChain().SetHead(block) return nil, nil } diff --git a/rpc/api/eth.go b/rpc/api/eth.go index 4cd5f2695..6db006a46 100644 --- a/rpc/api/eth.go +++ b/rpc/api/eth.go @@ -168,7 +168,7 @@ func (self *ethApi) IsMining(req *shared.Request) (interface{}, error) { } func (self *ethApi) IsSyncing(req *shared.Request) (interface{}, error) { - current := self.ethereum.ChainManager().CurrentBlock().NumberU64() + current := self.ethereum.BlockChain().CurrentBlock().NumberU64() origin, height := self.ethereum.Downloader().Boundaries() if current < height { diff --git a/rpc/api/eth_args.go b/rpc/api/eth_args.go index 8bd077e20..66c190a51 100644 --- a/rpc/api/eth_args.go +++ b/rpc/api/eth_args.go @@ -24,8 +24,8 @@ import ( "strings" "github.com/ethereum/go-ethereum/common" - "github.com/ethereum/go-ethereum/core/state" "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/core/vm" "github.com/ethereum/go-ethereum/rpc/shared" ) @@ -830,7 +830,7 @@ type LogRes struct { TransactionIndex *hexnum `json:"transactionIndex"` } -func NewLogRes(log *state.Log) LogRes { +func NewLogRes(log *vm.Log) LogRes { var l LogRes l.Topics = make([]*hexdata, len(log.Topics)) for j, topic := range log.Topics { @@ -847,7 +847,7 @@ func NewLogRes(log *state.Log) LogRes { return l } -func NewLogsRes(logs state.Logs) (ls []LogRes) { +func NewLogsRes(logs vm.Logs) (ls []LogRes) { ls = make([]LogRes, len(logs)) for i, log := range logs { diff --git a/tests/block_test_util.go b/tests/block_test_util.go index 3ca00bae8..fb9ca16e6 100644 --- a/tests/block_test_util.go +++ b/tests/block_test_util.go @@ -181,7 +181,7 @@ func runBlockTest(test *BlockTest) error { return fmt.Errorf("InsertPreState: %v", err) } - cm := ethereum.ChainManager() + cm := ethereum.BlockChain() validBlocks, err := test.TryBlocksInsert(cm) if err != nil { return err @@ -253,13 +253,13 @@ func (t *BlockTest) InsertPreState(ethereum *eth.Ethereum) (*state.StateDB, erro statedb.SetState(common.HexToAddress(addrString), common.HexToHash(k), common.HexToHash(v)) } } - // sync objects to trie - statedb.SyncObjects() - // sync trie to disk - statedb.Sync() - if !bytes.Equal(t.Genesis.Root().Bytes(), statedb.Root().Bytes()) { - return nil, fmt.Errorf("computed state root does not match genesis block %x %x", t.Genesis.Root().Bytes()[:4], statedb.Root().Bytes()[:4]) + root, err := statedb.Commit() + if err != nil { + return nil, fmt.Errorf("error writing state: %v", err) + } + if t.Genesis.Root() != root { + return nil, fmt.Errorf("computed state root does not match genesis block: genesis=%x computed=%x", t.Genesis.Root().Bytes()[:4], root.Bytes()[:4]) } return statedb, nil } @@ -276,7 +276,7 @@ func (t *BlockTest) InsertPreState(ethereum *eth.Ethereum) (*state.StateDB, erro expected we are expected to ignore it and continue processing and then validate the post state. */ -func (t *BlockTest) TryBlocksInsert(chainManager *core.ChainManager) ([]btBlock, error) { +func (t *BlockTest) TryBlocksInsert(blockchain *core.BlockChain) ([]btBlock, error) { validBlocks := make([]btBlock, 0) // insert the test blocks, which will execute all transactions for _, b := range t.Json.Blocks { @@ -289,7 +289,7 @@ func (t *BlockTest) TryBlocksInsert(chainManager *core.ChainManager) ([]btBlock, } } // RLP decoding worked, try to insert into chain: - _, err = chainManager.InsertChain(types.Blocks{cb}) + _, err = blockchain.InsertChain(types.Blocks{cb}) if err != nil { if b.BlockHeader == nil { continue // OK - block is supposed to be invalid, continue with next block @@ -426,7 +426,7 @@ func (t *BlockTest) ValidatePostState(statedb *state.StateDB) error { return nil } -func (test *BlockTest) ValidateImportedHeaders(cm *core.ChainManager, validBlocks []btBlock) error { +func (test *BlockTest) ValidateImportedHeaders(cm *core.BlockChain, validBlocks []btBlock) error { // to get constant lookup when verifying block headers by hash (some tests have many blocks) bmap := make(map[string]btBlock, len(test.Json.Blocks)) for _, b := range validBlocks { diff --git a/tests/state_test_util.go b/tests/state_test_util.go index 95ecdd0a8..a1c066c82 100644 --- a/tests/state_test_util.go +++ b/tests/state_test_util.go @@ -168,7 +168,7 @@ func runStateTest(test VmTest) error { ret []byte // gas *big.Int // err error - logs state.Logs + logs vm.Logs ) ret, logs, _, _ = RunState(statedb, env, test.Transaction) @@ -201,9 +201,9 @@ func runStateTest(test VmTest) error { } } - statedb.Sync() - if common.HexToHash(test.PostStateRoot) != statedb.Root() { - return fmt.Errorf("Post state root error. Expected %s, got %x", test.PostStateRoot, statedb.Root()) + root, _ := statedb.Commit() + if common.HexToHash(test.PostStateRoot) != root { + return fmt.Errorf("Post state root error. Expected %s, got %x", test.PostStateRoot, root) } // check logs @@ -216,7 +216,7 @@ func runStateTest(test VmTest) error { return nil } -func RunState(statedb *state.StateDB, env, tx map[string]string) ([]byte, state.Logs, *big.Int, error) { +func RunState(statedb *state.StateDB, env, tx map[string]string) ([]byte, vm.Logs, *big.Int, error) { var ( data = common.FromHex(tx["data"]) gas = common.Big(tx["gasLimit"]) @@ -247,7 +247,7 @@ func RunState(statedb *state.StateDB, env, tx map[string]string) ([]byte, state. if core.IsNonceErr(err) || core.IsInvalidTxErr(err) || state.IsGasLimitErr(err) { statedb.Set(snapshot) } - statedb.SyncObjects() + statedb.Commit() return ret, vmenv.state.Logs(), vmenv.Gas, err } diff --git a/tests/util.go b/tests/util.go index 72d927ada..fb9e518c8 100644 --- a/tests/util.go +++ b/tests/util.go @@ -30,7 +30,7 @@ import ( "github.com/ethereum/go-ethereum/ethdb" ) -func checkLogs(tlog []Log, logs state.Logs) error { +func checkLogs(tlog []Log, logs vm.Logs) error { if len(tlog) != len(logs) { return fmt.Errorf("log length mismatch. Expected %d, got %d", len(tlog), len(logs)) @@ -53,7 +53,7 @@ func checkLogs(tlog []Log, logs state.Logs) error { } } } - genBloom := common.LeftPadBytes(types.LogsBloom(state.Logs{logs[i]}).Bytes(), 256) + genBloom := common.LeftPadBytes(types.LogsBloom(vm.Logs{logs[i]}).Bytes(), 256) if !bytes.Equal(genBloom, common.Hex2Bytes(log.BloomF)) { return fmt.Errorf("bloom mismatch") @@ -181,18 +181,18 @@ func (self *Env) BlockNumber() *big.Int { return self.number } func (self *Env) Coinbase() common.Address { return self.coinbase } func (self *Env) Time() *big.Int { return self.time } func (self *Env) Difficulty() *big.Int { return self.difficulty } -func (self *Env) State() *state.StateDB { return self.state } +func (self *Env) Db() vm.Database { return self.state } func (self *Env) GasLimit() *big.Int { return self.gasLimit } func (self *Env) VmType() vm.Type { return vm.StdVmTy } func (self *Env) GetHash(n uint64) common.Hash { return common.BytesToHash(crypto.Sha3([]byte(big.NewInt(int64(n)).String()))) } -func (self *Env) AddLog(log *state.Log) { +func (self *Env) AddLog(log *vm.Log) { self.state.AddLog(log) } func (self *Env) Depth() int { return self.depth } func (self *Env) SetDepth(i int) { self.depth = i } -func (self *Env) CanTransfer(from vm.Account, balance *big.Int) bool { +func (self *Env) CanTransfer(from common.Address, balance *big.Int) bool { if self.skipTransfer { if self.initial { self.initial = false @@ -200,58 +200,53 @@ func (self *Env) CanTransfer(from vm.Account, balance *big.Int) bool { } } - return from.Balance().Cmp(balance) >= 0 + return self.state.GetBalance(from).Cmp(balance) >= 0 +} +func (self *Env) MakeSnapshot() vm.Database { + return self.state.Copy() +} +func (self *Env) SetSnapshot(copy vm.Database) { + self.state.Set(copy.(*state.StateDB)) } func (self *Env) Transfer(from, to vm.Account, amount *big.Int) error { if self.skipTransfer { return nil } - return vm.Transfer(from, to, amount) -} - -func (self *Env) vm(addr *common.Address, data []byte, gas, price, value *big.Int) *core.Execution { - exec := core.NewExecution(self, addr, data, gas, price, value) - - return exec + return core.Transfer(from, to, amount) } -func (self *Env) Call(caller vm.ContextRef, addr common.Address, data []byte, gas, price, value *big.Int) ([]byte, error) { +func (self *Env) Call(caller vm.ContractRef, addr common.Address, data []byte, gas, price, value *big.Int) ([]byte, error) { if self.vmTest && self.depth > 0 { caller.ReturnGas(gas, price) return nil, nil } - exe := self.vm(&addr, data, gas, price, value) - ret, err := exe.Call(addr, caller) - self.Gas = exe.Gas + ret, err := core.Call(self, caller, addr, data, gas, price, value) + self.Gas = gas return ret, err } -func (self *Env) CallCode(caller vm.ContextRef, addr common.Address, data []byte, gas, price, value *big.Int) ([]byte, error) { +func (self *Env) CallCode(caller vm.ContractRef, addr common.Address, data []byte, gas, price, value *big.Int) ([]byte, error) { if self.vmTest && self.depth > 0 { caller.ReturnGas(gas, price) return nil, nil } - - caddr := caller.Address() - exe := self.vm(&caddr, data, gas, price, value) - return exe.Call(addr, caller) + return core.CallCode(self, caller, addr, data, gas, price, value) } -func (self *Env) Create(caller vm.ContextRef, data []byte, gas, price, value *big.Int) ([]byte, error, vm.ContextRef) { - exe := self.vm(nil, data, gas, price, value) +func (self *Env) Create(caller vm.ContractRef, data []byte, gas, price, value *big.Int) ([]byte, common.Address, error) { if self.vmTest { caller.ReturnGas(gas, price) nonce := self.state.GetNonce(caller.Address()) obj := self.state.GetOrNewStateObject(crypto.CreateAddress(caller.Address(), nonce)) - return nil, nil, obj + return nil, obj.Address(), nil } else { - return exe.Create(caller) + return core.Create(self, caller, data, gas, price, value) } } diff --git a/tests/vm_test.go b/tests/vm_test.go index 96718db3c..34beb85e5 100644 --- a/tests/vm_test.go +++ b/tests/vm_test.go @@ -24,14 +24,14 @@ import ( func BenchmarkVmAckermann32Tests(b *testing.B) { fn := filepath.Join(vmTestDir, "vmPerformanceTest.json") - if err := BenchVmTest(fn, bconf{"ackermann32", true, os.Getenv("JITVM") == "true"}, b); err != nil { + if err := BenchVmTest(fn, bconf{"ackermann32", os.Getenv("JITFORCE") == "true", os.Getenv("JITVM") == "true"}, b); err != nil { b.Error(err) } } func BenchmarkVmFibonacci16Tests(b *testing.B) { fn := filepath.Join(vmTestDir, "vmPerformanceTest.json") - if err := BenchVmTest(fn, bconf{"fibonacci16", true, os.Getenv("JITVM") == "true"}, b); err != nil { + if err := BenchVmTest(fn, bconf{"fibonacci16", os.Getenv("JITFORCE") == "true", os.Getenv("JITVM") == "true"}, b); err != nil { b.Error(err) } } diff --git a/tests/vm_test_util.go b/tests/vm_test_util.go index 71a4f5e33..b61995e31 100644 --- a/tests/vm_test_util.go +++ b/tests/vm_test_util.go @@ -185,7 +185,7 @@ func runVmTest(test VmTest) error { ret []byte gas *big.Int err error - logs state.Logs + logs vm.Logs ) ret, logs, gas, err = RunVm(statedb, env, test.Exec) @@ -234,7 +234,7 @@ func runVmTest(test VmTest) error { return nil } -func RunVm(state *state.StateDB, env, exec map[string]string) ([]byte, state.Logs, *big.Int, error) { +func RunVm(state *state.StateDB, env, exec map[string]string) ([]byte, vm.Logs, *big.Int, error) { var ( to = common.HexToAddress(exec["address"]) from = common.HexToAddress(exec["caller"]) diff --git a/trie/arc.go b/trie/arc.go new file mode 100644 index 000000000..9da012e16 --- /dev/null +++ b/trie/arc.go @@ -0,0 +1,194 @@ +// Copyright (c) 2015 Hans Alexander Gugel <alexander.gugel@gmail.com> +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +// This file contains a modified version of package arc from +// https://github.com/alexanderGugel/arc +// +// It implements the ARC (Adaptive Replacement Cache) algorithm as detailed in +// https://www.usenix.org/legacy/event/fast03/tech/full_papers/megiddo/megiddo.pdf + +package trie + +import ( + "container/list" + "sync" +) + +type arc struct { + p int + c int + t1 *list.List + b1 *list.List + t2 *list.List + b2 *list.List + cache map[string]*entry + mutex sync.Mutex +} + +type entry struct { + key hashNode + value node + ll *list.List + el *list.Element +} + +// newARC returns a new Adaptive Replacement Cache with the +// given capacity. +func newARC(c int) *arc { + return &arc{ + c: c, + t1: list.New(), + b1: list.New(), + t2: list.New(), + b2: list.New(), + cache: make(map[string]*entry, c), + } +} + +// Put inserts a new key-value pair into the cache. +// This optimizes future access to this entry (side effect). +func (a *arc) Put(key hashNode, value node) bool { + a.mutex.Lock() + defer a.mutex.Unlock() + ent, ok := a.cache[string(key)] + if ok != true { + ent = &entry{key: key, value: value} + a.req(ent) + a.cache[string(key)] = ent + } else { + ent.value = value + a.req(ent) + } + return ok +} + +// Get retrieves a previously via Set inserted entry. +// This optimizes future access to this entry (side effect). +func (a *arc) Get(key hashNode) (value node, ok bool) { + a.mutex.Lock() + defer a.mutex.Unlock() + ent, ok := a.cache[string(key)] + if ok { + a.req(ent) + return ent.value, ent.value != nil + } + return nil, false +} + +func (a *arc) req(ent *entry) { + if ent.ll == a.t1 || ent.ll == a.t2 { + // Case I + ent.setMRU(a.t2) + } else if ent.ll == a.b1 { + // Case II + // Cache Miss in t1 and t2 + + // Adaptation + var d int + if a.b1.Len() >= a.b2.Len() { + d = 1 + } else { + d = a.b2.Len() / a.b1.Len() + } + a.p = a.p + d + if a.p > a.c { + a.p = a.c + } + + a.replace(ent) + ent.setMRU(a.t2) + } else if ent.ll == a.b2 { + // Case III + // Cache Miss in t1 and t2 + + // Adaptation + var d int + if a.b2.Len() >= a.b1.Len() { + d = 1 + } else { + d = a.b1.Len() / a.b2.Len() + } + a.p = a.p - d + if a.p < 0 { + a.p = 0 + } + + a.replace(ent) + ent.setMRU(a.t2) + } else if ent.ll == nil { + // Case IV + + if a.t1.Len()+a.b1.Len() == a.c { + // Case A + if a.t1.Len() < a.c { + a.delLRU(a.b1) + a.replace(ent) + } else { + a.delLRU(a.t1) + } + } else if a.t1.Len()+a.b1.Len() < a.c { + // Case B + if a.t1.Len()+a.t2.Len()+a.b1.Len()+a.b2.Len() >= a.c { + if a.t1.Len()+a.t2.Len()+a.b1.Len()+a.b2.Len() == 2*a.c { + a.delLRU(a.b2) + } + a.replace(ent) + } + } + + ent.setMRU(a.t1) + } +} + +func (a *arc) delLRU(list *list.List) { + lru := list.Back() + list.Remove(lru) + delete(a.cache, string(lru.Value.(*entry).key)) +} + +func (a *arc) replace(ent *entry) { + if a.t1.Len() > 0 && ((a.t1.Len() > a.p) || (ent.ll == a.b2 && a.t1.Len() == a.p)) { + lru := a.t1.Back().Value.(*entry) + lru.value = nil + lru.setMRU(a.b1) + } else { + lru := a.t2.Back().Value.(*entry) + lru.value = nil + lru.setMRU(a.b2) + } +} + +func (e *entry) setLRU(list *list.List) { + e.detach() + e.ll = list + e.el = e.ll.PushBack(e) +} + +func (e *entry) setMRU(list *list.List) { + e.detach() + e.ll = list + e.el = e.ll.PushFront(e) +} + +func (e *entry) detach() { + if e.ll != nil { + e.ll.Remove(e.el) + } +} diff --git a/trie/cache.go b/trie/cache.go deleted file mode 100644 index e475fc861..000000000 --- a/trie/cache.go +++ /dev/null @@ -1,78 +0,0 @@ -// Copyright 2014 The go-ethereum Authors -// This file is part of the go-ethereum library. -// -// The go-ethereum library is free software: you can redistribute it and/or modify -// it under the terms of the GNU Lesser General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// The go-ethereum library is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU Lesser General Public License for more details. -// -// You should have received a copy of the GNU Lesser General Public License -// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. - -package trie - -import ( - "github.com/ethereum/go-ethereum/ethdb" - "github.com/ethereum/go-ethereum/logger/glog" - "github.com/syndtr/goleveldb/leveldb" -) - -type Backend interface { - Get([]byte) ([]byte, error) - Put([]byte, []byte) error -} - -type Cache struct { - batch *leveldb.Batch - store map[string][]byte - backend Backend -} - -func NewCache(backend Backend) *Cache { - return &Cache{new(leveldb.Batch), make(map[string][]byte), backend} -} - -func (self *Cache) Get(key []byte) []byte { - data := self.store[string(key)] - if data == nil { - data, _ = self.backend.Get(key) - } - - return data -} - -func (self *Cache) Put(key []byte, data []byte) { - self.batch.Put(key, data) - self.store[string(key)] = data -} - -// Flush flushes the trie to the backing layer. If this is a leveldb instance -// we'll use a batched write, otherwise we'll use regular put. -func (self *Cache) Flush() { - if db, ok := self.backend.(*ethdb.LDBDatabase); ok { - if err := db.LDB().Write(self.batch, nil); err != nil { - glog.Fatal("db write err:", err) - } - } else { - for k, v := range self.store { - self.backend.Put([]byte(k), v) - } - } -} - -func (self *Cache) Copy() *Cache { - cache := NewCache(self.backend) - for k, v := range self.store { - cache.store[k] = v - } - return cache -} - -func (self *Cache) Reset() { - //self.store = make(map[string][]byte) -} diff --git a/trie/encoding.go b/trie/encoding.go index 9c862d78f..3c172b843 100644 --- a/trie/encoding.go +++ b/trie/encoding.go @@ -16,34 +16,36 @@ package trie -func CompactEncode(hexSlice []byte) []byte { - terminator := 0 +func compactEncode(hexSlice []byte) []byte { + terminator := byte(0) if hexSlice[len(hexSlice)-1] == 16 { terminator = 1 - } - - if terminator == 1 { hexSlice = hexSlice[:len(hexSlice)-1] } - - oddlen := len(hexSlice) % 2 - flags := byte(2*terminator + oddlen) - if oddlen != 0 { - hexSlice = append([]byte{flags}, hexSlice...) - } else { - hexSlice = append([]byte{flags, 0}, hexSlice...) + var ( + odd = byte(len(hexSlice) % 2) + buflen = len(hexSlice)/2 + 1 + bi, hi = 0, 0 // indices + hs = byte(0) // shift: flips between 0 and 4 + ) + if odd == 0 { + bi = 1 + hs = 4 } - - l := len(hexSlice) / 2 - var buf = make([]byte, l) - for i := 0; i < l; i++ { - buf[i] = 16*hexSlice[2*i] + hexSlice[2*i+1] + buf := make([]byte, buflen) + buf[0] = terminator<<5 | byte(odd)<<4 + for bi < len(buf) && hi < len(hexSlice) { + buf[bi] |= hexSlice[hi] << hs + if hs == 0 { + bi++ + } + hi, hs = hi+1, hs^(1<<2) } return buf } -func CompactDecode(str []byte) []byte { - base := CompactHexDecode(str) +func compactDecode(str []byte) []byte { + base := compactHexDecode(str) base = base[:len(base)-1] if base[0] >= 2 { base = append(base, 16) @@ -53,11 +55,10 @@ func CompactDecode(str []byte) []byte { } else { base = base[2:] } - return base } -func CompactHexDecode(str []byte) []byte { +func compactHexDecode(str []byte) []byte { l := len(str)*2 + 1 var nibbles = make([]byte, l) for i, b := range str { @@ -68,7 +69,7 @@ func CompactHexDecode(str []byte) []byte { return nibbles } -func DecodeCompact(key []byte) []byte { +func decodeCompact(key []byte) []byte { l := len(key) / 2 var res = make([]byte, l) for i := 0; i < l; i++ { @@ -77,3 +78,30 @@ func DecodeCompact(key []byte) []byte { } return res } + +// prefixLen returns the length of the common prefix of a and b. +func prefixLen(a, b []byte) int { + var i, length = 0, len(a) + if len(b) < length { + length = len(b) + } + for ; i < length; i++ { + if a[i] != b[i] { + break + } + } + return i +} + +func hasTerm(s []byte) bool { + return s[len(s)-1] == 16 +} + +func remTerm(s []byte) []byte { + if hasTerm(s) { + b := make([]byte, len(s)-1) + copy(b, s) + return b + } + return s +} diff --git a/trie/encoding_test.go b/trie/encoding_test.go index e49b57ef0..061d48d58 100644 --- a/trie/encoding_test.go +++ b/trie/encoding_test.go @@ -23,7 +23,7 @@ import ( checker "gopkg.in/check.v1" ) -func Test(t *testing.T) { checker.TestingT(t) } +func TestEncoding(t *testing.T) { checker.TestingT(t) } type TrieEncodingSuite struct{} @@ -32,64 +32,64 @@ var _ = checker.Suite(&TrieEncodingSuite{}) func (s *TrieEncodingSuite) TestCompactEncode(c *checker.C) { // even compact encode test1 := []byte{1, 2, 3, 4, 5} - res1 := CompactEncode(test1) + res1 := compactEncode(test1) c.Assert(res1, checker.DeepEquals, []byte("\x11\x23\x45")) // odd compact encode test2 := []byte{0, 1, 2, 3, 4, 5} - res2 := CompactEncode(test2) + res2 := compactEncode(test2) c.Assert(res2, checker.DeepEquals, []byte("\x00\x01\x23\x45")) //odd terminated compact encode test3 := []byte{0, 15, 1, 12, 11, 8 /*term*/, 16} - res3 := CompactEncode(test3) + res3 := compactEncode(test3) c.Assert(res3, checker.DeepEquals, []byte("\x20\x0f\x1c\xb8")) // even terminated compact encode test4 := []byte{15, 1, 12, 11, 8 /*term*/, 16} - res4 := CompactEncode(test4) + res4 := compactEncode(test4) c.Assert(res4, checker.DeepEquals, []byte("\x3f\x1c\xb8")) } func (s *TrieEncodingSuite) TestCompactHexDecode(c *checker.C) { exp := []byte{7, 6, 6, 5, 7, 2, 6, 2, 16} - res := CompactHexDecode([]byte("verb")) + res := compactHexDecode([]byte("verb")) c.Assert(res, checker.DeepEquals, exp) } func (s *TrieEncodingSuite) TestCompactDecode(c *checker.C) { // odd compact decode exp := []byte{1, 2, 3, 4, 5} - res := CompactDecode([]byte("\x11\x23\x45")) + res := compactDecode([]byte("\x11\x23\x45")) c.Assert(res, checker.DeepEquals, exp) // even compact decode exp = []byte{0, 1, 2, 3, 4, 5} - res = CompactDecode([]byte("\x00\x01\x23\x45")) + res = compactDecode([]byte("\x00\x01\x23\x45")) c.Assert(res, checker.DeepEquals, exp) // even terminated compact decode exp = []byte{0, 15, 1, 12, 11, 8 /*term*/, 16} - res = CompactDecode([]byte("\x20\x0f\x1c\xb8")) + res = compactDecode([]byte("\x20\x0f\x1c\xb8")) c.Assert(res, checker.DeepEquals, exp) // even terminated compact decode exp = []byte{15, 1, 12, 11, 8 /*term*/, 16} - res = CompactDecode([]byte("\x3f\x1c\xb8")) + res = compactDecode([]byte("\x3f\x1c\xb8")) c.Assert(res, checker.DeepEquals, exp) } func (s *TrieEncodingSuite) TestDecodeCompact(c *checker.C) { exp, _ := hex.DecodeString("012345") - res := DecodeCompact([]byte{0, 1, 2, 3, 4, 5}) + res := decodeCompact([]byte{0, 1, 2, 3, 4, 5}) c.Assert(res, checker.DeepEquals, exp) exp, _ = hex.DecodeString("012345") - res = DecodeCompact([]byte{0, 1, 2, 3, 4, 5, 16}) + res = decodeCompact([]byte{0, 1, 2, 3, 4, 5, 16}) c.Assert(res, checker.DeepEquals, exp) exp, _ = hex.DecodeString("abcdef") - res = DecodeCompact([]byte{10, 11, 12, 13, 14, 15}) + res = decodeCompact([]byte{10, 11, 12, 13, 14, 15}) c.Assert(res, checker.DeepEquals, exp) } @@ -97,29 +97,27 @@ func BenchmarkCompactEncode(b *testing.B) { testBytes := []byte{0, 15, 1, 12, 11, 8 /*term*/, 16} for i := 0; i < b.N; i++ { - CompactEncode(testBytes) + compactEncode(testBytes) } } func BenchmarkCompactDecode(b *testing.B) { testBytes := []byte{0, 15, 1, 12, 11, 8 /*term*/, 16} for i := 0; i < b.N; i++ { - CompactDecode(testBytes) + compactDecode(testBytes) } } func BenchmarkCompactHexDecode(b *testing.B) { testBytes := []byte{7, 6, 6, 5, 7, 2, 6, 2, 16} for i := 0; i < b.N; i++ { - CompactHexDecode(testBytes) + compactHexDecode(testBytes) } - } func BenchmarkDecodeCompact(b *testing.B) { testBytes := []byte{7, 6, 6, 5, 7, 2, 6, 2, 16} for i := 0; i < b.N; i++ { - DecodeCompact(testBytes) + decodeCompact(testBytes) } - } diff --git a/trie/fullnode.go b/trie/fullnode.go deleted file mode 100644 index 8ff019ec4..000000000 --- a/trie/fullnode.go +++ /dev/null @@ -1,94 +0,0 @@ -// Copyright 2014 The go-ethereum Authors -// This file is part of the go-ethereum library. -// -// The go-ethereum library is free software: you can redistribute it and/or modify -// it under the terms of the GNU Lesser General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// The go-ethereum library is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU Lesser General Public License for more details. -// -// You should have received a copy of the GNU Lesser General Public License -// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. - -package trie - -type FullNode struct { - trie *Trie - nodes [17]Node - dirty bool -} - -func NewFullNode(t *Trie) *FullNode { - return &FullNode{trie: t} -} - -func (self *FullNode) Dirty() bool { return self.dirty } -func (self *FullNode) Value() Node { - self.nodes[16] = self.trie.trans(self.nodes[16]) - return self.nodes[16] -} -func (self *FullNode) Branches() []Node { - return self.nodes[:16] -} - -func (self *FullNode) Copy(t *Trie) Node { - nnode := NewFullNode(t) - for i, node := range self.nodes { - if node != nil { - nnode.nodes[i] = node - } - } - nnode.dirty = true - - return nnode -} - -// Returns the length of non-nil nodes -func (self *FullNode) Len() (amount int) { - for _, node := range self.nodes { - if node != nil { - amount++ - } - } - - return -} - -func (self *FullNode) Hash() interface{} { - return self.trie.store(self) -} - -func (self *FullNode) RlpData() interface{} { - t := make([]interface{}, 17) - for i, node := range self.nodes { - if node != nil { - t[i] = node.Hash() - } else { - t[i] = "" - } - } - - return t -} - -func (self *FullNode) set(k byte, value Node) { - self.nodes[int(k)] = value - self.dirty = true -} - -func (self *FullNode) branch(i byte) Node { - if self.nodes[int(i)] != nil { - self.nodes[int(i)] = self.trie.trans(self.nodes[int(i)]) - - return self.nodes[int(i)] - } - return nil -} - -func (self *FullNode) setDirty(dirty bool) { - self.dirty = dirty -} diff --git a/trie/hashnode.go b/trie/hashnode.go deleted file mode 100644 index d4a0bc7ec..000000000 --- a/trie/hashnode.go +++ /dev/null @@ -1,46 +0,0 @@ -// Copyright 2014 The go-ethereum Authors -// This file is part of the go-ethereum library. -// -// The go-ethereum library is free software: you can redistribute it and/or modify -// it under the terms of the GNU Lesser General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// The go-ethereum library is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU Lesser General Public License for more details. -// -// You should have received a copy of the GNU Lesser General Public License -// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. - -package trie - -import "github.com/ethereum/go-ethereum/common" - -type HashNode struct { - key []byte - trie *Trie - dirty bool -} - -func NewHash(key []byte, trie *Trie) *HashNode { - return &HashNode{key, trie, false} -} - -func (self *HashNode) RlpData() interface{} { - return self.key -} - -func (self *HashNode) Hash() interface{} { - return self.key -} - -func (self *HashNode) setDirty(dirty bool) { - self.dirty = dirty -} - -// These methods will never be called but we have to satisfy Node interface -func (self *HashNode) Value() Node { return nil } -func (self *HashNode) Dirty() bool { return true } -func (self *HashNode) Copy(t *Trie) Node { return NewHash(common.CopyBytes(self.key), t) } diff --git a/trie/iterator.go b/trie/iterator.go index 9c4c7fbe5..38555fe08 100644 --- a/trie/iterator.go +++ b/trie/iterator.go @@ -16,9 +16,7 @@ package trie -import ( - "bytes" -) +import "bytes" type Iterator struct { trie *Trie @@ -32,32 +30,29 @@ func NewIterator(trie *Trie) *Iterator { } func (self *Iterator) Next() bool { - self.trie.mu.Lock() - defer self.trie.mu.Unlock() - isIterStart := false if self.Key == nil { isIterStart = true self.Key = make([]byte, 32) } - key := RemTerm(CompactHexDecode(self.Key)) + key := remTerm(compactHexDecode(self.Key)) k := self.next(self.trie.root, key, isIterStart) - self.Key = []byte(DecodeCompact(k)) + self.Key = []byte(decodeCompact(k)) return len(k) > 0 } -func (self *Iterator) next(node Node, key []byte, isIterStart bool) []byte { +func (self *Iterator) next(node interface{}, key []byte, isIterStart bool) []byte { if node == nil { return nil } switch node := node.(type) { - case *FullNode: + case fullNode: if len(key) > 0 { - k := self.next(node.branch(key[0]), key[1:], isIterStart) + k := self.next(node[key[0]], key[1:], isIterStart) if k != nil { return append([]byte{key[0]}, k...) } @@ -69,31 +64,31 @@ func (self *Iterator) next(node Node, key []byte, isIterStart bool) []byte { } for i := r; i < 16; i++ { - k := self.key(node.branch(byte(i))) + k := self.key(node[i]) if k != nil { return append([]byte{i}, k...) } } - case *ShortNode: - k := RemTerm(node.Key()) - if vnode, ok := node.Value().(*ValueNode); ok { + case shortNode: + k := remTerm(node.Key) + if vnode, ok := node.Val.(valueNode); ok { switch bytes.Compare([]byte(k), key) { case 0: if isIterStart { - self.Value = vnode.Val() + self.Value = vnode return k } case 1: - self.Value = vnode.Val() + self.Value = vnode return k } } else { - cnode := node.Value() + cnode := node.Val var ret []byte skey := key[len(k):] - if BeginsWith(key, k) { + if bytes.HasPrefix(key, k) { ret = self.next(cnode, skey, isIterStart) } else if bytes.Compare(k, key[:len(k)]) > 0 { return self.key(node) @@ -103,37 +98,36 @@ func (self *Iterator) next(node Node, key []byte, isIterStart bool) []byte { return append(k, ret...) } } - } + case hashNode: + return self.next(self.trie.resolveHash(node), key, isIterStart) + } return nil } -func (self *Iterator) key(node Node) []byte { +func (self *Iterator) key(node interface{}) []byte { switch node := node.(type) { - case *ShortNode: + case shortNode: // Leaf node - if vnode, ok := node.Value().(*ValueNode); ok { - k := RemTerm(node.Key()) - self.Value = vnode.Val() - + k := remTerm(node.Key) + if vnode, ok := node.Val.(valueNode); ok { + self.Value = vnode return k - } else { - k := RemTerm(node.Key()) - return append(k, self.key(node.Value())...) } - case *FullNode: - if node.Value() != nil { - self.Value = node.Value().(*ValueNode).Val() - + return append(k, self.key(node.Val)...) + case fullNode: + if node[16] != nil { + self.Value = node[16].(valueNode) return []byte{16} } - for i := 0; i < 16; i++ { - k := self.key(node.branch(byte(i))) + k := self.key(node[i]) if k != nil { return append([]byte{byte(i)}, k...) } } + case hashNode: + return self.key(self.trie.resolveHash(node)) } return nil diff --git a/trie/iterator_test.go b/trie/iterator_test.go index 148f9adf9..fdc60b412 100644 --- a/trie/iterator_test.go +++ b/trie/iterator_test.go @@ -19,7 +19,7 @@ package trie import "testing" func TestIterator(t *testing.T) { - trie := NewEmpty() + trie := newEmpty() vals := []struct{ k, v string }{ {"do", "verb"}, {"ether", "wookiedoo"}, @@ -32,11 +32,11 @@ func TestIterator(t *testing.T) { v := make(map[string]bool) for _, val := range vals { v[val.k] = false - trie.UpdateString(val.k, val.v) + trie.Update([]byte(val.k), []byte(val.v)) } trie.Commit() - it := trie.Iterator() + it := NewIterator(trie) for it.Next() { v[string(it.Key)] = true } diff --git a/trie/node.go b/trie/node.go index 9d49029de..0bfa21dc4 100644 --- a/trie/node.go +++ b/trie/node.go @@ -16,46 +16,172 @@ package trie -import "fmt" +import ( + "fmt" + "io" + "strings" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/rlp" +) var indices = []string{"0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "a", "b", "c", "d", "e", "f", "[17]"} -type Node interface { - Value() Node - Copy(*Trie) Node // All nodes, for now, return them self - Dirty() bool +type node interface { fstring(string) string - Hash() interface{} - RlpData() interface{} - setDirty(dirty bool) } -// Value node -func (self *ValueNode) String() string { return self.fstring("") } -func (self *FullNode) String() string { return self.fstring("") } -func (self *ShortNode) String() string { return self.fstring("") } -func (self *ValueNode) fstring(ind string) string { return fmt.Sprintf("%x ", self.data) } +type ( + fullNode [17]node + shortNode struct { + Key []byte + Val node + } + hashNode []byte + valueNode []byte +) -//func (self *HashNode) fstring(ind string) string { return fmt.Sprintf("< %x > ", self.key) } -func (self *HashNode) fstring(ind string) string { - return fmt.Sprintf("%v", self.trie.trans(self)) -} +// Pretty printing. +func (n fullNode) String() string { return n.fstring("") } +func (n shortNode) String() string { return n.fstring("") } +func (n hashNode) String() string { return n.fstring("") } +func (n valueNode) String() string { return n.fstring("") } -// Full node -func (self *FullNode) fstring(ind string) string { +func (n fullNode) fstring(ind string) string { resp := fmt.Sprintf("[\n%s ", ind) - for i, node := range self.nodes { + for i, node := range n { if node == nil { resp += fmt.Sprintf("%s: <nil> ", indices[i]) } else { resp += fmt.Sprintf("%s: %v", indices[i], node.fstring(ind+" ")) } } - return resp + fmt.Sprintf("\n%s] ", ind) } +func (n shortNode) fstring(ind string) string { + return fmt.Sprintf("{%x: %v} ", n.Key, n.Val.fstring(ind+" ")) +} +func (n hashNode) fstring(ind string) string { + return fmt.Sprintf("<%x> ", []byte(n)) +} +func (n valueNode) fstring(ind string) string { + return fmt.Sprintf("%x ", []byte(n)) +} + +func mustDecodeNode(dbkey, buf []byte) node { + n, err := decodeNode(buf) + if err != nil { + panic(fmt.Sprintf("node %x: %v", dbkey, err)) + } + return n +} + +// decodeNode parses the RLP encoding of a trie node. +func decodeNode(buf []byte) (node, error) { + if len(buf) == 0 { + return nil, io.ErrUnexpectedEOF + } + elems, _, err := rlp.SplitList(buf) + if err != nil { + return nil, fmt.Errorf("decode error: %v", err) + } + switch c, _ := rlp.CountValues(elems); c { + case 2: + n, err := decodeShort(elems) + return n, wrapError(err, "short") + case 17: + n, err := decodeFull(elems) + return n, wrapError(err, "full") + default: + return nil, fmt.Errorf("invalid number of list elements: %v", c) + } +} + +func decodeShort(buf []byte) (node, error) { + kbuf, rest, err := rlp.SplitString(buf) + if err != nil { + return nil, err + } + key := compactDecode(kbuf) + if key[len(key)-1] == 16 { + // value node + val, _, err := rlp.SplitString(rest) + if err != nil { + return nil, fmt.Errorf("invalid value node: %v", err) + } + return shortNode{key, valueNode(val)}, nil + } + r, _, err := decodeRef(rest) + if err != nil { + return nil, wrapError(err, "val") + } + return shortNode{key, r}, nil +} + +func decodeFull(buf []byte) (fullNode, error) { + var n fullNode + for i := 0; i < 16; i++ { + cld, rest, err := decodeRef(buf) + if err != nil { + return n, wrapError(err, fmt.Sprintf("[%d]", i)) + } + n[i], buf = cld, rest + } + val, _, err := rlp.SplitString(buf) + if err != nil { + return n, err + } + if len(val) > 0 { + n[16] = valueNode(val) + } + return n, nil +} + +const hashLen = len(common.Hash{}) + +func decodeRef(buf []byte) (node, []byte, error) { + kind, val, rest, err := rlp.Split(buf) + if err != nil { + return nil, buf, err + } + switch { + case kind == rlp.List: + // 'embedded' node reference. The encoding must be smaller + // than a hash in order to be valid. + if size := len(buf) - len(rest); size > hashLen { + err := fmt.Errorf("oversized embedded node (size is %d bytes, want size < %d)", size, hashLen) + return nil, buf, err + } + n, err := decodeNode(buf) + return n, rest, err + case kind == rlp.String && len(val) == 0: + // empty node + return nil, rest, nil + case kind == rlp.String && len(val) == 32: + return hashNode(val), rest, nil + default: + return nil, nil, fmt.Errorf("invalid RLP string size %d (want 0 or 32)", len(val)) + } +} + +// wraps a decoding error with information about the path to the +// invalid child node (for debugging encoding issues). +type decodeError struct { + what error + stack []string +} + +func wrapError(err error, ctx string) error { + if err == nil { + return nil + } + if decErr, ok := err.(*decodeError); ok { + decErr.stack = append(decErr.stack, ctx) + return decErr + } + return &decodeError{err, []string{ctx}} +} -// Short node -func (self *ShortNode) fstring(ind string) string { - return fmt.Sprintf("[ %x: %v ] ", self.key, self.value.fstring(ind+" ")) +func (err *decodeError) Error() string { + return fmt.Sprintf("%v (decode path: %s)", err.what, strings.Join(err.stack, "<-")) } diff --git a/trie/proof.go b/trie/proof.go new file mode 100644 index 000000000..a705c49db --- /dev/null +++ b/trie/proof.go @@ -0,0 +1,122 @@ +package trie + +import ( + "bytes" + "errors" + "fmt" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/crypto/sha3" + "github.com/ethereum/go-ethereum/rlp" +) + +// Prove constructs a merkle proof for key. The result contains all +// encoded nodes on the path to the value at key. The value itself is +// also included in the last node and can be retrieved by verifying +// the proof. +// +// The returned proof is nil if the trie does not contain a value for key. +// For existing keys, the proof will have at least one element. +func (t *Trie) Prove(key []byte) []rlp.RawValue { + // Collect all nodes on the path to key. + key = compactHexDecode(key) + nodes := []node{} + tn := t.root + for len(key) > 0 { + switch n := tn.(type) { + case shortNode: + if len(key) < len(n.Key) || !bytes.Equal(n.Key, key[:len(n.Key)]) { + // The trie doesn't contain the key. + return nil + } + tn = n.Val + key = key[len(n.Key):] + nodes = append(nodes, n) + case fullNode: + tn = n[key[0]] + key = key[1:] + nodes = append(nodes, n) + case nil: + return nil + case hashNode: + tn = t.resolveHash(n) + default: + panic(fmt.Sprintf("%T: invalid node: %v", tn, tn)) + } + } + if t.hasher == nil { + t.hasher = newHasher() + } + proof := make([]rlp.RawValue, 0, len(nodes)) + for i, n := range nodes { + // Don't bother checking for errors here since hasher panics + // if encoding doesn't work and we're not writing to any database. + n, _ = t.hasher.replaceChildren(n, nil) + hn, _ := t.hasher.store(n, nil, false) + if _, ok := hn.(hashNode); ok || i == 0 { + // If the node's database encoding is a hash (or is the + // root node), it becomes a proof element. + enc, _ := rlp.EncodeToBytes(n) + proof = append(proof, enc) + } + } + return proof +} + +// VerifyProof checks merkle proofs. The given proof must contain the +// value for key in a trie with the given root hash. VerifyProof +// returns an error if the proof contains invalid trie nodes or the +// wrong value. +func VerifyProof(rootHash common.Hash, key []byte, proof []rlp.RawValue) (value []byte, err error) { + key = compactHexDecode(key) + sha := sha3.NewKeccak256() + wantHash := rootHash.Bytes() + for i, buf := range proof { + sha.Reset() + sha.Write(buf) + if !bytes.Equal(sha.Sum(nil), wantHash) { + return nil, fmt.Errorf("bad proof node %d: hash mismatch", i) + } + n, err := decodeNode(buf) + if err != nil { + return nil, fmt.Errorf("bad proof node %d: %v", i, err) + } + keyrest, cld := get(n, key) + switch cld := cld.(type) { + case nil: + return nil, fmt.Errorf("key mismatch at proof node %d", i) + case hashNode: + key = keyrest + wantHash = cld + case valueNode: + if i != len(proof)-1 { + return nil, errors.New("additional nodes at end of proof") + } + return cld, nil + } + } + return nil, errors.New("unexpected end of proof") +} + +func get(tn node, key []byte) ([]byte, node) { + for len(key) > 0 { + switch n := tn.(type) { + case shortNode: + if len(key) < len(n.Key) || !bytes.Equal(n.Key, key[:len(n.Key)]) { + return nil, nil + } + tn = n.Val + key = key[len(n.Key):] + case fullNode: + tn = n[key[0]] + key = key[1:] + case hashNode: + return key, n + case nil: + return key, nil + default: + panic(fmt.Sprintf("%T: invalid node: %v", tn, tn)) + } + } + return nil, tn.(valueNode) +} diff --git a/trie/proof_test.go b/trie/proof_test.go new file mode 100644 index 000000000..6b5bef05c --- /dev/null +++ b/trie/proof_test.go @@ -0,0 +1,139 @@ +package trie + +import ( + "bytes" + crand "crypto/rand" + mrand "math/rand" + "testing" + "time" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/rlp" +) + +func init() { + mrand.Seed(time.Now().Unix()) +} + +func TestProof(t *testing.T) { + trie, vals := randomTrie(500) + root := trie.Hash() + for _, kv := range vals { + proof := trie.Prove(kv.k) + if proof == nil { + t.Fatalf("missing key %x while constructing proof", kv.k) + } + val, err := VerifyProof(root, kv.k, proof) + if err != nil { + t.Fatalf("VerifyProof error for key %x: %v\nraw proof: %x", kv.k, err, proof) + } + if !bytes.Equal(val, kv.v) { + t.Fatalf("VerifyProof returned wrong value for key %x: got %x, want %x", kv.k, val, kv.v) + } + } +} + +func TestOneElementProof(t *testing.T) { + trie := new(Trie) + updateString(trie, "k", "v") + proof := trie.Prove([]byte("k")) + if proof == nil { + t.Fatal("nil proof") + } + if len(proof) != 1 { + t.Error("proof should have one element") + } + val, err := VerifyProof(trie.Hash(), []byte("k"), proof) + if err != nil { + t.Fatalf("VerifyProof error: %v\nraw proof: %x", err, proof) + } + if !bytes.Equal(val, []byte("v")) { + t.Fatalf("VerifyProof returned wrong value: got %x, want 'k'", val) + } +} + +func TestVerifyBadProof(t *testing.T) { + trie, vals := randomTrie(800) + root := trie.Hash() + for _, kv := range vals { + proof := trie.Prove(kv.k) + if proof == nil { + t.Fatal("nil proof") + } + mutateByte(proof[mrand.Intn(len(proof))]) + if _, err := VerifyProof(root, kv.k, proof); err == nil { + t.Fatalf("expected proof to fail for key %x", kv.k) + } + } +} + +// mutateByte changes one byte in b. +func mutateByte(b []byte) { + for r := mrand.Intn(len(b)); ; { + new := byte(mrand.Intn(255)) + if new != b[r] { + b[r] = new + break + } + } +} + +func BenchmarkProve(b *testing.B) { + trie, vals := randomTrie(100) + var keys []string + for k := range vals { + keys = append(keys, k) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + kv := vals[keys[i%len(keys)]] + if trie.Prove(kv.k) == nil { + b.Fatalf("nil proof for %x", kv.k) + } + } +} + +func BenchmarkVerifyProof(b *testing.B) { + trie, vals := randomTrie(100) + root := trie.Hash() + var keys []string + var proofs [][]rlp.RawValue + for k := range vals { + keys = append(keys, k) + proofs = append(proofs, trie.Prove([]byte(k))) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + im := i % len(keys) + if _, err := VerifyProof(root, []byte(keys[im]), proofs[im]); err != nil { + b.Fatalf("key %x: error", keys[im], err) + } + } +} + +func randomTrie(n int) (*Trie, map[string]*kv) { + trie := new(Trie) + vals := make(map[string]*kv) + for i := byte(0); i < 100; i++ { + value := &kv{common.LeftPadBytes([]byte{i}, 32), []byte{i}, false} + value2 := &kv{common.LeftPadBytes([]byte{i + 10}, 32), []byte{i}, false} + trie.Update(value.k, value.v) + trie.Update(value2.k, value2.v) + vals[string(value.k)] = value + vals[string(value2.k)] = value2 + } + for i := 0; i < n; i++ { + value := &kv{randBytes(32), randBytes(20), false} + trie.Update(value.k, value.v) + vals[string(value.k)] = value + } + return trie, vals +} + +func randBytes(n int) []byte { + r := make([]byte, n) + crand.Read(r) + return r +} diff --git a/trie/secure_trie.go b/trie/secure_trie.go index 47c7542bb..47d1934d0 100644 --- a/trie/secure_trie.go +++ b/trie/secure_trie.go @@ -16,46 +16,93 @@ package trie -import "github.com/ethereum/go-ethereum/crypto" +import ( + "hash" -var keyPrefix = []byte("secure-key-") + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/crypto/sha3" +) +var secureKeyPrefix = []byte("secure-key-") + +// SecureTrie wraps a trie with key hashing. In a secure trie, all +// access operations hash the key using keccak256. This prevents +// calling code from creating long chains of nodes that +// increase the access time. +// +// Contrary to a regular trie, a SecureTrie can only be created with +// New and must have an attached database. The database also stores +// the preimage of each key. +// +// SecureTrie is not safe for concurrent use. type SecureTrie struct { *Trie -} -func NewSecure(root []byte, backend Backend) *SecureTrie { - return &SecureTrie{New(root, backend)} + hash hash.Hash + secKeyBuf []byte + hashKeyBuf []byte } -func (self *SecureTrie) Update(key, value []byte) Node { - shaKey := crypto.Sha3(key) - self.Trie.cache.Put(append(keyPrefix, shaKey...), key) - - return self.Trie.Update(shaKey, value) -} -func (self *SecureTrie) UpdateString(key, value string) Node { - return self.Update([]byte(key), []byte(value)) +// NewSecure creates a trie with an existing root node from db. +// +// If root is the zero hash or the sha3 hash of an empty string, the +// trie is initially empty. Otherwise, New will panics if db is nil +// and returns ErrMissingRoot if the root node cannpt be found. +// Accessing the trie loads nodes from db on demand. +func NewSecure(root common.Hash, db Database) (*SecureTrie, error) { + if db == nil { + panic("NewSecure called with nil database") + } + trie, err := New(root, db) + if err != nil { + return nil, err + } + return &SecureTrie{Trie: trie}, nil } -func (self *SecureTrie) Get(key []byte) []byte { - return self.Trie.Get(crypto.Sha3(key)) +// Get returns the value for key stored in the trie. +// The value bytes must not be modified by the caller. +func (t *SecureTrie) Get(key []byte) []byte { + return t.Trie.Get(t.hashKey(key)) } -func (self *SecureTrie) GetString(key string) []byte { - return self.Get([]byte(key)) + +// Update associates key with value in the trie. Subsequent calls to +// Get will return value. If value has length zero, any existing value +// is deleted from the trie and calls to Get will return nil. +// +// The value bytes must not be modified by the caller while they are +// stored in the trie. +func (t *SecureTrie) Update(key, value []byte) { + hk := t.hashKey(key) + t.Trie.Update(hk, value) + t.Trie.db.Put(t.secKey(hk), key) } -func (self *SecureTrie) Delete(key []byte) Node { - return self.Trie.Delete(crypto.Sha3(key)) +// Delete removes any existing value for key from the trie. +func (t *SecureTrie) Delete(key []byte) { + t.Trie.Delete(t.hashKey(key)) } -func (self *SecureTrie) DeleteString(key string) Node { - return self.Delete([]byte(key)) + +// GetKey returns the sha3 preimage of a hashed key that was +// previously used to store a value. +func (t *SecureTrie) GetKey(shaKey []byte) []byte { + key, _ := t.Trie.db.Get(t.secKey(shaKey)) + return key } -func (self *SecureTrie) Copy() *SecureTrie { - return &SecureTrie{self.Trie.Copy()} +func (t *SecureTrie) secKey(key []byte) []byte { + t.secKeyBuf = append(t.secKeyBuf[:0], secureKeyPrefix...) + t.secKeyBuf = append(t.secKeyBuf, key...) + return t.secKeyBuf } -func (self *SecureTrie) GetKey(shaKey []byte) []byte { - return self.Trie.cache.Get(append(keyPrefix, shaKey...)) +func (t *SecureTrie) hashKey(key []byte) []byte { + if t.hash == nil { + t.hash = sha3.NewKeccak256() + t.hashKeyBuf = make([]byte, 32) + } + t.hash.Reset() + t.hash.Write(key) + t.hashKeyBuf = t.hash.Sum(t.hashKeyBuf[:0]) + return t.hashKeyBuf } diff --git a/trie/secure_trie_test.go b/trie/secure_trie_test.go new file mode 100644 index 000000000..13c6cd02e --- /dev/null +++ b/trie/secure_trie_test.go @@ -0,0 +1,74 @@ +// Copyright 2015 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. + +package trie + +import ( + "bytes" + "testing" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/ethdb" +) + +func newEmptySecure() *SecureTrie { + db, _ := ethdb.NewMemDatabase() + trie, _ := NewSecure(common.Hash{}, db) + return trie +} + +func TestSecureDelete(t *testing.T) { + trie := newEmptySecure() + vals := []struct{ k, v string }{ + {"do", "verb"}, + {"ether", "wookiedoo"}, + {"horse", "stallion"}, + {"shaman", "horse"}, + {"doge", "coin"}, + {"ether", ""}, + {"dog", "puppy"}, + {"shaman", ""}, + } + for _, val := range vals { + if val.v != "" { + trie.Update([]byte(val.k), []byte(val.v)) + } else { + trie.Delete([]byte(val.k)) + } + } + hash := trie.Hash() + exp := common.HexToHash("29b235a58c3c25ab83010c327d5932bcf05324b7d6b1185e650798034783ca9d") + if hash != exp { + t.Errorf("expected %x got %x", exp, hash) + } +} + +func TestSecureGetKey(t *testing.T) { + trie := newEmptySecure() + trie.Update([]byte("foo"), []byte("bar")) + + key := []byte("foo") + value := []byte("bar") + seckey := crypto.Sha3(key) + + if !bytes.Equal(trie.Get(key), value) { + t.Errorf("Get did not return bar") + } + if k := trie.GetKey(seckey); !bytes.Equal(k, key) { + t.Errorf("GetKey returned %q, want %q", k, key) + } +} diff --git a/trie/shortnode.go b/trie/shortnode.go deleted file mode 100644 index 569d5f109..000000000 --- a/trie/shortnode.go +++ /dev/null @@ -1,57 +0,0 @@ -// Copyright 2014 The go-ethereum Authors -// This file is part of the go-ethereum library. -// -// The go-ethereum library is free software: you can redistribute it and/or modify -// it under the terms of the GNU Lesser General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// The go-ethereum library is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU Lesser General Public License for more details. -// -// You should have received a copy of the GNU Lesser General Public License -// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. - -package trie - -import "github.com/ethereum/go-ethereum/common" - -type ShortNode struct { - trie *Trie - key []byte - value Node - dirty bool -} - -func NewShortNode(t *Trie, key []byte, value Node) *ShortNode { - return &ShortNode{t, CompactEncode(key), value, false} -} -func (self *ShortNode) Value() Node { - self.value = self.trie.trans(self.value) - - return self.value -} -func (self *ShortNode) Dirty() bool { return self.dirty } -func (self *ShortNode) Copy(t *Trie) Node { - node := &ShortNode{t, nil, self.value.Copy(t), self.dirty} - node.key = common.CopyBytes(self.key) - node.dirty = true - return node -} - -func (self *ShortNode) RlpData() interface{} { - return []interface{}{self.key, self.value.Hash()} -} -func (self *ShortNode) Hash() interface{} { - return self.trie.store(self) -} - -func (self *ShortNode) Key() []byte { - return CompactDecode(self.key) -} - -func (self *ShortNode) setDirty(dirty bool) { - self.dirty = dirty -} diff --git a/trie/slice.go b/trie/slice.go deleted file mode 100644 index ccefbd064..000000000 --- a/trie/slice.go +++ /dev/null @@ -1,69 +0,0 @@ -// Copyright 2014 The go-ethereum Authors -// This file is part of the go-ethereum library. -// -// The go-ethereum library is free software: you can redistribute it and/or modify -// it under the terms of the GNU Lesser General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// The go-ethereum library is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU Lesser General Public License for more details. -// -// You should have received a copy of the GNU Lesser General Public License -// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. - -package trie - -import ( - "bytes" - "math" -) - -// Helper function for comparing slices -func CompareIntSlice(a, b []int) bool { - if len(a) != len(b) { - return false - } - for i, v := range a { - if v != b[i] { - return false - } - } - return true -} - -// Returns the amount of nibbles that match each other from 0 ... -func MatchingNibbleLength(a, b []byte) int { - var i, length = 0, int(math.Min(float64(len(a)), float64(len(b)))) - - for i < length { - if a[i] != b[i] { - break - } - i++ - } - - return i -} - -func HasTerm(s []byte) bool { - return s[len(s)-1] == 16 -} - -func RemTerm(s []byte) []byte { - if HasTerm(s) { - return s[:len(s)-1] - } - - return s -} - -func BeginsWith(a, b []byte) bool { - if len(b) > len(a) { - return false - } - - return bytes.Equal(a[:len(b)], b) -} diff --git a/trie/trie.go b/trie/trie.go index abf48a850..aa8d39fe2 100644 --- a/trie/trie.go +++ b/trie/trie.go @@ -19,372 +19,425 @@ package trie import ( "bytes" - "container/list" + "errors" "fmt" - "sync" + "hash" "github.com/ethereum/go-ethereum/common" - "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/crypto/sha3" + "github.com/ethereum/go-ethereum/logger" + "github.com/ethereum/go-ethereum/logger/glog" + "github.com/ethereum/go-ethereum/rlp" ) -func ParanoiaCheck(t1 *Trie, backend Backend) (bool, *Trie) { - t2 := New(nil, backend) +const defaultCacheCapacity = 800 - it := t1.Iterator() - for it.Next() { - t2.Update(it.Key, it.Value) - } - - return bytes.Equal(t2.Hash(), t1.Hash()), t2 -} - -type Trie struct { - mu sync.Mutex - root Node - roothash []byte - cache *Cache - - revisions *list.List -} - -func New(root []byte, backend Backend) *Trie { - trie := &Trie{} - trie.revisions = list.New() - trie.roothash = root - if backend != nil { - trie.cache = NewCache(backend) - } +var ( + // The global cache stores decoded trie nodes by hash as they get loaded. + globalCache = newARC(defaultCacheCapacity) + // This is the known root hash of an empty trie. + emptyRoot = common.HexToHash("56e81f171bcc55a6ff8345e692c0f86e5b48e01b996cadc001622fb5e363b421") +) - if root != nil { - value := common.NewValueFromBytes(trie.cache.Get(root)) - trie.root = trie.mknode(value) - } +var ErrMissingRoot = errors.New("missing root node") - return trie +// Database must be implemented by backing stores for the trie. +type Database interface { + DatabaseWriter + // Get returns the value for key from the database. + Get(key []byte) (value []byte, err error) } -func (self *Trie) Iterator() *Iterator { - return NewIterator(self) +// DatabaseWriter wraps the Put method of a backing store for the trie. +type DatabaseWriter interface { + // Put stores the mapping key->value in the database. + // Implementations must not hold onto the value bytes, the trie + // will reuse the slice across calls to Put. + Put(key, value []byte) error } -func (self *Trie) Copy() *Trie { - cpy := make([]byte, 32) - copy(cpy, self.roothash) // NOTE: cpy isn't being used anywhere? - trie := New(nil, nil) - trie.cache = self.cache.Copy() - if self.root != nil { - trie.root = self.root.Copy(trie) - } - - return trie +// Trie is a Merkle Patricia Trie. +// The zero value is an empty trie with no database. +// Use New to create a trie that sits on top of a database. +// +// Trie is not safe for concurrent use. +type Trie struct { + root node + db Database + *hasher } -// Legacy support -func (self *Trie) Root() []byte { return self.Hash() } -func (self *Trie) Hash() []byte { - var hash []byte - if self.root != nil { - t := self.root.Hash() - if byts, ok := t.([]byte); ok && len(byts) > 0 { - hash = byts - } else { - hash = crypto.Sha3(common.Encode(self.root.RlpData())) +// New creates a trie with an existing root node from db. +// +// If root is the zero hash or the sha3 hash of an empty string, the +// trie is initially empty and does not require a database. Otherwise, +// New will panics if db is nil or root does not exist in the +// database. Accessing the trie loads nodes from db on demand. +func New(root common.Hash, db Database) (*Trie, error) { + trie := &Trie{db: db} + if (root != common.Hash{}) && root != emptyRoot { + if db == nil { + panic("trie.New: cannot use existing root without a database") } - } else { - hash = crypto.Sha3(common.Encode("")) - } - - if !bytes.Equal(hash, self.roothash) { - self.revisions.PushBack(self.roothash) - self.roothash = hash + if v, _ := trie.db.Get(root[:]); len(v) == 0 { + return nil, ErrMissingRoot + } + trie.root = hashNode(root.Bytes()) } - - return hash + return trie, nil } -func (self *Trie) Commit() { - self.mu.Lock() - defer self.mu.Unlock() - // Hash first - self.Hash() - - self.cache.Flush() +// Iterator returns an iterator over all mappings in the trie. +func (t *Trie) Iterator() *Iterator { + return NewIterator(t) } -// Reset should only be called if the trie has been hashed -func (self *Trie) Reset() { - self.mu.Lock() - defer self.mu.Unlock() - - self.cache.Reset() - - if self.revisions.Len() > 0 { - revision := self.revisions.Remove(self.revisions.Back()).([]byte) - self.roothash = revision +// Get returns the value for key stored in the trie. +// The value bytes must not be modified by the caller. +func (t *Trie) Get(key []byte) []byte { + key = compactHexDecode(key) + tn := t.root + for len(key) > 0 { + switch n := tn.(type) { + case shortNode: + if len(key) < len(n.Key) || !bytes.Equal(n.Key, key[:len(n.Key)]) { + return nil + } + tn = n.Val + key = key[len(n.Key):] + case fullNode: + tn = n[key[0]] + key = key[1:] + case nil: + return nil + case hashNode: + tn = t.resolveHash(n) + default: + panic(fmt.Sprintf("%T: invalid node: %v", tn, tn)) + } } - value := common.NewValueFromBytes(self.cache.Get(self.roothash)) - self.root = self.mknode(value) + return tn.(valueNode) } -func (self *Trie) UpdateString(key, value string) Node { return self.Update([]byte(key), []byte(value)) } -func (self *Trie) Update(key, value []byte) Node { - self.mu.Lock() - defer self.mu.Unlock() - - k := CompactHexDecode(key) - +// Update associates key with value in the trie. Subsequent calls to +// Get will return value. If value has length zero, any existing value +// is deleted from the trie and calls to Get will return nil. +// +// The value bytes must not be modified by the caller while they are +// stored in the trie. +func (t *Trie) Update(key, value []byte) { + k := compactHexDecode(key) if len(value) != 0 { - node := NewValueNode(self, value) - node.dirty = true - self.root = self.insert(self.root, k, node) + t.root = t.insert(t.root, k, valueNode(value)) } else { - self.root = self.delete(self.root, k) + t.root = t.delete(t.root, k) } - - return self.root -} - -func (self *Trie) GetString(key string) []byte { return self.Get([]byte(key)) } -func (self *Trie) Get(key []byte) []byte { - self.mu.Lock() - defer self.mu.Unlock() - - k := CompactHexDecode(key) - - n := self.get(self.root, k) - if n != nil { - return n.(*ValueNode).Val() - } - - return nil } -func (self *Trie) DeleteString(key string) Node { return self.Delete([]byte(key)) } -func (self *Trie) Delete(key []byte) Node { - self.mu.Lock() - defer self.mu.Unlock() - - k := CompactHexDecode(key) - self.root = self.delete(self.root, k) - - return self.root -} - -func (self *Trie) insert(node Node, key []byte, value Node) Node { +func (t *Trie) insert(n node, key []byte, value node) node { if len(key) == 0 { return value } - - if node == nil { - node := NewShortNode(self, key, value) - node.dirty = true - return node - } - - switch node := node.(type) { - case *ShortNode: - k := node.Key() - cnode := node.Value() - if bytes.Equal(k, key) { - node := NewShortNode(self, key, value) - node.dirty = true - return node - + switch n := n.(type) { + case shortNode: + matchlen := prefixLen(key, n.Key) + // If the whole key matches, keep this short node as is + // and only update the value. + if matchlen == len(n.Key) { + return shortNode{n.Key, t.insert(n.Val, key[matchlen:], value)} } - - var n Node - matchlength := MatchingNibbleLength(key, k) - if matchlength == len(k) { - n = self.insert(cnode, key[matchlength:], value) - } else { - pnode := self.insert(nil, k[matchlength+1:], cnode) - nnode := self.insert(nil, key[matchlength+1:], value) - fulln := NewFullNode(self) - fulln.dirty = true - fulln.set(k[matchlength], pnode) - fulln.set(key[matchlength], nnode) - n = fulln - } - if matchlength == 0 { - return n + // Otherwise branch out at the index where they differ. + var branch fullNode + branch[n.Key[matchlen]] = t.insert(nil, n.Key[matchlen+1:], n.Val) + branch[key[matchlen]] = t.insert(nil, key[matchlen+1:], value) + // Replace this shortNode with the branch if it occurs at index 0. + if matchlen == 0 { + return branch } + // Otherwise, replace it with a short node leading up to the branch. + return shortNode{key[:matchlen], branch} - snode := NewShortNode(self, key[:matchlength], n) - snode.dirty = true - return snode + case fullNode: + n[key[0]] = t.insert(n[key[0]], key[1:], value) + return n - case *FullNode: - cpy := node.Copy(self).(*FullNode) - cpy.set(key[0], self.insert(node.branch(key[0]), key[1:], value)) - cpy.dirty = true + case nil: + return shortNode{key, value} - return cpy + case hashNode: + // We've hit a part of the trie that isn't loaded yet. Load + // the node and insert into it. This leaves all child nodes on + // the path to the value in the trie. + // + // TODO: track whether insertion changed the value and keep + // n as a hash node if it didn't. + return t.insert(t.resolveHash(n), key, value) default: - panic(fmt.Sprintf("%T: invalid node: %v", node, node)) + panic(fmt.Sprintf("%T: invalid node: %v", n, n)) } } -func (self *Trie) get(node Node, key []byte) Node { - if len(key) == 0 { - return node - } - - if node == nil { - return nil - } - - switch node := node.(type) { - case *ShortNode: - k := node.Key() - cnode := node.Value() - - if len(key) >= len(k) && bytes.Equal(k, key[:len(k)]) { - return self.get(cnode, key[len(k):]) - } - - return nil - case *FullNode: - return self.get(node.branch(key[0]), key[1:]) - default: - panic(fmt.Sprintf("%T: invalid node: %v", node, node)) - } +// Delete removes any existing value for key from the trie. +func (t *Trie) Delete(key []byte) { + k := compactHexDecode(key) + t.root = t.delete(t.root, k) } -func (self *Trie) delete(node Node, key []byte) Node { - if len(key) == 0 && node == nil { - return nil - } - - switch node := node.(type) { - case *ShortNode: - k := node.Key() - cnode := node.Value() - if bytes.Equal(key, k) { - return nil - } else if bytes.Equal(key[:len(k)], k) { - child := self.delete(cnode, key[len(k):]) - - var n Node - switch child := child.(type) { - case *ShortNode: - nkey := append(k, child.Key()...) - n = NewShortNode(self, nkey, child.Value()) - n.(*ShortNode).dirty = true - case *FullNode: - sn := NewShortNode(self, node.Key(), child) - sn.dirty = true - sn.key = node.key - n = sn - } - - return n - } else { - return node +// delete returns the new root of the trie with key deleted. +// It reduces the trie to minimal form by simplifying +// nodes on the way up after deleting recursively. +func (t *Trie) delete(n node, key []byte) node { + switch n := n.(type) { + case shortNode: + matchlen := prefixLen(key, n.Key) + if matchlen < len(n.Key) { + return n // don't replace n on mismatch + } + if matchlen == len(key) { + return nil // remove n entirely for whole matches + } + // The key is longer than n.Key. Remove the remaining suffix + // from the subtrie. Child can never be nil here since the + // subtrie must contain at least two other values with keys + // longer than n.Key. + child := t.delete(n.Val, key[len(n.Key):]) + switch child := child.(type) { + case shortNode: + // Deleting from the subtrie reduced it to another + // short node. Merge the nodes to avoid creating a + // shortNode{..., shortNode{...}}. Use concat (which + // always creates a new slice) instead of append to + // avoid modifying n.Key since it might be shared with + // other nodes. + return shortNode{concat(n.Key, child.Key...), child.Val} + default: + return shortNode{n.Key, child} } - case *FullNode: - n := node.Copy(self).(*FullNode) - n.set(key[0], self.delete(n.branch(key[0]), key[1:])) - n.dirty = true - + case fullNode: + n[key[0]] = t.delete(n[key[0]], key[1:]) + // Check how many non-nil entries are left after deleting and + // reduce the full node to a short node if only one entry is + // left. Since n must've contained at least two children + // before deletion (otherwise it would not be a full node) n + // can never be reduced to nil. + // + // When the loop is done, pos contains the index of the single + // value that is left in n or -2 if n contains at least two + // values. pos := -1 - for i := 0; i < 17; i++ { - if n.branch(byte(i)) != nil { + for i, cld := range n { + if cld != nil { if pos == -1 { pos = i } else { pos = -2 + break } } } - - var nnode Node - if pos == 16 { - nnode = NewShortNode(self, []byte{16}, n.branch(byte(pos))) - nnode.(*ShortNode).dirty = true - } else if pos >= 0 { - cnode := n.branch(byte(pos)) - switch cnode := cnode.(type) { - case *ShortNode: - // Stitch keys - k := append([]byte{byte(pos)}, cnode.Key()...) - nnode = NewShortNode(self, k, cnode.Value()) - nnode.(*ShortNode).dirty = true - case *FullNode: - nnode = NewShortNode(self, []byte{byte(pos)}, n.branch(byte(pos))) - nnode.(*ShortNode).dirty = true + if pos >= 0 { + if pos != 16 { + // If the remaining entry is a short node, it replaces + // n and its key gets the missing nibble tacked to the + // front. This avoids creating an invalid + // shortNode{..., shortNode{...}}. Since the entry + // might not be loaded yet, resolve it just for this + // check. + cnode := t.resolve(n[pos]) + if cnode, ok := cnode.(shortNode); ok { + k := append([]byte{byte(pos)}, cnode.Key...) + return shortNode{k, cnode.Val} + } } - } else { - nnode = n + // Otherwise, n is replaced by a one-nibble short node + // containing the child. + return shortNode{[]byte{byte(pos)}, n[pos]} } + // n still contains at least two values and cannot be reduced. + return n - return nnode case nil: return nil + + case hashNode: + // We've hit a part of the trie that isn't loaded yet. Load + // the node and delete from it. This leaves all child nodes on + // the path to the value in the trie. + // + // TODO: track whether deletion actually hit a key and keep + // n as a hash node if it didn't. + return t.delete(t.resolveHash(n), key) + default: - panic(fmt.Sprintf("%T: invalid node: %v (%v)", node, node, key)) + panic(fmt.Sprintf("%T: invalid node: %v (%v)", n, n, key)) } } -// casting functions and cache storing -func (self *Trie) mknode(value *common.Value) Node { - l := value.Len() - switch l { - case 0: - return nil - case 2: - // A value node may consists of 2 bytes. - if value.Get(0).Len() != 0 { - key := CompactDecode(value.Get(0).Bytes()) - if key[len(key)-1] == 16 { - return NewShortNode(self, key, NewValueNode(self, value.Get(1).Bytes())) - } else { - return NewShortNode(self, key, self.mknode(value.Get(1))) - } - } - case 17: - if len(value.Bytes()) != 17 { - fnode := NewFullNode(self) - for i := 0; i < 16; i++ { - fnode.set(byte(i), self.mknode(value.Get(i))) - } - return fnode +func concat(s1 []byte, s2 ...byte) []byte { + r := make([]byte, len(s1)+len(s2)) + copy(r, s1) + copy(r[len(s1):], s2) + return r +} + +func (t *Trie) resolve(n node) node { + if n, ok := n.(hashNode); ok { + return t.resolveHash(n) + } + return n +} + +func (t *Trie) resolveHash(n hashNode) node { + if v, ok := globalCache.Get(n); ok { + return v + } + enc, err := t.db.Get(n) + if err != nil || enc == nil { + // TODO: This needs to be improved to properly distinguish errors. + // Disk I/O errors shouldn't produce nil (and cause a + // consensus failure or weird crash), but it is unclear how + // they could be handled because the entire stack above the trie isn't + // prepared to cope with missing state nodes. + if glog.V(logger.Error) { + glog.Errorf("Dangling hash node ref %x: %v", n, err) } - case 32: - return NewHash(value.Bytes(), self) + return nil + } + dec := mustDecodeNode(n, enc) + if dec != nil { + globalCache.Put(n, dec) } + return dec +} + +// Root returns the root hash of the trie. +// Deprecated: use Hash instead. +func (t *Trie) Root() []byte { return t.Hash().Bytes() } - return NewValueNode(self, value.Bytes()) +// Hash returns the root hash of the trie. It does not write to the +// database and can be used even if the trie doesn't have one. +func (t *Trie) Hash() common.Hash { + root, _ := t.hashRoot(nil) + return common.BytesToHash(root.(hashNode)) } -func (self *Trie) trans(node Node) Node { - switch node := node.(type) { - case *HashNode: - value := common.NewValueFromBytes(self.cache.Get(node.key)) - return self.mknode(value) - default: - return node +// Commit writes all nodes to the trie's database. +// Nodes are stored with their sha3 hash as the key. +// +// Committing flushes nodes from memory. +// Subsequent Get calls will load nodes from the database. +func (t *Trie) Commit() (root common.Hash, err error) { + if t.db == nil { + panic("Commit called on trie with nil database") } + return t.CommitTo(t.db) } -func (self *Trie) store(node Node) interface{} { - data := common.Encode(node) - if len(data) >= 32 { - key := crypto.Sha3(data) - if node.Dirty() { - //fmt.Println("save", node) - //fmt.Println() - self.cache.Put(key, data) - } +// CommitTo writes all nodes to the given database. +// Nodes are stored with their sha3 hash as the key. +// +// Committing flushes nodes from memory. Subsequent Get calls will +// load nodes from the trie's database. Calling code must ensure that +// the changes made to db are written back to the trie's attached +// database before using the trie. +func (t *Trie) CommitTo(db DatabaseWriter) (root common.Hash, err error) { + n, err := t.hashRoot(db) + if err != nil { + return (common.Hash{}), err + } + t.root = n + return common.BytesToHash(n.(hashNode)), nil +} - return key +func (t *Trie) hashRoot(db DatabaseWriter) (node, error) { + if t.root == nil { + return hashNode(emptyRoot.Bytes()), nil + } + if t.hasher == nil { + t.hasher = newHasher() } + return t.hasher.hash(t.root, db, true) +} - return node.RlpData() +type hasher struct { + tmp *bytes.Buffer + sha hash.Hash } -func (self *Trie) PrintRoot() { - fmt.Println(self.root) - fmt.Printf("root=%x\n", self.Root()) +func newHasher() *hasher { + return &hasher{tmp: new(bytes.Buffer), sha: sha3.NewKeccak256()} +} + +func (h *hasher) hash(n node, db DatabaseWriter, force bool) (node, error) { + hashed, err := h.replaceChildren(n, db) + if err != nil { + return hashNode{}, err + } + if n, err = h.store(hashed, db, force); err != nil { + return hashNode{}, err + } + return n, nil +} + +// hashChildren replaces child nodes of n with their hashes if the encoded +// size of the child is larger than a hash. +func (h *hasher) replaceChildren(n node, db DatabaseWriter) (node, error) { + var err error + switch n := n.(type) { + case shortNode: + n.Key = compactEncode(n.Key) + if _, ok := n.Val.(valueNode); !ok { + if n.Val, err = h.hash(n.Val, db, false); err != nil { + return n, err + } + } + if n.Val == nil { + // Ensure that nil children are encoded as empty strings. + n.Val = valueNode(nil) + } + return n, nil + case fullNode: + for i := 0; i < 16; i++ { + if n[i] != nil { + if n[i], err = h.hash(n[i], db, false); err != nil { + return n, err + } + } else { + // Ensure that nil children are encoded as empty strings. + n[i] = valueNode(nil) + } + } + if n[16] == nil { + n[16] = valueNode(nil) + } + return n, nil + default: + return n, nil + } +} + +func (h *hasher) store(n node, db DatabaseWriter, force bool) (node, error) { + // Don't store hashes or empty nodes. + if _, isHash := n.(hashNode); n == nil || isHash { + return n, nil + } + h.tmp.Reset() + if err := rlp.Encode(h.tmp, n); err != nil { + panic("encode error: " + err.Error()) + } + if h.tmp.Len() < 32 && !force { + // Nodes smaller than 32 bytes are stored inside their parent. + return n, nil + } + // Larger nodes are replaced by their hash and stored in the database. + h.sha.Reset() + h.sha.Write(h.tmp.Bytes()) + key := hashNode(h.sha.Sum(nil)) + if db != nil { + err := db.Put(key, h.tmp.Bytes()) + return key, err + } + return key, nil } diff --git a/trie/trie_test.go b/trie/trie_test.go index ae4e5efe4..c96861bed 100644 --- a/trie/trie_test.go +++ b/trie/trie_test.go @@ -18,89 +18,109 @@ package trie import ( "bytes" + "encoding/binary" "fmt" + "io/ioutil" + "os" "testing" + "github.com/davecgh/go-spew/spew" "github.com/ethereum/go-ethereum/common" - "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/ethdb" ) -type Db map[string][]byte - -func (self Db) Get(k []byte) ([]byte, error) { return self[string(k)], nil } -func (self Db) Put(k, v []byte) error { self[string(k)] = v; return nil } - -// Used for testing -func NewEmpty() *Trie { - return New(nil, make(Db)) +func init() { + spew.Config.Indent = " " + spew.Config.DisableMethods = true } -func NewEmptySecure() *SecureTrie { - return NewSecure(nil, make(Db)) +// Used for testing +func newEmpty() *Trie { + db, _ := ethdb.NewMemDatabase() + trie, _ := New(common.Hash{}, db) + return trie } func TestEmptyTrie(t *testing.T) { - trie := NewEmpty() + var trie Trie res := trie.Hash() - exp := crypto.Sha3(common.Encode("")) - if !bytes.Equal(res, exp) { + exp := emptyRoot + if res != common.Hash(exp) { t.Errorf("expected %x got %x", exp, res) } } func TestNull(t *testing.T) { - trie := NewEmpty() - + var trie Trie key := make([]byte, 32) value := common.FromHex("0x823140710bf13990e4500136726d8b55") trie.Update(key, value) value = trie.Get(key) } +func TestMissingRoot(t *testing.T) { + db, _ := ethdb.NewMemDatabase() + trie, err := New(common.HexToHash("0beec7b5ea3f0fdbc95d0dd47f3c5bc275da8a33"), db) + if trie != nil { + t.Error("New returned non-nil trie for invalid root") + } + if err != ErrMissingRoot { + t.Error("New returned wrong error: %v", err) + } +} + func TestInsert(t *testing.T) { - trie := NewEmpty() + trie := newEmpty() - trie.UpdateString("doe", "reindeer") - trie.UpdateString("dog", "puppy") - trie.UpdateString("dogglesworth", "cat") + updateString(trie, "doe", "reindeer") + updateString(trie, "dog", "puppy") + updateString(trie, "dogglesworth", "cat") - exp := common.Hex2Bytes("8aad789dff2f538bca5d8ea56e8abe10f4c7ba3a5dea95fea4cd6e7c3a1168d3") + exp := common.HexToHash("8aad789dff2f538bca5d8ea56e8abe10f4c7ba3a5dea95fea4cd6e7c3a1168d3") root := trie.Hash() - if !bytes.Equal(root, exp) { + if root != exp { t.Errorf("exp %x got %x", exp, root) } - trie = NewEmpty() - trie.UpdateString("A", "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa") + trie = newEmpty() + updateString(trie, "A", "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa") - exp = common.Hex2Bytes("d23786fb4a010da3ce639d66d5e904a11dbc02746d1ce25029e53290cabf28ab") - root = trie.Hash() - if !bytes.Equal(root, exp) { + exp = common.HexToHash("d23786fb4a010da3ce639d66d5e904a11dbc02746d1ce25029e53290cabf28ab") + root, err := trie.Commit() + if err != nil { + t.Fatalf("commit error: %v", err) + } + if root != exp { t.Errorf("exp %x got %x", exp, root) } } func TestGet(t *testing.T) { - trie := NewEmpty() - - trie.UpdateString("doe", "reindeer") - trie.UpdateString("dog", "puppy") - trie.UpdateString("dogglesworth", "cat") + trie := newEmpty() + updateString(trie, "doe", "reindeer") + updateString(trie, "dog", "puppy") + updateString(trie, "dogglesworth", "cat") + + for i := 0; i < 2; i++ { + res := getString(trie, "dog") + if !bytes.Equal(res, []byte("puppy")) { + t.Errorf("expected puppy got %x", res) + } - res := trie.GetString("dog") - if !bytes.Equal(res, []byte("puppy")) { - t.Errorf("expected puppy got %x", res) - } + unknown := getString(trie, "unknown") + if unknown != nil { + t.Errorf("expected nil got %x", unknown) + } - unknown := trie.GetString("unknown") - if unknown != nil { - t.Errorf("expected nil got %x", unknown) + if i == 1 { + return + } + trie.Commit() } } func TestDelete(t *testing.T) { - trie := NewEmpty() - + trie := newEmpty() vals := []struct{ k, v string }{ {"do", "verb"}, {"ether", "wookiedoo"}, @@ -113,21 +133,21 @@ func TestDelete(t *testing.T) { } for _, val := range vals { if val.v != "" { - trie.UpdateString(val.k, val.v) + updateString(trie, val.k, val.v) } else { - trie.DeleteString(val.k) + deleteString(trie, val.k) } } hash := trie.Hash() - exp := common.Hex2Bytes("5991bb8c6514148a29db676a14ac506cd2cd5775ace63c30a4fe457715e9ac84") - if !bytes.Equal(hash, exp) { + exp := common.HexToHash("5991bb8c6514148a29db676a14ac506cd2cd5775ace63c30a4fe457715e9ac84") + if hash != exp { t.Errorf("expected %x got %x", exp, hash) } } func TestEmptyValues(t *testing.T) { - trie := NewEmpty() + trie := newEmpty() vals := []struct{ k, v string }{ {"do", "verb"}, @@ -140,78 +160,85 @@ func TestEmptyValues(t *testing.T) { {"shaman", ""}, } for _, val := range vals { - trie.UpdateString(val.k, val.v) + updateString(trie, val.k, val.v) } hash := trie.Hash() - exp := common.Hex2Bytes("5991bb8c6514148a29db676a14ac506cd2cd5775ace63c30a4fe457715e9ac84") - if !bytes.Equal(hash, exp) { + exp := common.HexToHash("5991bb8c6514148a29db676a14ac506cd2cd5775ace63c30a4fe457715e9ac84") + if hash != exp { t.Errorf("expected %x got %x", exp, hash) } } func TestReplication(t *testing.T) { - trie := NewEmpty() + trie := newEmpty() vals := []struct{ k, v string }{ {"do", "verb"}, {"ether", "wookiedoo"}, {"horse", "stallion"}, {"shaman", "horse"}, {"doge", "coin"}, - {"ether", ""}, {"dog", "puppy"}, - {"shaman", ""}, {"somethingveryoddindeedthis is", "myothernodedata"}, } for _, val := range vals { - trie.UpdateString(val.k, val.v) + updateString(trie, val.k, val.v) } - trie.Commit() - - trie2 := New(trie.Root(), trie.cache.backend) - if string(trie2.GetString("horse")) != "stallion" { - t.Error("expected to have horse => stallion") + exp, err := trie.Commit() + if err != nil { + t.Fatalf("commit error: %v", err) } - hash := trie2.Hash() - exp := trie.Hash() - if !bytes.Equal(hash, exp) { + // create a new trie on top of the database and check that lookups work. + trie2, err := New(exp, trie.db) + if err != nil { + t.Fatalf("can't recreate trie at %x: %v", exp, err) + } + for _, kv := range vals { + if string(getString(trie2, kv.k)) != kv.v { + t.Errorf("trie2 doesn't have %q => %q", kv.k, kv.v) + } + } + hash, err := trie2.Commit() + if err != nil { + t.Fatalf("commit error: %v", err) + } + if hash != exp { t.Errorf("root failure. expected %x got %x", exp, hash) } -} - -func TestReset(t *testing.T) { - trie := NewEmpty() - vals := []struct{ k, v string }{ + // perform some insertions on the new trie. + vals2 := []struct{ k, v string }{ {"do", "verb"}, {"ether", "wookiedoo"}, {"horse", "stallion"}, + // {"shaman", "horse"}, + // {"doge", "coin"}, + // {"ether", ""}, + // {"dog", "puppy"}, + // {"somethingveryoddindeedthis is", "myothernodedata"}, + // {"shaman", ""}, } - for _, val := range vals { - trie.UpdateString(val.k, val.v) + for _, val := range vals2 { + updateString(trie2, val.k, val.v) } - trie.Commit() - - before := common.CopyBytes(trie.roothash) - trie.UpdateString("should", "revert") - trie.Hash() - // Should have no effect - trie.Hash() - trie.Hash() - // ### - - trie.Reset() - after := common.CopyBytes(trie.roothash) + if trie2.Hash() != exp { + t.Errorf("root failure. expected %x got %x", exp, hash) + } +} - if !bytes.Equal(before, after) { - t.Errorf("expected roots to be equal. %x - %x", before, after) +func paranoiaCheck(t1 *Trie) (bool, *Trie) { + t2 := new(Trie) + it := NewIterator(t1) + for it.Next() { + t2.Update(it.Key, it.Value) } + return t2.Hash() == t1.Hash(), t2 } func TestParanoia(t *testing.T) { t.Skip() - trie := NewEmpty() + trie := newEmpty() vals := []struct{ k, v string }{ {"do", "verb"}, @@ -225,13 +252,13 @@ func TestParanoia(t *testing.T) { {"somethingveryoddindeedthis is", "myothernodedata"}, } for _, val := range vals { - trie.UpdateString(val.k, val.v) + updateString(trie, val.k, val.v) } trie.Commit() - ok, t2 := ParanoiaCheck(trie, trie.cache.backend) + ok, t2 := paranoiaCheck(trie) if !ok { - t.Errorf("trie paranoia check failed %x %x", trie.roothash, t2.roothash) + t.Errorf("trie paranoia check failed %x %x", trie.Hash(), t2.Hash()) } } @@ -240,51 +267,26 @@ func TestOutput(t *testing.T) { t.Skip() base := "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" - trie := NewEmpty() + trie := newEmpty() for i := 0; i < 50; i++ { - trie.UpdateString(fmt.Sprintf("%s%d", base, i), "valueeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee") + updateString(trie, fmt.Sprintf("%s%d", base, i), "valueeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee") } fmt.Println("############################## FULL ################################") fmt.Println(trie.root) trie.Commit() fmt.Println("############################## SMALL ################################") - trie2 := New(trie.roothash, trie.cache.backend) - trie2.GetString(base + "20") + trie2, _ := New(trie.Hash(), trie.db) + getString(trie2, base+"20") fmt.Println(trie2.root) } -func BenchmarkGets(b *testing.B) { - trie := NewEmpty() - vals := []struct{ k, v string }{ - {"do", "verb"}, - {"ether", "wookiedoo"}, - {"horse", "stallion"}, - {"shaman", "horse"}, - {"doge", "coin"}, - {"ether", ""}, - {"dog", "puppy"}, - {"shaman", ""}, - {"somethingveryoddindeedthis is", "myothernodedata"}, - } - for _, val := range vals { - trie.UpdateString(val.k, val.v) - } - - b.ResetTimer() - for i := 0; i < b.N; i++ { - trie.Get([]byte("horse")) - } -} - -func BenchmarkUpdate(b *testing.B) { - trie := NewEmpty() - - b.ResetTimer() - for i := 0; i < b.N; i++ { - trie.UpdateString(fmt.Sprintf("aaaaaaaaa%d", i), "value") - } +func TestLargeValue(t *testing.T) { + trie := newEmpty() + trie.Update([]byte("key1"), []byte{99, 99, 99, 99}) + trie.Update([]byte("key2"), bytes.Repeat([]byte{1}, 32)) trie.Hash() + } type kv struct { @@ -293,7 +295,7 @@ type kv struct { } func TestLargeData(t *testing.T) { - trie := NewEmpty() + trie := newEmpty() vals := make(map[string]*kv) for i := byte(0); i < 255; i++ { @@ -305,7 +307,7 @@ func TestLargeData(t *testing.T) { vals[string(value2.k)] = value2 } - it := trie.Iterator() + it := NewIterator(trie) for it.Next() { vals[string(it.Key)].t = true } @@ -325,30 +327,82 @@ func TestLargeData(t *testing.T) { } } -func TestSecureDelete(t *testing.T) { - trie := NewEmptySecure() +func BenchmarkGet(b *testing.B) { benchGet(b, false) } +func BenchmarkGetDB(b *testing.B) { benchGet(b, true) } +func BenchmarkUpdateBE(b *testing.B) { benchUpdate(b, binary.BigEndian) } +func BenchmarkUpdateLE(b *testing.B) { benchUpdate(b, binary.LittleEndian) } +func BenchmarkHashBE(b *testing.B) { benchHash(b, binary.BigEndian) } +func BenchmarkHashLE(b *testing.B) { benchHash(b, binary.LittleEndian) } - vals := []struct{ k, v string }{ - {"do", "verb"}, - {"ether", "wookiedoo"}, - {"horse", "stallion"}, - {"shaman", "horse"}, - {"doge", "coin"}, - {"ether", ""}, - {"dog", "puppy"}, - {"shaman", ""}, +const benchElemCount = 20000 + +func benchGet(b *testing.B, commit bool) { + trie := new(Trie) + if commit { + dir, tmpdb := tempDB() + defer os.RemoveAll(dir) + trie, _ = New(common.Hash{}, tmpdb) } - for _, val := range vals { - if val.v != "" { - trie.UpdateString(val.k, val.v) - } else { - trie.DeleteString(val.k) - } + k := make([]byte, 32) + for i := 0; i < benchElemCount; i++ { + binary.LittleEndian.PutUint64(k, uint64(i)) + trie.Update(k, k) + } + binary.LittleEndian.PutUint64(k, benchElemCount/2) + if commit { + trie.Commit() } - hash := trie.Hash() - exp := common.Hex2Bytes("29b235a58c3c25ab83010c327d5932bcf05324b7d6b1185e650798034783ca9d") - if !bytes.Equal(hash, exp) { - t.Errorf("expected %x got %x", exp, hash) + b.ResetTimer() + for i := 0; i < b.N; i++ { + trie.Get(k) } } + +func benchUpdate(b *testing.B, e binary.ByteOrder) *Trie { + trie := newEmpty() + k := make([]byte, 32) + for i := 0; i < b.N; i++ { + e.PutUint64(k, uint64(i)) + trie.Update(k, k) + } + return trie +} + +func benchHash(b *testing.B, e binary.ByteOrder) { + trie := newEmpty() + k := make([]byte, 32) + for i := 0; i < benchElemCount; i++ { + e.PutUint64(k, uint64(i)) + trie.Update(k, k) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + trie.Hash() + } +} + +func tempDB() (string, Database) { + dir, err := ioutil.TempDir("", "trie-bench") + if err != nil { + panic(fmt.Sprintf("can't create temporary directory: %v", err)) + } + db, err := ethdb.NewLDBDatabase(dir, 300*1024) + if err != nil { + panic(fmt.Sprintf("can't create temporary database: %v", err)) + } + return dir, db +} + +func getString(trie *Trie, k string) []byte { + return trie.Get([]byte(k)) +} + +func updateString(trie *Trie, k, v string) { + trie.Update([]byte(k), []byte(v)) +} + +func deleteString(trie *Trie, k string) { + trie.Delete([]byte(k)) +} diff --git a/trie/valuenode.go b/trie/valuenode.go deleted file mode 100644 index 0afa64d54..000000000 --- a/trie/valuenode.go +++ /dev/null @@ -1,42 +0,0 @@ -// Copyright 2014 The go-ethereum Authors -// This file is part of the go-ethereum library. -// -// The go-ethereum library is free software: you can redistribute it and/or modify -// it under the terms of the GNU Lesser General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// The go-ethereum library is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU Lesser General Public License for more details. -// -// You should have received a copy of the GNU Lesser General Public License -// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. - -package trie - -import "github.com/ethereum/go-ethereum/common" - -type ValueNode struct { - trie *Trie - data []byte - dirty bool -} - -func NewValueNode(trie *Trie, data []byte) *ValueNode { - return &ValueNode{trie, data, false} -} - -func (self *ValueNode) Value() Node { return self } // Best not to call :-) -func (self *ValueNode) Val() []byte { return self.data } -func (self *ValueNode) Dirty() bool { return self.dirty } -func (self *ValueNode) Copy(t *Trie) Node { - return &ValueNode{t, common.CopyBytes(self.data), self.dirty} -} -func (self *ValueNode) RlpData() interface{} { return self.data } -func (self *ValueNode) Hash() interface{} { return self.data } - -func (self *ValueNode) setDirty(dirty bool) { - self.dirty = dirty -} diff --git a/xeth/xeth.go b/xeth/xeth.go index 00b70da6c..da712a984 100644 --- a/xeth/xeth.go +++ b/xeth/xeth.go @@ -33,9 +33,10 @@ import ( "github.com/ethereum/go-ethereum/core" "github.com/ethereum/go-ethereum/core/state" "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/core/vm" "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/eth" - "github.com/ethereum/go-ethereum/event/filter" + "github.com/ethereum/go-ethereum/eth/filters" "github.com/ethereum/go-ethereum/logger" "github.com/ethereum/go-ethereum/logger/glog" "github.com/ethereum/go-ethereum/miner" @@ -75,7 +76,7 @@ type XEth struct { whisper *Whisper quit chan struct{} - filterManager *filter.FilterManager + filterManager *filters.FilterSystem logMu sync.RWMutex logQueue map[int]*logQueue @@ -111,7 +112,7 @@ func New(ethereum *eth.Ethereum, frontend Frontend) *XEth { backend: ethereum, frontend: frontend, quit: make(chan struct{}), - filterManager: filter.NewFilterManager(ethereum.EventMux()), + filterManager: filters.NewFilterSystem(ethereum.EventMux()), logQueue: make(map[int]*logQueue), blockQueue: make(map[int]*hashQueue), transactionQueue: make(map[int]*hashQueue), @@ -125,10 +126,9 @@ func New(ethereum *eth.Ethereum, frontend Frontend) *XEth { if frontend == nil { xeth.frontend = dummyFrontend{} } - xeth.state = NewState(xeth, xeth.backend.ChainManager().State()) + xeth.state = NewState(xeth, xeth.backend.BlockChain().State()) go xeth.start() - go xeth.filterManager.Start() return xeth } @@ -142,7 +142,7 @@ done: self.logMu.Lock() for id, filter := range self.logQueue { if time.Since(filter.timeout) > filterTickerTime { - self.filterManager.UninstallFilter(id) + self.filterManager.Remove(id) delete(self.logQueue, id) } } @@ -151,7 +151,7 @@ done: self.blockMu.Lock() for id, filter := range self.blockQueue { if time.Since(filter.timeout) > filterTickerTime { - self.filterManager.UninstallFilter(id) + self.filterManager.Remove(id) delete(self.blockQueue, id) } } @@ -160,7 +160,7 @@ done: self.transactionMu.Lock() for id, filter := range self.transactionQueue { if time.Since(filter.timeout) > filterTickerTime { - self.filterManager.UninstallFilter(id) + self.filterManager.Remove(id) delete(self.transactionQueue, id) } } @@ -214,7 +214,7 @@ func (self *XEth) AtStateNum(num int64) *XEth { if block := self.getBlockByHeight(num); block != nil { st = state.New(block.Root(), self.backend.ChainDb()) } else { - st = state.New(self.backend.ChainManager().GetBlockByNumber(0).Root(), self.backend.ChainDb()) + st = state.New(self.backend.BlockChain().GetBlockByNumber(0).Root(), self.backend.ChainDb()) } } @@ -290,19 +290,19 @@ func (self *XEth) getBlockByHeight(height int64) *types.Block { num = uint64(height) } - return self.backend.ChainManager().GetBlockByNumber(num) + return self.backend.BlockChain().GetBlockByNumber(num) } func (self *XEth) BlockByHash(strHash string) *Block { hash := common.HexToHash(strHash) - block := self.backend.ChainManager().GetBlock(hash) + block := self.backend.BlockChain().GetBlock(hash) return NewBlock(block) } func (self *XEth) EthBlockByHash(strHash string) *types.Block { hash := common.HexToHash(strHash) - block := self.backend.ChainManager().GetBlock(hash) + block := self.backend.BlockChain().GetBlock(hash) return block } @@ -356,11 +356,11 @@ func (self *XEth) EthBlockByNumber(num int64) *types.Block { } func (self *XEth) Td(hash common.Hash) *big.Int { - return self.backend.ChainManager().GetTd(hash) + return self.backend.BlockChain().GetTd(hash) } func (self *XEth) CurrentBlock() *types.Block { - return self.backend.ChainManager().CurrentBlock() + return self.backend.BlockChain().CurrentBlock() } func (self *XEth) GetBlockReceipts(bhash common.Hash) types.Receipts { @@ -372,7 +372,7 @@ func (self *XEth) GetTxReceipt(txhash common.Hash) *types.Receipt { } func (self *XEth) GasLimit() *big.Int { - return self.backend.ChainManager().GasLimit() + return self.backend.BlockChain().GasLimit() } func (self *XEth) Block(v interface{}) *Block { @@ -504,7 +504,7 @@ func (self *XEth) IsContract(address string) bool { } func (self *XEth) UninstallFilter(id int) bool { - defer self.filterManager.UninstallFilter(id) + defer self.filterManager.Remove(id) if _, ok := self.logQueue[id]; ok { self.logMu.Lock() @@ -532,22 +532,24 @@ func (self *XEth) NewLogFilter(earliest, latest int64, skip, max int, address [] self.logMu.Lock() defer self.logMu.Unlock() - var id int - filter := core.NewFilter(self.backend) + filter := filters.New(self.backend.ChainDb()) + id := self.filterManager.Add(filter) + self.logQueue[id] = &logQueue{timeout: time.Now()} + filter.SetEarliestBlock(earliest) filter.SetLatestBlock(latest) filter.SetSkip(skip) filter.SetMax(max) filter.SetAddress(cAddress(address)) filter.SetTopics(cTopics(topics)) - filter.LogsCallback = func(logs state.Logs) { + filter.LogsCallback = func(logs vm.Logs) { self.logMu.Lock() defer self.logMu.Unlock() - self.logQueue[id].add(logs...) + if queue := self.logQueue[id]; queue != nil { + queue.add(logs...) + } } - id = self.filterManager.InstallFilter(filter) - self.logQueue[id] = &logQueue{timeout: time.Now()} return id } @@ -556,16 +558,18 @@ func (self *XEth) NewTransactionFilter() int { self.transactionMu.Lock() defer self.transactionMu.Unlock() - var id int - filter := core.NewFilter(self.backend) + filter := filters.New(self.backend.ChainDb()) + id := self.filterManager.Add(filter) + self.transactionQueue[id] = &hashQueue{timeout: time.Now()} + filter.TransactionCallback = func(tx *types.Transaction) { self.transactionMu.Lock() defer self.transactionMu.Unlock() - self.transactionQueue[id].add(tx.Hash()) + if queue := self.transactionQueue[id]; queue != nil { + queue.add(tx.Hash()) + } } - id = self.filterManager.InstallFilter(filter) - self.transactionQueue[id] = &hashQueue{timeout: time.Now()} return id } @@ -573,16 +577,18 @@ func (self *XEth) NewBlockFilter() int { self.blockMu.Lock() defer self.blockMu.Unlock() - var id int - filter := core.NewFilter(self.backend) - filter.BlockCallback = func(block *types.Block, logs state.Logs) { + filter := filters.New(self.backend.ChainDb()) + id := self.filterManager.Add(filter) + self.blockQueue[id] = &hashQueue{timeout: time.Now()} + + filter.BlockCallback = func(block *types.Block, logs vm.Logs) { self.blockMu.Lock() defer self.blockMu.Unlock() - self.blockQueue[id].add(block.Hash()) + if queue := self.blockQueue[id]; queue != nil { + queue.add(block.Hash()) + } } - id = self.filterManager.InstallFilter(filter) - self.blockQueue[id] = &hashQueue{timeout: time.Now()} return id } @@ -598,7 +604,7 @@ func (self *XEth) GetFilterType(id int) byte { return UnknownFilterTy } -func (self *XEth) LogFilterChanged(id int) state.Logs { +func (self *XEth) LogFilterChanged(id int) vm.Logs { self.logMu.Lock() defer self.logMu.Unlock() @@ -628,8 +634,8 @@ func (self *XEth) TransactionFilterChanged(id int) []common.Hash { return nil } -func (self *XEth) Logs(id int) state.Logs { - filter := self.filterManager.GetFilter(id) +func (self *XEth) Logs(id int) vm.Logs { + filter := self.filterManager.Get(id) if filter != nil { return filter.Find() } @@ -637,8 +643,8 @@ func (self *XEth) Logs(id int) state.Logs { return nil } -func (self *XEth) AllLogs(earliest, latest int64, skip, max int, address []string, topics [][]string) state.Logs { - filter := core.NewFilter(self.backend) +func (self *XEth) AllLogs(earliest, latest int64, skip, max int, address []string, topics [][]string) vm.Logs { + filter := filters.New(self.backend.ChainDb()) filter.SetEarliestBlock(earliest) filter.SetLatestBlock(latest) filter.SetSkip(skip) @@ -849,7 +855,7 @@ func (self *XEth) Call(fromStr, toStr, valueStr, gasStr, gasPriceStr, dataStr st } header := self.CurrentBlock().Header() - vmenv := core.NewEnv(statedb, self.backend.ChainManager(), msg, header) + vmenv := core.NewEnv(statedb, self.backend.BlockChain(), msg, header) res, gas, err := core.ApplyMessage(vmenv, msg, from) return common.ToHex(res), gas.String(), err @@ -1022,16 +1028,24 @@ func (m callmsg) Value() *big.Int { return m.value } func (m callmsg) Data() []byte { return m.data } type logQueue struct { - logs state.Logs + mu sync.Mutex + + logs vm.Logs timeout time.Time id int } -func (l *logQueue) add(logs ...*state.Log) { +func (l *logQueue) add(logs ...*vm.Log) { + l.mu.Lock() + defer l.mu.Unlock() + l.logs = append(l.logs, logs...) } -func (l *logQueue) get() state.Logs { +func (l *logQueue) get() vm.Logs { + l.mu.Lock() + defer l.mu.Unlock() + l.timeout = time.Now() tmp := l.logs l.logs = nil @@ -1039,16 +1053,24 @@ func (l *logQueue) get() state.Logs { } type hashQueue struct { + mu sync.Mutex + hashes []common.Hash timeout time.Time id int } func (l *hashQueue) add(hashes ...common.Hash) { + l.mu.Lock() + defer l.mu.Unlock() + l.hashes = append(l.hashes, hashes...) } func (l *hashQueue) get() []common.Hash { + l.mu.Lock() + defer l.mu.Unlock() + l.timeout = time.Now() tmp := l.hashes l.hashes = nil |