aboutsummaryrefslogtreecommitdiffstats
path: root/ethchain
diff options
context:
space:
mode:
Diffstat (limited to 'ethchain')
-rw-r--r--ethchain/state.go16
-rw-r--r--ethchain/state_object.go7
-rw-r--r--ethchain/state_test.go31
-rw-r--r--ethchain/vm.go17
4 files changed, 65 insertions, 6 deletions
diff --git a/ethchain/state.go b/ethchain/state.go
index 6ec6916f4..e209e0e2f 100644
--- a/ethchain/state.go
+++ b/ethchain/state.go
@@ -99,7 +99,21 @@ func (s *State) Cmp(other *State) bool {
}
func (s *State) Copy() *State {
- return NewState(s.trie.Copy())
+ state := NewState(s.trie.Copy())
+ for k, subState := range s.states {
+ state.states[k] = subState.Copy()
+ }
+
+ return state
+}
+
+func (s *State) Snapshot() *State {
+ return s.Copy()
+}
+
+func (s *State) Revert(snapshot *State) {
+ s.trie = snapshot.trie
+ s.states = snapshot.states
}
func (s *State) Put(key, object []byte) {
diff --git a/ethchain/state_object.go b/ethchain/state_object.go
index 4d615e2fe..3e9c6df40 100644
--- a/ethchain/state_object.go
+++ b/ethchain/state_object.go
@@ -81,12 +81,17 @@ func (c *StateObject) SetStorage(num *big.Int, val *ethutil.Value) {
c.SetAddr(addr, val)
}
-func (c *StateObject) GetMem(num *big.Int) *ethutil.Value {
+func (c *StateObject) GetStorage(num *big.Int) *ethutil.Value {
nb := ethutil.BigToBytes(num, 256)
return c.Addr(nb)
}
+/* DEPRECATED */
+func (c *StateObject) GetMem(num *big.Int) *ethutil.Value {
+ return c.GetStorage(num)
+}
+
func (c *StateObject) GetInstr(pc *big.Int) *ethutil.Value {
if int64(len(c.script)-1) < pc.Int64() {
return ethutil.NewValue(0)
diff --git a/ethchain/state_test.go b/ethchain/state_test.go
new file mode 100644
index 000000000..4cc3fdf75
--- /dev/null
+++ b/ethchain/state_test.go
@@ -0,0 +1,31 @@
+package ethchain
+
+import (
+ "fmt"
+ "github.com/ethereum/eth-go/ethdb"
+ "github.com/ethereum/eth-go/ethutil"
+ "testing"
+)
+
+func TestSnapshot(t *testing.T) {
+ ethutil.ReadConfig("", ethutil.LogStd, "")
+
+ db, _ := ethdb.NewMemDatabase()
+ state := NewState(ethutil.NewTrie(db, ""))
+
+ stateObject := NewContract([]byte("aa"), ethutil.Big1, ZeroHash256)
+ state.UpdateStateObject(stateObject)
+ stateObject.SetStorage(ethutil.Big("0"), ethutil.NewValue(42))
+
+ snapshot := state.Snapshot()
+
+ stateObject = state.GetStateObject([]byte("aa"))
+ stateObject.SetStorage(ethutil.Big("0"), ethutil.NewValue(43))
+
+ state.Revert(snapshot)
+
+ stateObject = state.GetStateObject([]byte("aa"))
+ if !stateObject.GetStorage(ethutil.Big("0")).Cmp(ethutil.NewValue(42)) {
+ t.Error("Expected storage 0 to be 42")
+ }
+}
diff --git a/ethchain/vm.go b/ethchain/vm.go
index e067a9c96..e025920f3 100644
--- a/ethchain/vm.go
+++ b/ethchain/vm.go
@@ -426,6 +426,10 @@ func (vm *Vm) RunClosure(closure *Closure, hook DebugHook) (ret []byte, err erro
value := stack.Pop()
size, offset := stack.Popn()
+ // Snapshot the current stack so we are able to
+ // revert back to it later.
+ snapshot := vm.state.Snapshot()
+
// Generate a new address
addr := ethutil.CreateAddress(closure.callee.Address(), closure.callee.N())
// Create a new contract
@@ -448,6 +452,9 @@ func (vm *Vm) RunClosure(closure *Closure, hook DebugHook) (ret []byte, err erro
closure.Script, err = closure.Call(vm, nil, hook)
if err != nil {
stack.Push(ethutil.BigFalse)
+
+ // Revert the state as it was before.
+ vm.state.Revert(snapshot)
} else {
stack.Push(ethutil.BigD(addr))
@@ -473,6 +480,8 @@ func (vm *Vm) RunClosure(closure *Closure, hook DebugHook) (ret []byte, err erro
// Get the arguments from the memory
args := mem.Get(inOffset.Int64(), inSize.Int64())
+ snapshot := vm.state.Snapshot()
+
// Fetch the contract which will serve as the closure body
contract := vm.state.GetStateObject(addr.Bytes())
@@ -495,14 +504,14 @@ func (vm *Vm) RunClosure(closure *Closure, hook DebugHook) (ret []byte, err erro
if err != nil {
stack.Push(ethutil.BigFalse)
// Reset the changes applied this object
- //contract.State().Reset()
+ vm.state.Revert(snapshot)
} else {
stack.Push(ethutil.BigTrue)
- }
- vm.state.UpdateStateObject(contract)
+ vm.state.UpdateStateObject(contract)
- mem.Set(retOffset.Int64(), retSize.Int64(), ret)
+ mem.Set(retOffset.Int64(), retSize.Int64(), ret)
+ }
} else {
ethutil.Config.Log.Debugf("Contract %x not found\n", addr.Bytes())
stack.Push(ethutil.BigFalse)