diff options
-rw-r--r-- | accounts/account_manager.go | 145 | ||||
-rw-r--r-- | accounts/accounts_test.go | 44 | ||||
-rw-r--r-- | cmd/geth/js_test.go | 10 | ||||
-rw-r--r-- | cmd/geth/main.go | 104 | ||||
-rw-r--r-- | cmd/utils/flags.go | 2 | ||||
-rw-r--r-- | core/block_processor.go | 9 | ||||
-rw-r--r-- | core/chain_makers.go | 4 | ||||
-rw-r--r-- | core/error.go | 22 | ||||
-rw-r--r-- | core/execution.go | 2 | ||||
-rw-r--r-- | core/genesis.go | 2 | ||||
-rw-r--r-- | core/state/state_object.go | 8 | ||||
-rw-r--r-- | core/state/state_test.go | 2 | ||||
-rw-r--r-- | core/state/statedb.go | 26 | ||||
-rw-r--r-- | core/state_transition.go | 2 | ||||
-rw-r--r-- | core/vm/errors.go | 24 | ||||
-rw-r--r-- | core/vm/vm.go | 4 | ||||
-rw-r--r-- | crypto/crypto.go | 2 | ||||
-rw-r--r-- | crypto/key_store_passphrase.go | 53 | ||||
-rw-r--r-- | crypto/key_store_plain.go | 155 | ||||
-rw-r--r-- | eth/backend.go | 12 | ||||
-rw-r--r-- | miner/worker.go | 2 | ||||
-rw-r--r-- | tests/block_test_util.go | 2 | ||||
-rw-r--r-- | tests/init.go | 5 | ||||
-rw-r--r-- | tests/state_test_util.go | 2 | ||||
-rw-r--r-- | trie/fullnode.go | 17 | ||||
-rw-r--r-- | trie/hashnode.go | 11 | ||||
-rw-r--r-- | trie/node.go | 1 | ||||
-rw-r--r-- | trie/shortnode.go | 12 | ||||
-rw-r--r-- | trie/trie.go | 37 | ||||
-rw-r--r-- | trie/trie_test.go | 2 | ||||
-rw-r--r-- | trie/valuenode.go | 23 | ||||
-rw-r--r-- | xeth/xeth.go | 5 |
32 files changed, 474 insertions, 277 deletions
diff --git a/accounts/account_manager.go b/accounts/account_manager.go index 13f16296a..17b128e9e 100644 --- a/accounts/account_manager.go +++ b/accounts/account_manager.go @@ -26,7 +26,7 @@ This abstracts part of a user's interaction with an account she controls. It's not an abstraction of core Ethereum accounts data type / logic - for that see the core processing code of blocks / txs. -Currently this is pretty much a passthrough to the KeyStore2 interface, +Currently this is pretty much a passthrough to the KeyStore interface, and accounts persistence is derived from stored keys' addresses */ @@ -36,6 +36,7 @@ import ( "crypto/ecdsa" crand "crypto/rand" "errors" + "fmt" "os" "sync" "time" @@ -49,17 +50,12 @@ var ( ErrNoKeys = errors.New("no keys in store") ) -const ( - // Default unlock duration (in seconds) when an account is unlocked from the console - DefaultAccountUnlockDuration = 300 -) - type Account struct { Address common.Address } type Manager struct { - keyStore crypto.KeyStore2 + keyStore crypto.KeyStore unlocked map[common.Address]*unlocked mutex sync.RWMutex } @@ -69,7 +65,7 @@ type unlocked struct { abort chan struct{} } -func NewManager(keyStore crypto.KeyStore2) *Manager { +func NewManager(keyStore crypto.KeyStore) *Manager { return &Manager{ keyStore: keyStore, unlocked: make(map[common.Address]*unlocked), @@ -86,19 +82,6 @@ func (am *Manager) HasAccount(addr common.Address) bool { return false } -func (am *Manager) Primary() (addr common.Address, err error) { - addrs, err := am.keyStore.GetKeyAddresses() - if os.IsNotExist(err) { - return common.Address{}, ErrNoKeys - } else if err != nil { - return common.Address{}, err - } - if len(addrs) == 0 { - return common.Address{}, ErrNoKeys - } - return addrs[0], nil -} - func (am *Manager) DeleteAccount(address common.Address, auth string) error { return am.keyStore.DeleteKey(address, auth) } @@ -114,28 +97,58 @@ func (am *Manager) Sign(a Account, toSign []byte) (signature []byte, err error) return signature, err } -// TimedUnlock unlocks the account with the given address. -// When timeout has passed, the account will be locked again. +// unlock indefinitely +func (am *Manager) Unlock(addr common.Address, keyAuth string) error { + return am.TimedUnlock(addr, keyAuth, 0) +} + +// Unlock unlocks the account with the given address. The account +// stays unlocked for the duration of timeout +// it timeout is 0 the account is unlocked for the entire session func (am *Manager) TimedUnlock(addr common.Address, keyAuth string, timeout time.Duration) error { key, err := am.keyStore.GetKey(addr, keyAuth) if err != nil { return err } - u := am.addUnlocked(addr, key) - go am.dropLater(addr, u, timeout) + var u *unlocked + am.mutex.Lock() + defer am.mutex.Unlock() + var found bool + u, found = am.unlocked[addr] + if found { + // terminate dropLater for this key to avoid unexpected drops. + if u.abort != nil { + close(u.abort) + } + } + if timeout > 0 { + u = &unlocked{Key: key, abort: make(chan struct{})} + go am.expire(addr, u, timeout) + } else { + u = &unlocked{Key: key} + } + am.unlocked[addr] = u return nil } -// Unlock unlocks the account with the given address. The account -// stays unlocked until the program exits or until a TimedUnlock -// timeout (started after the call to Unlock) expires. -func (am *Manager) Unlock(addr common.Address, keyAuth string) error { - key, err := am.keyStore.GetKey(addr, keyAuth) - if err != nil { - return err +func (am *Manager) expire(addr common.Address, u *unlocked, timeout time.Duration) { + t := time.NewTimer(timeout) + defer t.Stop() + select { + case <-u.abort: + // just quit + case <-t.C: + am.mutex.Lock() + // only drop if it's still the same key instance that dropLater + // was launched with. we can check that using pointer equality + // because the map stores a new pointer every time the key is + // unlocked. + if am.unlocked[addr] == u { + zeroKey(u.PrivateKey) + delete(am.unlocked, addr) + } + am.mutex.Unlock() } - am.addUnlocked(addr, key) - return nil } func (am *Manager) NewAccount(auth string) (Account, error) { @@ -146,6 +159,20 @@ func (am *Manager) NewAccount(auth string) (Account, error) { return Account{Address: key.Address}, nil } +func (am *Manager) AddressByIndex(index int) (addr string, err error) { + var addrs []common.Address + addrs, err = am.keyStore.GetKeyAddresses() + if err != nil { + return + } + if index < 0 || index >= len(addrs) { + err = fmt.Errorf("index out of range: %d (should be 0-%d)", index, len(addrs)-1) + } else { + addr = addrs[index].Hex() + } + return +} + func (am *Manager) Accounts() ([]Account, error) { addresses, err := am.keyStore.GetKeyAddresses() if os.IsNotExist(err) { @@ -162,43 +189,6 @@ func (am *Manager) Accounts() ([]Account, error) { return accounts, err } -func (am *Manager) addUnlocked(addr common.Address, key *crypto.Key) *unlocked { - u := &unlocked{Key: key, abort: make(chan struct{})} - am.mutex.Lock() - prev, found := am.unlocked[addr] - if found { - // terminate dropLater for this key to avoid unexpected drops. - close(prev.abort) - // the key is zeroed here instead of in dropLater because - // there might not actually be a dropLater running for this - // key, i.e. when Unlock was used. - zeroKey(prev.PrivateKey) - } - am.unlocked[addr] = u - am.mutex.Unlock() - return u -} - -func (am *Manager) dropLater(addr common.Address, u *unlocked, timeout time.Duration) { - t := time.NewTimer(timeout) - defer t.Stop() - select { - case <-u.abort: - // just quit - case <-t.C: - am.mutex.Lock() - // only drop if it's still the same key instance that dropLater - // was launched with. we can check that using pointer equality - // because the map stores a new pointer every time the key is - // unlocked. - if am.unlocked[addr] == u { - zeroKey(u.PrivateKey) - delete(am.unlocked, addr) - } - am.mutex.Unlock() - } -} - // zeroKey zeroes a private key in memory. func zeroKey(k *ecdsa.PrivateKey) { b := k.D.Bits() @@ -229,6 +219,19 @@ func (am *Manager) Import(path string, keyAuth string) (Account, error) { return Account{Address: key.Address}, nil } +func (am *Manager) Update(addr common.Address, authFrom, authTo string) (err error) { + var key *crypto.Key + key, err = am.keyStore.GetKey(addr, authFrom) + + if err == nil { + err = am.keyStore.StoreKey(key, authTo) + if err == nil { + am.keyStore.Cleanup(addr) + } + } + return +} + func (am *Manager) ImportPreSaleKey(keyJSON []byte, password string) (acc Account, err error) { var key *crypto.Key key, err = crypto.ImportPreSaleKey(am.keyStore, keyJSON, password) diff --git a/accounts/accounts_test.go b/accounts/accounts_test.go index 427114cbd..4b94b78fd 100644 --- a/accounts/accounts_test.go +++ b/accounts/accounts_test.go @@ -58,9 +58,51 @@ func TestTimedUnlock(t *testing.T) { if err != ErrLocked { t.Fatal("Signing should've failed with ErrLocked timeout expired, got ", err) } + +} + +func TestOverrideUnlock(t *testing.T) { + dir, ks := tmpKeyStore(t, crypto.NewKeyStorePassphrase) + defer os.RemoveAll(dir) + + am := NewManager(ks) + pass := "foo" + a1, err := am.NewAccount(pass) + toSign := randentropy.GetEntropyCSPRNG(32) + + // Unlock indefinitely + if err = am.Unlock(a1.Address, pass); err != nil { + t.Fatal(err) + } + + // Signing without passphrase works because account is temp unlocked + _, err = am.Sign(a1, toSign) + if err != nil { + t.Fatal("Signing shouldn't return an error after unlocking, got ", err) + } + + // reset unlock to a shorter period, invalidates the previous unlock + if err = am.TimedUnlock(a1.Address, pass, 100*time.Millisecond); err != nil { + t.Fatal(err) + } + + // Signing without passphrase still works because account is temp unlocked + _, err = am.Sign(a1, toSign) + if err != nil { + t.Fatal("Signing shouldn't return an error after unlocking, got ", err) + } + + // Signing fails again after automatic locking + time.Sleep(150 * time.Millisecond) + _, err = am.Sign(a1, toSign) + if err != ErrLocked { + t.Fatal("Signing should've failed with ErrLocked timeout expired, got ", err) + } } -func tmpKeyStore(t *testing.T, new func(string) crypto.KeyStore2) (string, crypto.KeyStore2) { +// + +func tmpKeyStore(t *testing.T, new func(string) crypto.KeyStore) (string, crypto.KeyStore) { d, err := ioutil.TempDir("", "eth-keystore-test") if err != nil { t.Fatal(err) diff --git a/cmd/geth/js_test.go b/cmd/geth/js_test.go index cfbe26bee..480f77c91 100644 --- a/cmd/geth/js_test.go +++ b/cmd/geth/js_test.go @@ -20,8 +20,8 @@ import ( "github.com/ethereum/go-ethereum/core/state" "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/eth" - "github.com/ethereum/go-ethereum/rpc/comms" "github.com/ethereum/go-ethereum/rpc/codec" + "github.com/ethereum/go-ethereum/rpc/comms" ) const ( @@ -127,6 +127,7 @@ func TestNodeInfo(t *testing.T) { } defer ethereum.Stop() defer os.RemoveAll(tmp) + want := `{"DiscPort":0,"IP":"0.0.0.0","ListenAddr":"","Name":"test","NodeID":"4cb2fc32924e94277bf94b5e4c983beedb2eabd5a0bc941db32202735c6625d020ca14a5963d1738af43b6ac0a711d61b1a06de931a499fe2aa0b1a132a902b5","NodeUrl":"enode://4cb2fc32924e94277bf94b5e4c983beedb2eabd5a0bc941db32202735c6625d020ca14a5963d1738af43b6ac0a711d61b1a06de931a499fe2aa0b1a132a902b5@0.0.0.0:0","TCPPort":0,"Td":"131072"}` checkEvalJSON(t, repl, `admin.nodeInfo`, want) } @@ -140,8 +141,7 @@ func TestAccounts(t *testing.T) { defer os.RemoveAll(tmp) checkEvalJSON(t, repl, `eth.accounts`, `["`+testAddress+`"]`) - checkEvalJSON(t, repl, `eth.coinbase`, `"`+testAddress+`"`) - + checkEvalJSON(t, repl, `eth.coinbase`, `null`) val, err := repl.re.Run(`personal.newAccount("password")`) if err != nil { t.Errorf("expected no error, got %v", err) @@ -151,9 +151,7 @@ func TestAccounts(t *testing.T) { t.Errorf("address not hex: %q", addr) } - // skip until order fixed #824 - // checkEvalJSON(t, repl, `eth.accounts`, `["`+testAddress+`", "`+addr+`"]`) - // checkEvalJSON(t, repl, `eth.coinbase`, `"`+testAddress+`"`) + checkEvalJSON(t, repl, `eth.accounts`, `["`+testAddress+`","`+addr+`"]`) } func TestBlockChain(t *testing.T) { diff --git a/cmd/geth/main.go b/cmd/geth/main.go index be40d5137..ffd26a7c2 100644 --- a/cmd/geth/main.go +++ b/cmd/geth/main.go @@ -153,9 +153,12 @@ Note that exporting your key in unencrypted format is NOT supported. Keys are stored under <DATADIR>/keys. It is safe to transfer the entire directory or the individual keys therein -between ethereum nodes. +between ethereum nodes by simply copying. Make sure you backup your keys regularly. +In order to use your account to send transactions, you need to unlock them using the +'--unlock' option. The argument is a comma + And finally. DO NOT FORGET YOUR PASSWORD. `, Subcommands: []cli.Command{ @@ -187,6 +190,33 @@ password to file or expose in any other way. `, }, { + Action: accountUpdate, + Name: "update", + Usage: "update an existing account", + Description: ` + + ethereum account update <address> + +Update an existing account. + +The account is saved in the newest version in encrypted format, you are prompted +for a passphrase to unlock the account and another to save the updated file. + +This same command can therefore be used to migrate an account of a deprecated +format to the newest format or change the password for an account. + +For non-interactive use the passphrase can be specified with the --password flag: + + ethereum --password <passwordfile> account new + +Since only one password can be given, only format update can be performed, +changing your password is only possible interactively. + +Note that account update has the a side effect that the order of your accounts +changes. + `, + }, + { Action: accountImport, Name: "import", Usage: "import a private key into a new account", @@ -430,19 +460,30 @@ func execJSFiles(ctx *cli.Context) { ethereum.WaitForShutdown() } -func unlockAccount(ctx *cli.Context, am *accounts.Manager, account string) (passphrase string) { +func unlockAccount(ctx *cli.Context, am *accounts.Manager, addr string, i int) (addrHex, auth string) { var err error // Load startup keys. XXX we are going to need a different format - if !((len(account) == 40) || (len(account) == 42)) { // with or without 0x - utils.Fatalf("Invalid account address '%s'", account) + if !((len(addr) == 40) || (len(addr) == 42)) { // with or without 0x + var index int + index, err = strconv.Atoi(addr) + if err != nil { + utils.Fatalf("Invalid account address '%s'", addr) + } + + addrHex, err = am.AddressByIndex(index) + if err != nil { + utils.Fatalf("%v", err) + } + } else { + addrHex = addr } // Attempt to unlock the account 3 times attempts := 3 for tries := 0; tries < attempts; tries++ { - msg := fmt.Sprintf("Unlocking account %s | Attempt %d/%d", account, tries+1, attempts) - passphrase = getPassPhrase(ctx, msg, false) - err = am.Unlock(common.HexToAddress(account), passphrase) + msg := fmt.Sprintf("Unlocking account %s | Attempt %d/%d", addr, tries+1, attempts) + auth = getPassPhrase(ctx, msg, false, i) + err = am.Unlock(common.HexToAddress(addrHex), auth) if err == nil { break } @@ -450,7 +491,7 @@ func unlockAccount(ctx *cli.Context, am *accounts.Manager, account string) (pass if err != nil { utils.Fatalf("Unlock account failed '%v'", err) } - fmt.Printf("Account '%s' unlocked.\n", account) + fmt.Printf("Account '%s' unlocked.\n", addr) return } @@ -492,16 +533,12 @@ func startEth(ctx *cli.Context, eth *eth.Ethereum) { account := ctx.GlobalString(utils.UnlockedAccountFlag.Name) accounts := strings.Split(account, " ") - for _, account := range accounts { + for i, account := range accounts { if len(account) > 0 { if account == "primary" { - primaryAcc, err := am.Primary() - if err != nil { - utils.Fatalf("no primary account: %v", err) - } - account = primaryAcc.Hex() + utils.Fatalf("the 'primary' keyword is deprecated. You can use integer indexes, but the indexes are not permanent, they can change if you add external keys, export your keys or copy your keystore to another node.") } - unlockAccount(ctx, am, account) + unlockAccount(ctx, am, account, i) } } // Start auxiliary services if enabled. @@ -528,14 +565,12 @@ func accountList(ctx *cli.Context) { if err != nil { utils.Fatalf("Could not list accounts: %v", err) } - name := "Primary" for i, acct := range accts { - fmt.Printf("%s #%d: %x\n", name, i, acct) - name = "Account" + fmt.Printf("Account #%d: %x\n", i, acct) } } -func getPassPhrase(ctx *cli.Context, desc string, confirmation bool) (passphrase string) { +func getPassPhrase(ctx *cli.Context, desc string, confirmation bool, i int) (passphrase string) { passfile := ctx.GlobalString(utils.PasswordFileFlag.Name) if len(passfile) == 0 { fmt.Println(desc) @@ -559,14 +594,22 @@ func getPassPhrase(ctx *cli.Context, desc string, confirmation bool) (passphrase if err != nil { utils.Fatalf("Unable to read password file '%s': %v", passfile, err) } - passphrase = string(passbytes) + // this is backwards compatible if the same password unlocks several accounts + // it also has the consequence that trailing newlines will not count as part + // of the password, so --password <(echo -n 'pass') will now work without -n + passphrases := strings.Split(string(passbytes), "\n") + if i >= len(passphrases) { + passphrase = passphrases[len(passphrases)-1] + } else { + passphrase = passphrases[i] + } } return } func accountCreate(ctx *cli.Context) { am := utils.MakeAccountManager(ctx) - passphrase := getPassPhrase(ctx, "Your new account is locked with a password. Please give a password. Do not forget this password.", true) + passphrase := getPassPhrase(ctx, "Your new account is locked with a password. Please give a password. Do not forget this password.", true, 0) acct, err := am.NewAccount(passphrase) if err != nil { utils.Fatalf("Could not create the account: %v", err) @@ -574,6 +617,21 @@ func accountCreate(ctx *cli.Context) { fmt.Printf("Address: %x\n", acct) } +func accountUpdate(ctx *cli.Context) { + am := utils.MakeAccountManager(ctx) + arg := ctx.Args().First() + if len(arg) == 0 { + utils.Fatalf("account address or index must be given as argument") + } + + addr, authFrom := unlockAccount(ctx, am, arg, 0) + authTo := getPassPhrase(ctx, "Please give a new password. Do not forget this password.", true, 0) + err := am.Update(common.HexToAddress(addr), authFrom, authTo) + if err != nil { + utils.Fatalf("Could not update the account: %v", err) + } +} + func importWallet(ctx *cli.Context) { keyfile := ctx.Args().First() if len(keyfile) == 0 { @@ -585,7 +643,7 @@ func importWallet(ctx *cli.Context) { } am := utils.MakeAccountManager(ctx) - passphrase := getPassPhrase(ctx, "", false) + passphrase := getPassPhrase(ctx, "", false, 0) acct, err := am.ImportPreSaleKey(keyJson, passphrase) if err != nil { @@ -600,7 +658,7 @@ func accountImport(ctx *cli.Context) { utils.Fatalf("keyfile must be given as argument") } am := utils.MakeAccountManager(ctx) - passphrase := getPassPhrase(ctx, "Your new account is locked with a password. Please give a password. Do not forget this password.", true) + passphrase := getPassPhrase(ctx, "Your new account is locked with a password. Please give a password. Do not forget this password.", true, 0) acct, err := am.Import(keyfile, passphrase) if err != nil { utils.Fatalf("Could not create the account: %v", err) diff --git a/cmd/utils/flags.go b/cmd/utils/flags.go index b4182ff59..20d3543d6 100644 --- a/cmd/utils/flags.go +++ b/cmd/utils/flags.go @@ -133,7 +133,7 @@ var ( UnlockedAccountFlag = cli.StringFlag{ Name: "unlock", - Usage: "Unlock the account given until this program exits (prompts for password). '--unlock primary' unlocks the primary account", + Usage: "Unlock the account given until this program exits (prompts for password). '--unlock n' unlocks the n-th account in order or creation.", Value: "", } PasswordFileFlag = cli.StringFlag{ diff --git a/core/block_processor.go b/core/block_processor.go index e8014ec22..660c917e4 100644 --- a/core/block_processor.go +++ b/core/block_processor.go @@ -9,6 +9,7 @@ import ( "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/state" "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/core/vm" "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/event" "github.com/ethereum/go-ethereum/logger" @@ -72,12 +73,12 @@ func (self *BlockProcessor) ApplyTransaction(coinbase *state.StateObject, stated cb := statedb.GetStateObject(coinbase.Address()) _, gas, err := ApplyMessage(NewEnv(statedb, self.bc, tx, header), tx, cb) - if err != nil && (IsNonceErr(err) || state.IsGasLimitErr(err) || IsInvalidTxErr(err)) { + if err != nil && err != vm.OutOfGasError { return nil, nil, err } // Update the state with pending changes - statedb.Update() + statedb.SyncIntermediate() usedGas.Add(usedGas, gas) receipt := types.NewReceipt(statedb.Root().Bytes(), usedGas) @@ -118,7 +119,7 @@ func (self *BlockProcessor) ApplyTransactions(coinbase *state.StateObject, state statedb.StartRecord(tx.Hash(), block.Hash(), i) receipt, txGas, err := self.ApplyTransaction(coinbase, statedb, header, tx, totalUsedGas, transientProcess) - if err != nil && (IsNonceErr(err) || state.IsGasLimitErr(err) || IsInvalidTxErr(err)) { + if err != nil && err != vm.OutOfGasError { return nil, err } @@ -243,7 +244,7 @@ func (sm *BlockProcessor) processWithParent(block, parent *types.Block) (logs st // Commit state objects/accounts to a temporary trie (does not save) // used to calculate the state root. - state.Update() + state.SyncObjects() if header.Root != state.Root() { err = fmt.Errorf("invalid merkle root. received=%x got=%x", header.Root, state.Root()) return diff --git a/core/chain_makers.go b/core/chain_makers.go index 37475e0ae..c46f627f8 100644 --- a/core/chain_makers.go +++ b/core/chain_makers.go @@ -77,7 +77,7 @@ func (b *BlockGen) AddTx(tx *types.Transaction) { if err != nil { panic(err) } - b.statedb.Update() + b.statedb.SyncIntermediate() b.header.GasUsed.Add(b.header.GasUsed, gas) receipt := types.NewReceipt(b.statedb.Root().Bytes(), b.header.GasUsed) logs := b.statedb.GetLogs(tx.Hash()) @@ -135,7 +135,7 @@ func GenerateChain(parent *types.Block, db common.Database, n int, gen func(int, gen(i, b) } AccumulateRewards(statedb, h, b.uncles) - statedb.Update() + statedb.SyncIntermediate() h.Root = statedb.Root() return types.NewBlock(h, b.txs, b.uncles, b.receipts) } diff --git a/core/error.go b/core/error.go index 3f3c350df..fb64d09b2 100644 --- a/core/error.go +++ b/core/error.go @@ -30,7 +30,6 @@ func ParentError(hash common.Hash) error { func IsParentErr(err error) bool { _, ok := err.(*ParentErr) - return ok } @@ -48,7 +47,6 @@ func UncleError(format string, v ...interface{}) error { func IsUncleErr(err error) bool { _, ok := err.(*UncleErr) - return ok } @@ -67,7 +65,6 @@ func ValidationError(format string, v ...interface{}) *ValidationErr { func IsValidationErr(err error) bool { _, ok := err.(*ValidationErr) - return ok } @@ -86,7 +83,6 @@ func NonceError(is, exp uint64) *NonceErr { func IsNonceErr(err error) bool { _, ok := err.(*NonceErr) - return ok } @@ -121,24 +117,6 @@ func InvalidTxError(err error) *InvalidTxErr { func IsInvalidTxErr(err error) bool { _, ok := err.(*InvalidTxErr) - - return ok -} - -type OutOfGasErr struct { - Message string -} - -func OutOfGasError() *OutOfGasErr { - return &OutOfGasErr{Message: "Out of gas"} -} -func (self *OutOfGasErr) Error() string { - return self.Message -} - -func IsOutOfGasErr(err error) bool { - _, ok := err.(*OutOfGasErr) - return ok } diff --git a/core/execution.go b/core/execution.go index 9fb0210de..a8c4ffb6d 100644 --- a/core/execution.go +++ b/core/execution.go @@ -53,7 +53,7 @@ func (self *Execution) exec(contextAddr *common.Address, code []byte, caller vm. if env.Depth() > int(params.CallCreateDepth.Int64()) { caller.ReturnGas(self.Gas, self.price) - return nil, vm.DepthError{} + return nil, vm.DepthError } vsnapshot := env.State().Copy() diff --git a/core/genesis.go b/core/genesis.go index df13466ec..d27e7097b 100644 --- a/core/genesis.go +++ b/core/genesis.go @@ -64,7 +64,7 @@ func GenesisBlockForTesting(db common.Database, addr common.Address, balance *bi statedb := state.New(common.Hash{}, db) obj := statedb.GetOrNewStateObject(addr) obj.SetBalance(balance) - statedb.Update() + statedb.SyncObjects() statedb.Sync() block := types.NewBlock(&types.Header{ Difficulty: params.GenesisDifficulty, diff --git a/core/state/state_object.go b/core/state/state_object.go index a31c182d2..e40aeda82 100644 --- a/core/state/state_object.go +++ b/core/state/state_object.go @@ -57,8 +57,6 @@ type StateObject struct { initCode Code // Cached storage (flushed when updated) storage Storage - // Temporary prepaid gas, reward after transition - prepaid *big.Int // Total gas pool is the total amount of gas currently // left if this object is the coinbase. Gas is directly @@ -77,14 +75,10 @@ func (self *StateObject) Reset() { } func NewStateObject(address common.Address, db common.Database) *StateObject { - // This to ensure that it has 20 bytes (and not 0 bytes), thus left or right pad doesn't matter. - //address := common.ToAddress(addr) - object := &StateObject{db: db, address: address, balance: new(big.Int), gasPool: new(big.Int), dirty: true} object.trie = trie.NewSecure((common.Hash{}).Bytes(), db) object.storage = make(Storage) object.gasPool = new(big.Int) - object.prepaid = new(big.Int) return object } @@ -110,7 +104,6 @@ func NewStateObjectFromBytes(address common.Address, data []byte, db common.Data object.trie = trie.NewSecure(extobject.Root[:], db) object.storage = make(map[string]common.Hash) object.gasPool = new(big.Int) - object.prepaid = new(big.Int) object.code, _ = db.Get(extobject.CodeHash) return object @@ -172,7 +165,6 @@ func (self *StateObject) Update() { self.setAddr([]byte(key), value) } - self.storage = make(Storage) } func (c *StateObject) GetInstr(pc *big.Int) *common.Value { diff --git a/core/state/state_test.go b/core/state/state_test.go index 00e133dab..b63b8ae9b 100644 --- a/core/state/state_test.go +++ b/core/state/state_test.go @@ -72,7 +72,7 @@ func TestNull(t *testing.T) { //value := common.FromHex("0x823140710bf13990e4500136726d8b55") var value common.Hash state.SetState(address, common.Hash{}, value) - state.Update() + state.SyncIntermediate() state.Sync() value = state.GetState(address, common.Hash{}) if !common.EmptyHash(value) { diff --git a/core/state/statedb.go b/core/state/statedb.go index f6f63f329..4ccda1fc7 100644 --- a/core/state/statedb.go +++ b/core/state/statedb.go @@ -18,6 +18,7 @@ import ( type StateDB struct { db common.Database trie *trie.SecureTrie + root common.Hash stateObjects map[string]*StateObject @@ -31,7 +32,7 @@ type StateDB struct { // Create a new state from a given trie func New(root common.Hash, db common.Database) *StateDB { trie := trie.NewSecure(root[:], db) - return &StateDB{db: db, trie: trie, stateObjects: make(map[string]*StateObject), refund: new(big.Int), logs: make(map[common.Hash]Logs)} + return &StateDB{root: root, db: db, trie: trie, stateObjects: make(map[string]*StateObject), refund: new(big.Int), logs: make(map[common.Hash]Logs)} } func (self *StateDB) PrintRoot() { @@ -185,7 +186,7 @@ func (self *StateDB) DeleteStateObject(stateObject *StateObject) { addr := stateObject.Address() self.trie.Delete(addr[:]) - delete(self.stateObjects, addr.Str()) + //delete(self.stateObjects, addr.Str()) } // Retrieve a state object given my the address. Nil if not found @@ -323,7 +324,8 @@ func (self *StateDB) Refunds() *big.Int { return self.refund } -func (self *StateDB) Update() { +// SyncIntermediate updates the intermediate state and all mid steps +func (self *StateDB) SyncIntermediate() { self.refund = new(big.Int) for _, stateObject := range self.stateObjects { @@ -340,6 +342,24 @@ func (self *StateDB) Update() { } } +// SyncObjects syncs the changed objects to the trie +func (self *StateDB) SyncObjects() { + self.trie = trie.NewSecure(self.root[:], self.db) + + self.refund = new(big.Int) + + for _, stateObject := range self.stateObjects { + if stateObject.remove { + self.DeleteStateObject(stateObject) + } else { + stateObject.Update() + + self.UpdateStateObject(stateObject) + } + stateObject.dirty = false + } +} + // Debug stuff func (self *StateDB) CreateOutputForDiff() { for _, stateObject := range self.stateObjects { diff --git a/core/state_transition.go b/core/state_transition.go index 5611ffd0f..465000e87 100644 --- a/core/state_transition.go +++ b/core/state_transition.go @@ -122,7 +122,7 @@ func (self *StateTransition) To() *state.StateObject { func (self *StateTransition) UseGas(amount *big.Int) error { if self.gas.Cmp(amount) < 0 { - return OutOfGasError() + return vm.OutOfGasError } self.gas.Sub(self.gas, amount) diff --git a/core/vm/errors.go b/core/vm/errors.go index 799eb6797..75b9c0f10 100644 --- a/core/vm/errors.go +++ b/core/vm/errors.go @@ -1,21 +1,14 @@ package vm import ( + "errors" "fmt" "github.com/ethereum/go-ethereum/params" ) -type OutOfGasError struct{} - -func (self OutOfGasError) Error() string { - return "Out Of Gas" -} - -func IsOOGErr(err error) bool { - _, ok := err.(OutOfGasError) - return ok -} +var OutOfGasError = errors.New("Out of gas") +var DepthError = fmt.Errorf("Max call depth exceeded (%d)", params.CallCreateDepth) type StackError struct { req, has int @@ -33,14 +26,3 @@ func IsStack(err error) bool { _, ok := err.(StackError) return ok } - -type DepthError struct{} - -func (self DepthError) Error() string { - return fmt.Sprintf("Max call depth exceeded (%d)", params.CallCreateDepth) -} - -func IsDepthErr(err error) bool { - _, ok := err.(DepthError) - return ok -} diff --git a/core/vm/vm.go b/core/vm/vm.go index ba803683b..e390fb89c 100644 --- a/core/vm/vm.go +++ b/core/vm/vm.go @@ -116,7 +116,7 @@ func (self *Vm) Run(context *Context, input []byte) (ret []byte, err error) { context.UseGas(context.Gas) - return context.Return(nil), OutOfGasError{} + return context.Return(nil), OutOfGasError } // Resize the memory calculated previously mem.Resize(newMemSize.Uint64()) @@ -789,7 +789,7 @@ func (self *Vm) RunPrecompiled(p *PrecompiledAccount, input []byte, context *Con return context.Return(ret), nil } else { - return nil, OutOfGasError{} + return nil, OutOfGasError } } diff --git a/crypto/crypto.go b/crypto/crypto.go index 153bbbc5d..deef67415 100644 --- a/crypto/crypto.go +++ b/crypto/crypto.go @@ -209,7 +209,7 @@ func ImportBlockTestKey(privKeyBytes []byte) error { } // creates a Key and stores that in the given KeyStore by decrypting a presale key JSON -func ImportPreSaleKey(keyStore KeyStore2, keyJSON []byte, password string) (*Key, error) { +func ImportPreSaleKey(keyStore KeyStore, keyJSON []byte, password string) (*Key, error) { key, err := decryptPreSaleKey(keyJSON, password) if err != nil { return nil, err diff --git a/crypto/key_store_passphrase.go b/crypto/key_store_passphrase.go index 2000a2438..47909bc76 100644 --- a/crypto/key_store_passphrase.go +++ b/crypto/key_store_passphrase.go @@ -41,8 +41,6 @@ import ( "errors" "fmt" "io" - "os" - "path/filepath" "reflect" "code.google.com/p/go-uuid/uuid" @@ -65,7 +63,7 @@ type keyStorePassphrase struct { keysDirPath string } -func NewKeyStorePassphrase(path string) KeyStore2 { +func NewKeyStorePassphrase(path string) KeyStore { return &keyStorePassphrase{path} } @@ -74,20 +72,23 @@ func (ks keyStorePassphrase) GenerateNewKey(rand io.Reader, auth string) (key *K } func (ks keyStorePassphrase) GetKey(keyAddr common.Address, auth string) (key *Key, err error) { - keyBytes, keyId, err := DecryptKeyFromFile(ks, keyAddr, auth) - if err != nil { - return nil, err - } - key = &Key{ - Id: uuid.UUID(keyId), - Address: keyAddr, - PrivateKey: ToECDSA(keyBytes), + keyBytes, keyId, err := decryptKeyFromFile(ks.keysDirPath, keyAddr, auth) + if err == nil { + key = &Key{ + Id: uuid.UUID(keyId), + Address: keyAddr, + PrivateKey: ToECDSA(keyBytes), + } } - return key, err + return +} + +func (ks keyStorePassphrase) Cleanup(keyAddr common.Address) (err error) { + return cleanup(ks.keysDirPath, keyAddr) } func (ks keyStorePassphrase) GetKeyAddresses() (addresses []common.Address, err error) { - return GetKeyAddresses(ks.keysDirPath) + return getKeyAddresses(ks.keysDirPath) } func (ks keyStorePassphrase) StoreKey(key *Key, auth string) (err error) { @@ -139,42 +140,40 @@ func (ks keyStorePassphrase) StoreKey(key *Key, auth string) (err error) { return err } - return WriteKeyFile(key.Address, ks.keysDirPath, keyJSON) + return writeKeyFile(key.Address, ks.keysDirPath, keyJSON) } func (ks keyStorePassphrase) DeleteKey(keyAddr common.Address, auth string) (err error) { // only delete if correct passphrase is given - _, _, err = DecryptKeyFromFile(ks, keyAddr, auth) + _, _, err = decryptKeyFromFile(ks.keysDirPath, keyAddr, auth) if err != nil { return err } - keyDirPath := filepath.Join(ks.keysDirPath, hex.EncodeToString(keyAddr[:])) - return os.RemoveAll(keyDirPath) + return deleteKey(ks.keysDirPath, keyAddr) } -func DecryptKeyFromFile(ks keyStorePassphrase, keyAddr common.Address, auth string) (keyBytes []byte, keyId []byte, err error) { - fileContent, err := GetKeyFile(ks.keysDirPath, keyAddr) +func decryptKeyFromFile(keysDirPath string, keyAddr common.Address, auth string) (keyBytes []byte, keyId []byte, err error) { + fmt.Printf("%v\n", keyAddr.Hex()) + m := make(map[string]interface{}) + err = getKey(keysDirPath, keyAddr, &m) if err != nil { - return nil, nil, err + return } - m := make(map[string]interface{}) - err = json.Unmarshal(fileContent, &m) - v := reflect.ValueOf(m["version"]) if v.Kind() == reflect.String && v.String() == "1" { k := new(encryptedKeyJSONV1) - err := json.Unmarshal(fileContent, k) + err = getKey(keysDirPath, keyAddr, &k) if err != nil { - return nil, nil, err + return } return decryptKeyV1(k, auth) } else { k := new(encryptedKeyJSONV3) - err := json.Unmarshal(fileContent, k) + err = getKey(keysDirPath, keyAddr, &k) if err != nil { - return nil, nil, err + return } return decryptKeyV3(k, auth) } diff --git a/crypto/key_store_plain.go b/crypto/key_store_plain.go index 6a8afe27d..c13c5e7a4 100644 --- a/crypto/key_store_plain.go +++ b/crypto/key_store_plain.go @@ -27,28 +27,30 @@ import ( "encoding/hex" "encoding/json" "fmt" - "github.com/ethereum/go-ethereum/common" "io" "io/ioutil" "os" "path/filepath" + "time" + + "github.com/ethereum/go-ethereum/common" ) -// TODO: rename to KeyStore when replacing existing KeyStore -type KeyStore2 interface { +type KeyStore interface { // create new key using io.Reader entropy source and optionally using auth string GenerateNewKey(io.Reader, string) (*Key, error) - GetKey(common.Address, string) (*Key, error) // key from addr and auth string + GetKey(common.Address, string) (*Key, error) // get key from addr and auth string GetKeyAddresses() ([]common.Address, error) // get all addresses StoreKey(*Key, string) error // store key optionally using auth string DeleteKey(common.Address, string) error // delete key by addr and auth string + Cleanup(keyAddr common.Address) (err error) } type keyStorePlain struct { keysDirPath string } -func NewKeyStorePlain(path string) KeyStore2 { +func NewKeyStorePlain(path string) KeyStore { return &keyStorePlain{path} } @@ -56,7 +58,7 @@ func (ks keyStorePlain) GenerateNewKey(rand io.Reader, auth string) (key *Key, e return GenerateNewKeyDefault(ks, rand, auth) } -func GenerateNewKeyDefault(ks KeyStore2, rand io.Reader, auth string) (key *Key, err error) { +func GenerateNewKeyDefault(ks KeyStore, rand io.Reader, auth string) (key *Key, err error) { defer func() { if r := recover(); r != nil { err = fmt.Errorf("GenerateNewKey error: %v", r) @@ -68,62 +70,149 @@ func GenerateNewKeyDefault(ks KeyStore2, rand io.Reader, auth string) (key *Key, } func (ks keyStorePlain) GetKey(keyAddr common.Address, auth string) (key *Key, err error) { - fileContent, err := GetKeyFile(ks.keysDirPath, keyAddr) + key = new(Key) + err = getKey(ks.keysDirPath, keyAddr, key) + return +} + +func getKey(keysDirPath string, keyAddr common.Address, content interface{}) (err error) { + fileContent, err := getKeyFile(keysDirPath, keyAddr) if err != nil { - return nil, err + return } - - key = new(Key) - err = json.Unmarshal(fileContent, key) - return key, err + return json.Unmarshal(fileContent, content) } func (ks keyStorePlain) GetKeyAddresses() (addresses []common.Address, err error) { - return GetKeyAddresses(ks.keysDirPath) + return getKeyAddresses(ks.keysDirPath) +} + +func (ks keyStorePlain) Cleanup(keyAddr common.Address) (err error) { + return cleanup(ks.keysDirPath, keyAddr) } func (ks keyStorePlain) StoreKey(key *Key, auth string) (err error) { keyJSON, err := json.Marshal(key) if err != nil { - return err + return } - err = WriteKeyFile(key.Address, ks.keysDirPath, keyJSON) - return err + err = writeKeyFile(key.Address, ks.keysDirPath, keyJSON) + return } func (ks keyStorePlain) DeleteKey(keyAddr common.Address, auth string) (err error) { - keyDirPath := filepath.Join(ks.keysDirPath, keyAddr.Hex()) - err = os.RemoveAll(keyDirPath) - return err + return deleteKey(ks.keysDirPath, keyAddr) } -func GetKeyFile(keysDirPath string, keyAddr common.Address) (fileContent []byte, err error) { - fileName := hex.EncodeToString(keyAddr[:]) - return ioutil.ReadFile(filepath.Join(keysDirPath, fileName, fileName)) +func deleteKey(keysDirPath string, keyAddr common.Address) (err error) { + var path string + path, err = getKeyFilePath(keysDirPath, keyAddr) + if err == nil { + addrHex := hex.EncodeToString(keyAddr[:]) + if path == filepath.Join(keysDirPath, addrHex, addrHex) { + path = filepath.Join(keysDirPath, addrHex) + } + err = os.RemoveAll(path) + } + return +} + +func getKeyFilePath(keysDirPath string, keyAddr common.Address) (keyFilePath string, err error) { + addrHex := hex.EncodeToString(keyAddr[:]) + matches, err := filepath.Glob(filepath.Join(keysDirPath, fmt.Sprintf("*--%s", addrHex))) + if len(matches) > 0 { + if err == nil { + keyFilePath = matches[len(matches)-1] + } + return + } + keyFilePath = filepath.Join(keysDirPath, addrHex, addrHex) + _, err = os.Stat(keyFilePath) + return } -func WriteKeyFile(addr common.Address, keysDirPath string, content []byte) (err error) { - addrHex := hex.EncodeToString(addr[:]) - keyDirPath := filepath.Join(keysDirPath, addrHex) - keyFilePath := filepath.Join(keyDirPath, addrHex) - err = os.MkdirAll(keyDirPath, 0700) // read, write and dir search for user +func cleanup(keysDirPath string, keyAddr common.Address) (err error) { + fileInfos, err := ioutil.ReadDir(keysDirPath) + if err != nil { + return + } + var paths []string + account := hex.EncodeToString(keyAddr[:]) + for _, fileInfo := range fileInfos { + path := filepath.Join(keysDirPath, fileInfo.Name()) + if len(path) >= 40 { + addr := path[len(path)-40 : len(path)] + if addr == account { + if path == filepath.Join(keysDirPath, addr, addr) { + path = filepath.Join(keysDirPath, addr) + } + paths = append(paths, path) + } + } + } + if len(paths) > 1 { + for i := 0; err == nil && i < len(paths)-1; i++ { + err = os.RemoveAll(paths[i]) + if err != nil { + break + } + } + } + return +} + +func getKeyFile(keysDirPath string, keyAddr common.Address) (fileContent []byte, err error) { + var keyFilePath string + keyFilePath, err = getKeyFilePath(keysDirPath, keyAddr) + if err == nil { + fileContent, err = ioutil.ReadFile(keyFilePath) + } + return +} + +func writeKeyFile(addr common.Address, keysDirPath string, content []byte) (err error) { + filename := keyFileName(addr) + // read, write and dir search for user + err = os.MkdirAll(keysDirPath, 0700) if err != nil { return err } - return ioutil.WriteFile(keyFilePath, content, 0600) // read, write for user + // read, write for user + return ioutil.WriteFile(filepath.Join(keysDirPath, filename), content, 0600) +} + +// keyFilePath implements the naming convention for keyfiles: +// UTC--<created_at UTC ISO8601>-<address hex> +func keyFileName(keyAddr common.Address) string { + ts := time.Now().UTC() + return fmt.Sprintf("UTC--%s--%s", toISO8601(ts), hex.EncodeToString(keyAddr[:])) +} + +func toISO8601(t time.Time) string { + var tz string + name, offset := t.Zone() + if name == "UTC" { + tz = "Z" + } else { + tz = fmt.Sprintf("%03d00", offset/3600) + } + return fmt.Sprintf("%04d-%02d-%02dT%02d:%02d:%02d.%09d%s", t.Year(), t.Month(), t.Day(), t.Hour(), t.Minute(), t.Second(), t.Nanosecond(), tz) } -func GetKeyAddresses(keysDirPath string) (addresses []common.Address, err error) { +func getKeyAddresses(keysDirPath string) (addresses []common.Address, err error) { fileInfos, err := ioutil.ReadDir(keysDirPath) if err != nil { return nil, err } for _, fileInfo := range fileInfos { - address, err := hex.DecodeString(fileInfo.Name()) - if err != nil { - continue + filename := fileInfo.Name() + if len(filename) >= 40 { + addr := filename[len(filename)-40 : len(filename)] + address, err := hex.DecodeString(addr) + if err == nil { + addresses = append(addresses, common.BytesToAddress(address)) + } } - addresses = append(addresses, common.BytesToAddress(address)) } return addresses, err } diff --git a/eth/backend.go b/eth/backend.go index 618eec9fb..e62252b6c 100644 --- a/eth/backend.go +++ b/eth/backend.go @@ -464,17 +464,9 @@ func (s *Ethereum) StartMining(threads int) error { func (s *Ethereum) Etherbase() (eb common.Address, err error) { eb = s.etherbase if (eb == common.Address{}) { - primary, err := s.accountManager.Primary() - if err != nil { - return eb, err - } - if (primary == common.Address{}) { - err = fmt.Errorf("no accounts found") - return eb, err - } - eb = primary + err = fmt.Errorf("etherbase address must be explicitly specified") } - return eb, nil + return } func (s *Ethereum) StopMining() { s.miner.Stop() } diff --git a/miner/worker.go b/miner/worker.go index dd004da6e..1615ff84b 100644 --- a/miner/worker.go +++ b/miner/worker.go @@ -453,7 +453,7 @@ func (self *worker) commitNewWork() { if atomic.LoadInt32(&self.mining) == 1 { // commit state root after all state transitions. core.AccumulateRewards(self.current.state, header, uncles) - current.state.Update() + current.state.SyncObjects() self.current.state.Sync() header.Root = current.state.Root() } diff --git a/tests/block_test_util.go b/tests/block_test_util.go index 67f6a1d18..3b20da492 100644 --- a/tests/block_test_util.go +++ b/tests/block_test_util.go @@ -215,7 +215,7 @@ func (t *BlockTest) InsertPreState(ethereum *eth.Ethereum) (*state.StateDB, erro } } // sync objects to trie - statedb.Update() + statedb.SyncObjects() // sync trie to disk statedb.Sync() diff --git a/tests/init.go b/tests/init.go index dd8df930f..1deaf5912 100644 --- a/tests/init.go +++ b/tests/init.go @@ -20,11 +20,6 @@ var ( BlockSkipTests = []string{ "SimpleTx3", - // these panic in block_processor.go:84 , see https://github.com/ethereum/go-ethereum/issues/1384 - "TRANSCT_rvalue_TooShort", - "TRANSCT_rvalue_TooLarge", - "TRANSCT_svalue_TooLarge", - // TODO: check why these fail "BLOCK__RandomByteAtTheEnd", "TRANSCT__RandomByteAtTheEnd", diff --git a/tests/state_test_util.go b/tests/state_test_util.go index 2f3d497be..7f1a22ac0 100644 --- a/tests/state_test_util.go +++ b/tests/state_test_util.go @@ -175,7 +175,7 @@ func RunState(statedb *state.StateDB, env, tx map[string]string) ([]byte, state. if core.IsNonceErr(err) || core.IsInvalidTxErr(err) || state.IsGasLimitErr(err) { statedb.Set(snapshot) } - statedb.Update() + statedb.SyncObjects() return ret, vmenv.state.Logs(), vmenv.Gas, err } diff --git a/trie/fullnode.go b/trie/fullnode.go index 522fdb373..1bfdcd5bf 100644 --- a/trie/fullnode.go +++ b/trie/fullnode.go @@ -1,17 +1,16 @@ package trie -import "fmt" - type FullNode struct { trie *Trie nodes [17]Node + dirty bool } func NewFullNode(t *Trie) *FullNode { return &FullNode{trie: t} } -func (self *FullNode) Dirty() bool { return true } +func (self *FullNode) Dirty() bool { return self.dirty } func (self *FullNode) Value() Node { self.nodes[16] = self.trie.trans(self.nodes[16]) return self.nodes[16] @@ -24,9 +23,10 @@ func (self *FullNode) Copy(t *Trie) Node { nnode := NewFullNode(t) for i, node := range self.nodes { if node != nil { - nnode.nodes[i] = node.Copy(t) + nnode.nodes[i] = node } } + nnode.dirty = true return nnode } @@ -60,11 +60,8 @@ func (self *FullNode) RlpData() interface{} { } func (self *FullNode) set(k byte, value Node) { - if _, ok := value.(*ValueNode); ok && k != 16 { - fmt.Println(value, k) - } - self.nodes[int(k)] = value + self.dirty = true } func (self *FullNode) branch(i byte) Node { @@ -75,3 +72,7 @@ func (self *FullNode) branch(i byte) Node { } return nil } + +func (self *FullNode) setDirty(dirty bool) { + self.dirty = dirty +} diff --git a/trie/hashnode.go b/trie/hashnode.go index 8125cc3c9..e82ab8069 100644 --- a/trie/hashnode.go +++ b/trie/hashnode.go @@ -3,12 +3,13 @@ package trie import "github.com/ethereum/go-ethereum/common" type HashNode struct { - key []byte - trie *Trie + key []byte + trie *Trie + dirty bool } func NewHash(key []byte, trie *Trie) *HashNode { - return &HashNode{key, trie} + return &HashNode{key, trie, false} } func (self *HashNode) RlpData() interface{} { @@ -19,6 +20,10 @@ func (self *HashNode) Hash() interface{} { return self.key } +func (self *HashNode) setDirty(dirty bool) { + self.dirty = dirty +} + // These methods will never be called but we have to satisfy Node interface func (self *HashNode) Value() Node { return nil } func (self *HashNode) Dirty() bool { return true } diff --git a/trie/node.go b/trie/node.go index 0d8a7cff9..dccbc64a3 100644 --- a/trie/node.go +++ b/trie/node.go @@ -11,6 +11,7 @@ type Node interface { fstring(string) string Hash() interface{} RlpData() interface{} + setDirty(dirty bool) } // Value node diff --git a/trie/shortnode.go b/trie/shortnode.go index edd490b4d..c86e50096 100644 --- a/trie/shortnode.go +++ b/trie/shortnode.go @@ -6,20 +6,22 @@ type ShortNode struct { trie *Trie key []byte value Node + dirty bool } func NewShortNode(t *Trie, key []byte, value Node) *ShortNode { - return &ShortNode{t, []byte(CompactEncode(key)), value} + return &ShortNode{t, []byte(CompactEncode(key)), value, false} } func (self *ShortNode) Value() Node { self.value = self.trie.trans(self.value) return self.value } -func (self *ShortNode) Dirty() bool { return true } +func (self *ShortNode) Dirty() bool { return self.dirty } func (self *ShortNode) Copy(t *Trie) Node { - node := &ShortNode{t, nil, self.value.Copy(t)} + node := &ShortNode{t, nil, self.value.Copy(t), self.dirty} node.key = common.CopyBytes(self.key) + node.dirty = true return node } @@ -33,3 +35,7 @@ func (self *ShortNode) Hash() interface{} { func (self *ShortNode) Key() []byte { return CompactDecode(string(self.key)) } + +func (self *ShortNode) setDirty(dirty bool) { + self.dirty = dirty +} diff --git a/trie/trie.go b/trie/trie.go index d990338ee..7e17baa2f 100644 --- a/trie/trie.go +++ b/trie/trie.go @@ -117,7 +117,9 @@ func (self *Trie) Update(key, value []byte) Node { k := CompactHexDecode(string(key)) if len(value) != 0 { - self.root = self.insert(self.root, k, &ValueNode{self, value}) + node := NewValueNode(self, value) + node.dirty = true + self.root = self.insert(self.root, k, node) } else { self.root = self.delete(self.root, k) } @@ -157,7 +159,9 @@ func (self *Trie) insert(node Node, key []byte, value Node) Node { } if node == nil { - return NewShortNode(self, key, value) + node := NewShortNode(self, key, value) + node.dirty = true + return node } switch node := node.(type) { @@ -165,7 +169,10 @@ func (self *Trie) insert(node Node, key []byte, value Node) Node { k := node.Key() cnode := node.Value() if bytes.Equal(k, key) { - return NewShortNode(self, key, value) + node := NewShortNode(self, key, value) + node.dirty = true + return node + } var n Node @@ -176,6 +183,7 @@ func (self *Trie) insert(node Node, key []byte, value Node) Node { pnode := self.insert(nil, k[matchlength+1:], cnode) nnode := self.insert(nil, key[matchlength+1:], value) fulln := NewFullNode(self) + fulln.dirty = true fulln.set(k[matchlength], pnode) fulln.set(key[matchlength], nnode) n = fulln @@ -184,11 +192,14 @@ func (self *Trie) insert(node Node, key []byte, value Node) Node { return n } - return NewShortNode(self, key[:matchlength], n) + snode := NewShortNode(self, key[:matchlength], n) + snode.dirty = true + return snode case *FullNode: cpy := node.Copy(self).(*FullNode) cpy.set(key[0], self.insert(node.branch(key[0]), key[1:], value)) + cpy.dirty = true return cpy @@ -242,8 +253,10 @@ func (self *Trie) delete(node Node, key []byte) Node { case *ShortNode: nkey := append(k, child.Key()...) n = NewShortNode(self, nkey, child.Value()) + n.(*ShortNode).dirty = true case *FullNode: sn := NewShortNode(self, node.Key(), child) + sn.dirty = true sn.key = node.key n = sn } @@ -256,6 +269,7 @@ func (self *Trie) delete(node Node, key []byte) Node { case *FullNode: n := node.Copy(self).(*FullNode) n.set(key[0], self.delete(n.branch(key[0]), key[1:])) + n.dirty = true pos := -1 for i := 0; i < 17; i++ { @@ -271,6 +285,7 @@ func (self *Trie) delete(node Node, key []byte) Node { var nnode Node if pos == 16 { nnode = NewShortNode(self, []byte{16}, n.branch(byte(pos))) + nnode.(*ShortNode).dirty = true } else if pos >= 0 { cnode := n.branch(byte(pos)) switch cnode := cnode.(type) { @@ -278,8 +293,10 @@ func (self *Trie) delete(node Node, key []byte) Node { // Stitch keys k := append([]byte{byte(pos)}, cnode.Key()...) nnode = NewShortNode(self, k, cnode.Value()) + nnode.(*ShortNode).dirty = true case *FullNode: nnode = NewShortNode(self, []byte{byte(pos)}, n.branch(byte(pos))) + nnode.(*ShortNode).dirty = true } } else { nnode = n @@ -304,7 +321,7 @@ func (self *Trie) mknode(value *common.Value) Node { if value.Get(0).Len() != 0 { key := CompactDecode(string(value.Get(0).Bytes())) if key[len(key)-1] == 16 { - return NewShortNode(self, key, &ValueNode{self, value.Get(1).Bytes()}) + return NewShortNode(self, key, NewValueNode(self, value.Get(1).Bytes())) } else { return NewShortNode(self, key, self.mknode(value.Get(1))) } @@ -318,10 +335,10 @@ func (self *Trie) mknode(value *common.Value) Node { return fnode } case 32: - return &HashNode{value.Bytes(), self} + return NewHash(value.Bytes(), self) } - return &ValueNode{self, value.Bytes()} + return NewValueNode(self, value.Bytes()) } func (self *Trie) trans(node Node) Node { @@ -338,7 +355,11 @@ func (self *Trie) store(node Node) interface{} { data := common.Encode(node) if len(data) >= 32 { key := crypto.Sha3(data) - self.cache.Put(key, data) + if node.Dirty() { + //fmt.Println("save", node) + //fmt.Println() + self.cache.Put(key, data) + } return key } diff --git a/trie/trie_test.go b/trie/trie_test.go index 9a58958d8..60f0873a8 100644 --- a/trie/trie_test.go +++ b/trie/trie_test.go @@ -152,7 +152,7 @@ func TestReplication(t *testing.T) { } trie.Commit() - trie2 := New(trie.roothash, trie.cache.backend) + trie2 := New(trie.Root(), trie.cache.backend) if string(trie2.GetString("horse")) != "stallion" { t.Error("expected to have horse => stallion") } diff --git a/trie/valuenode.go b/trie/valuenode.go index 7bf8ff06e..6adb59652 100644 --- a/trie/valuenode.go +++ b/trie/valuenode.go @@ -3,13 +3,24 @@ package trie import "github.com/ethereum/go-ethereum/common" type ValueNode struct { - trie *Trie - data []byte + trie *Trie + data []byte + dirty bool } -func (self *ValueNode) Value() Node { return self } // Best not to call :-) -func (self *ValueNode) Val() []byte { return self.data } -func (self *ValueNode) Dirty() bool { return true } -func (self *ValueNode) Copy(t *Trie) Node { return &ValueNode{t, common.CopyBytes(self.data)} } +func NewValueNode(trie *Trie, data []byte) *ValueNode { + return &ValueNode{trie, data, false} +} + +func (self *ValueNode) Value() Node { return self } // Best not to call :-) +func (self *ValueNode) Val() []byte { return self.data } +func (self *ValueNode) Dirty() bool { return self.dirty } +func (self *ValueNode) Copy(t *Trie) Node { + return &ValueNode{t, common.CopyBytes(self.data), self.dirty} +} func (self *ValueNode) RlpData() interface{} { return self.data } func (self *ValueNode) Hash() interface{} { return self.data } + +func (self *ValueNode) setDirty(dirty bool) { + self.dirty = dirty +} diff --git a/xeth/xeth.go b/xeth/xeth.go index 1e87b738d..88d802820 100644 --- a/xeth/xeth.go +++ b/xeth/xeth.go @@ -467,7 +467,10 @@ func (self *XEth) IsListening() bool { } func (self *XEth) Coinbase() string { - eb, _ := self.backend.Etherbase() + eb, err := self.backend.Etherbase() + if err != nil { + return "0x0" + } return eb.Hex() } |