diff options
Diffstat (limited to 'ethutil')
-rw-r--r-- | ethutil/README.md | 4 | ||||
-rw-r--r-- | ethutil/big.go | 15 | ||||
-rw-r--r-- | ethutil/common.go | 35 | ||||
-rw-r--r-- | ethutil/common_test.go | 17 | ||||
-rw-r--r-- | ethutil/config.go | 78 | ||||
-rw-r--r-- | ethutil/db.go | 2 | ||||
-rw-r--r-- | ethutil/encoding.go | 3 | ||||
-rw-r--r-- | ethutil/encoding_test.go | 30 | ||||
-rw-r--r-- | ethutil/helpers.go | 5 | ||||
-rw-r--r-- | ethutil/key.go | 19 | ||||
-rw-r--r-- | ethutil/parsing.go | 141 | ||||
-rw-r--r-- | ethutil/parsing_test.go | 6 | ||||
-rw-r--r-- | ethutil/rlp.go | 8 | ||||
-rw-r--r-- | ethutil/rlp_test.go | 49 | ||||
-rw-r--r-- | ethutil/trie.go | 212 | ||||
-rw-r--r-- | ethutil/trie_test.go | 142 | ||||
-rw-r--r-- | ethutil/value.go | 34 | ||||
-rw-r--r-- | ethutil/value_test.go | 52 |
18 files changed, 681 insertions, 171 deletions
diff --git a/ethutil/README.md b/ethutil/README.md index c98612e1e..1ed56b71b 100644 --- a/ethutil/README.md +++ b/ethutil/README.md @@ -53,6 +53,8 @@ trie.Put("doge", "coin") // Look up the key "do" in the trie out := trie.Get("do") fmt.Println(out) // => verb + +trie.Delete("puppy") ``` The patricia trie, in combination with RLP, provides a robust, @@ -82,7 +84,7 @@ type (e.g. `Slice()` returns []interface{}, `Uint()` return 0, etc). `NewEmptyValue()` returns a new \*Value with it's initial value set to a `[]interface{}` -`AppendLint()` appends a list to the current value. +`AppendList()` appends a list to the current value. `Append(v)` appends the value (v) to the current value/list. diff --git a/ethutil/big.go b/ethutil/big.go index 979078bef..1a3902fa3 100644 --- a/ethutil/big.go +++ b/ethutil/big.go @@ -35,3 +35,18 @@ func BigD(data []byte) *big.Int { return n } + +func BigToBytes(num *big.Int, base int) []byte { + ret := make([]byte, base/8) + + return append(ret[:len(ret)-len(num.Bytes())], num.Bytes()...) +} + +// Functions like the build in "copy" function +// but works on big integers +func BigCopy(src *big.Int) (ret *big.Int) { + ret = new(big.Int) + ret.Add(ret, src) + + return +} diff --git a/ethutil/common.go b/ethutil/common.go new file mode 100644 index 000000000..07df6bb13 --- /dev/null +++ b/ethutil/common.go @@ -0,0 +1,35 @@ +package ethutil + +import ( + "fmt" + "math/big" +) + +var ( + Ether = BigPow(10, 18) + Finney = BigPow(10, 15) + Szabo = BigPow(10, 12) + Vito = BigPow(10, 9) + Turing = BigPow(10, 6) + Eins = BigPow(10, 3) + Wei = big.NewInt(1) +) + +func CurrencyToString(num *big.Int) string { + switch { + case num.Cmp(Ether) >= 0: + return fmt.Sprintf("%v Ether", new(big.Int).Div(num, Ether)) + case num.Cmp(Finney) >= 0: + return fmt.Sprintf("%v Finney", new(big.Int).Div(num, Finney)) + case num.Cmp(Szabo) >= 0: + return fmt.Sprintf("%v Szabo", new(big.Int).Div(num, Szabo)) + case num.Cmp(Vito) >= 0: + return fmt.Sprintf("%v Vito", new(big.Int).Div(num, Vito)) + case num.Cmp(Turing) >= 0: + return fmt.Sprintf("%v Turing", new(big.Int).Div(num, Turing)) + case num.Cmp(Eins) >= 0: + return fmt.Sprintf("%v Eins", new(big.Int).Div(num, Eins)) + } + + return fmt.Sprintf("%v Wei", num) +} diff --git a/ethutil/common_test.go b/ethutil/common_test.go new file mode 100644 index 000000000..3a6a37ff5 --- /dev/null +++ b/ethutil/common_test.go @@ -0,0 +1,17 @@ +package ethutil + +import ( + "fmt" + "math/big" + "testing" +) + +func TestCommon(t *testing.T) { + fmt.Println(CurrencyToString(BigPow(10, 19))) + fmt.Println(CurrencyToString(BigPow(10, 16))) + fmt.Println(CurrencyToString(BigPow(10, 13))) + fmt.Println(CurrencyToString(BigPow(10, 10))) + fmt.Println(CurrencyToString(BigPow(10, 7))) + fmt.Println(CurrencyToString(BigPow(10, 4))) + fmt.Println(CurrencyToString(big.NewInt(10))) +} diff --git a/ethutil/config.go b/ethutil/config.go index 2a239f8e2..5bf56134d 100644 --- a/ethutil/config.go +++ b/ethutil/config.go @@ -1,6 +1,7 @@ package ethutil import ( + "fmt" "log" "os" "os/user" @@ -18,7 +19,7 @@ const ( type config struct { Db Database - Log Logger + Log *Logger ExecPath string Debug bool Ver string @@ -34,17 +35,19 @@ func ReadConfig(base string) *config { usr, _ := user.Current() path := path.Join(usr.HomeDir, base) - //Check if the logging directory already exists, create it if not - _, err := os.Stat(path) - if err != nil { - if os.IsNotExist(err) { - log.Printf("Debug logging directory %s doesn't exist, creating it", path) - os.Mkdir(path, 0777) + if len(base) > 0 { + //Check if the logging directory already exists, create it if not + _, err := os.Stat(path) + if err != nil { + if os.IsNotExist(err) { + log.Printf("Debug logging directory %s doesn't exist, creating it\n", path) + os.Mkdir(path, 0777) + } } } - Config = &config{ExecPath: path, Debug: true, Ver: "0.2.2"} - Config.Log = NewLogger(LogFile|LogStd, 0) + Config = &config{ExecPath: path, Debug: true, Ver: "0.3.0"} + Config.Log = NewLogger(LogFile|LogStd, LogLevelDebug) } return Config @@ -57,15 +60,20 @@ const ( LogStd = 0x2 ) +type LogSystem interface { + Println(v ...interface{}) + Printf(format string, v ...interface{}) +} + type Logger struct { - logSys []*log.Logger + logSys []LogSystem logLevel int } -func NewLogger(flag LoggerType, level int) Logger { - var loggers []*log.Logger +func NewLogger(flag LoggerType, level int) *Logger { + var loggers []LogSystem - flags := log.LstdFlags | log.Lshortfile + flags := log.LstdFlags if flag&LogFile > 0 { file, err := os.OpenFile(path.Join(Config.ExecPath, "debug.log"), os.O_RDWR|os.O_CREATE|os.O_APPEND, os.ModePerm) @@ -73,30 +81,60 @@ func NewLogger(flag LoggerType, level int) Logger { log.Panic("unable to create file logger", err) } - log := log.New(file, "[ETH]", flags) + log := log.New(file, "", flags) loggers = append(loggers, log) } if flag&LogStd > 0 { - log := log.New(os.Stdout, "[ETH]", flags) + log := log.New(os.Stdout, "", flags) loggers = append(loggers, log) } - return Logger{logSys: loggers, logLevel: level} + return &Logger{logSys: loggers, logLevel: level} +} + +func (log *Logger) AddLogSystem(logger LogSystem) { + log.logSys = append(log.logSys, logger) +} + +const ( + LogLevelDebug = iota + LogLevelInfo +) + +func (log *Logger) Debugln(v ...interface{}) { + if log.logLevel != LogLevelDebug { + return + } + + for _, logger := range log.logSys { + logger.Println(v...) + } +} + +func (log *Logger) Debugf(format string, v ...interface{}) { + if log.logLevel != LogLevelDebug { + return + } + + for _, logger := range log.logSys { + logger.Printf(format, v...) + } } -func (log Logger) Debugln(v ...interface{}) { - if log.logLevel != 0 { +func (log *Logger) Infoln(v ...interface{}) { + if log.logLevel > LogLevelInfo { return } + fmt.Println(len(log.logSys)) for _, logger := range log.logSys { logger.Println(v...) } } -func (log Logger) Debugf(format string, v ...interface{}) { - if log.logLevel != 0 { +func (log *Logger) Infof(format string, v ...interface{}) { + if log.logLevel > LogLevelInfo { return } diff --git a/ethutil/db.go b/ethutil/db.go index 3681c4b05..abbf4a2b0 100644 --- a/ethutil/db.go +++ b/ethutil/db.go @@ -4,6 +4,8 @@ package ethutil type Database interface { Put(key []byte, value []byte) Get(key []byte) ([]byte, error) + GetKeys() []*Key + Delete(key []byte) error LastKnownTD() []byte Close() Print() diff --git a/ethutil/encoding.go b/ethutil/encoding.go index 207548c93..1f661947a 100644 --- a/ethutil/encoding.go +++ b/ethutil/encoding.go @@ -3,7 +3,6 @@ package ethutil import ( "bytes" "encoding/hex" - _ "fmt" "strings" ) @@ -36,7 +35,7 @@ func CompactEncode(hexSlice []int) string { func CompactDecode(str string) []int { base := CompactHexDecode(str) base = base[:len(base)-1] - if base[0] >= 2 { // && base[len(base)-1] != 16 { + if base[0] >= 2 { base = append(base, 16) } if base[0]%2 == 1 { diff --git a/ethutil/encoding_test.go b/ethutil/encoding_test.go index bcabab0b1..cbfbc0eaf 100644 --- a/ethutil/encoding_test.go +++ b/ethutil/encoding_test.go @@ -35,3 +35,33 @@ func TestCompactHexDecode(t *testing.T) { t.Error("Error compact hex decode. Expected", exp, "got", res) } } + +func TestCompactDecode(t *testing.T) { + exp := []int{1, 2, 3, 4, 5} + res := CompactDecode("\x11\x23\x45") + + if !CompareIntSlice(res, exp) { + t.Error("odd compact decode. Expected", exp, "got", res) + } + + exp = []int{0, 1, 2, 3, 4, 5} + res = CompactDecode("\x00\x01\x23\x45") + + if !CompareIntSlice(res, exp) { + t.Error("even compact decode. Expected", exp, "got", res) + } + + exp = []int{0, 15, 1, 12, 11, 8 /*term*/, 16} + res = CompactDecode("\x20\x0f\x1c\xb8") + + if !CompareIntSlice(res, exp) { + t.Error("even terminated compact decode. Expected", exp, "got", res) + } + + exp = []int{15, 1, 12, 11, 8 /*term*/, 16} + res = CompactDecode("\x3f\x1c\xb8") + + if !CompareIntSlice(res, exp) { + t.Error("even terminated compact decode. Expected", exp, "got", res) + } +}
\ No newline at end of file diff --git a/ethutil/helpers.go b/ethutil/helpers.go index 1c6adf256..aa0f79a04 100644 --- a/ethutil/helpers.go +++ b/ethutil/helpers.go @@ -27,7 +27,6 @@ func Ripemd160(data []byte) []byte { func Sha3Bin(data []byte) []byte { d := sha3.NewKeccak256() - d.Reset() d.Write(data) return d.Sum(nil) @@ -59,3 +58,7 @@ func MatchingNibbleLength(a, b []int) int { func Hex(d []byte) string { return hex.EncodeToString(d) } +func FromHex(str string) []byte { + h, _ := hex.DecodeString(str) + return h +} diff --git a/ethutil/key.go b/ethutil/key.go new file mode 100644 index 000000000..ec195f213 --- /dev/null +++ b/ethutil/key.go @@ -0,0 +1,19 @@ +package ethutil + +type Key struct { + PrivateKey []byte + PublicKey []byte +} + +func NewKeyFromBytes(data []byte) *Key { + val := NewValueFromBytes(data) + return &Key{val.Get(0).Bytes(), val.Get(1).Bytes()} +} + +func (k *Key) Address() []byte { + return Sha3Bin(k.PublicKey[1:])[12:] +} + +func (k *Key) RlpEncode() []byte { + return EmptyValue().Append(k.PrivateKey).Append(k.PublicKey).Encode() +} diff --git a/ethutil/parsing.go b/ethutil/parsing.go index 2c41fb4df..553bb9717 100644 --- a/ethutil/parsing.go +++ b/ethutil/parsing.go @@ -1,95 +1,88 @@ package ethutil import ( - "errors" - "fmt" "math/big" "strconv" - "strings" ) // Op codes -var OpCodes = map[string]string{ - "STOP": "0", - "ADD": "1", - "MUL": "2", - "SUB": "3", - "DIV": "4", - "SDIV": "5", - "MOD": "6", - "SMOD": "7", - "EXP": "8", - "NEG": "9", - "LT": "10", - "LE": "11", - "GT": "12", - "GE": "13", - "EQ": "14", - "NOT": "15", - "MYADDRESS": "16", - "TXSENDER": "17", - - "PUSH": "48", - "POP": "49", - "LOAD": "54", +var OpCodes = map[string]byte{ + "STOP": 0x00, + "ADD": 0x01, + "MUL": 0x02, + "SUB": 0x03, + "DIV": 0x04, + "SDIV": 0x05, + "MOD": 0x06, + "SMOD": 0x07, + "EXP": 0x08, + "NEG": 0x09, + "LT": 0x0a, + "LE": 0x0b, + "GT": 0x0c, + "GE": 0x0d, + "EQ": 0x0e, + "NOT": 0x0f, + "MYADDRESS": 0x10, + "TXSENDER": 0x11, + "TXVALUE": 0x12, + "TXDATAN": 0x13, + "TXDATA": 0x14, + "BLK_PREVHASH": 0x15, + "BLK_COINBASE": 0x16, + "BLK_TIMESTAMP": 0x17, + "BLK_NUMBER": 0x18, + "BLK_DIFFICULTY": 0x19, + "BLK_NONCE": 0x1a, + "BASEFEE": 0x1b, + "SHA256": 0x20, + "RIPEMD160": 0x21, + "ECMUL": 0x22, + "ECADD": 0x23, + "ECSIGN": 0x24, + "ECRECOVER": 0x25, + "ECVALID": 0x26, + "SHA3": 0x27, + "PUSH": 0x30, + "POP": 0x31, + "DUP": 0x32, + "SWAP": 0x33, + "MLOAD": 0x34, + "MSTORE": 0x35, + "SLOAD": 0x36, + "SSTORE": 0x37, + "JMP": 0x38, + "JMPI": 0x39, + "IND": 0x3a, + "EXTRO": 0x3b, + "BALANCE": 0x3c, + "MKTX": 0x3d, + "SUICIDE": 0x3f, } -func CompileInstr(s string) (string, error) { - tokens := strings.Split(s, " ") - if OpCodes[tokens[0]] == "" { - return s, errors.New(fmt.Sprintf("OP not found: %s", tokens[0])) +func IsOpCode(s string) bool { + for key, _ := range OpCodes { + if key == s { + return true + } } + return false +} - code := OpCodes[tokens[0]] // Replace op codes with the proper numerical equivalent - op := new(big.Int) - op.SetString(code, 0) - - args := make([]*big.Int, 6) - for i, val := range tokens[1:len(tokens)] { - num := new(big.Int) - num.SetString(val, 0) - args[i] = num - } - - // Big int equation = op + x * 256 + y * 256**2 + z * 256**3 + a * 256**4 + b * 256**5 + c * 256**6 - base := new(big.Int) - x := new(big.Int) - y := new(big.Int) - z := new(big.Int) - a := new(big.Int) - b := new(big.Int) - c := new(big.Int) - - if args[0] != nil { - x.Mul(args[0], big.NewInt(256)) - } - if args[1] != nil { - y.Mul(args[1], BigPow(256, 2)) - } - if args[2] != nil { - z.Mul(args[2], BigPow(256, 3)) - } - if args[3] != nil { - a.Mul(args[3], BigPow(256, 4)) - } - if args[4] != nil { - b.Mul(args[4], BigPow(256, 5)) - } - if args[5] != nil { - c.Mul(args[5], BigPow(256, 6)) +func CompileInstr(s string) ([]byte, error) { + isOp := IsOpCode(s) + if isOp { + return []byte{OpCodes[s]}, nil } - base.Add(op, x) - base.Add(base, y) - base.Add(base, z) - base.Add(base, a) - base.Add(base, b) - base.Add(base, c) + num := new(big.Int) + num.SetString(s, 0) - return base.String(), nil + return num.Bytes(), nil } func Instr(instr string) (int, []string, error) { + base := new(big.Int) base.SetString(instr, 0) diff --git a/ethutil/parsing_test.go b/ethutil/parsing_test.go index 482eef3ee..6b59777e6 100644 --- a/ethutil/parsing_test.go +++ b/ethutil/parsing_test.go @@ -1,5 +1,6 @@ package ethutil +/* import ( "math" "testing" @@ -13,20 +14,19 @@ func TestCompile(t *testing.T) { } calc := (48 + 0*256 + 0*int64(math.Pow(256, 2))) - if Big(instr).Int64() != calc { + if BigD(instr).Int64() != calc { t.Error("Expected", calc, ", got:", instr) } } func TestValidInstr(t *testing.T) { - /* op, args, err := Instr("68163") if err != nil { t.Error("Error decoding instruction") } - */ } func TestInvalidInstr(t *testing.T) { } +*/ diff --git a/ethutil/rlp.go b/ethutil/rlp.go index 025d269a0..e633f5f1d 100644 --- a/ethutil/rlp.go +++ b/ethutil/rlp.go @@ -86,13 +86,6 @@ func DecodeWithReader(reader *bytes.Buffer) interface{} { // TODO Use a bytes.Buffer instead of a raw byte slice. // Cleaner code, and use draining instead of seeking the next bytes to read func Decode(data []byte, pos uint64) (interface{}, uint64) { - /* - if pos > uint64(len(data)-1) { - log.Println(data) - log.Panicf("index out of range %d for data %q, l = %d", pos, data, len(data)) - } - */ - var slice []interface{} char := int(data[pos]) switch { @@ -131,7 +124,6 @@ func Decode(data []byte, pos uint64) (interface{}, uint64) { case char <= 0xff: l := uint64(data[pos]) - 0xf7 - //b := BigD(data[pos+1 : pos+1+l]).Uint64() b := ReadVarint(bytes.NewReader(data[pos+1 : pos+1+l])) pos = pos + l + 1 diff --git a/ethutil/rlp_test.go b/ethutil/rlp_test.go index 32bcbdce1..2a58bfc0f 100644 --- a/ethutil/rlp_test.go +++ b/ethutil/rlp_test.go @@ -2,15 +2,13 @@ package ethutil import ( "bytes" - "encoding/hex" - "fmt" "math/big" "reflect" "testing" ) func TestRlpValueEncoding(t *testing.T) { - val := EmptyRlpValue() + val := EmptyValue() val.AppendList().Append(1).Append(2).Append(3) val.Append("4").AppendList().Append(5) @@ -63,7 +61,7 @@ func TestEncode(t *testing.T) { str := string(bytes) if str != strRes { - t.Error(fmt.Sprintf("Expected %q, got %q", strRes, str)) + t.Errorf("Expected %q, got %q", strRes, str) } sliceRes := "\xcc\x83dog\x83god\x83cat" @@ -71,7 +69,7 @@ func TestEncode(t *testing.T) { bytes = Encode(strs) slice := string(bytes) if slice != sliceRes { - t.Error(fmt.Sprintf("Expected %q, got %q", sliceRes, slice)) + t.Error("Expected %q, got %q", sliceRes, slice) } intRes := "\x82\x04\x00" @@ -108,13 +106,9 @@ func TestEncodeDecodeBigInt(t *testing.T) { encoded := Encode(bigInt) value := NewValueFromBytes(encoded) - fmt.Println(value.BigInt(), bigInt) if value.BigInt().Cmp(bigInt) != 0 { t.Errorf("Expected %v, got %v", bigInt, value.BigInt()) } - - dec, _ := hex.DecodeString("52f4fc1e") - fmt.Println(NewValueFromBytes(dec).BigInt()) } func TestEncodeDecodeBytes(t *testing.T) { @@ -125,43 +119,6 @@ func TestEncodeDecodeBytes(t *testing.T) { } } -/* -var ZeroHash256 = make([]byte, 32) -var ZeroHash160 = make([]byte, 20) -var EmptyShaList = Sha3Bin(Encode([]interface{}{})) - -var GenisisHeader = []interface{}{ - // Previous hash (none) - //"", - ZeroHash256, - // Sha of uncles - Sha3Bin(Encode([]interface{}{})), - // Coinbase - ZeroHash160, - // Root state - "", - // Sha of transactions - //EmptyShaList, - Sha3Bin(Encode([]interface{}{})), - // Difficulty - BigPow(2, 22), - // Time - //big.NewInt(0), - int64(0), - // extra - "", - // Nonce - big.NewInt(42), -} - -func TestEnc(t *testing.T) { - //enc := Encode(GenisisHeader) - //fmt.Printf("%x (%d)\n", enc, len(enc)) - h, _ := hex.DecodeString("f8a0a00000000000000000000000000000000000000000000000000000000000000000a01dcc4de8dec75d7aab85b567b6ccd41ad312451b948a7413f0a142fd40d49347940000000000000000000000000000000000000000a06d076baa9c4074fb2df222dd16a96b0155a1e6686b3e5748b4e9ca0a208a425ca01dcc4de8dec75d7aab85b567b6ccd41ad312451b948a7413f0a142fd40d493478340000080802a") - fmt.Printf("%x\n", Sha3Bin(h)) -} -*/ - func BenchmarkEncodeDecode(b *testing.B) { for i := 0; i < b.N; i++ { bytes := Encode([]interface{}{"dog", "god", "cat"}) diff --git a/ethutil/trie.go b/ethutil/trie.go index 0a3f73136..a17dc37ad 100644 --- a/ethutil/trie.go +++ b/ethutil/trie.go @@ -5,6 +5,15 @@ import ( "reflect" ) +// TODO +// A StateObject is an object that has a state root +// This is goig to be the object for the second level caching (the caching of object which have a state such as contracts) +type StateObject interface { + State() *Trie + Sync() + Undo() +} + type Node struct { Key []byte Value *Value @@ -20,8 +29,9 @@ func (n *Node) Copy() *Node { } type Cache struct { - nodes map[string]*Node - db Database + nodes map[string]*Node + db Database + IsDirty bool } func NewCache(db Database) *Cache { @@ -36,6 +46,7 @@ func (cache *Cache) Put(v interface{}) interface{} { sha := Sha3Bin(enc) cache.nodes[string(sha)] = NewNode(sha, value, true) + cache.IsDirty = true return sha } @@ -59,13 +70,25 @@ func (cache *Cache) Get(key []byte) *Value { return value } +func (cache *Cache) Delete(key []byte) { + delete(cache.nodes, string(key)) + + cache.db.Delete(key) +} + func (cache *Cache) Commit() { + // Don't try to commit if it isn't dirty + if !cache.IsDirty { + return + } + for key, node := range cache.nodes { if node.Dirty { cache.db.Put([]byte(key), node.Value.Encode()) node.Dirty = false } } + cache.IsDirty = false // If the nodes grows beyond the 200 entries we simple empty it // FIXME come up with something better @@ -80,6 +103,7 @@ func (cache *Cache) Undo() { delete(cache.nodes, key) } } + cache.IsDirty = false } // A (modified) Radix Trie implementation. The Trie implements @@ -89,18 +113,29 @@ func (cache *Cache) Undo() { // Please note that the data isn't persisted unless `Sync` is // explicitly called. type Trie struct { - Root interface{} + prevRoot interface{} + Root interface{} //db Database cache *Cache } func NewTrie(db Database, Root interface{}) *Trie { - return &Trie{cache: NewCache(db), Root: Root} + return &Trie{cache: NewCache(db), Root: Root, prevRoot: Root} } // Save the cached value to the database. func (t *Trie) Sync() { t.cache.Commit() + t.prevRoot = t.Root +} + +func (t *Trie) Undo() { + t.cache.Undo() + t.Root = t.prevRoot +} + +func (t *Trie) Cache() *Cache { + return t.cache } /* @@ -119,6 +154,10 @@ func (t *Trie) Get(key string) string { return c.Str() } +func (t *Trie) Delete(key string) { + t.Update(key, "") +} + func (t *Trie) GetState(node interface{}, key []int) interface{} { n := NewValue(node) // Return the node if key is empty (= found) @@ -168,13 +207,15 @@ func (t *Trie) GetNode(node interface{}) *Value { } func (t *Trie) UpdateState(node interface{}, key []int, value string) interface{} { + if value != "" { return t.InsertState(node, key, value) } else { // delete it + return t.DeleteState(node, key) } - return "" + return t.Root } func (t *Trie) Put(node interface{}) interface{} { @@ -228,6 +269,7 @@ func (t *Trie) InsertState(node interface{}, key []int, value interface{}) inter // Check for "special" 2 slice type node if currentNode.Len() == 2 { // Decode the key + k := CompactDecode(currentNode.Get(0).Str()) v := currentNode.Get(1).Raw() @@ -282,6 +324,87 @@ func (t *Trie) InsertState(node interface{}, key []int, value interface{}) inter return "" } +func (t *Trie) DeleteState(node interface{}, key []int) interface{} { + if len(key) == 0 { + return "" + } + + // New node + n := NewValue(node) + if node == nil || (n.Type() == reflect.String && (n.Str() == "" || n.Get(0).IsNil())) || n.Len() == 0 { + return "" + } + + currentNode := t.GetNode(node) + // Check for "special" 2 slice type node + if currentNode.Len() == 2 { + // Decode the key + k := CompactDecode(currentNode.Get(0).Str()) + v := currentNode.Get(1).Raw() + + // Matching key pair (ie. there's already an object with this key) + if CompareIntSlice(k, key) { + return "" + } else if CompareIntSlice(key[:len(k)], k) { + hash := t.DeleteState(v, key[len(k):]) + child := t.GetNode(hash) + + var newNode []interface{} + if child.Len() == 2 { + newKey := append(k, CompactDecode(child.Get(0).Str())...) + newNode = []interface{}{CompactEncode(newKey), child.Get(1).Raw()} + } else { + newNode = []interface{}{currentNode.Get(0).Str(), hash} + } + + return t.Put(newNode) + } else { + return node + } + } else { + // Copy the current node over to the new node and replace the first nibble in the key + n := EmptyStringSlice(17) + var newNode []interface{} + + for i := 0; i < 17; i++ { + cpy := currentNode.Get(i).Raw() + if cpy != nil { + n[i] = cpy + } + } + + n[key[0]] = t.DeleteState(n[key[0]], key[1:]) + amount := -1 + for i := 0; i < 17; i++ { + if n[i] != "" { + if amount == -1 { + amount = i + } else { + amount = -2 + } + } + } + if amount == 16 { + newNode = []interface{}{CompactEncode([]int{16}), n[amount]} + } else if amount >= 0 { + child := t.GetNode(n[amount]) + if child.Len() == 17 { + newNode = []interface{}{CompactEncode([]int{amount}), n[amount]} + } else if child.Len() == 2 { + key := append([]int{amount}, CompactDecode(child.Get(0).Str())...) + newNode = []interface{}{CompactEncode(key), child.Get(1).Str()} + } + + } else { + newNode = n + } + + return t.Put(newNode) + } + + return "" +} + // Simple compare function which creates a rlp value out of the evaluated objects func (t *Trie) Cmp(trie *Trie) bool { return NewValue(t.Root).Cmp(NewValue(trie.Root)) @@ -296,3 +419,82 @@ func (t *Trie) Copy() *Trie { return trie } + +type TrieIterator struct { + trie *Trie + key string + value string + + shas [][]byte + values []string +} + +func (t *Trie) NewIterator() *TrieIterator { + return &TrieIterator{trie: t} +} + +// Some time in the near future this will need refactoring :-) +// XXX Note to self, IsSlice == inline node. Str == sha3 to node +func (it *TrieIterator) workNode(currentNode *Value) { + if currentNode.Len() == 2 { + k := CompactDecode(currentNode.Get(0).Str()) + + if currentNode.Get(1).Str() == "" { + it.workNode(currentNode.Get(1)) + } else { + if k[len(k)-1] == 16 { + it.values = append(it.values, currentNode.Get(1).Str()) + } else { + it.shas = append(it.shas, currentNode.Get(1).Bytes()) + it.getNode(currentNode.Get(1).Bytes()) + } + } + } else { + for i := 0; i < currentNode.Len(); i++ { + if i == 16 && currentNode.Get(i).Len() != 0 { + it.values = append(it.values, currentNode.Get(i).Str()) + } else { + if currentNode.Get(i).Str() == "" { + it.workNode(currentNode.Get(i)) + } else { + val := currentNode.Get(i).Str() + if val != "" { + it.shas = append(it.shas, currentNode.Get(1).Bytes()) + it.getNode([]byte(val)) + } + } + } + } + } +} + +func (it *TrieIterator) getNode(node []byte) { + currentNode := it.trie.cache.Get(node) + it.workNode(currentNode) +} + +func (it *TrieIterator) Collect() [][]byte { + if it.trie.Root == "" { + return nil + } + + it.getNode(NewValue(it.trie.Root).Bytes()) + + return it.shas +} + +func (it *TrieIterator) Purge() int { + shas := it.Collect() + for _, sha := range shas { + it.trie.cache.Delete(sha) + } + return len(it.values) +} + +func (it *TrieIterator) Key() string { + return "" +} + +func (it *TrieIterator) Value() string { + return "" +} diff --git a/ethutil/trie_test.go b/ethutil/trie_test.go index b87d35e1a..7c398f1de 100644 --- a/ethutil/trie_test.go +++ b/ethutil/trie_test.go @@ -1,11 +1,12 @@ package ethutil import ( - _ "encoding/hex" - _ "fmt" + "reflect" "testing" ) +const LONG_WORD = "1234567890abcdefghijklmnopqrstuvwxxzABCEFGHIJKLMNOPQRSTUVWXYZ" + type MemDatabase struct { db map[string][]byte } @@ -20,15 +21,24 @@ func (db *MemDatabase) Put(key []byte, value []byte) { func (db *MemDatabase) Get(key []byte) ([]byte, error) { return db.db[string(key)], nil } +func (db *MemDatabase) Delete(key []byte) error { + delete(db.db, string(key)) + return nil +} +func (db *MemDatabase) GetKeys() []*Key { return nil } func (db *MemDatabase) Print() {} func (db *MemDatabase) Close() {} func (db *MemDatabase) LastKnownTD() []byte { return nil } -func TestTrieSync(t *testing.T) { +func New() (*MemDatabase, *Trie) { db, _ := NewMemDatabase() - trie := NewTrie(db, "") + return db, NewTrie(db, "") +} - trie.Update("dog", "kindofalongsentencewhichshouldbeencodedinitsentirety") +func TestTrieSync(t *testing.T) { + db, trie := New() + + trie.Update("dog", LONG_WORD) if len(db.db) != 0 { t.Error("Expected no data in database") } @@ -38,3 +48,125 @@ func TestTrieSync(t *testing.T) { t.Error("Expected data to be persisted") } } + +func TestTrieDirtyTracking(t *testing.T) { + _, trie := New() + trie.Update("dog", LONG_WORD) + if !trie.cache.IsDirty { + t.Error("Expected trie to be dirty") + } + + trie.Sync() + if trie.cache.IsDirty { + t.Error("Expected trie not to be dirty") + } + + trie.Update("test", LONG_WORD) + trie.cache.Undo() + if trie.cache.IsDirty { + t.Error("Expected trie not to be dirty") + } + +} + +func TestTrieReset(t *testing.T) { + _, trie := New() + + trie.Update("cat", LONG_WORD) + if len(trie.cache.nodes) == 0 { + t.Error("Expected cached nodes") + } + + trie.cache.Undo() + + if len(trie.cache.nodes) != 0 { + t.Error("Expected no nodes after undo") + } +} + +func TestTrieGet(t *testing.T) { + _, trie := New() + + trie.Update("cat", LONG_WORD) + x := trie.Get("cat") + if x != LONG_WORD { + t.Error("expected %s, got %s", LONG_WORD, x) + } +} + +func TestTrieUpdating(t *testing.T) { + _, trie := New() + trie.Update("cat", LONG_WORD) + trie.Update("cat", LONG_WORD+"1") + x := trie.Get("cat") + if x != LONG_WORD+"1" { + t.Error("expected %S, got %s", LONG_WORD+"1", x) + } +} + +func TestTrieCmp(t *testing.T) { + _, trie1 := New() + _, trie2 := New() + + trie1.Update("doge", LONG_WORD) + trie2.Update("doge", LONG_WORD) + if !trie1.Cmp(trie2) { + t.Error("Expected tries to be equal") + } + + trie1.Update("dog", LONG_WORD) + trie2.Update("cat", LONG_WORD) + if trie1.Cmp(trie2) { + t.Errorf("Expected tries not to be equal %x %x", trie1.Root, trie2.Root) + } +} + +func TestTrieDelete(t *testing.T) { + _, trie := New() + trie.Update("cat", LONG_WORD) + exp := trie.Root + trie.Update("dog", LONG_WORD) + trie.Delete("dog") + if !reflect.DeepEqual(exp, trie.Root) { + t.Errorf("Expected tries to be equal %x : %x", exp, trie.Root) + } + + trie.Update("dog", LONG_WORD) + exp = trie.Root + trie.Update("dude", LONG_WORD) + trie.Delete("dude") + if !reflect.DeepEqual(exp, trie.Root) { + t.Errorf("Expected tries to be equal %x : %x", exp, trie.Root) + } +} + +func TestTrieDeleteWithValue(t *testing.T) { + _, trie := New() + trie.Update("c", LONG_WORD) + exp := trie.Root + trie.Update("ca", LONG_WORD) + trie.Update("cat", LONG_WORD) + trie.Delete("ca") + trie.Delete("cat") + if !reflect.DeepEqual(exp, trie.Root) { + t.Errorf("Expected tries to be equal %x : %x", exp, trie.Root) + } + +} + +func TestTrieIterator(t *testing.T) { + _, trie := New() + trie.Update("c", LONG_WORD) + trie.Update("ca", LONG_WORD) + trie.Update("cat", LONG_WORD) + + lenBefore := len(trie.cache.nodes) + it := trie.NewIterator() + if num := it.Purge(); num != 3 { + t.Errorf("Expected purge to return 3, got %d", num) + } + + if lenBefore == len(trie.cache.nodes) { + t.Errorf("Expected cached nodes to be deleted") + } +} diff --git a/ethutil/value.go b/ethutil/value.go index 2a990783e..3dd84d12d 100644 --- a/ethutil/value.go +++ b/ethutil/value.go @@ -36,7 +36,8 @@ func (val *Value) Len() int { if data, ok := val.Val.([]interface{}); ok { return len(data) } else if data, ok := val.Val.([]byte); ok { - // FIXME + return len(data) + } else if data, ok := val.Val.(string); ok { return len(data) } @@ -60,6 +61,10 @@ func (val *Value) Uint() uint64 { return uint64(Val) } else if Val, ok := val.Val.(uint64); ok { return Val + } else if Val, ok := val.Val.(int); ok { + return uint64(Val) + } else if Val, ok := val.Val.(uint); ok { + return uint64(Val) } else if Val, ok := val.Val.([]byte); ok { return ReadVarint(bytes.NewReader(Val)) } @@ -80,6 +85,8 @@ func (val *Value) BigInt() *big.Int { b := new(big.Int).SetBytes(a) return b + } else if a, ok := val.Val.(*big.Int); ok { + return a } else { return big.NewInt(int64(val.Uint())) } @@ -92,6 +99,8 @@ func (val *Value) Str() string { return string(a) } else if a, ok := val.Val.(string); ok { return a + } else if a, ok := val.Val.(byte); ok { + return string(a) } return "" @@ -102,7 +111,7 @@ func (val *Value) Bytes() []byte { return a } - return make([]byte, 0) + return []byte{} } func (val *Value) Slice() []interface{} { @@ -131,6 +140,19 @@ func (val *Value) SliceFromTo(from, to int) *Value { return NewValue(slice[from:to]) } +// TODO More type checking methods +func (val *Value) IsSlice() bool { + return val.Type() == reflect.Slice +} + +func (val *Value) IsStr() bool { + return val.Type() == reflect.String +} + +func (val *Value) IsEmpty() bool { + return val.Val == nil || ((val.IsSlice() || val.IsStr()) && val.Len() == 0) +} + // Threat the value as a slice func (val *Value) Get(idx int) *Value { if d, ok := val.Val.([]interface{}); ok { @@ -140,7 +162,7 @@ func (val *Value) Get(idx int) *Value { } if idx < 0 { - panic("negative idx for Rlp Get") + panic("negative idx for Value Get") } return NewValue(d[idx]) @@ -158,9 +180,9 @@ func (val *Value) Encode() []byte { return Encode(val.Val) } -func NewValueFromBytes(rlpData []byte) *Value { - if len(rlpData) != 0 { - data, _ := Decode(rlpData, 0) +func NewValueFromBytes(data []byte) *Value { + if len(data) != 0 { + data, _ := Decode(data, 0) return NewValue(data) } diff --git a/ethutil/value_test.go b/ethutil/value_test.go new file mode 100644 index 000000000..0e2da5328 --- /dev/null +++ b/ethutil/value_test.go @@ -0,0 +1,52 @@ +package ethutil + +import ( + "bytes" + "math/big" + "testing" +) + +func TestValueCmp(t *testing.T) { + val1 := NewValue("hello") + val2 := NewValue("world") + if val1.Cmp(val2) { + t.Error("Expected values not to be equal") + } + + val3 := NewValue("hello") + val4 := NewValue("hello") + if !val3.Cmp(val4) { + t.Error("Expected values to be equal") + } +} + +func TestValueTypes(t *testing.T) { + str := NewValue("str") + num := NewValue(1) + inter := NewValue([]interface{}{1}) + byt := NewValue([]byte{1, 2, 3, 4}) + bigInt := NewValue(big.NewInt(10)) + + if str.Str() != "str" { + t.Errorf("expected Str to return 'str', got %s", str.Str()) + } + + if num.Uint() != 1 { + t.Errorf("expected Uint to return '1', got %d", num.Uint()) + } + + interExp := []interface{}{1} + if !NewValue(inter.Interface()).Cmp(NewValue(interExp)) { + t.Errorf("expected Interface to return '%v', got %v", interExp, num.Interface()) + } + + bytExp := []byte{1, 2, 3, 4} + if bytes.Compare(byt.Bytes(), bytExp) != 0 { + t.Errorf("expected Bytes to return '%v', got %v", bytExp, byt.Bytes()) + } + + bigExp := big.NewInt(10) + if bigInt.BigInt().Cmp(bigExp) != 0 { + t.Errorf("expected BigInt to return '%v', got %v", bigExp, bigInt.BigInt()) + } +} |