diff options
Diffstat (limited to 'ethchain')
-rw-r--r-- | ethchain/state.go | 16 | ||||
-rw-r--r-- | ethchain/state_object.go | 7 | ||||
-rw-r--r-- | ethchain/state_test.go | 31 | ||||
-rw-r--r-- | ethchain/vm.go | 17 |
4 files changed, 65 insertions, 6 deletions
diff --git a/ethchain/state.go b/ethchain/state.go index 6ec6916f4..e209e0e2f 100644 --- a/ethchain/state.go +++ b/ethchain/state.go @@ -99,7 +99,21 @@ func (s *State) Cmp(other *State) bool { } func (s *State) Copy() *State { - return NewState(s.trie.Copy()) + state := NewState(s.trie.Copy()) + for k, subState := range s.states { + state.states[k] = subState.Copy() + } + + return state +} + +func (s *State) Snapshot() *State { + return s.Copy() +} + +func (s *State) Revert(snapshot *State) { + s.trie = snapshot.trie + s.states = snapshot.states } func (s *State) Put(key, object []byte) { diff --git a/ethchain/state_object.go b/ethchain/state_object.go index 4d615e2fe..3e9c6df40 100644 --- a/ethchain/state_object.go +++ b/ethchain/state_object.go @@ -81,12 +81,17 @@ func (c *StateObject) SetStorage(num *big.Int, val *ethutil.Value) { c.SetAddr(addr, val) } -func (c *StateObject) GetMem(num *big.Int) *ethutil.Value { +func (c *StateObject) GetStorage(num *big.Int) *ethutil.Value { nb := ethutil.BigToBytes(num, 256) return c.Addr(nb) } +/* DEPRECATED */ +func (c *StateObject) GetMem(num *big.Int) *ethutil.Value { + return c.GetStorage(num) +} + func (c *StateObject) GetInstr(pc *big.Int) *ethutil.Value { if int64(len(c.script)-1) < pc.Int64() { return ethutil.NewValue(0) diff --git a/ethchain/state_test.go b/ethchain/state_test.go new file mode 100644 index 000000000..4cc3fdf75 --- /dev/null +++ b/ethchain/state_test.go @@ -0,0 +1,31 @@ +package ethchain + +import ( + "fmt" + "github.com/ethereum/eth-go/ethdb" + "github.com/ethereum/eth-go/ethutil" + "testing" +) + +func TestSnapshot(t *testing.T) { + ethutil.ReadConfig("", ethutil.LogStd, "") + + db, _ := ethdb.NewMemDatabase() + state := NewState(ethutil.NewTrie(db, "")) + + stateObject := NewContract([]byte("aa"), ethutil.Big1, ZeroHash256) + state.UpdateStateObject(stateObject) + stateObject.SetStorage(ethutil.Big("0"), ethutil.NewValue(42)) + + snapshot := state.Snapshot() + + stateObject = state.GetStateObject([]byte("aa")) + stateObject.SetStorage(ethutil.Big("0"), ethutil.NewValue(43)) + + state.Revert(snapshot) + + stateObject = state.GetStateObject([]byte("aa")) + if !stateObject.GetStorage(ethutil.Big("0")).Cmp(ethutil.NewValue(42)) { + t.Error("Expected storage 0 to be 42") + } +} diff --git a/ethchain/vm.go b/ethchain/vm.go index e067a9c96..e025920f3 100644 --- a/ethchain/vm.go +++ b/ethchain/vm.go @@ -426,6 +426,10 @@ func (vm *Vm) RunClosure(closure *Closure, hook DebugHook) (ret []byte, err erro value := stack.Pop() size, offset := stack.Popn() + // Snapshot the current stack so we are able to + // revert back to it later. + snapshot := vm.state.Snapshot() + // Generate a new address addr := ethutil.CreateAddress(closure.callee.Address(), closure.callee.N()) // Create a new contract @@ -448,6 +452,9 @@ func (vm *Vm) RunClosure(closure *Closure, hook DebugHook) (ret []byte, err erro closure.Script, err = closure.Call(vm, nil, hook) if err != nil { stack.Push(ethutil.BigFalse) + + // Revert the state as it was before. + vm.state.Revert(snapshot) } else { stack.Push(ethutil.BigD(addr)) @@ -473,6 +480,8 @@ func (vm *Vm) RunClosure(closure *Closure, hook DebugHook) (ret []byte, err erro // Get the arguments from the memory args := mem.Get(inOffset.Int64(), inSize.Int64()) + snapshot := vm.state.Snapshot() + // Fetch the contract which will serve as the closure body contract := vm.state.GetStateObject(addr.Bytes()) @@ -495,14 +504,14 @@ func (vm *Vm) RunClosure(closure *Closure, hook DebugHook) (ret []byte, err erro if err != nil { stack.Push(ethutil.BigFalse) // Reset the changes applied this object - //contract.State().Reset() + vm.state.Revert(snapshot) } else { stack.Push(ethutil.BigTrue) - } - vm.state.UpdateStateObject(contract) + vm.state.UpdateStateObject(contract) - mem.Set(retOffset.Int64(), retSize.Int64(), ret) + mem.Set(retOffset.Int64(), retSize.Int64(), ret) + } } else { ethutil.Config.Log.Debugf("Contract %x not found\n", addr.Bytes()) stack.Push(ethutil.BigFalse) |