diff options
author | Javier Peletier <jm@epiclabs.io> | 2018-03-05 23:00:03 +0800 |
---|---|---|
committer | Javier Peletier <jm@epiclabs.io> | 2018-03-05 23:00:03 +0800 |
commit | 13b566e06e9aae28bddde431c7d53a335272411a (patch) | |
tree | b649ab64ca2c00327500aed5c826d5a6f454a6cf | |
parent | 1e72271f571f916691c5c18b8f0c4c5f7e0445c3 (diff) | |
parent | 1548518644071c8fa8eb98a8cb8a8c4603400acb (diff) | |
download | go-tangerine-13b566e06e9aae28bddde431c7d53a335272411a.tar go-tangerine-13b566e06e9aae28bddde431c7d53a335272411a.tar.gz go-tangerine-13b566e06e9aae28bddde431c7d53a335272411a.tar.bz2 go-tangerine-13b566e06e9aae28bddde431c7d53a335272411a.tar.lz go-tangerine-13b566e06e9aae28bddde431c7d53a335272411a.tar.xz go-tangerine-13b566e06e9aae28bddde431c7d53a335272411a.tar.zst go-tangerine-13b566e06e9aae28bddde431c7d53a335272411a.zip |
accounts/abi: Add one-parameter event test case from enriquefynn/unpack_one_arg_event
256 files changed, 14983 insertions, 4481 deletions
diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 6076fe46a..a7b617655 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -5,5 +5,7 @@ accounts/usbwallet @karalabe consensus @karalabe core/ @karalabe @holiman eth/ @karalabe +les/ @zsfelfoldi +light/ @zsfelfoldi mobile/ @karalabe p2p/ @fjl @zsfelfoldi diff --git a/.github/stale.yml b/.github/stale.yml new file mode 100644 index 000000000..c621939c3 --- /dev/null +++ b/.github/stale.yml @@ -0,0 +1,17 @@ +# Number of days of inactivity before an issue becomes stale +daysUntilStale: 366 +# Number of days of inactivity before a stale issue is closed +daysUntilClose: 42 +# Issues with these labels will never be considered stale +exemptLabels: + - pinned + - security +# Label to use when marking an issue as stale +staleLabel: stale +# Comment to post when marking an issue as stale. Set to `false` to disable +markComment: > + This issue has been automatically marked as stale because it has not had + recent activity. It will be closed if no further activity occurs. Thank you + for your contributions. +# Comment to post when closing a stale issue. Set to `false` to disable +closeComment: false diff --git a/.gitignore b/.gitignore index 0763d8492..b8d292901 100644 --- a/.gitignore +++ b/.gitignore @@ -34,6 +34,9 @@ profile.cov # IdeaIDE .idea +# VS Code +.vscode + # dashboard /dashboard/assets/flow-typed /dashboard/assets/node_modules diff --git a/.travis.yml b/.travis.yml index 3941fa785..a76a78954 100644 --- a/.travis.yml +++ b/.travis.yml @@ -6,17 +6,6 @@ matrix: - os: linux dist: trusty sudo: required - go: 1.7.x - script: - - sudo modprobe fuse - - sudo chmod 666 /dev/fuse - - sudo chown root:$USER /etc/fuse.conf - - go run build/ci.go install - - go run build/ci.go test -coverage - - - os: linux - dist: trusty - sudo: required go: 1.8.x script: - sudo modprobe fuse @@ -5,6 +5,8 @@ Official golang implementation of the Ethereum protocol. [![API Reference]( https://camo.githubusercontent.com/915b7be44ada53c290eb157634330494ebe3e30a/68747470733a2f2f676f646f632e6f72672f6769746875622e636f6d2f676f6c616e672f6764646f3f7374617475732e737667 )](https://godoc.org/github.com/ethereum/go-ethereum) +[![Go Report Card](https://goreportcard.com/badge/github.com/ethereum/go-ethereum)](https://goreportcard.com/report/github.com/ethereum/go-ethereum) +[![Travis](https://travis-ci.org/ethereum/go-ethereum.svg?branch=master)](https://travis-ci.org/ethereum/go-ethereum) [![Gitter](https://badges.gitter.im/Join%20Chat.svg)](https://gitter.im/ethereum/go-ethereum?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge) Automated builds are available for stable releases and the unstable master branch. @@ -1 +1 @@ -1.8.1 +1.8.3 diff --git a/accounts/abi/abi.go b/accounts/abi/abi.go index 65e69d064..da2ef9178 100644 --- a/accounts/abi/abi.go +++ b/accounts/abi/abi.go @@ -136,11 +136,11 @@ func (abi *ABI) UnmarshalJSON(data []byte) error { // MethodById looks up a method by the 4-byte id // returns nil if none found -func (abi *ABI) MethodById(sigdata []byte) *Method { +func (abi *ABI) MethodById(sigdata []byte) (*Method, error) { for _, method := range abi.Methods { if bytes.Equal(method.Id(), sigdata[:4]) { - return &method + return &method, nil } } - return nil + return nil, fmt.Errorf("no method with id: %#x", sigdata[:4]) } diff --git a/accounts/abi/abi_test.go b/accounts/abi/abi_test.go index 325f33a82..8018df775 100644 --- a/accounts/abi/abi_test.go +++ b/accounts/abi/abi_test.go @@ -702,7 +702,11 @@ func TestABI_MethodById(t *testing.T) { } for name, m := range abi.Methods { a := fmt.Sprintf("%v", m) - b := fmt.Sprintf("%v", abi.MethodById(m.Id())) + m2, err := abi.MethodById(m.Id()) + if err != nil { + t.Fatalf("Failed to look up ABI method: %v", err) + } + b := fmt.Sprintf("%v", m2) if a != b { t.Errorf("Method %v (id %v) not 'findable' by id in ABI", name, common.ToHex(m.Id())) } diff --git a/accounts/abi/argument.go b/accounts/abi/argument.go index 04ca6150a..1b480da60 100644 --- a/accounts/abi/argument.go +++ b/accounts/abi/argument.go @@ -67,6 +67,17 @@ func (arguments Arguments) LengthNonIndexed() int { return out } +// NonIndexed returns the arguments with indexed arguments filtered out +func (arguments Arguments) NonIndexed() Arguments { + var ret []Argument + for _, arg := range arguments { + if !arg.Indexed { + ret = append(ret, arg) + } + } + return ret +} + // isTuple returns true for non-atomic constructs, like (uint,uint) or uint[] func (arguments Arguments) isTuple() bool { return len(arguments) > 1 @@ -74,21 +85,25 @@ func (arguments Arguments) isTuple() bool { // Unpack performs the operation hexdata -> Go format func (arguments Arguments) Unpack(v interface{}, data []byte) error { - if arguments.isTuple() { - return arguments.unpackTuple(v, data) - } - return arguments.unpackAtomic(v, data) -} -func (arguments Arguments) unpackTuple(v interface{}, output []byte) error { // make sure the passed value is arguments pointer - valueOf := reflect.ValueOf(v) - if reflect.Ptr != valueOf.Kind() { + if reflect.Ptr != reflect.ValueOf(v).Kind() { return fmt.Errorf("abi: Unpack(non-pointer %T)", v) } + marshalledValues, err := arguments.UnpackValues(data) + if err != nil { + return err + } + if arguments.isTuple() { + return arguments.unpackTuple(v, marshalledValues) + } + return arguments.unpackAtomic(v, marshalledValues) +} + +func (arguments Arguments) unpackTuple(v interface{}, marshalledValues []interface{}) error { var ( - value = valueOf.Elem() + value = reflect.ValueOf(v).Elem() typ = value.Type() kind = value.Kind() ) @@ -110,30 +125,9 @@ func (arguments Arguments) unpackTuple(v interface{}, output []byte) error { exists[field] = true } } - // `i` counts the nonindexed arguments. - // `j` counts the number of complex types. - // both `i` and `j` are used to to correctly compute `data` offset. + for i, arg := range arguments.NonIndexed() { - i, j := -1, 0 - for _, arg := range arguments { - - if arg.Indexed { - // can't read, continue - continue - } - i++ - marshalledValue, err := toGoType((i+j)*32, arg.Type, output) - if err != nil { - return err - } - - if arg.Type.T == ArrayTy { - // combined index ('i' + 'j') need to be adjusted only by size of array, thus - // we need to decrement 'j' because 'i' was incremented - j += arg.Type.Size - 1 - } - - reflectValue := reflect.ValueOf(marshalledValue) + reflectValue := reflect.ValueOf(marshalledValues[i]) switch kind { case reflect.Struct: @@ -166,34 +160,72 @@ func (arguments Arguments) unpackTuple(v interface{}, output []byte) error { } // unpackAtomic unpacks ( hexdata -> go ) a single value -func (arguments Arguments) unpackAtomic(v interface{}, output []byte) error { - // make sure the passed value is arguments pointer - valueOf := reflect.ValueOf(v) - if reflect.Ptr != valueOf.Kind() { - return fmt.Errorf("abi: Unpack(non-pointer %T)", v) - } - arg := arguments[0] - if arg.Indexed { - return fmt.Errorf("abi: attempting to unpack indexed variable into element.") +func (arguments Arguments) unpackAtomic(v interface{}, marshalledValues []interface{}) error { + if len(marshalledValues) != 1 { + return fmt.Errorf("abi: wrong length, expected single value, got %d", len(marshalledValues)) } + elem := reflect.ValueOf(v).Elem() + reflectValue := reflect.ValueOf(marshalledValues[0]) + return set(elem, reflectValue, arguments.NonIndexed()[0]) +} - value := valueOf.Elem() +// Computes the full size of an array; +// i.e. counting nested arrays, which count towards size for unpacking. +func getArraySize(arr *Type) int { + size := arr.Size + // Arrays can be nested, with each element being the same size + arr = arr.Elem + for arr.T == ArrayTy { + // Keep multiplying by elem.Size while the elem is an array. + size *= arr.Size + arr = arr.Elem + } + // Now we have the full array size, including its children. + return size +} - marshalledValue, err := toGoType(0, arg.Type, output) - if err != nil { - return err +// UnpackValues can be used to unpack ABI-encoded hexdata according to the ABI-specification, +// without supplying a struct to unpack into. Instead, this method returns a list containing the +// values. An atomic argument will be a list with one element. +func (arguments Arguments) UnpackValues(data []byte) ([]interface{}, error) { + retval := make([]interface{}, 0, arguments.LengthNonIndexed()) + virtualArgs := 0 + for index, arg := range arguments.NonIndexed() { + marshalledValue, err := toGoType((index+virtualArgs)*32, arg.Type, data) + if arg.Type.T == ArrayTy { + // If we have a static array, like [3]uint256, these are coded as + // just like uint256,uint256,uint256. + // This means that we need to add two 'virtual' arguments when + // we count the index from now on. + // + // Array values nested multiple levels deep are also encoded inline: + // [2][3]uint256: uint256,uint256,uint256,uint256,uint256,uint256 + // + // Calculate the full array size to get the correct offset for the next argument. + // Decrement it by 1, as the normal index increment is still applied. + virtualArgs += getArraySize(&arg.Type) - 1 + } + if err != nil { + return nil, err + } + retval = append(retval, marshalledValue) } - return set(value, reflect.ValueOf(marshalledValue), arg) + return retval, nil } -// Unpack performs the operation Go format -> Hexdata +// PackValues performs the operation Go format -> Hexdata +// It is the semantic opposite of UnpackValues +func (arguments Arguments) PackValues(args []interface{}) ([]byte, error) { + return arguments.Pack(args...) +} + +// Pack performs the operation Go format -> Hexdata func (arguments Arguments) Pack(args ...interface{}) ([]byte, error) { // Make sure arguments match up and pack them abiArgs := arguments if len(args) != len(abiArgs) { return nil, fmt.Errorf("argument count mismatch: %d for %d", len(args), len(abiArgs)) } - // variable input is the output appended at the end of packed // output. This is used for strings and bytes types input. var variableInput []byte @@ -207,7 +239,6 @@ func (arguments Arguments) Pack(args ...interface{}) ([]byte, error) { inputOffset += 32 } } - var ret []byte for i, a := range args { input := abiArgs[i] @@ -216,7 +247,6 @@ func (arguments Arguments) Pack(args ...interface{}) ([]byte, error) { if err != nil { return nil, err } - // check for a slice type (string, bytes, slice) if input.Type.requiresLengthPrefix() { // calculate the offset diff --git a/accounts/abi/bind/backends/simulated.go b/accounts/abi/bind/backends/simulated.go index bd342a8cb..fe7dea4da 100644 --- a/accounts/abi/bind/backends/simulated.go +++ b/accounts/abi/bind/backends/simulated.go @@ -428,10 +428,23 @@ func (fb *filterBackend) HeaderByNumber(ctx context.Context, block rpc.BlockNumb } return fb.bc.GetHeaderByNumber(uint64(block.Int64())), nil } + func (fb *filterBackend) GetReceipts(ctx context.Context, hash common.Hash) (types.Receipts, error) { return core.GetBlockReceipts(fb.db, hash, core.GetBlockNumber(fb.db, hash)), nil } +func (fb *filterBackend) GetLogs(ctx context.Context, hash common.Hash) ([][]*types.Log, error) { + receipts := core.GetBlockReceipts(fb.db, hash, core.GetBlockNumber(fb.db, hash)) + if receipts == nil { + return nil, nil + } + logs := make([][]*types.Log, len(receipts)) + for i, receipt := range receipts { + logs[i] = receipt.Logs + } + return logs, nil +} + func (fb *filterBackend) SubscribeTxPreEvent(ch chan<- core.TxPreEvent) event.Subscription { return event.NewSubscription(func(quit <-chan struct{}) error { <-quit diff --git a/accounts/abi/bind/bind.go b/accounts/abi/bind/bind.go index e31b45481..7fdd2c624 100644 --- a/accounts/abi/bind/bind.go +++ b/accounts/abi/bind/bind.go @@ -164,118 +164,147 @@ var bindType = map[Lang]func(kind abi.Type) string{ LangJava: bindTypeJava, } +// Helper function for the binding generators. +// It reads the unmatched characters after the inner type-match, +// (since the inner type is a prefix of the total type declaration), +// looks for valid arrays (possibly a dynamic one) wrapping the inner type, +// and returns the sizes of these arrays. +// +// Returned array sizes are in the same order as solidity signatures; inner array size first. +// Array sizes may also be "", indicating a dynamic array. +func wrapArray(stringKind string, innerLen int, innerMapping string) (string, []string) { + remainder := stringKind[innerLen:] + //find all the sizes + matches := regexp.MustCompile(`\[(\d*)\]`).FindAllStringSubmatch(remainder, -1) + parts := make([]string, 0, len(matches)) + for _, match := range matches { + //get group 1 from the regex match + parts = append(parts, match[1]) + } + return innerMapping, parts +} + +// Translates the array sizes to a Go-lang declaration of a (nested) array of the inner type. +// Simply returns the inner type if arraySizes is empty. +func arrayBindingGo(inner string, arraySizes []string) string { + out := "" + //prepend all array sizes, from outer (end arraySizes) to inner (start arraySizes) + for i := len(arraySizes) - 1; i >= 0; i-- { + out += "[" + arraySizes[i] + "]" + } + out += inner + return out +} + // bindTypeGo converts a Solidity type to a Go one. Since there is no clear mapping // from all Solidity types to Go ones (e.g. uint17), those that cannot be exactly // mapped will use an upscaled type (e.g. *big.Int). func bindTypeGo(kind abi.Type) string { stringKind := kind.String() + innerLen, innerMapping := bindUnnestedTypeGo(stringKind) + return arrayBindingGo(wrapArray(stringKind, innerLen, innerMapping)) +} + +// The inner function of bindTypeGo, this finds the inner type of stringKind. +// (Or just the type itself if it is not an array or slice) +// The length of the matched part is returned, with the the translated type. +func bindUnnestedTypeGo(stringKind string) (int, string) { switch { case strings.HasPrefix(stringKind, "address"): - parts := regexp.MustCompile(`address(\[[0-9]*\])?`).FindStringSubmatch(stringKind) - if len(parts) != 2 { - return stringKind - } - return fmt.Sprintf("%scommon.Address", parts[1]) + return len("address"), "common.Address" case strings.HasPrefix(stringKind, "bytes"): - parts := regexp.MustCompile(`bytes([0-9]*)(\[[0-9]*\])?`).FindStringSubmatch(stringKind) - if len(parts) != 3 { - return stringKind - } - return fmt.Sprintf("%s[%s]byte", parts[2], parts[1]) + parts := regexp.MustCompile(`bytes([0-9]*)`).FindStringSubmatch(stringKind) + return len(parts[0]), fmt.Sprintf("[%s]byte", parts[1]) case strings.HasPrefix(stringKind, "int") || strings.HasPrefix(stringKind, "uint"): - parts := regexp.MustCompile(`(u)?int([0-9]*)(\[[0-9]*\])?`).FindStringSubmatch(stringKind) - if len(parts) != 4 { - return stringKind - } + parts := regexp.MustCompile(`(u)?int([0-9]*)`).FindStringSubmatch(stringKind) switch parts[2] { case "8", "16", "32", "64": - return fmt.Sprintf("%s%sint%s", parts[3], parts[1], parts[2]) + return len(parts[0]), fmt.Sprintf("%sint%s", parts[1], parts[2]) } - return fmt.Sprintf("%s*big.Int", parts[3]) + return len(parts[0]), "*big.Int" - case strings.HasPrefix(stringKind, "bool") || strings.HasPrefix(stringKind, "string"): - parts := regexp.MustCompile(`([a-z]+)(\[[0-9]*\])?`).FindStringSubmatch(stringKind) - if len(parts) != 3 { - return stringKind - } - return fmt.Sprintf("%s%s", parts[2], parts[1]) + case strings.HasPrefix(stringKind, "bool"): + return len("bool"), "bool" + + case strings.HasPrefix(stringKind, "string"): + return len("string"), "string" default: - return stringKind + return len(stringKind), stringKind } } +// Translates the array sizes to a Java declaration of a (nested) array of the inner type. +// Simply returns the inner type if arraySizes is empty. +func arrayBindingJava(inner string, arraySizes []string) string { + // Java array type declarations do not include the length. + return inner + strings.Repeat("[]", len(arraySizes)) +} + // bindTypeJava converts a Solidity type to a Java one. Since there is no clear mapping // from all Solidity types to Java ones (e.g. uint17), those that cannot be exactly // mapped will use an upscaled type (e.g. BigDecimal). func bindTypeJava(kind abi.Type) string { stringKind := kind.String() + innerLen, innerMapping := bindUnnestedTypeJava(stringKind) + return arrayBindingJava(wrapArray(stringKind, innerLen, innerMapping)) +} + +// The inner function of bindTypeJava, this finds the inner type of stringKind. +// (Or just the type itself if it is not an array or slice) +// The length of the matched part is returned, with the the translated type. +func bindUnnestedTypeJava(stringKind string) (int, string) { switch { case strings.HasPrefix(stringKind, "address"): parts := regexp.MustCompile(`address(\[[0-9]*\])?`).FindStringSubmatch(stringKind) if len(parts) != 2 { - return stringKind + return len(stringKind), stringKind } if parts[1] == "" { - return fmt.Sprintf("Address") + return len("address"), "Address" } - return fmt.Sprintf("Addresses") + return len(parts[0]), "Addresses" case strings.HasPrefix(stringKind, "bytes"): - parts := regexp.MustCompile(`bytes([0-9]*)(\[[0-9]*\])?`).FindStringSubmatch(stringKind) - if len(parts) != 3 { - return stringKind - } - if parts[2] != "" { - return "byte[][]" + parts := regexp.MustCompile(`bytes([0-9]*)`).FindStringSubmatch(stringKind) + if len(parts) != 2 { + return len(stringKind), stringKind } - return "byte[]" + return len(parts[0]), "byte[]" case strings.HasPrefix(stringKind, "int") || strings.HasPrefix(stringKind, "uint"): - parts := regexp.MustCompile(`(u)?int([0-9]*)(\[[0-9]*\])?`).FindStringSubmatch(stringKind) - if len(parts) != 4 { - return stringKind - } - switch parts[2] { - case "8", "16", "32", "64": - if parts[1] == "" { - if parts[3] == "" { - return fmt.Sprintf("int%s", parts[2]) - } - return fmt.Sprintf("int%s[]", parts[2]) - } + //Note that uint and int (without digits) are also matched, + // these are size 256, and will translate to BigInt (the default). + parts := regexp.MustCompile(`(u)?int([0-9]*)`).FindStringSubmatch(stringKind) + if len(parts) != 3 { + return len(stringKind), stringKind } - if parts[3] == "" { - return fmt.Sprintf("BigInt") + + namedSize := map[string]string{ + "8": "byte", + "16": "short", + "32": "int", + "64": "long", + }[parts[2]] + + //default to BigInt + if namedSize == "" { + namedSize = "BigInt" } - return fmt.Sprintf("BigInts") + return len(parts[0]), namedSize case strings.HasPrefix(stringKind, "bool"): - parts := regexp.MustCompile(`bool(\[[0-9]*\])?`).FindStringSubmatch(stringKind) - if len(parts) != 2 { - return stringKind - } - if parts[1] == "" { - return fmt.Sprintf("bool") - } - return fmt.Sprintf("bool[]") + return len("bool"), "boolean" case strings.HasPrefix(stringKind, "string"): - parts := regexp.MustCompile(`string(\[[0-9]*\])?`).FindStringSubmatch(stringKind) - if len(parts) != 2 { - return stringKind - } - if parts[1] == "" { - return fmt.Sprintf("String") - } - return fmt.Sprintf("String[]") + return len("string"), "String" default: - return stringKind + return len(stringKind), stringKind } } @@ -325,11 +354,13 @@ func namedTypeJava(javaKind string, solKind abi.Type) string { return "String" case "string[]": return "Strings" - case "bool": + case "boolean": return "Bool" - case "bool[]": + case "boolean[]": return "Bools" - case "BigInt": + case "BigInt[]": + return "BigInts" + default: parts := regexp.MustCompile(`(u)?int([0-9]*)(\[[0-9]*\])?`).FindStringSubmatch(solKind.String()) if len(parts) != 4 { return javaKind @@ -344,8 +375,6 @@ func namedTypeJava(javaKind string, solKind abi.Type) string { default: return javaKind } - default: - return javaKind } } diff --git a/accounts/abi/bind/bind_test.go b/accounts/abi/bind/bind_test.go index c4838e647..26816ec20 100644 --- a/accounts/abi/bind/bind_test.go +++ b/accounts/abi/bind/bind_test.go @@ -737,6 +737,72 @@ var bindTests = []struct { } `, }, + { + `DeeplyNestedArray`, + ` + contract DeeplyNestedArray { + uint64[3][4][5] public deepUint64Array; + function storeDeepUintArray(uint64[3][4][5] arr) public { + deepUint64Array = arr; + } + function retrieveDeepArray() public view returns (uint64[3][4][5]) { + return deepUint64Array; + } + } + `, + `6060604052341561000f57600080fd5b6106438061001e6000396000f300606060405260043610610057576000357c0100000000000000000000000000000000000000000000000000000000900463ffffffff168063344248551461005c5780638ed4573a1461011457806398ed1856146101ab575b600080fd5b341561006757600080fd5b610112600480806107800190600580602002604051908101604052809291906000905b828210156101055783826101800201600480602002604051908101604052809291906000905b828210156100f25783826060020160038060200260405190810160405280929190826003602002808284378201915050505050815260200190600101906100b0565b505050508152602001906001019061008a565b5050505091905050610208565b005b341561011f57600080fd5b61012761021d565b604051808260056000925b8184101561019b578284602002015160046000925b8184101561018d5782846020020151600360200280838360005b8381101561017c578082015181840152602081019050610161565b505050509050019260010192610147565b925050509260010192610132565b9250505091505060405180910390f35b34156101b657600080fd5b6101de6004808035906020019091908035906020019091908035906020019091905050610309565b604051808267ffffffffffffffff1667ffffffffffffffff16815260200191505060405180910390f35b80600090600561021992919061035f565b5050565b6102256103b0565b6000600580602002604051908101604052809291906000905b8282101561030057838260040201600480602002604051908101604052809291906000905b828210156102ed578382016003806020026040519081016040528092919082600380156102d9576020028201916000905b82829054906101000a900467ffffffffffffffff1667ffffffffffffffff16815260200190600801906020826007010492830192600103820291508084116102945790505b505050505081526020019060010190610263565b505050508152602001906001019061023e565b50505050905090565b60008360058110151561031857fe5b600402018260048110151561032957fe5b018160038110151561033757fe5b6004918282040191900660080292509250509054906101000a900467ffffffffffffffff1681565b826005600402810192821561039f579160200282015b8281111561039e5782518290600461038e9291906103df565b5091602001919060040190610375565b5b5090506103ac919061042d565b5090565b610780604051908101604052806005905b6103c9610459565b8152602001906001900390816103c15790505090565b826004810192821561041c579160200282015b8281111561041b5782518290600361040b929190610488565b50916020019190600101906103f2565b5b5090506104299190610536565b5090565b61045691905b8082111561045257600081816104499190610562565b50600401610433565b5090565b90565b610180604051908101604052806004905b6104726105a7565b81526020019060019003908161046a5790505090565b82600380016004900481019282156105255791602002820160005b838211156104ef57835183826101000a81548167ffffffffffffffff021916908367ffffffffffffffff16021790555092602001926008016020816007010492830192600103026104a3565b80156105235782816101000a81549067ffffffffffffffff02191690556008016020816007010492830192600103026104ef565b505b50905061053291906105d9565b5090565b61055f91905b8082111561055b57600081816105529190610610565b5060010161053c565b5090565b90565b50600081816105719190610610565b50600101600081816105839190610610565b50600101600081816105959190610610565b5060010160006105a59190610610565b565b6060604051908101604052806003905b600067ffffffffffffffff168152602001906001900390816105b75790505090565b61060d91905b8082111561060957600081816101000a81549067ffffffffffffffff0219169055506001016105df565b5090565b90565b50600090555600a165627a7a7230582087e5a43f6965ab6ef7a4ff056ab80ed78fd8c15cff57715a1bf34ec76a93661c0029`, + `[{"constant":false,"inputs":[{"name":"arr","type":"uint64[3][4][5]"}],"name":"storeDeepUintArray","outputs":[],"payable":false,"stateMutability":"nonpayable","type":"function"},{"constant":true,"inputs":[],"name":"retrieveDeepArray","outputs":[{"name":"","type":"uint64[3][4][5]"}],"payable":false,"stateMutability":"view","type":"function"},{"constant":true,"inputs":[{"name":"","type":"uint256"},{"name":"","type":"uint256"},{"name":"","type":"uint256"}],"name":"deepUint64Array","outputs":[{"name":"","type":"uint64"}],"payable":false,"stateMutability":"view","type":"function"}]`, + ` + // Generate a new random account and a funded simulator + key, _ := crypto.GenerateKey() + auth := bind.NewKeyedTransactor(key) + sim := backends.NewSimulatedBackend(core.GenesisAlloc{auth.From: {Balance: big.NewInt(10000000000)}}) + + //deploy the test contract + _, _, testContract, err := DeployDeeplyNestedArray(auth, sim) + if err != nil { + t.Fatalf("Failed to deploy test contract: %v", err) + } + + // Finish deploy. + sim.Commit() + + //Create coordinate-filled array, for testing purposes. + testArr := [5][4][3]uint64{} + for i := 0; i < 5; i++ { + testArr[i] = [4][3]uint64{} + for j := 0; j < 4; j++ { + testArr[i][j] = [3]uint64{} + for k := 0; k < 3; k++ { + //pack the coordinates, each array value will be unique, and can be validated easily. + testArr[i][j][k] = uint64(i) << 16 | uint64(j) << 8 | uint64(k) + } + } + } + + if _, err := testContract.StoreDeepUintArray(&bind.TransactOpts{ + From: auth.From, + Signer: auth.Signer, + }, testArr); err != nil { + t.Fatalf("Failed to store nested array in test contract: %v", err) + } + + sim.Commit() + + retrievedArr, err := testContract.RetrieveDeepArray(&bind.CallOpts{ + From: auth.From, + Pending: false, + }) + if err != nil { + t.Fatalf("Failed to retrieve nested array from test contract: %v", err) + } + + //quick check to see if contents were copied + // (See accounts/abi/unpack_test.go for more extensive testing) + if retrievedArr[4][3][2] != testArr[4][3][2] { + t.Fatalf("Retrieved value does not match expected value! got: %d, expected: %d. %v", retrievedArr[4][3][2], testArr[4][3][2], err) + }`, + }, } // Tests that packages generated by the binder can be successfully compiled and diff --git a/accounts/abi/pack_test.go b/accounts/abi/pack_test.go index 14ab516ac..58a5b7a58 100644 --- a/accounts/abi/pack_test.go +++ b/accounts/abi/pack_test.go @@ -300,6 +300,11 @@ func TestPack(t *testing.T) { common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"), }, { + "uint32[2][3][4]", + [4][3][2]uint32{{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}, {{13, 14}, {15, 16}, {17, 18}}, {{19, 20}, {21, 22}, {23, 24}}}, + common.Hex2Bytes("000000000000000000000000000000000000000000000000000000000000000100000000000000000000000000000000000000000000000000000000000000020000000000000000000000000000000000000000000000000000000000000003000000000000000000000000000000000000000000000000000000000000000400000000000000000000000000000000000000000000000000000000000000050000000000000000000000000000000000000000000000000000000000000006000000000000000000000000000000000000000000000000000000000000000700000000000000000000000000000000000000000000000000000000000000080000000000000000000000000000000000000000000000000000000000000009000000000000000000000000000000000000000000000000000000000000000a000000000000000000000000000000000000000000000000000000000000000b000000000000000000000000000000000000000000000000000000000000000c000000000000000000000000000000000000000000000000000000000000000d000000000000000000000000000000000000000000000000000000000000000e000000000000000000000000000000000000000000000000000000000000000f000000000000000000000000000000000000000000000000000000000000001000000000000000000000000000000000000000000000000000000000000000110000000000000000000000000000000000000000000000000000000000000012000000000000000000000000000000000000000000000000000000000000001300000000000000000000000000000000000000000000000000000000000000140000000000000000000000000000000000000000000000000000000000000015000000000000000000000000000000000000000000000000000000000000001600000000000000000000000000000000000000000000000000000000000000170000000000000000000000000000000000000000000000000000000000000018"), + }, + { "address[]", []common.Address{{1}, {2}}, common.Hex2Bytes("000000000000000000000000000000000000000000000000000000000000000200000000000000000000000001000000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000"), diff --git a/accounts/abi/unpack.go b/accounts/abi/unpack.go index 334245661..793d515ad 100644 --- a/accounts/abi/unpack.go +++ b/accounts/abi/unpack.go @@ -93,15 +93,28 @@ func readFixedBytes(t Type, word []byte) (interface{}, error) { } +func getFullElemSize(elem *Type) int { + //all other should be counted as 32 (slices have pointers to respective elements) + size := 32 + //arrays wrap it, each element being the same size + for elem.T == ArrayTy { + size *= elem.Size + elem = elem.Elem + } + return size +} + // iteratively unpack elements func forEachUnpack(t Type, output []byte, start, size int) (interface{}, error) { + if size < 0 { + return nil, fmt.Errorf("cannot marshal input to array, size is negative (%d)", size) + } if start+32*size > len(output) { return nil, fmt.Errorf("abi: cannot marshal in to go array: offset %d would go over slice boundary (len=%d)", len(output), start+32*size) } // this value will become our slice or our array, depending on the type var refSlice reflect.Value - slice := output[start : start+size*32] if t.T == SliceTy { // declare our slice @@ -113,15 +126,20 @@ func forEachUnpack(t Type, output []byte, start, size int) (interface{}, error) return nil, fmt.Errorf("abi: invalid type in array/slice unpacking stage") } - for i, j := start, 0; j*32 < len(slice); i, j = i+32, j+1 { - // this corrects the arrangement so that we get all the underlying array values - if t.Elem.T == ArrayTy && j != 0 { - i = start + t.Elem.Size*32*j - } + // Arrays have packed elements, resulting in longer unpack steps. + // Slices have just 32 bytes per element (pointing to the contents). + elemSize := 32 + if t.T == ArrayTy { + elemSize = getFullElemSize(t.Elem) + } + + for i, j := start, 0; j < size; i, j = i+elemSize, j+1 { + inter, err := toGoType(i, *t.Elem, output) if err != nil { return nil, err } + // append the item to our reflect slice refSlice.Index(j).Set(reflect.ValueOf(inter)) } @@ -181,16 +199,32 @@ func toGoType(index int, t Type, output []byte) (interface{}, error) { // interprets a 32 byte slice as an offset and then determines which indice to look to decode the type. func lengthPrefixPointsTo(index int, output []byte) (start int, length int, err error) { - offset := int(binary.BigEndian.Uint64(output[index+24 : index+32])) - if offset+32 > len(output) { - return 0, 0, fmt.Errorf("abi: cannot marshal in to go slice: offset %d would go over slice boundary (len=%d)", len(output), offset+32) + bigOffsetEnd := big.NewInt(0).SetBytes(output[index : index+32]) + bigOffsetEnd.Add(bigOffsetEnd, common.Big32) + outputLength := big.NewInt(int64(len(output))) + + if bigOffsetEnd.Cmp(outputLength) > 0 { + return 0, 0, fmt.Errorf("abi: cannot marshal in to go slice: offset %v would go over slice boundary (len=%v)", bigOffsetEnd, outputLength) } - length = int(binary.BigEndian.Uint64(output[offset+24 : offset+32])) - if offset+32+length > len(output) { - return 0, 0, fmt.Errorf("abi: cannot marshal in to go type: length insufficient %d require %d", len(output), offset+32+length) + + if bigOffsetEnd.BitLen() > 63 { + return 0, 0, fmt.Errorf("abi offset larger than int64: %v", bigOffsetEnd) } - start = offset + 32 - //fmt.Printf("LENGTH PREFIX INFO: \nsize: %v\noffset: %v\nstart: %v\n", length, offset, start) + offsetEnd := int(bigOffsetEnd.Uint64()) + lengthBig := big.NewInt(0).SetBytes(output[offsetEnd-32 : offsetEnd]) + + totalSize := big.NewInt(0) + totalSize.Add(totalSize, bigOffsetEnd) + totalSize.Add(totalSize, lengthBig) + if totalSize.BitLen() > 63 { + return 0, 0, fmt.Errorf("abi length larger than int64: %v", totalSize) + } + + if totalSize.Cmp(outputLength) > 0 { + return 0, 0, fmt.Errorf("abi: cannot marshal in to go type: length insufficient %v require %v", outputLength, totalSize) + } + start = int(bigOffsetEnd.Uint64()) + length = int(lengthBig.Uint64()) return } diff --git a/accounts/abi/unpack_test.go b/accounts/abi/unpack_test.go index a65426a30..ee6256709 100644 --- a/accounts/abi/unpack_test.go +++ b/accounts/abi/unpack_test.go @@ -130,7 +130,7 @@ var unpackTests = []unpackTest{ { def: `[{"type": "bytes32"}]`, enc: "0100000000000000000000000000000000000000000000000000000000000000", - want: common.HexToHash("0100000000000000000000000000000000000000000000000000000000000000"), + want: [32]byte{1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, }, { def: `[{"type": "function"}]`, @@ -190,6 +190,11 @@ var unpackTests = []unpackTest{ want: [2]uint32{1, 2}, }, { + def: `[{"type": "uint32[2][3][4]"}]`, + enc: "000000000000000000000000000000000000000000000000000000000000000100000000000000000000000000000000000000000000000000000000000000020000000000000000000000000000000000000000000000000000000000000003000000000000000000000000000000000000000000000000000000000000000400000000000000000000000000000000000000000000000000000000000000050000000000000000000000000000000000000000000000000000000000000006000000000000000000000000000000000000000000000000000000000000000700000000000000000000000000000000000000000000000000000000000000080000000000000000000000000000000000000000000000000000000000000009000000000000000000000000000000000000000000000000000000000000000a000000000000000000000000000000000000000000000000000000000000000b000000000000000000000000000000000000000000000000000000000000000c000000000000000000000000000000000000000000000000000000000000000d000000000000000000000000000000000000000000000000000000000000000e000000000000000000000000000000000000000000000000000000000000000f000000000000000000000000000000000000000000000000000000000000001000000000000000000000000000000000000000000000000000000000000000110000000000000000000000000000000000000000000000000000000000000012000000000000000000000000000000000000000000000000000000000000001300000000000000000000000000000000000000000000000000000000000000140000000000000000000000000000000000000000000000000000000000000015000000000000000000000000000000000000000000000000000000000000001600000000000000000000000000000000000000000000000000000000000000170000000000000000000000000000000000000000000000000000000000000018", + want: [4][3][2]uint32{{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}, {{13, 14}, {15, 16}, {17, 18}}, {{19, 20}, {21, 22}, {23, 24}}}, + }, + { def: `[{"type": "uint64[]"}]`, enc: "0000000000000000000000000000000000000000000000000000000000000020000000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000002", want: []uint64{1, 2}, @@ -435,6 +440,46 @@ func TestMultiReturnWithArray(t *testing.T) { } } +func TestMultiReturnWithDeeplyNestedArray(t *testing.T) { + // Similar to TestMultiReturnWithArray, but with a special case in mind: + // values of nested static arrays count towards the size as well, and any element following + // after such nested array argument should be read with the correct offset, + // so that it does not read content from the previous array argument. + const definition = `[{"name" : "multi", "outputs": [{"type": "uint64[3][2][4]"}, {"type": "uint64"}]}]` + abi, err := JSON(strings.NewReader(definition)) + if err != nil { + t.Fatal(err) + } + buff := new(bytes.Buffer) + // construct the test array, each 3 char element is joined with 61 '0' chars, + // to from the ((3 + 61) * 0.5) = 32 byte elements in the array. + buff.Write(common.Hex2Bytes(strings.Join([]string{ + "", //empty, to apply the 61-char separator to the first element as well. + "111", "112", "113", "121", "122", "123", + "211", "212", "213", "221", "222", "223", + "311", "312", "313", "321", "322", "323", + "411", "412", "413", "421", "422", "423", + }, "0000000000000000000000000000000000000000000000000000000000000"))) + buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000009876")) + + ret1, ret1Exp := new([4][2][3]uint64), [4][2][3]uint64{ + {{0x111, 0x112, 0x113}, {0x121, 0x122, 0x123}}, + {{0x211, 0x212, 0x213}, {0x221, 0x222, 0x223}}, + {{0x311, 0x312, 0x313}, {0x321, 0x322, 0x323}}, + {{0x411, 0x412, 0x413}, {0x421, 0x422, 0x423}}, + } + ret2, ret2Exp := new(uint64), uint64(0x9876) + if err := abi.Unpack(&[]interface{}{ret1, ret2}, "multi", buff.Bytes()); err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(*ret1, ret1Exp) { + t.Error("array result", *ret1, "!= Expected", ret1Exp) + } + if *ret2 != ret2Exp { + t.Error("int result", *ret2, "!= Expected", ret2Exp) + } +} + func TestUnmarshal(t *testing.T) { const definition = `[ { "name" : "int", "constant" : false, "outputs": [ { "type": "uint256" } ] }, @@ -683,3 +728,73 @@ func TestUnmarshal(t *testing.T) { t.Fatal("expected error:", err) } } + +func TestOOMMaliciousInput(t *testing.T) { + oomTests := []unpackTest{ + { + def: `[{"type": "uint8[]"}]`, + enc: "0000000000000000000000000000000000000000000000000000000000000020" + // offset + "0000000000000000000000000000000000000000000000000000000000000003" + // num elems + "0000000000000000000000000000000000000000000000000000000000000001" + // elem 1 + "0000000000000000000000000000000000000000000000000000000000000002", // elem 2 + }, + { // Length larger than 64 bits + def: `[{"type": "uint8[]"}]`, + enc: "0000000000000000000000000000000000000000000000000000000000000020" + // offset + "00ffffffffffffffffffffffffffffffffffffffffffffff0000000000000002" + // num elems + "0000000000000000000000000000000000000000000000000000000000000001" + // elem 1 + "0000000000000000000000000000000000000000000000000000000000000002", // elem 2 + }, + { // Offset very large (over 64 bits) + def: `[{"type": "uint8[]"}]`, + enc: "00ffffffffffffffffffffffffffffffffffffffffffffff0000000000000020" + // offset + "0000000000000000000000000000000000000000000000000000000000000002" + // num elems + "0000000000000000000000000000000000000000000000000000000000000001" + // elem 1 + "0000000000000000000000000000000000000000000000000000000000000002", // elem 2 + }, + { // Offset very large (below 64 bits) + def: `[{"type": "uint8[]"}]`, + enc: "0000000000000000000000000000000000000000000000007ffffffffff00020" + // offset + "0000000000000000000000000000000000000000000000000000000000000002" + // num elems + "0000000000000000000000000000000000000000000000000000000000000001" + // elem 1 + "0000000000000000000000000000000000000000000000000000000000000002", // elem 2 + }, + { // Offset negative (as 64 bit) + def: `[{"type": "uint8[]"}]`, + enc: "000000000000000000000000000000000000000000000000f000000000000020" + // offset + "0000000000000000000000000000000000000000000000000000000000000002" + // num elems + "0000000000000000000000000000000000000000000000000000000000000001" + // elem 1 + "0000000000000000000000000000000000000000000000000000000000000002", // elem 2 + }, + + { // Negative length + def: `[{"type": "uint8[]"}]`, + enc: "0000000000000000000000000000000000000000000000000000000000000020" + // offset + "000000000000000000000000000000000000000000000000f000000000000002" + // num elems + "0000000000000000000000000000000000000000000000000000000000000001" + // elem 1 + "0000000000000000000000000000000000000000000000000000000000000002", // elem 2 + }, + { // Very large length + def: `[{"type": "uint8[]"}]`, + enc: "0000000000000000000000000000000000000000000000000000000000000020" + // offset + "0000000000000000000000000000000000000000000000007fffffffff000002" + // num elems + "0000000000000000000000000000000000000000000000000000000000000001" + // elem 1 + "0000000000000000000000000000000000000000000000000000000000000002", // elem 2 + }, + } + for i, test := range oomTests { + def := fmt.Sprintf(`[{ "name" : "method", "outputs": %s}]`, test.def) + abi, err := JSON(strings.NewReader(def)) + if err != nil { + t.Fatalf("invalid ABI definition %s: %v", def, err) + } + encb, err := hex.DecodeString(test.enc) + if err != nil { + t.Fatalf("invalid hex: %s" + test.enc) + } + _, err = abi.Methods["method"].Outputs.UnpackValues(encb) + if err == nil { + t.Fatalf("Expected error on malicious input, test %d", i) + } + } +} diff --git a/build/ci.go b/build/ci.go index 544483c42..24b58c1ae 100644 --- a/build/ci.go +++ b/build/ci.go @@ -182,13 +182,13 @@ func doInstall(cmdline []string) { // Check Go version. People regularly open issues about compilation // failure with outdated Go. This should save them the trouble. if !strings.Contains(runtime.Version(), "devel") { - // Figure out the minor version number since we can't textually compare (1.10 < 1.7) + // Figure out the minor version number since we can't textually compare (1.10 < 1.8) var minor int fmt.Sscanf(strings.TrimPrefix(runtime.Version(), "go1."), "%d", &minor) - if minor < 7 { + if minor < 8 { log.Println("You have Go version", runtime.Version()) - log.Println("go-ethereum requires at least Go version 1.7 and cannot") + log.Println("go-ethereum requires at least Go version 1.8 and cannot") log.Println("be compiled with an earlier version. Please upgrade your Go installation.") os.Exit(1) } diff --git a/cmd/evm/main.go b/cmd/evm/main.go index 6c39cf8b8..a59cb1fb8 100644 --- a/cmd/evm/main.go +++ b/cmd/evm/main.go @@ -86,10 +86,6 @@ var ( Name: "create", Usage: "indicates the action should be create rather than call", } - DisableGasMeteringFlag = cli.BoolFlag{ - Name: "nogasmetering", - Usage: "disable gas metering", - } GenesisFlag = cli.StringFlag{ Name: "prestate", Usage: "JSON file with prestate (genesis) config", @@ -128,7 +124,6 @@ func init() { ValueFlag, DumpFlag, InputFlag, - DisableGasMeteringFlag, MemProfileFlag, CPUProfileFlag, StatDumpFlag, diff --git a/cmd/evm/runner.go b/cmd/evm/runner.go index a9a2e5420..8a7399840 100644 --- a/cmd/evm/runner.go +++ b/cmd/evm/runner.go @@ -161,9 +161,8 @@ func runCmd(ctx *cli.Context) error { GasPrice: utils.GlobalBig(ctx, PriceFlag.Name), Value: utils.GlobalBig(ctx, ValueFlag.Name), EVMConfig: vm.Config{ - Tracer: tracer, - Debug: ctx.GlobalBool(DebugFlag.Name) || ctx.GlobalBool(MachineFlag.Name), - DisableGasMetering: ctx.GlobalBool(DisableGasMeteringFlag.Name), + Tracer: tracer, + Debug: ctx.GlobalBool(DebugFlag.Name) || ctx.GlobalBool(MachineFlag.Name), }, } diff --git a/cmd/faucet/faucet.go b/cmd/faucet/faucet.go index 99527f9d1..5bad09bbd 100644 --- a/cmd/faucet/faucet.go +++ b/cmd/faucet/faucet.go @@ -533,9 +533,11 @@ func (f *faucet) loop() { } defer sub.Unsubscribe() - for { - select { - case head := <-heads: + // Start a goroutine to update the state from head notifications in the background + update := make(chan *types.Header) + + go func() { + for head := range update { // New chain head arrived, query the current stats and stream to clients var ( balance *big.Int @@ -588,6 +590,17 @@ func (f *faucet) loop() { } } f.lock.RUnlock() + } + }() + // Wait for various events and assing to the appropriate background threads + for { + select { + case head := <-heads: + // New head arrived, send if for state update if there's none running + select { + case update <- head: + default: + } case <-f.update: // Pending requests updated, stream to clients @@ -686,8 +699,6 @@ func authTwitter(url string) (string, string, common.Address, error) { if len(parts) < 4 || parts[len(parts)-2] != "status" { return "", "", common.Address{}, errors.New("Invalid Twitter status URL") } - username := parts[len(parts)-3] - // Twitter's API isn't really friendly with direct links. Still, we don't // want to do ask read permissions from users, so just load the public posts and // scrape it for the Ethereum address and profile URL. @@ -697,6 +708,13 @@ func authTwitter(url string) (string, string, common.Address, error) { } defer res.Body.Close() + // Resolve the username from the final redirect, no intermediate junk + parts = strings.Split(res.Request.URL.String(), "/") + if len(parts) < 4 || parts[len(parts)-2] != "status" { + return "", "", common.Address{}, errors.New("Invalid Twitter status URL") + } + username := parts[len(parts)-3] + body, err := ioutil.ReadAll(res.Body) if err != nil { return "", "", common.Address{}, err diff --git a/cmd/geth/consolecmd.go b/cmd/geth/consolecmd.go index 9d5cc38a1..2500a969c 100644 --- a/cmd/geth/consolecmd.go +++ b/cmd/geth/consolecmd.go @@ -22,6 +22,7 @@ import ( "os/signal" "path/filepath" "strings" + "syscall" "github.com/ethereum/go-ethereum/cmd/utils" "github.com/ethereum/go-ethereum/console" @@ -42,7 +43,7 @@ var ( Description: ` The Geth console is an interactive shell for the JavaScript runtime environment which exposes a node admin interface as well as the Ðapp JavaScript API. -See https://github.com/ethereum/go-ethereum/wiki/Javascipt-Console.`, +See https://github.com/ethereum/go-ethereum/wiki/JavaScript-Console.`, } attachCommand = cli.Command{ @@ -55,7 +56,7 @@ See https://github.com/ethereum/go-ethereum/wiki/Javascipt-Console.`, Description: ` The Geth console is an interactive shell for the JavaScript runtime environment which exposes a node admin interface as well as the Ðapp JavaScript API. -See https://github.com/ethereum/go-ethereum/wiki/Javascipt-Console. +See https://github.com/ethereum/go-ethereum/wiki/JavaScript-Console. This command allows to open a console on a running geth node.`, } @@ -68,7 +69,7 @@ This command allows to open a console on a running geth node.`, Category: "CONSOLE COMMANDS", Description: ` The JavaScript VM exposes a node admin interface as well as the Ðapp -JavaScript API. See https://github.com/ethereum/go-ethereum/wiki/Javascipt-Console`, +JavaScript API. See https://github.com/ethereum/go-ethereum/wiki/JavaScript-Console`, } ) @@ -207,7 +208,7 @@ func ephemeralConsole(ctx *cli.Context) error { } // Wait for pending callbacks, but stop for Ctrl-C. abort := make(chan os.Signal, 1) - signal.Notify(abort, os.Interrupt) + signal.Notify(abort, syscall.SIGINT, syscall.SIGTERM) go func() { <-abort diff --git a/cmd/puppeth/genesis.go b/cmd/puppeth/genesis.go index f747f4739..1974a94aa 100644 --- a/cmd/puppeth/genesis.go +++ b/cmd/puppeth/genesis.go @@ -168,19 +168,18 @@ type parityChainSpec struct { Engine struct { Ethash struct { Params struct { - MinimumDifficulty *hexutil.Big `json:"minimumDifficulty"` - DifficultyBoundDivisor *hexutil.Big `json:"difficultyBoundDivisor"` - GasLimitBoundDivisor hexutil.Uint64 `json:"gasLimitBoundDivisor"` - DurationLimit *hexutil.Big `json:"durationLimit"` - BlockReward *hexutil.Big `json:"blockReward"` - HomesteadTransition uint64 `json:"homesteadTransition"` - EIP150Transition uint64 `json:"eip150Transition"` - EIP160Transition uint64 `json:"eip160Transition"` - EIP161abcTransition uint64 `json:"eip161abcTransition"` - EIP161dTransition uint64 `json:"eip161dTransition"` - EIP649Reward *hexutil.Big `json:"eip649Reward"` - EIP100bTransition uint64 `json:"eip100bTransition"` - EIP649Transition uint64 `json:"eip649Transition"` + MinimumDifficulty *hexutil.Big `json:"minimumDifficulty"` + DifficultyBoundDivisor *hexutil.Big `json:"difficultyBoundDivisor"` + DurationLimit *hexutil.Big `json:"durationLimit"` + BlockReward *hexutil.Big `json:"blockReward"` + HomesteadTransition uint64 `json:"homesteadTransition"` + EIP150Transition uint64 `json:"eip150Transition"` + EIP160Transition uint64 `json:"eip160Transition"` + EIP161abcTransition uint64 `json:"eip161abcTransition"` + EIP161dTransition uint64 `json:"eip161dTransition"` + EIP649Reward *hexutil.Big `json:"eip649Reward"` + EIP100bTransition uint64 `json:"eip100bTransition"` + EIP649Transition uint64 `json:"eip649Transition"` } `json:"params"` } `json:"Ethash"` } `json:"engine"` @@ -188,6 +187,7 @@ type parityChainSpec struct { Params struct { MaximumExtraDataSize hexutil.Uint64 `json:"maximumExtraDataSize"` MinGasLimit hexutil.Uint64 `json:"minGasLimit"` + GasLimitBoundDivisor hexutil.Uint64 `json:"gasLimitBoundDivisor"` NetworkID hexutil.Uint64 `json:"networkID"` MaxCodeSize uint64 `json:"maxCodeSize"` EIP155Transition uint64 `json:"eip155Transition"` @@ -270,7 +270,6 @@ func newParityChainSpec(network string, genesis *core.Genesis, bootnodes []strin } spec.Engine.Ethash.Params.MinimumDifficulty = (*hexutil.Big)(params.MinimumDifficulty) spec.Engine.Ethash.Params.DifficultyBoundDivisor = (*hexutil.Big)(params.DifficultyBoundDivisor) - spec.Engine.Ethash.Params.GasLimitBoundDivisor = (hexutil.Uint64)(params.GasLimitBoundDivisor) spec.Engine.Ethash.Params.DurationLimit = (*hexutil.Big)(params.DurationLimit) spec.Engine.Ethash.Params.BlockReward = (*hexutil.Big)(ethash.FrontierBlockReward) spec.Engine.Ethash.Params.HomesteadTransition = genesis.Config.HomesteadBlock.Uint64() @@ -284,6 +283,7 @@ func newParityChainSpec(network string, genesis *core.Genesis, bootnodes []strin spec.Params.MaximumExtraDataSize = (hexutil.Uint64)(params.MaximumExtraDataSize) spec.Params.MinGasLimit = (hexutil.Uint64)(params.MinGasLimit) + spec.Params.GasLimitBoundDivisor = (hexutil.Uint64)(params.GasLimitBoundDivisor) spec.Params.NetworkID = (hexutil.Uint64)(genesis.Config.ChainId.Uint64()) spec.Params.MaxCodeSize = params.MaxCodeSize spec.Params.EIP155Transition = genesis.Config.EIP155Block.Uint64() diff --git a/cmd/puppeth/module_dashboard.go b/cmd/puppeth/module_dashboard.go index 1cb2d4549..3832b247f 100644 --- a/cmd/puppeth/module_dashboard.go +++ b/cmd/puppeth/module_dashboard.go @@ -631,6 +631,7 @@ func deployDashboard(client *sshClient, network string, conf *config, config *da "Tangerine": conf.Genesis.Config.EIP150Block, "Spurious": conf.Genesis.Config.EIP155Block, "Byzantium": conf.Genesis.Config.ByzantiumBlock, + "Constantinople": conf.Genesis.Config.ConstantinopleBlock, }) files[filepath.Join(workdir, "index.html")] = indexfile.Bytes() diff --git a/cmd/puppeth/wizard_intro.go b/cmd/puppeth/wizard_intro.go index 84998afc9..60aa0f7ff 100644 --- a/cmd/puppeth/wizard_intro.go +++ b/cmd/puppeth/wizard_intro.go @@ -59,15 +59,16 @@ func (w *wizard) run() { fmt.Println() // Make sure we have a good network name to work with fmt.Println() + // Docker accepts hyphens in image names, but doesn't like it for container names if w.network == "" { - fmt.Println("Please specify a network name to administer (no spaces, please)") + fmt.Println("Please specify a network name to administer (no spaces or hyphens, please)") for { w.network = w.readString() - if !strings.Contains(w.network, " ") { + if !strings.Contains(w.network, " ") && !strings.Contains(w.network, "-") { fmt.Printf("\nSweet, you can set this via --network=%s next time!\n\n", w.network) break } - log.Error("I also like to live dangerously, still no spaces") + log.Error("I also like to live dangerously, still no spaces or hyphens") } } log.Info("Administering Ethereum network", "name", w.network) diff --git a/cmd/swarm/config.go b/cmd/swarm/config.go index 29b5faefa..adac772ba 100644 --- a/cmd/swarm/config.go +++ b/cmd/swarm/config.go @@ -23,6 +23,7 @@ import ( "os" "reflect" "strconv" + "strings" "unicode" cli "gopkg.in/urfave/cli.v1" @@ -97,10 +98,15 @@ func buildConfig(ctx *cli.Context) (config *bzzapi.Config, err error) { config = bzzapi.NewDefaultConfig() //first load settings from config file (if provided) config, err = configFileOverride(config, ctx) + if err != nil { + return nil, err + } //override settings provided by environment variables config = envVarsOverride(config) //override settings provided by command line config = cmdLineOverride(config, ctx) + //validate configuration parameters + err = validateConfig(config) return } @@ -194,12 +200,16 @@ func cmdLineOverride(currentConfig *bzzapi.Config, ctx *cli.Context) *bzzapi.Con utils.Fatalf(SWARM_ERR_SWAP_SET_NO_API) } - //EnsApi can be set to "", so can't check for empty string, as it is allowed! if ctx.GlobalIsSet(EnsAPIFlag.Name) { - currentConfig.EnsApi = ctx.GlobalString(EnsAPIFlag.Name) + ensAPIs := ctx.GlobalStringSlice(EnsAPIFlag.Name) + // preserve backward compatibility to disable ENS with --ens-api="" + if len(ensAPIs) == 1 && ensAPIs[0] == "" { + ensAPIs = nil + } + currentConfig.EnsAPIs = ensAPIs } - if ensaddr := ctx.GlobalString(EnsAddrFlag.Name); ensaddr != "" { + if ensaddr := ctx.GlobalString(DeprecatedEnsAddrFlag.Name); ensaddr != "" { currentConfig.EnsRoot = common.HexToAddress(ensaddr) } @@ -266,9 +276,8 @@ func envVarsOverride(currentConfig *bzzapi.Config) (config *bzzapi.Config) { utils.Fatalf(SWARM_ERR_SWAP_SET_NO_API) } - //EnsApi can be set to "", so can't check for empty string, as it is allowed - if ensapi, exists := os.LookupEnv(SWARM_ENV_ENS_API); exists { - currentConfig.EnsApi = ensapi + if ensapi := os.Getenv(SWARM_ENV_ENS_API); ensapi != "" { + currentConfig.EnsAPIs = strings.Split(ensapi, ",") } if ensaddr := os.Getenv(SWARM_ENV_ENS_ADDR); ensaddr != "" { @@ -309,6 +318,43 @@ func checkDeprecated(ctx *cli.Context) { if ctx.GlobalString(DeprecatedEthAPIFlag.Name) != "" { utils.Fatalf("--ethapi is no longer a valid command line flag, please use --ens-api and/or --swap-api.") } + // warn if --ens-api flag is set + if ctx.GlobalString(DeprecatedEnsAddrFlag.Name) != "" { + log.Warn("--ens-addr is no longer a valid command line flag, please use --ens-api to specify contract address.") + } +} + +//validate configuration parameters +func validateConfig(cfg *bzzapi.Config) (err error) { + for _, ensAPI := range cfg.EnsAPIs { + if ensAPI != "" { + if err := validateEnsAPIs(ensAPI); err != nil { + return fmt.Errorf("invalid format [tld:][contract-addr@]url for ENS API endpoint configuration %q: %v", ensAPI, err) + } + } + } + return nil +} + +//validate EnsAPIs configuration parameter +func validateEnsAPIs(s string) (err error) { + // missing contract address + if strings.HasPrefix(s, "@") { + return errors.New("missing contract address") + } + // missing url + if strings.HasSuffix(s, "@") { + return errors.New("missing url") + } + // missing tld + if strings.HasPrefix(s, ":") { + return errors.New("missing tld") + } + // missing url + if strings.HasSuffix(s, ":") { + return errors.New("missing url") + } + return nil } //print a Config as string diff --git a/cmd/swarm/config_test.go b/cmd/swarm/config_test.go index 166980d14..9bf584f50 100644 --- a/cmd/swarm/config_test.go +++ b/cmd/swarm/config_test.go @@ -457,3 +457,98 @@ func TestCmdLineOverridesFile(t *testing.T) { node.Shutdown() } + +func TestValidateConfig(t *testing.T) { + for _, c := range []struct { + cfg *api.Config + err string + }{ + { + cfg: &api.Config{EnsAPIs: []string{ + "/data/testnet/geth.ipc", + }}, + }, + { + cfg: &api.Config{EnsAPIs: []string{ + "http://127.0.0.1:1234", + }}, + }, + { + cfg: &api.Config{EnsAPIs: []string{ + "ws://127.0.0.1:1234", + }}, + }, + { + cfg: &api.Config{EnsAPIs: []string{ + "test:/data/testnet/geth.ipc", + }}, + }, + { + cfg: &api.Config{EnsAPIs: []string{ + "test:ws://127.0.0.1:1234", + }}, + }, + { + cfg: &api.Config{EnsAPIs: []string{ + "314159265dD8dbb310642f98f50C066173C1259b@/data/testnet/geth.ipc", + }}, + }, + { + cfg: &api.Config{EnsAPIs: []string{ + "314159265dD8dbb310642f98f50C066173C1259b@http://127.0.0.1:1234", + }}, + }, + { + cfg: &api.Config{EnsAPIs: []string{ + "314159265dD8dbb310642f98f50C066173C1259b@ws://127.0.0.1:1234", + }}, + }, + { + cfg: &api.Config{EnsAPIs: []string{ + "test:314159265dD8dbb310642f98f50C066173C1259b@/data/testnet/geth.ipc", + }}, + }, + { + cfg: &api.Config{EnsAPIs: []string{ + "eth:314159265dD8dbb310642f98f50C066173C1259b@http://127.0.0.1:1234", + }}, + }, + { + cfg: &api.Config{EnsAPIs: []string{ + "eth:314159265dD8dbb310642f98f50C066173C1259b@ws://127.0.0.1:12344", + }}, + }, + { + cfg: &api.Config{EnsAPIs: []string{ + "eth:", + }}, + err: "invalid format [tld:][contract-addr@]url for ENS API endpoint configuration \"eth:\": missing url", + }, + { + cfg: &api.Config{EnsAPIs: []string{ + "314159265dD8dbb310642f98f50C066173C1259b@", + }}, + err: "invalid format [tld:][contract-addr@]url for ENS API endpoint configuration \"314159265dD8dbb310642f98f50C066173C1259b@\": missing url", + }, + { + cfg: &api.Config{EnsAPIs: []string{ + ":314159265dD8dbb310642f98f50C066173C1259", + }}, + err: "invalid format [tld:][contract-addr@]url for ENS API endpoint configuration \":314159265dD8dbb310642f98f50C066173C1259\": missing tld", + }, + { + cfg: &api.Config{EnsAPIs: []string{ + "@/data/testnet/geth.ipc", + }}, + err: "invalid format [tld:][contract-addr@]url for ENS API endpoint configuration \"@/data/testnet/geth.ipc\": missing contract address", + }, + } { + err := validateConfig(c.cfg) + if c.err != "" && err.Error() != c.err { + t.Errorf("expected error %q, got %q", c.err, err) + } + if c.err == "" && err != nil { + t.Errorf("unexpected error %q", err) + } + } +} diff --git a/cmd/swarm/main.go b/cmd/swarm/main.go index 77315a426..360020b77 100644 --- a/cmd/swarm/main.go +++ b/cmd/swarm/main.go @@ -17,11 +17,9 @@ package main import ( - "context" "crypto/ecdsa" "fmt" "io/ioutil" - "math/big" "os" "os/signal" "runtime" @@ -29,14 +27,12 @@ import ( "strconv" "strings" "syscall" - "time" "github.com/ethereum/go-ethereum/accounts" "github.com/ethereum/go-ethereum/accounts/keystore" "github.com/ethereum/go-ethereum/cmd/utils" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/console" - "github.com/ethereum/go-ethereum/contracts/ens" "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/ethclient" "github.com/ethereum/go-ethereum/internal/debug" @@ -45,9 +41,9 @@ import ( "github.com/ethereum/go-ethereum/p2p" "github.com/ethereum/go-ethereum/p2p/discover" "github.com/ethereum/go-ethereum/params" - "github.com/ethereum/go-ethereum/rpc" "github.com/ethereum/go-ethereum/swarm" bzzapi "github.com/ethereum/go-ethereum/swarm/api" + swarmmetrics "github.com/ethereum/go-ethereum/swarm/metrics" "gopkg.in/urfave/cli.v1" ) @@ -110,16 +106,11 @@ var ( Usage: "Swarm Syncing enabled (default true)", EnvVar: SWARM_ENV_SYNC_ENABLE, } - EnsAPIFlag = cli.StringFlag{ + EnsAPIFlag = cli.StringSliceFlag{ Name: "ens-api", - Usage: "URL of the Ethereum API provider to use for ENS record lookups", + Usage: "ENS API endpoint for a TLD and with contract address, can be repeated, format [tld:][contract-addr@]url", EnvVar: SWARM_ENV_ENS_API, } - EnsAddrFlag = cli.StringFlag{ - Name: "ens-addr", - Usage: "ENS contract address (default is detected as testnet or mainnet using --ens-api)", - EnvVar: SWARM_ENV_ENS_ADDR, - } SwarmApiFlag = cli.StringFlag{ Name: "bzzapi", Usage: "Swarm HTTP endpoint", @@ -156,6 +147,10 @@ var ( Name: "ethapi", Usage: "DEPRECATED: please use --ens-api and --swap-api", } + DeprecatedEnsAddrFlag = cli.StringFlag{ + Name: "ens-addr", + Usage: "DEPRECATED: ENS contract address, please use --ens-api with contract address according to its format", + } ) //declare a few constant error messages, useful for later error check comparisons in test @@ -343,7 +338,6 @@ DEPRECATED: use 'swarm db clean'. // bzzd-specific flags CorsStringFlag, EnsAPIFlag, - EnsAddrFlag, SwarmTomlConfigPathFlag, SwarmConfigPathFlag, SwarmSwapEnabledFlag, @@ -363,11 +357,17 @@ DEPRECATED: use 'swarm db clean'. SwarmUploadMimeType, //deprecated flags DeprecatedEthAPIFlag, + DeprecatedEnsAddrFlag, } app.Flags = append(app.Flags, debug.Flags...) + app.Flags = append(app.Flags, swarmmetrics.Flags...) app.Before = func(ctx *cli.Context) error { runtime.GOMAXPROCS(runtime.NumCPU()) - return debug.Setup(ctx) + if err := debug.Setup(ctx); err != nil { + return err + } + swarmmetrics.Setup(ctx) + return nil } app.After = func(ctx *cli.Context) error { debug.Exit() @@ -448,38 +448,6 @@ func bzzd(ctx *cli.Context) error { return nil } -// detectEnsAddr determines the ENS contract address by getting both the -// version and genesis hash using the client and matching them to either -// mainnet or testnet addresses -func detectEnsAddr(client *rpc.Client) (common.Address, error) { - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - - var version string - if err := client.CallContext(ctx, &version, "net_version"); err != nil { - return common.Address{}, err - } - - block, err := ethclient.NewClient(client).BlockByNumber(ctx, big.NewInt(0)) - if err != nil { - return common.Address{}, err - } - - switch { - - case version == "1" && block.Hash() == params.MainnetGenesisHash: - log.Info("using Mainnet ENS contract address", "addr", ens.MainNetAddress) - return ens.MainNetAddress, nil - - case version == "3" && block.Hash() == params.TestnetGenesisHash: - log.Info("using Testnet ENS contract address", "addr", ens.TestNetAddress) - return ens.TestNetAddress, nil - - default: - return common.Address{}, fmt.Errorf("unknown version and genesis hash: %s %s", version, block.Hash()) - } -} - func registerBzzService(bzzconfig *bzzapi.Config, ctx *cli.Context, stack *node.Node) { //define the swarm service boot function @@ -494,27 +462,7 @@ func registerBzzService(bzzconfig *bzzapi.Config, ctx *cli.Context, stack *node. } } - var ensClient *ethclient.Client - if bzzconfig.EnsApi != "" { - log.Info("connecting to ENS API", "url", bzzconfig.EnsApi) - client, err := rpc.Dial(bzzconfig.EnsApi) - if err != nil { - return nil, fmt.Errorf("error connecting to ENS API %s: %s", bzzconfig.EnsApi, err) - } - ensClient = ethclient.NewClient(client) - - //no ENS root address set yet - if bzzconfig.EnsRoot == (common.Address{}) { - ensAddr, err := detectEnsAddr(client) - if err == nil { - bzzconfig.EnsRoot = ensAddr - } else { - log.Warn(fmt.Sprintf("could not determine ENS contract address, using default %s", bzzconfig.EnsRoot), "err", err) - } - } - } - - return swarm.NewSwarm(ctx, swapClient, ensClient, bzzconfig, bzzconfig.SwapEnabled, bzzconfig.SyncEnabled, bzzconfig.Cors) + return swarm.NewSwarm(ctx, swapClient, bzzconfig) } //register within the ethereum node if err := stack.Register(boot); err != nil { diff --git a/cmd/swarm/manifest.go b/cmd/swarm/manifest.go index aa276e0f9..41a69a5d0 100644 --- a/cmd/swarm/manifest.go +++ b/cmd/swarm/manifest.go @@ -35,7 +35,7 @@ const bzzManifestJSON = "application/bzz-manifest+json" func add(ctx *cli.Context) { args := ctx.Args() if len(args) < 3 { - utils.Fatalf("Need atleast three arguments <MHASH> <path> <HASH> [<content-type>]") + utils.Fatalf("Need at least three arguments <MHASH> <path> <HASH> [<content-type>]") } var ( @@ -69,7 +69,7 @@ func update(ctx *cli.Context) { args := ctx.Args() if len(args) < 3 { - utils.Fatalf("Need atleast three arguments <MHASH> <path> <HASH>") + utils.Fatalf("Need at least three arguments <MHASH> <path> <HASH>") } var ( @@ -101,7 +101,7 @@ func update(ctx *cli.Context) { func remove(ctx *cli.Context) { args := ctx.Args() if len(args) < 2 { - utils.Fatalf("Need atleast two arguments <MHASH> <path>") + utils.Fatalf("Need at least two arguments <MHASH> <path>") } var ( diff --git a/cmd/utils/cmd.go b/cmd/utils/cmd.go index 53cdf7861..186d18d8f 100644 --- a/cmd/utils/cmd.go +++ b/cmd/utils/cmd.go @@ -25,6 +25,7 @@ import ( "os/signal" "runtime" "strings" + "syscall" "github.com/ethereum/go-ethereum/core" "github.com/ethereum/go-ethereum/core/types" @@ -64,7 +65,7 @@ func StartNode(stack *node.Node) { } go func() { sigc := make(chan os.Signal, 1) - signal.Notify(sigc, os.Interrupt) + signal.Notify(sigc, syscall.SIGINT, syscall.SIGTERM) defer signal.Stop(sigc) <-sigc log.Info("Got interrupt, shutting down...") @@ -85,7 +86,7 @@ func ImportChain(chain *core.BlockChain, fn string) error { // If a signal is received, the import will stop at the next batch. interrupt := make(chan os.Signal, 1) stop := make(chan struct{}) - signal.Notify(interrupt, os.Interrupt) + signal.Notify(interrupt, syscall.SIGINT, syscall.SIGTERM) defer signal.Stop(interrupt) defer close(interrupt) go func() { diff --git a/cmd/utils/flags.go b/cmd/utils/flags.go index 5fd5013f0..dbf26b8e0 100644 --- a/cmd/utils/flags.go +++ b/cmd/utils/flags.go @@ -400,7 +400,7 @@ var ( RPCVirtualHostsFlag = cli.StringFlag{ Name: "rpcvhosts", Usage: "Comma separated list of virtual hostnames from which to accept requests (server enforced). Accepts '*' wildcard.", - Value: "localhost", + Value: strings.Join(node.DefaultConfig.HTTPVirtualHosts, ","), } RPCApiFlag = cli.StringFlag{ Name: "rpcapi", @@ -695,8 +695,9 @@ func setHTTP(ctx *cli.Context, cfg *node.Config) { if ctx.GlobalIsSet(RPCApiFlag.Name) { cfg.HTTPModules = splitAndTrim(ctx.GlobalString(RPCApiFlag.Name)) } - - cfg.HTTPVirtualHosts = splitAndTrim(ctx.GlobalString(RPCVirtualHostsFlag.Name)) + if ctx.GlobalIsSet(RPCVirtualHostsFlag.Name) { + cfg.HTTPVirtualHosts = splitAndTrim(ctx.GlobalString(RPCVirtualHostsFlag.Name)) + } } // setWS creates the WebSocket RPC listener interface string from the set diff --git a/cmd/wnode/main.go b/cmd/wnode/main.go index 971b1c0ab..84bdfa4c3 100644 --- a/cmd/wnode/main.go +++ b/cmd/wnode/main.go @@ -22,6 +22,7 @@ package main import ( "bufio" "crypto/ecdsa" + crand "crypto/rand" "crypto/sha512" "encoding/binary" "encoding/hex" @@ -48,6 +49,7 @@ import ( ) const quitCommand = "~Q" +const entropySize = 32 // singletons var ( @@ -55,6 +57,7 @@ var ( shh *whisper.Whisper done chan struct{} mailServer mailserver.WMailServer + entropy [entropySize]byte input = bufio.NewReader(os.Stdin) ) @@ -76,14 +79,15 @@ var ( // cmd arguments var ( - bootstrapMode = flag.Bool("standalone", false, "boostrap node: don't actively connect to peers, wait for incoming connections") - forwarderMode = flag.Bool("forwarder", false, "forwarder mode: only forward messages, neither send nor decrypt messages") + bootstrapMode = flag.Bool("standalone", false, "boostrap node: don't initiate connection to peers, just wait for incoming connections") + forwarderMode = flag.Bool("forwarder", false, "forwarder mode: only forward messages, neither encrypt nor decrypt messages") mailServerMode = flag.Bool("mailserver", false, "mail server mode: delivers expired messages on demand") requestMail = flag.Bool("mailclient", false, "request expired messages from the bootstrap server") asymmetricMode = flag.Bool("asym", false, "use asymmetric encryption") generateKey = flag.Bool("generatekey", false, "generate and show the private key") fileExMode = flag.Bool("fileexchange", false, "file exchange mode") - testMode = flag.Bool("test", false, "use of predefined parameters for diagnostics") + fileReader = flag.Bool("filereader", false, "load and decrypt messages saved as files, display as plain text") + testMode = flag.Bool("test", false, "use of predefined parameters for diagnostics (password, etc.)") echoMode = flag.Bool("echo", false, "echo mode: prints some arguments for diagnostics") argVerbosity = flag.Int("verbosity", int(log.LvlError), "log verbosity level") @@ -99,13 +103,14 @@ var ( argIDFile = flag.String("idfile", "", "file name with node id (private key)") argEnode = flag.String("boot", "", "bootstrap node you want to connect to (e.g. enode://e454......08d50@52.176.211.200:16428)") argTopic = flag.String("topic", "", "topic in hexadecimal format (e.g. 70a4beef)") - argSaveDir = flag.String("savedir", "", "directory where incoming messages will be saved as files") + argSaveDir = flag.String("savedir", "", "directory where all incoming messages will be saved as files") ) func main() { processArgs() initialize() run() + shutdown() } func processArgs() { @@ -205,21 +210,6 @@ func initialize() { MinimumAcceptedPOW: *argPoW, } - if *mailServerMode { - if len(msPassword) == 0 { - msPassword, err = console.Stdin.PromptPassword("Please enter the Mail Server password: ") - if err != nil { - utils.Fatalf("Failed to read Mail Server password: %s", err) - } - } - - shh = whisper.New(cfg) - shh.RegisterServer(&mailServer) - mailServer.Init(shh, *argDBPath, msPassword, *argServerPoW) - } else { - shh = whisper.New(cfg) - } - if *argPoW != whisper.DefaultMinimumPoW { err := shh.SetMinimumPoW(*argPoW) if err != nil { @@ -261,6 +251,26 @@ func initialize() { maxPeers = 800 } + _, err = crand.Read(entropy[:]) + if err != nil { + utils.Fatalf("crypto/rand failed: %s", err) + } + + if *mailServerMode { + if len(msPassword) == 0 { + msPassword, err = console.Stdin.PromptPassword("Please enter the Mail Server password: ") + if err != nil { + utils.Fatalf("Failed to read Mail Server password: %s", err) + } + } + + shh = whisper.New(cfg) + shh.RegisterServer(&mailServer) + mailServer.Init(shh, *argDBPath, msPassword, *argServerPoW) + } else { + shh = whisper.New(cfg) + } + server = &p2p.Server{ Config: p2p.Config{ PrivateKey: nodeid, @@ -276,10 +286,11 @@ func initialize() { } } -func startServer() { +func startServer() error { err := server.Start() if err != nil { - utils.Fatalf("Failed to start Whisper peer: %s.", err) + fmt.Printf("Failed to start Whisper peer: %s.", err) + return err } fmt.Printf("my public key: %s \n", common.ToHex(crypto.FromECDSAPub(&asymKey.PublicKey))) @@ -298,6 +309,7 @@ func startServer() { if !*forwarderMode { fmt.Printf("Please type the message. To quit type: '%s'\n", quitCommand) } + return nil } func isKeyValid(k *ecdsa.PublicKey) bool { @@ -411,8 +423,10 @@ func waitForConnection(timeout bool) { } func run() { - defer mailServer.Close() - startServer() + err := startServer() + if err != nil { + return + } defer server.Stop() shh.Start(nil) defer shh.Stop() @@ -425,21 +439,26 @@ func run() { requestExpiredMessagesLoop() } else if *fileExMode { sendFilesLoop() + } else if *fileReader { + fileReaderLoop() } else { sendLoop() } } +func shutdown() { + close(done) + mailServer.Close() +} + func sendLoop() { for { s := scanLine("") if s == quitCommand { fmt.Println("Quit command received") - close(done) - break + return } sendMsg([]byte(s)) - if *asymmetricMode { // print your own message for convenience, // because in asymmetric mode it is impossible to decrypt it @@ -455,13 +474,11 @@ func sendFilesLoop() { s := scanLine("") if s == quitCommand { fmt.Println("Quit command received") - close(done) - break + return } b, err := ioutil.ReadFile(s) if err != nil { fmt.Printf(">>> Error: %s \n", err) - continue } else { h := sendMsg(b) if (h == common.Hash{}) { @@ -475,6 +492,38 @@ func sendFilesLoop() { } } +func fileReaderLoop() { + watcher1 := shh.GetFilter(symFilterID) + watcher2 := shh.GetFilter(asymFilterID) + if watcher1 == nil && watcher2 == nil { + fmt.Println("Error: neither symmetric nor asymmetric filter is installed") + return + } + + for { + s := scanLine("") + if s == quitCommand { + fmt.Println("Quit command received") + return + } + raw, err := ioutil.ReadFile(s) + if err != nil { + fmt.Printf(">>> Error: %s \n", err) + } else { + env := whisper.Envelope{Data: raw} // the topic is zero + msg := env.Open(watcher1) // force-open envelope regardless of the topic + if msg == nil { + msg = env.Open(watcher2) + } + if msg == nil { + fmt.Printf(">>> Error: failed to decrypt the message \n") + } else { + printMessageInfo(msg) + } + } + } +} + func scanLine(prompt string) string { if len(prompt) > 0 { fmt.Print(prompt) @@ -548,20 +597,18 @@ func messageLoop() { for { select { case <-ticker.C: - messages := sf.Retrieve() + m1 := sf.Retrieve() + m2 := af.Retrieve() + messages := append(m1, m2...) for _, msg := range messages { - if *fileExMode || len(msg.Payload) > 2048 { + // All messages are saved upon specifying argSaveDir. + // fileExMode only specifies how messages are displayed on the console after they are saved. + // if fileExMode == true, only the hashes are displayed, since messages might be too big. + if len(*argSaveDir) > 0 { writeMessageToFile(*argSaveDir, msg) - } else { - printMessageInfo(msg) } - } - messages = af.Retrieve() - for _, msg := range messages { - if *fileExMode || len(msg.Payload) > 2048 { - writeMessageToFile(*argSaveDir, msg) - } else { + if !*fileExMode && len(msg.Payload) <= 2048 { printMessageInfo(msg) } } @@ -596,27 +643,30 @@ func writeMessageToFile(dir string, msg *whisper.ReceivedMessage) { address = crypto.PubkeyToAddress(*msg.Src) } - if whisper.IsPubKeyEqual(msg.Src, &asymKey.PublicKey) { - // message from myself: don't save, only report - fmt.Printf("\n%s <%x>: message received: '%s'\n", timestamp, address, name) - } else if len(dir) > 0 { + // this is a sample code; uncomment if you don't want to save your own messages. + //if whisper.IsPubKeyEqual(msg.Src, &asymKey.PublicKey) { + // fmt.Printf("\n%s <%x>: message from myself received, not saved: '%s'\n", timestamp, address, name) + // return + //} + + if len(dir) > 0 { fullpath := filepath.Join(dir, name) - err := ioutil.WriteFile(fullpath, msg.Payload, 0644) + err := ioutil.WriteFile(fullpath, msg.Raw, 0644) if err != nil { fmt.Printf("\n%s {%x}: message received but not saved: %s\n", timestamp, address, err) } else { - fmt.Printf("\n%s {%x}: message received and saved as '%s' (%d bytes)\n", timestamp, address, name, len(msg.Payload)) + fmt.Printf("\n%s {%x}: message received and saved as '%s' (%d bytes)\n", timestamp, address, name, len(msg.Raw)) } } else { - fmt.Printf("\n%s {%x}: big message received (%d bytes), but not saved: %s\n", timestamp, address, len(msg.Payload), name) + fmt.Printf("\n%s {%x}: message received (%d bytes), but not saved: %s\n", timestamp, address, len(msg.Raw), name) } } func requestExpiredMessagesLoop() { - var key, peerID []byte + var key, peerID, bloom []byte var timeLow, timeUpp uint32 var t string - var xt, empty whisper.TopicType + var xt whisper.TopicType keyID, err := shh.AddSymKeyFromPassword(msPassword) if err != nil { @@ -639,18 +689,19 @@ func requestExpiredMessagesLoop() { utils.Fatalf("Failed to parse the topic: %s", err) } xt = whisper.BytesToTopic(x) + bloom = whisper.TopicToBloom(xt) + obfuscateBloom(bloom) + } else { + bloom = whisper.MakeFullNodeBloom() } if timeUpp == 0 { timeUpp = 0xFFFFFFFF } - data := make([]byte, 8+whisper.TopicLength) + data := make([]byte, 8, 8+whisper.BloomFilterSize) binary.BigEndian.PutUint32(data, timeLow) binary.BigEndian.PutUint32(data[4:], timeUpp) - copy(data[8:], xt[:]) - if xt == empty { - data = data[:8] - } + data = append(data, bloom...) var params whisper.MessageParams params.PoW = *argServerPoW @@ -684,3 +735,20 @@ func extractIDFromEnode(s string) []byte { } return n.ID[:] } + +// obfuscateBloom adds 16 random bits to the the bloom +// filter, in order to obfuscate the containing topics. +// it does so deterministically within every session. +// despite additional bits, it will match on average +// 32000 times less messages than full node's bloom filter. +func obfuscateBloom(bloom []byte) { + const half = entropySize / 2 + for i := 0; i < half; i++ { + x := int(entropy[i]) + if entropy[half+i] < 128 { + x += 256 + } + + bloom[x/8] = 1 << uint(x%8) // set the bit number X + } +} diff --git a/common/big.go b/common/big.go index b552608bc..65d4377bf 100644 --- a/common/big.go +++ b/common/big.go @@ -25,6 +25,6 @@ var ( Big3 = big.NewInt(3) Big0 = big.NewInt(0) Big32 = big.NewInt(32) - Big256 = big.NewInt(0xff) + Big256 = big.NewInt(256) Big257 = big.NewInt(257) ) diff --git a/consensus/ethash/algorithm.go b/consensus/ethash/algorithm.go index 10767bb31..905a7b1ea 100644 --- a/consensus/ethash/algorithm.go +++ b/consensus/ethash/algorithm.go @@ -19,6 +19,7 @@ package ethash import ( "encoding/binary" "hash" + "math/big" "reflect" "runtime" "sync" @@ -47,6 +48,48 @@ const ( loopAccesses = 64 // Number of accesses in hashimoto loop ) +// cacheSize returns the size of the ethash verification cache that belongs to a certain +// block number. +func cacheSize(block uint64) uint64 { + epoch := int(block / epochLength) + if epoch < maxEpoch { + return cacheSizes[epoch] + } + return calcCacheSize(epoch) +} + +// calcCacheSize calculates the cache size for epoch. The cache size grows linearly, +// however, we always take the highest prime below the linearly growing threshold in order +// to reduce the risk of accidental regularities leading to cyclic behavior. +func calcCacheSize(epoch int) uint64 { + size := cacheInitBytes + cacheGrowthBytes*uint64(epoch) - hashBytes + for !new(big.Int).SetUint64(size / hashBytes).ProbablyPrime(1) { // Always accurate for n < 2^64 + size -= 2 * hashBytes + } + return size +} + +// datasetSize returns the size of the ethash mining dataset that belongs to a certain +// block number. +func datasetSize(block uint64) uint64 { + epoch := int(block / epochLength) + if epoch < maxEpoch { + return datasetSizes[epoch] + } + return calcDatasetSize(epoch) +} + +// calcDatasetSize calculates the dataset size for epoch. The dataset size grows linearly, +// however, we always take the highest prime below the linearly growing threshold in order +// to reduce the risk of accidental regularities leading to cyclic behavior. +func calcDatasetSize(epoch int) uint64 { + size := datasetInitBytes + datasetGrowthBytes*uint64(epoch) - mixBytes + for !new(big.Int).SetUint64(size / mixBytes).ProbablyPrime(1) { // Always accurate for n < 2^64 + size -= 2 * mixBytes + } + return size +} + // hasher is a repetitive hasher allowing the same hash data structures to be // reused between hash runs instead of requiring new ones to be created. type hasher func(dest []byte, data []byte) diff --git a/consensus/ethash/algorithm_go1.7.go b/consensus/ethash/algorithm_go1.7.go deleted file mode 100644 index c7f7f48e4..000000000 --- a/consensus/ethash/algorithm_go1.7.go +++ /dev/null @@ -1,47 +0,0 @@ -// Copyright 2017 The go-ethereum Authors -// This file is part of the go-ethereum library. -// -// The go-ethereum library is free software: you can redistribute it and/or modify -// it under the terms of the GNU Lesser General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// The go-ethereum library is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU Lesser General Public License for more details. -// -// You should have received a copy of the GNU Lesser General Public License -// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. - -// +build !go1.8 - -package ethash - -// cacheSize calculates and returns the size of the ethash verification cache that -// belongs to a certain block number. The cache size grows linearly, however, we -// always take the highest prime below the linearly growing threshold in order to -// reduce the risk of accidental regularities leading to cyclic behavior. -func cacheSize(block uint64) uint64 { - // If we have a pre-generated value, use that - epoch := int(block / epochLength) - if epoch < maxEpoch { - return cacheSizes[epoch] - } - // We don't have a way to verify primes fast before Go 1.8 - panic("fast prime testing unsupported in Go < 1.8") -} - -// datasetSize calculates and returns the size of the ethash mining dataset that -// belongs to a certain block number. The dataset size grows linearly, however, we -// always take the highest prime below the linearly growing threshold in order to -// reduce the risk of accidental regularities leading to cyclic behavior. -func datasetSize(block uint64) uint64 { - // If we have a pre-generated value, use that - epoch := int(block / epochLength) - if epoch < maxEpoch { - return datasetSizes[epoch] - } - // We don't have a way to verify primes fast before Go 1.8 - panic("fast prime testing unsupported in Go < 1.8") -} diff --git a/consensus/ethash/algorithm_go1.8.go b/consensus/ethash/algorithm_go1.8.go deleted file mode 100644 index 975fdffe5..000000000 --- a/consensus/ethash/algorithm_go1.8.go +++ /dev/null @@ -1,63 +0,0 @@ -// Copyright 2017 The go-ethereum Authors -// This file is part of the go-ethereum library. -// -// The go-ethereum library is free software: you can redistribute it and/or modify -// it under the terms of the GNU Lesser General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// The go-ethereum library is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU Lesser General Public License for more details. -// -// You should have received a copy of the GNU Lesser General Public License -// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. - -// +build go1.8 - -package ethash - -import "math/big" - -// cacheSize returns the size of the ethash verification cache that belongs to a certain -// block number. -func cacheSize(block uint64) uint64 { - epoch := int(block / epochLength) - if epoch < maxEpoch { - return cacheSizes[epoch] - } - return calcCacheSize(epoch) -} - -// calcCacheSize calculates the cache size for epoch. The cache size grows linearly, -// however, we always take the highest prime below the linearly growing threshold in order -// to reduce the risk of accidental regularities leading to cyclic behavior. -func calcCacheSize(epoch int) uint64 { - size := cacheInitBytes + cacheGrowthBytes*uint64(epoch) - hashBytes - for !new(big.Int).SetUint64(size / hashBytes).ProbablyPrime(1) { // Always accurate for n < 2^64 - size -= 2 * hashBytes - } - return size -} - -// datasetSize returns the size of the ethash mining dataset that belongs to a certain -// block number. -func datasetSize(block uint64) uint64 { - epoch := int(block / epochLength) - if epoch < maxEpoch { - return datasetSizes[epoch] - } - return calcDatasetSize(epoch) -} - -// calcDatasetSize calculates the dataset size for epoch. The dataset size grows linearly, -// however, we always take the highest prime below the linearly growing threshold in order -// to reduce the risk of accidental regularities leading to cyclic behavior. -func calcDatasetSize(epoch int) uint64 { - size := datasetInitBytes + datasetGrowthBytes*uint64(epoch) - mixBytes - for !new(big.Int).SetUint64(size / mixBytes).ProbablyPrime(1) { // Always accurate for n < 2^64 - size -= 2 * mixBytes - } - return size -} diff --git a/consensus/ethash/algorithm_go1.8_test.go b/consensus/ethash/algorithm_go1.8_test.go deleted file mode 100644 index 6648bd6a9..000000000 --- a/consensus/ethash/algorithm_go1.8_test.go +++ /dev/null @@ -1,37 +0,0 @@ -// Copyright 2017 The go-ethereum Authors -// This file is part of the go-ethereum library. -// -// The go-ethereum library is free software: you can redistribute it and/or modify -// it under the terms of the GNU Lesser General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// The go-ethereum library is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU Lesser General Public License for more details. -// -// You should have received a copy of the GNU Lesser General Public License -// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. - -// +build go1.8 - -package ethash - -import "testing" - -// Tests whether the dataset size calculator works correctly by cross checking the -// hard coded lookup table with the value generated by it. -func TestSizeCalculations(t *testing.T) { - // Verify all the cache and dataset sizes from the lookup table. - for epoch, want := range cacheSizes { - if size := calcCacheSize(epoch); size != want { - t.Errorf("cache %d: cache size mismatch: have %d, want %d", epoch, size, want) - } - } - for epoch, want := range datasetSizes { - if size := calcDatasetSize(epoch); size != want { - t.Errorf("dataset %d: dataset size mismatch: have %d, want %d", epoch, size, want) - } - } -} diff --git a/consensus/ethash/algorithm_test.go b/consensus/ethash/algorithm_test.go index a54f3b582..841e39233 100644 --- a/consensus/ethash/algorithm_test.go +++ b/consensus/ethash/algorithm_test.go @@ -30,6 +30,22 @@ import ( "github.com/ethereum/go-ethereum/core/types" ) +// Tests whether the dataset size calculator works correctly by cross checking the +// hard coded lookup table with the value generated by it. +func TestSizeCalculations(t *testing.T) { + // Verify all the cache and dataset sizes from the lookup table. + for epoch, want := range cacheSizes { + if size := calcCacheSize(epoch); size != want { + t.Errorf("cache %d: cache size mismatch: have %d, want %d", epoch, size, want) + } + } + for epoch, want := range datasetSizes { + if size := calcDatasetSize(epoch); size != want { + t.Errorf("dataset %d: dataset size mismatch: have %d, want %d", epoch, size, want) + } + } +} + // Tests that verification caches can be correctly generated. func TestCacheGeneration(t *testing.T) { tests := []struct { diff --git a/consensus/ethash/consensus.go b/consensus/ethash/consensus.go index 92a23d4a4..99eec8221 100644 --- a/consensus/ethash/consensus.go +++ b/consensus/ethash/consensus.go @@ -53,7 +53,6 @@ var ( errDuplicateUncle = errors.New("duplicate uncle") errUncleIsAncestor = errors.New("uncle is ancestor") errDanglingUncle = errors.New("uncle's parent is not ancestor") - errNonceOutOfRange = errors.New("nonce out of range") errInvalidDifficulty = errors.New("non-positive difficulty") errInvalidMixDigest = errors.New("invalid mix digest") errInvalidPoW = errors.New("invalid proof-of-work") @@ -356,7 +355,7 @@ func calcDifficultyByzantium(time uint64, parent *types.Header) *big.Int { if x.Cmp(params.MinimumDifficulty) < 0 { x.Set(params.MinimumDifficulty) } - // calculate a fake block numer for the ice-age delay: + // calculate a fake block number for the ice-age delay: // https://github.com/ethereum/EIPs/pull/669 // fake_block_number = min(0, block.number - 3_000_000 fakeBlockNumber := new(big.Int) @@ -474,18 +473,13 @@ func (ethash *Ethash) VerifySeal(chain consensus.ChainReader, header *types.Head if ethash.shared != nil { return ethash.shared.VerifySeal(chain, header) } - // Sanity check that the block number is below the lookup table size (60M blocks) - number := header.Number.Uint64() - if number/epochLength >= maxEpoch { - // Go < 1.7 cannot calculate new cache/dataset sizes (no fast prime check) - return errNonceOutOfRange - } // Ensure that we have a valid difficulty for the block if header.Difficulty.Sign() <= 0 { return errInvalidDifficulty } - // Recompute the digest and PoW value and verify against the header + number := header.Number.Uint64() + cache := ethash.cache(number) size := datasetSize(number) if ethash.config.PowMode == ModeTest { diff --git a/consensus/ethash/ethash.go b/consensus/ethash/ethash.go index 91e20112a..1b3dcee30 100644 --- a/consensus/ethash/ethash.go +++ b/consensus/ethash/ethash.go @@ -35,9 +35,9 @@ import ( mmap "github.com/edsrzf/mmap-go" "github.com/ethereum/go-ethereum/consensus" "github.com/ethereum/go-ethereum/log" + "github.com/ethereum/go-ethereum/metrics" "github.com/ethereum/go-ethereum/rpc" "github.com/hashicorp/golang-lru/simplelru" - metrics "github.com/rcrowley/go-metrics" ) var ErrInvalidDumpMagic = errors.New("invalid dump magic") diff --git a/console/console.go b/console/console.go index 52fe1f542..b280d4e65 100644 --- a/console/console.go +++ b/console/console.go @@ -26,6 +26,7 @@ import ( "regexp" "sort" "strings" + "syscall" "github.com/ethereum/go-ethereum/internal/jsre" "github.com/ethereum/go-ethereum/internal/web3ext" @@ -332,7 +333,7 @@ func (c *Console) Interactive() { }() // Monitor Ctrl-C too in case the input is empty and we need to bail abort := make(chan os.Signal, 1) - signal.Notify(abort, os.Interrupt) + signal.Notify(abort, syscall.SIGINT, syscall.SIGTERM) // Start sending prompts to the user and reading back inputs for { diff --git a/contracts/chequebook/cheque_test.go b/contracts/chequebook/cheque_test.go index b7555d081..6b6b28e65 100644 --- a/contracts/chequebook/cheque_test.go +++ b/contracts/chequebook/cheque_test.go @@ -281,8 +281,8 @@ func TestDeposit(t *testing.T) { t.Fatalf("expected balance %v, got %v", exp, chbook.Balance()) } - // autodeposit every 30ms if new cheque issued - interval := 30 * time.Millisecond + // autodeposit every 200ms if new cheque issued + interval := 200 * time.Millisecond chbook.AutoDeposit(interval, common.Big1, balance) _, err = chbook.Issue(addr1, amount) if err != nil { diff --git a/core/asm/lexer.go b/core/asm/lexer.go index a34b2cbd8..405499950 100644 --- a/core/asm/lexer.go +++ b/core/asm/lexer.go @@ -206,7 +206,7 @@ func lexLine(l *lexer) stateFn { return lexComment case isSpace(r): l.ignore() - case isAlphaNumeric(r) || r == '_': + case isLetter(r) || r == '_': return lexElement case isNumber(r): return lexNumber @@ -278,7 +278,7 @@ func lexElement(l *lexer) stateFn { return lexLine } -func isAlphaNumeric(t rune) bool { +func isLetter(t rune) bool { return unicode.IsLetter(t) } diff --git a/core/blockchain.go b/core/blockchain.go index 4ae0e4f4e..6006e6674 100644 --- a/core/blockchain.go +++ b/core/blockchain.go @@ -46,7 +46,7 @@ import ( ) var ( - blockInsertTimer = metrics.NewTimer("chain/inserts") + blockInsertTimer = metrics.NewRegisteredTimer("chain/inserts", nil) ErrNoGenesis = errors.New("Genesis not found in chain") ) @@ -107,8 +107,8 @@ type BlockChain struct { procmu sync.RWMutex // block processor lock checkpoint int // checkpoint counts towards the new checkpoint - currentBlock *types.Block // Current head of the block chain - currentFastBlock *types.Block // Current head of the fast-sync chain (may be above the block chain!) + currentBlock atomic.Value // Current head of the block chain + currentFastBlock atomic.Value // Current head of the fast-sync chain (may be above the block chain!) stateCache state.Database // State database to reuse between imports (contains state cache) bodyCache *lru.Cache // Cache for the most recent block bodies @@ -224,10 +224,10 @@ func (bc *BlockChain) loadLastState() error { } } // Everything seems to be fine, set as the head block - bc.currentBlock = currentBlock + bc.currentBlock.Store(currentBlock) // Restore the last known head header - currentHeader := bc.currentBlock.Header() + currentHeader := currentBlock.Header() if head := GetHeadHeaderHash(bc.db); head != (common.Hash{}) { if header := bc.GetHeaderByHash(head); header != nil { currentHeader = header @@ -236,21 +236,23 @@ func (bc *BlockChain) loadLastState() error { bc.hc.SetCurrentHeader(currentHeader) // Restore the last known head fast block - bc.currentFastBlock = bc.currentBlock + bc.currentFastBlock.Store(currentBlock) if head := GetHeadFastBlockHash(bc.db); head != (common.Hash{}) { if block := bc.GetBlockByHash(head); block != nil { - bc.currentFastBlock = block + bc.currentFastBlock.Store(block) } } // Issue a status log for the user + currentFastBlock := bc.CurrentFastBlock() + headerTd := bc.GetTd(currentHeader.Hash(), currentHeader.Number.Uint64()) - blockTd := bc.GetTd(bc.currentBlock.Hash(), bc.currentBlock.NumberU64()) - fastTd := bc.GetTd(bc.currentFastBlock.Hash(), bc.currentFastBlock.NumberU64()) + blockTd := bc.GetTd(currentBlock.Hash(), currentBlock.NumberU64()) + fastTd := bc.GetTd(currentFastBlock.Hash(), currentFastBlock.NumberU64()) log.Info("Loaded most recent local header", "number", currentHeader.Number, "hash", currentHeader.Hash(), "td", headerTd) - log.Info("Loaded most recent local full block", "number", bc.currentBlock.Number(), "hash", bc.currentBlock.Hash(), "td", blockTd) - log.Info("Loaded most recent local fast block", "number", bc.currentFastBlock.Number(), "hash", bc.currentFastBlock.Hash(), "td", fastTd) + log.Info("Loaded most recent local full block", "number", currentBlock.Number(), "hash", currentBlock.Hash(), "td", blockTd) + log.Info("Loaded most recent local fast block", "number", currentFastBlock.Number(), "hash", currentFastBlock.Hash(), "td", fastTd) return nil } @@ -279,30 +281,32 @@ func (bc *BlockChain) SetHead(head uint64) error { bc.futureBlocks.Purge() // Rewind the block chain, ensuring we don't end up with a stateless head block - if bc.currentBlock != nil && currentHeader.Number.Uint64() < bc.currentBlock.NumberU64() { - bc.currentBlock = bc.GetBlock(currentHeader.Hash(), currentHeader.Number.Uint64()) + if currentBlock := bc.CurrentBlock(); currentBlock != nil && currentHeader.Number.Uint64() < currentBlock.NumberU64() { + bc.currentBlock.Store(bc.GetBlock(currentHeader.Hash(), currentHeader.Number.Uint64())) } - if bc.currentBlock != nil { - if _, err := state.New(bc.currentBlock.Root(), bc.stateCache); err != nil { + if currentBlock := bc.CurrentBlock(); currentBlock != nil { + if _, err := state.New(currentBlock.Root(), bc.stateCache); err != nil { // Rewound state missing, rolled back to before pivot, reset to genesis - bc.currentBlock = nil + bc.currentBlock.Store(bc.genesisBlock) } } // Rewind the fast block in a simpleton way to the target head - if bc.currentFastBlock != nil && currentHeader.Number.Uint64() < bc.currentFastBlock.NumberU64() { - bc.currentFastBlock = bc.GetBlock(currentHeader.Hash(), currentHeader.Number.Uint64()) + if currentFastBlock := bc.CurrentFastBlock(); currentFastBlock != nil && currentHeader.Number.Uint64() < currentFastBlock.NumberU64() { + bc.currentFastBlock.Store(bc.GetBlock(currentHeader.Hash(), currentHeader.Number.Uint64())) } // If either blocks reached nil, reset to the genesis state - if bc.currentBlock == nil { - bc.currentBlock = bc.genesisBlock + if currentBlock := bc.CurrentBlock(); currentBlock == nil { + bc.currentBlock.Store(bc.genesisBlock) } - if bc.currentFastBlock == nil { - bc.currentFastBlock = bc.genesisBlock + if currentFastBlock := bc.CurrentFastBlock(); currentFastBlock == nil { + bc.currentFastBlock.Store(bc.genesisBlock) } - if err := WriteHeadBlockHash(bc.db, bc.currentBlock.Hash()); err != nil { + currentBlock := bc.CurrentBlock() + currentFastBlock := bc.CurrentFastBlock() + if err := WriteHeadBlockHash(bc.db, currentBlock.Hash()); err != nil { log.Crit("Failed to reset head full block", "err", err) } - if err := WriteHeadFastBlockHash(bc.db, bc.currentFastBlock.Hash()); err != nil { + if err := WriteHeadFastBlockHash(bc.db, currentFastBlock.Hash()); err != nil { log.Crit("Failed to reset head fast block", "err", err) } return bc.loadLastState() @@ -321,7 +325,7 @@ func (bc *BlockChain) FastSyncCommitHead(hash common.Hash) error { } // If all checks out, manually set the head block bc.mu.Lock() - bc.currentBlock = block + bc.currentBlock.Store(block) bc.mu.Unlock() log.Info("Committed new head block", "number", block.Number(), "hash", hash) @@ -330,28 +334,19 @@ func (bc *BlockChain) FastSyncCommitHead(hash common.Hash) error { // GasLimit returns the gas limit of the current HEAD block. func (bc *BlockChain) GasLimit() uint64 { - bc.mu.RLock() - defer bc.mu.RUnlock() - - return bc.currentBlock.GasLimit() + return bc.CurrentBlock().GasLimit() } // CurrentBlock retrieves the current head block of the canonical chain. The // block is retrieved from the blockchain's internal cache. func (bc *BlockChain) CurrentBlock() *types.Block { - bc.mu.RLock() - defer bc.mu.RUnlock() - - return bc.currentBlock + return bc.currentBlock.Load().(*types.Block) } // CurrentFastBlock retrieves the current fast-sync head block of the canonical // chain. The block is retrieved from the blockchain's internal cache. func (bc *BlockChain) CurrentFastBlock() *types.Block { - bc.mu.RLock() - defer bc.mu.RUnlock() - - return bc.currentFastBlock + return bc.currentFastBlock.Load().(*types.Block) } // SetProcessor sets the processor required for making state modifications. @@ -416,10 +411,10 @@ func (bc *BlockChain) ResetWithGenesisBlock(genesis *types.Block) error { } bc.genesisBlock = genesis bc.insert(bc.genesisBlock) - bc.currentBlock = bc.genesisBlock + bc.currentBlock.Store(bc.genesisBlock) bc.hc.SetGenesis(bc.genesisBlock.Header()) bc.hc.SetCurrentHeader(bc.genesisBlock.Header()) - bc.currentFastBlock = bc.genesisBlock + bc.currentFastBlock.Store(bc.genesisBlock) return nil } @@ -444,7 +439,7 @@ func (bc *BlockChain) repair(head **types.Block) error { // Export writes the active chain to the given writer. func (bc *BlockChain) Export(w io.Writer) error { - return bc.ExportN(w, uint64(0), bc.currentBlock.NumberU64()) + return bc.ExportN(w, uint64(0), bc.CurrentBlock().NumberU64()) } // ExportN writes a subset of the active chain to the given writer. @@ -488,7 +483,7 @@ func (bc *BlockChain) insert(block *types.Block) { if err := WriteHeadBlockHash(bc.db, block.Hash()); err != nil { log.Crit("Failed to insert head block hash", "err", err) } - bc.currentBlock = block + bc.currentBlock.Store(block) // If the block is better than our head or is on a different chain, force update heads if updateHeads { @@ -497,7 +492,7 @@ func (bc *BlockChain) insert(block *types.Block) { if err := WriteHeadFastBlockHash(bc.db, block.Hash()); err != nil { log.Crit("Failed to insert head fast block hash", "err", err) } - bc.currentFastBlock = block + bc.currentFastBlock.Store(block) } } @@ -648,22 +643,21 @@ func (bc *BlockChain) Stop() { bc.wg.Wait() // Ensure the state of a recent block is also stored to disk before exiting. - // It is fine if this state does not exist (fast start/stop cycle), but it is - // advisable to leave an N block gap from the head so 1) a restart loads up - // the last N blocks as sync assistance to remote nodes; 2) a restart during - // a (small) reorg doesn't require deep reprocesses; 3) chain "repair" from - // missing states are constantly tested. - // - // This may be tuned a bit on mainnet if its too annoying to reprocess the last - // N blocks. + // We're writing three different states to catch different restart scenarios: + // - HEAD: So we don't need to reprocess any blocks in the general case + // - HEAD-1: So we don't do large reorgs if our HEAD becomes an uncle + // - HEAD-127: So we have a hard limit on the number of blocks reexecuted if !bc.cacheConfig.Disabled { triedb := bc.stateCache.TrieDB() - if number := bc.CurrentBlock().NumberU64(); number >= triesInMemory { - recent := bc.GetBlockByNumber(bc.CurrentBlock().NumberU64() - triesInMemory + 1) - log.Info("Writing cached state to disk", "block", recent.Number(), "hash", recent.Hash(), "root", recent.Root()) - if err := triedb.Commit(recent.Root(), true); err != nil { - log.Error("Failed to commit recent state trie", "err", err) + for _, offset := range []uint64{0, 1, triesInMemory - 1} { + if number := bc.CurrentBlock().NumberU64(); number > offset { + recent := bc.GetBlockByNumber(number - offset) + + log.Info("Writing cached state to disk", "block", recent.Number(), "hash", recent.Hash(), "root", recent.Root()) + if err := triedb.Commit(recent.Root(), true); err != nil { + log.Error("Failed to commit recent state trie", "err", err) + } } } for !bc.triegc.Empty() { @@ -715,13 +709,15 @@ func (bc *BlockChain) Rollback(chain []common.Hash) { if currentHeader.Hash() == hash { bc.hc.SetCurrentHeader(bc.GetHeader(currentHeader.ParentHash, currentHeader.Number.Uint64()-1)) } - if bc.currentFastBlock.Hash() == hash { - bc.currentFastBlock = bc.GetBlock(bc.currentFastBlock.ParentHash(), bc.currentFastBlock.NumberU64()-1) - WriteHeadFastBlockHash(bc.db, bc.currentFastBlock.Hash()) + if currentFastBlock := bc.CurrentFastBlock(); currentFastBlock.Hash() == hash { + newFastBlock := bc.GetBlock(currentFastBlock.ParentHash(), currentFastBlock.NumberU64()-1) + bc.currentFastBlock.Store(newFastBlock) + WriteHeadFastBlockHash(bc.db, newFastBlock.Hash()) } - if bc.currentBlock.Hash() == hash { - bc.currentBlock = bc.GetBlock(bc.currentBlock.ParentHash(), bc.currentBlock.NumberU64()-1) - WriteHeadBlockHash(bc.db, bc.currentBlock.Hash()) + if currentBlock := bc.CurrentBlock(); currentBlock.Hash() == hash { + newBlock := bc.GetBlock(currentBlock.ParentHash(), currentBlock.NumberU64()-1) + bc.currentBlock.Store(newBlock) + WriteHeadBlockHash(bc.db, newBlock.Hash()) } } } @@ -830,11 +826,12 @@ func (bc *BlockChain) InsertReceiptChain(blockChain types.Blocks, receiptChain [ bc.mu.Lock() head := blockChain[len(blockChain)-1] if td := bc.GetTd(head.Hash(), head.NumberU64()); td != nil { // Rewind may have occurred, skip in that case - if bc.GetTd(bc.currentFastBlock.Hash(), bc.currentFastBlock.NumberU64()).Cmp(td) < 0 { + currentFastBlock := bc.CurrentFastBlock() + if bc.GetTd(currentFastBlock.Hash(), currentFastBlock.NumberU64()).Cmp(td) < 0 { if err := WriteHeadFastBlockHash(bc.db, head.Hash()); err != nil { log.Crit("Failed to update head fast block hash", "err", err) } - bc.currentFastBlock = head + bc.currentFastBlock.Store(head) } } bc.mu.Unlock() @@ -881,7 +878,8 @@ func (bc *BlockChain) WriteBlockWithState(block *types.Block, receipts []*types. bc.mu.Lock() defer bc.mu.Unlock() - localTd := bc.GetTd(bc.currentBlock.Hash(), bc.currentBlock.NumberU64()) + currentBlock := bc.CurrentBlock() + localTd := bc.GetTd(currentBlock.Hash(), currentBlock.NumberU64()) externTd := new(big.Int).Add(block.Difficulty(), ptd) // Irrelevant of the canonical status, write the block itself to the database @@ -956,14 +954,15 @@ func (bc *BlockChain) WriteBlockWithState(block *types.Block, receipts []*types. // Second clause in the if statement reduces the vulnerability to selfish mining. // Please refer to http://www.cs.cornell.edu/~ie53/publications/btcProcFC.pdf reorg := externTd.Cmp(localTd) > 0 + currentBlock = bc.CurrentBlock() if !reorg && externTd.Cmp(localTd) == 0 { // Split same-difficulty blocks by number, then at random - reorg = block.NumberU64() < bc.currentBlock.NumberU64() || (block.NumberU64() == bc.currentBlock.NumberU64() && mrand.Float64() < 0.5) + reorg = block.NumberU64() < currentBlock.NumberU64() || (block.NumberU64() == currentBlock.NumberU64() && mrand.Float64() < 0.5) } if reorg { // Reorganise the chain if the parent is not the head block - if block.ParentHash() != bc.currentBlock.Hash() { - if err := bc.reorg(bc.currentBlock, block); err != nil { + if block.ParentHash() != currentBlock.Hash() { + if err := bc.reorg(currentBlock, block); err != nil { return NonStatTy, err } } @@ -1092,7 +1091,8 @@ func (bc *BlockChain) insertChain(chain types.Blocks) (int, []interface{}, []*ty case err == consensus.ErrPrunedAncestor: // Block competing with the canonical chain, store in the db, but don't process // until the competitor TD goes above the canonical TD - localTd := bc.GetTd(bc.currentBlock.Hash(), bc.currentBlock.NumberU64()) + currentBlock := bc.CurrentBlock() + localTd := bc.GetTd(currentBlock.Hash(), currentBlock.NumberU64()) externTd := new(big.Int).Add(bc.GetTd(block.ParentHash(), block.NumberU64()-1), block.Difficulty()) if localTd.Cmp(externTd) > 0 { if err = bc.WriteBlockWithoutState(block, externTd); err != nil { @@ -1481,9 +1481,6 @@ func (bc *BlockChain) writeHeader(header *types.Header) error { // CurrentHeader retrieves the current head header of the canonical chain. The // header is retrieved from the HeaderChain's internal cache. func (bc *BlockChain) CurrentHeader() *types.Header { - bc.mu.RLock() - defer bc.mu.RUnlock() - return bc.hc.CurrentHeader() } diff --git a/core/blockchain_test.go b/core/blockchain_test.go index 635379161..748cdc5c7 100644 --- a/core/blockchain_test.go +++ b/core/blockchain_test.go @@ -34,26 +34,6 @@ import ( "github.com/ethereum/go-ethereum/params" ) -// newTestBlockChain creates a blockchain without validation. -func newTestBlockChain(fake bool) *BlockChain { - db, _ := ethdb.NewMemDatabase() - gspec := &Genesis{ - Config: params.TestChainConfig, - Difficulty: big.NewInt(1), - } - gspec.MustCommit(db) - engine := ethash.NewFullFaker() - if !fake { - engine = ethash.NewTester() - } - blockchain, err := NewBlockChain(db, nil, gspec.Config, engine, vm.Config{}) - if err != nil { - panic(err) - } - blockchain.SetValidator(bproc{}) - return blockchain -} - // Test fork of length N starting from block i func testFork(t *testing.T, blockchain *BlockChain, i, n int, full bool, comparator func(td1, td2 *big.Int)) { // Copy old chain up to #i into a new db @@ -183,13 +163,18 @@ func insertChain(done chan bool, blockchain *BlockChain, chain types.Blocks, t * } func TestLastBlock(t *testing.T) { - bchain := newTestBlockChain(false) - defer bchain.Stop() + _, blockchain, err := newCanonical(ethash.NewFaker(), 0, true) + if err != nil { + t.Fatalf("failed to create pristine chain: %v", err) + } + defer blockchain.Stop() - block := makeBlockChain(bchain.CurrentBlock(), 1, ethash.NewFaker(), bchain.db, 0)[0] - bchain.insert(block) - if block.Hash() != GetHeadBlockHash(bchain.db) { - t.Errorf("Write/Get HeadBlockHash failed") + blocks := makeBlockChain(blockchain.CurrentBlock(), 1, ethash.NewFullFaker(), blockchain.db, 0) + if _, err := blockchain.InsertChain(blocks); err != nil { + t.Fatalf("Failed to insert block: %v", err) + } + if blocks[len(blocks)-1].Hash() != GetHeadBlockHash(blockchain.db) { + t.Fatalf("Write/Get HeadBlockHash failed") } } @@ -337,55 +322,13 @@ func testBrokenChain(t *testing.T, full bool) { } } -type bproc struct{} - -func (bproc) ValidateBody(*types.Block) error { return nil } -func (bproc) ValidateState(block, parent *types.Block, state *state.StateDB, receipts types.Receipts, usedGas uint64) error { - return nil -} -func (bproc) Process(block *types.Block, statedb *state.StateDB, cfg vm.Config) (types.Receipts, []*types.Log, uint64, error) { - return nil, nil, 0, nil -} - -func makeHeaderChainWithDiff(genesis *types.Block, d []int, seed byte) []*types.Header { - blocks := makeBlockChainWithDiff(genesis, d, seed) - headers := make([]*types.Header, len(blocks)) - for i, block := range blocks { - headers[i] = block.Header() - } - return headers -} - -func makeBlockChainWithDiff(genesis *types.Block, d []int, seed byte) []*types.Block { - var chain []*types.Block - for i, difficulty := range d { - header := &types.Header{ - Coinbase: common.Address{seed}, - Number: big.NewInt(int64(i + 1)), - Difficulty: big.NewInt(int64(difficulty)), - UncleHash: types.EmptyUncleHash, - TxHash: types.EmptyRootHash, - ReceiptHash: types.EmptyRootHash, - Time: big.NewInt(int64(i) + 1), - } - if i == 0 { - header.ParentHash = genesis.Hash() - } else { - header.ParentHash = chain[i-1].Hash() - } - block := types.NewBlockWithHeader(header) - chain = append(chain, block) - } - return chain -} - // Tests that reorganising a long difficult chain after a short easy one // overwrites the canonical numbers and links in the database. func TestReorgLongHeaders(t *testing.T) { testReorgLong(t, false) } func TestReorgLongBlocks(t *testing.T) { testReorgLong(t, true) } func testReorgLong(t *testing.T, full bool) { - testReorg(t, []int{1, 2, 4}, []int{1, 2, 3, 4}, 10, full) + testReorg(t, []int64{0, 0, -9}, []int64{0, 0, 0, -9}, 393280, full) } // Tests that reorganising a short difficult chain after a long easy one @@ -394,45 +337,82 @@ func TestReorgShortHeaders(t *testing.T) { testReorgShort(t, false) } func TestReorgShortBlocks(t *testing.T) { testReorgShort(t, true) } func testReorgShort(t *testing.T, full bool) { - testReorg(t, []int{1, 2, 3, 4}, []int{1, 10}, 11, full) + // Create a long easy chain vs. a short heavy one. Due to difficulty adjustment + // we need a fairly long chain of blocks with different difficulties for a short + // one to become heavyer than a long one. The 96 is an empirical value. + easy := make([]int64, 96) + for i := 0; i < len(easy); i++ { + easy[i] = 60 + } + diff := make([]int64, len(easy)-1) + for i := 0; i < len(diff); i++ { + diff[i] = -9 + } + testReorg(t, easy, diff, 12615120, full) } -func testReorg(t *testing.T, first, second []int, td int64, full bool) { - bc := newTestBlockChain(true) - defer bc.Stop() +func testReorg(t *testing.T, first, second []int64, td int64, full bool) { + // Create a pristine chain and database + db, blockchain, err := newCanonical(ethash.NewFaker(), 0, full) + if err != nil { + t.Fatalf("failed to create pristine chain: %v", err) + } + defer blockchain.Stop() // Insert an easy and a difficult chain afterwards + easyBlocks, _ := GenerateChain(params.TestChainConfig, blockchain.CurrentBlock(), ethash.NewFaker(), db, len(first), func(i int, b *BlockGen) { + b.OffsetTime(first[i]) + }) + diffBlocks, _ := GenerateChain(params.TestChainConfig, blockchain.CurrentBlock(), ethash.NewFaker(), db, len(second), func(i int, b *BlockGen) { + b.OffsetTime(second[i]) + }) if full { - bc.InsertChain(makeBlockChainWithDiff(bc.genesisBlock, first, 11)) - bc.InsertChain(makeBlockChainWithDiff(bc.genesisBlock, second, 22)) + if _, err := blockchain.InsertChain(easyBlocks); err != nil { + t.Fatalf("failed to insert easy chain: %v", err) + } + if _, err := blockchain.InsertChain(diffBlocks); err != nil { + t.Fatalf("failed to insert difficult chain: %v", err) + } } else { - bc.InsertHeaderChain(makeHeaderChainWithDiff(bc.genesisBlock, first, 11), 1) - bc.InsertHeaderChain(makeHeaderChainWithDiff(bc.genesisBlock, second, 22), 1) + easyHeaders := make([]*types.Header, len(easyBlocks)) + for i, block := range easyBlocks { + easyHeaders[i] = block.Header() + } + diffHeaders := make([]*types.Header, len(diffBlocks)) + for i, block := range diffBlocks { + diffHeaders[i] = block.Header() + } + if _, err := blockchain.InsertHeaderChain(easyHeaders, 1); err != nil { + t.Fatalf("failed to insert easy chain: %v", err) + } + if _, err := blockchain.InsertHeaderChain(diffHeaders, 1); err != nil { + t.Fatalf("failed to insert difficult chain: %v", err) + } } // Check that the chain is valid number and link wise if full { - prev := bc.CurrentBlock() - for block := bc.GetBlockByNumber(bc.CurrentBlock().NumberU64() - 1); block.NumberU64() != 0; prev, block = block, bc.GetBlockByNumber(block.NumberU64()-1) { + prev := blockchain.CurrentBlock() + for block := blockchain.GetBlockByNumber(blockchain.CurrentBlock().NumberU64() - 1); block.NumberU64() != 0; prev, block = block, blockchain.GetBlockByNumber(block.NumberU64()-1) { if prev.ParentHash() != block.Hash() { t.Errorf("parent block hash mismatch: have %x, want %x", prev.ParentHash(), block.Hash()) } } } else { - prev := bc.CurrentHeader() - for header := bc.GetHeaderByNumber(bc.CurrentHeader().Number.Uint64() - 1); header.Number.Uint64() != 0; prev, header = header, bc.GetHeaderByNumber(header.Number.Uint64()-1) { + prev := blockchain.CurrentHeader() + for header := blockchain.GetHeaderByNumber(blockchain.CurrentHeader().Number.Uint64() - 1); header.Number.Uint64() != 0; prev, header = header, blockchain.GetHeaderByNumber(header.Number.Uint64()-1) { if prev.ParentHash != header.Hash() { t.Errorf("parent header hash mismatch: have %x, want %x", prev.ParentHash, header.Hash()) } } } // Make sure the chain total difficulty is the correct one - want := new(big.Int).Add(bc.genesisBlock.Difficulty(), big.NewInt(td)) + want := new(big.Int).Add(blockchain.genesisBlock.Difficulty(), big.NewInt(td)) if full { - if have := bc.GetTdByHash(bc.CurrentBlock().Hash()); have.Cmp(want) != 0 { + if have := blockchain.GetTdByHash(blockchain.CurrentBlock().Hash()); have.Cmp(want) != 0 { t.Errorf("total difficulty mismatch: have %v, want %v", have, want) } } else { - if have := bc.GetTdByHash(bc.CurrentHeader().Hash()); have.Cmp(want) != 0 { + if have := blockchain.GetTdByHash(blockchain.CurrentHeader().Hash()); have.Cmp(want) != 0 { t.Errorf("total difficulty mismatch: have %v, want %v", have, want) } } @@ -443,19 +423,28 @@ func TestBadHeaderHashes(t *testing.T) { testBadHashes(t, false) } func TestBadBlockHashes(t *testing.T) { testBadHashes(t, true) } func testBadHashes(t *testing.T, full bool) { - bc := newTestBlockChain(true) - defer bc.Stop() + // Create a pristine chain and database + db, blockchain, err := newCanonical(ethash.NewFaker(), 0, full) + if err != nil { + t.Fatalf("failed to create pristine chain: %v", err) + } + defer blockchain.Stop() // Create a chain, ban a hash and try to import - var err error if full { - blocks := makeBlockChainWithDiff(bc.genesisBlock, []int{1, 2, 4}, 10) + blocks := makeBlockChain(blockchain.CurrentBlock(), 3, ethash.NewFaker(), db, 10) + BadHashes[blocks[2].Header().Hash()] = true - _, err = bc.InsertChain(blocks) + defer func() { delete(BadHashes, blocks[2].Header().Hash()) }() + + _, err = blockchain.InsertChain(blocks) } else { - headers := makeHeaderChainWithDiff(bc.genesisBlock, []int{1, 2, 4}, 10) + headers := makeHeaderChain(blockchain.CurrentHeader(), 3, ethash.NewFaker(), db, 10) + BadHashes[headers[2].Hash()] = true - _, err = bc.InsertHeaderChain(headers, 1) + defer func() { delete(BadHashes, headers[2].Hash()) }() + + _, err = blockchain.InsertHeaderChain(headers, 1) } if err != ErrBlacklistedHash { t.Errorf("error mismatch: have: %v, want: %v", err, ErrBlacklistedHash) @@ -468,40 +457,41 @@ func TestReorgBadHeaderHashes(t *testing.T) { testReorgBadHashes(t, false) } func TestReorgBadBlockHashes(t *testing.T) { testReorgBadHashes(t, true) } func testReorgBadHashes(t *testing.T, full bool) { - bc := newTestBlockChain(true) - defer bc.Stop() - + // Create a pristine chain and database + db, blockchain, err := newCanonical(ethash.NewFaker(), 0, full) + if err != nil { + t.Fatalf("failed to create pristine chain: %v", err) + } // Create a chain, import and ban afterwards - headers := makeHeaderChainWithDiff(bc.genesisBlock, []int{1, 2, 3, 4}, 10) - blocks := makeBlockChainWithDiff(bc.genesisBlock, []int{1, 2, 3, 4}, 10) + headers := makeHeaderChain(blockchain.CurrentHeader(), 4, ethash.NewFaker(), db, 10) + blocks := makeBlockChain(blockchain.CurrentBlock(), 4, ethash.NewFaker(), db, 10) if full { - if _, err := bc.InsertChain(blocks); err != nil { - t.Fatalf("failed to import blocks: %v", err) + if _, err = blockchain.InsertChain(blocks); err != nil { + t.Errorf("failed to import blocks: %v", err) } - if bc.CurrentBlock().Hash() != blocks[3].Hash() { - t.Errorf("last block hash mismatch: have: %x, want %x", bc.CurrentBlock().Hash(), blocks[3].Header().Hash()) + if blockchain.CurrentBlock().Hash() != blocks[3].Hash() { + t.Errorf("last block hash mismatch: have: %x, want %x", blockchain.CurrentBlock().Hash(), blocks[3].Header().Hash()) } BadHashes[blocks[3].Header().Hash()] = true defer func() { delete(BadHashes, blocks[3].Header().Hash()) }() } else { - if _, err := bc.InsertHeaderChain(headers, 1); err != nil { - t.Fatalf("failed to import headers: %v", err) + if _, err = blockchain.InsertHeaderChain(headers, 1); err != nil { + t.Errorf("failed to import headers: %v", err) } - if bc.CurrentHeader().Hash() != headers[3].Hash() { - t.Errorf("last header hash mismatch: have: %x, want %x", bc.CurrentHeader().Hash(), headers[3].Hash()) + if blockchain.CurrentHeader().Hash() != headers[3].Hash() { + t.Errorf("last header hash mismatch: have: %x, want %x", blockchain.CurrentHeader().Hash(), headers[3].Hash()) } BadHashes[headers[3].Hash()] = true defer func() { delete(BadHashes, headers[3].Hash()) }() } + blockchain.Stop() // Create a new BlockChain and check that it rolled back the state. - ncm, err := NewBlockChain(bc.db, nil, bc.chainConfig, ethash.NewFaker(), vm.Config{}) + ncm, err := NewBlockChain(blockchain.db, nil, blockchain.chainConfig, ethash.NewFaker(), vm.Config{}) if err != nil { t.Fatalf("failed to create new chain manager: %v", err) } - defer ncm.Stop() - if full { if ncm.CurrentBlock().Hash() != blocks[2].Header().Hash() { t.Errorf("last block hash mismatch: have: %x, want %x", ncm.CurrentBlock().Hash(), blocks[2].Header().Hash()) @@ -514,6 +504,7 @@ func testReorgBadHashes(t *testing.T, full bool) { t.Errorf("last header hash mismatch: have: %x, want %x", ncm.CurrentHeader().Hash(), headers[2].Hash()) } } + ncm.Stop() } // Tests chain insertions in the face of one entity containing an invalid nonce. @@ -989,10 +980,13 @@ done: // Tests if the canonical block can be fetched from the database during chain insertion. func TestCanonicalBlockRetrieval(t *testing.T) { - bc := newTestBlockChain(true) - defer bc.Stop() + _, blockchain, err := newCanonical(ethash.NewFaker(), 0, true) + if err != nil { + t.Fatalf("failed to create pristine chain: %v", err) + } + defer blockchain.Stop() - chain, _ := GenerateChain(bc.chainConfig, bc.genesisBlock, ethash.NewFaker(), bc.db, 10, func(i int, gen *BlockGen) {}) + chain, _ := GenerateChain(blockchain.chainConfig, blockchain.genesisBlock, ethash.NewFaker(), blockchain.db, 10, func(i int, gen *BlockGen) {}) var pend sync.WaitGroup pend.Add(len(chain)) @@ -1003,14 +997,14 @@ func TestCanonicalBlockRetrieval(t *testing.T) { // try to retrieve a block by its canonical hash and see if the block data can be retrieved. for { - ch := GetCanonicalHash(bc.db, block.NumberU64()) + ch := GetCanonicalHash(blockchain.db, block.NumberU64()) if ch == (common.Hash{}) { continue // busy wait for canonical hash to be written } if ch != block.Hash() { t.Fatalf("unknown canonical hash, want %s, got %s", block.Hash().Hex(), ch.Hex()) } - fb := GetBlock(bc.db, ch, block.NumberU64()) + fb := GetBlock(blockchain.db, ch, block.NumberU64()) if fb == nil { t.Fatalf("unable to retrieve block %d for canonical hash: %s", block.NumberU64(), ch.Hex()) } @@ -1021,7 +1015,7 @@ func TestCanonicalBlockRetrieval(t *testing.T) { } }(chain[i]) - if _, err := bc.InsertChain(types.Blocks{chain[i]}); err != nil { + if _, err := blockchain.InsertChain(types.Blocks{chain[i]}); err != nil { t.Fatalf("failed to insert block %d: %v", i, err) } } diff --git a/core/database_util.go b/core/database_util.go index c6b125dae..8c4698985 100644 --- a/core/database_util.go +++ b/core/database_util.go @@ -47,6 +47,7 @@ var ( headHeaderKey = []byte("LastHeader") headBlockKey = []byte("LastBlock") headFastKey = []byte("LastFast") + trieSyncKey = []byte("TrieSync") // Data item prefixes (use single byte to avoid mixing data types, avoid `i`). headerPrefix = []byte("h") // headerPrefix + num (uint64 big endian) + hash -> header @@ -70,8 +71,8 @@ var ( ErrChainConfigNotFound = errors.New("ChainConfig not found") // general config not found error - preimageCounter = metrics.NewCounter("db/preimage/total") - preimageHitCounter = metrics.NewCounter("db/preimage/hits") + preimageCounter = metrics.NewRegisteredCounter("db/preimage/total", nil) + preimageHitCounter = metrics.NewRegisteredCounter("db/preimage/hits", nil) ) // TxLookupEntry is a positional metadata to help looking up the data content of @@ -146,6 +147,16 @@ func GetHeadFastBlockHash(db DatabaseReader) common.Hash { return common.BytesToHash(data) } +// GetTrieSyncProgress retrieves the number of tries nodes fast synced to allow +// reportinc correct numbers across restarts. +func GetTrieSyncProgress(db DatabaseReader) uint64 { + data, _ := db.Get(trieSyncKey) + if len(data) == 0 { + return 0 + } + return new(big.Int).SetBytes(data).Uint64() +} + // GetHeaderRLP retrieves a block header in its raw RLP database encoding, or nil // if the header's not found. func GetHeaderRLP(db DatabaseReader, hash common.Hash, number uint64) rlp.RawValue { @@ -374,6 +385,15 @@ func WriteHeadFastBlockHash(db ethdb.Putter, hash common.Hash) error { return nil } +// WriteTrieSyncProgress stores the fast sync trie process counter to support +// retrieving it across restarts. +func WriteTrieSyncProgress(db ethdb.Putter, count uint64) error { + if err := db.Put(trieSyncKey, new(big.Int).SetUint64(count).Bytes()); err != nil { + log.Crit("Failed to store fast sync trie progress", "err", err) + } + return nil +} + // WriteHeader serializes a block header into the database. func WriteHeader(db ethdb.Putter, header *types.Header) error { data, err := rlp.EncodeToBytes(header) diff --git a/core/fees.go b/core/fees.go deleted file mode 100644 index 83275ea36..000000000 --- a/core/fees.go +++ /dev/null @@ -1,23 +0,0 @@ -// Copyright 2014 The go-ethereum Authors -// This file is part of the go-ethereum library. -// -// The go-ethereum library is free software: you can redistribute it and/or modify -// it under the terms of the GNU Lesser General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// The go-ethereum library is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU Lesser General Public License for more details. -// -// You should have received a copy of the GNU Lesser General Public License -// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. - -package core - -import ( - "math/big" -) - -var BlockReward = big.NewInt(5e+18) diff --git a/core/genesis_test.go b/core/genesis_test.go index cd548d4b1..052ded699 100644 --- a/core/genesis_test.go +++ b/core/genesis_test.go @@ -118,10 +118,12 @@ func TestSetupGenesis(t *testing.T) { // Commit the 'old' genesis block with Homestead transition at #2. // Advance to block #4, past the homestead transition block of customg. genesis := oldcustomg.MustCommit(db) + bc, _ := NewBlockChain(db, nil, oldcustomg.Config, ethash.NewFullFaker(), vm.Config{}) defer bc.Stop() - bc.SetValidator(bproc{}) - bc.InsertChain(makeBlockChainWithDiff(genesis, []int{2, 3, 4, 5}, 0)) + + blocks, _ := GenerateChain(oldcustomg.Config, genesis, ethash.NewFaker(), db, 4, nil) + bc.InsertChain(blocks) bc.CurrentBlock() // This should return a compatibility error. return SetupGenesisBlock(db, &customg) diff --git a/core/headerchain.go b/core/headerchain.go index 0e5215293..73cd5d2c4 100644 --- a/core/headerchain.go +++ b/core/headerchain.go @@ -32,6 +32,7 @@ import ( "github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/params" "github.com/hashicorp/golang-lru" + "sync/atomic" ) const ( @@ -51,8 +52,8 @@ type HeaderChain struct { chainDb ethdb.Database genesisHeader *types.Header - currentHeader *types.Header // Current head of the header chain (may be above the block chain!) - currentHeaderHash common.Hash // Hash of the current head of the header chain (prevent recomputing all the time) + currentHeader atomic.Value // Current head of the header chain (may be above the block chain!) + currentHeaderHash common.Hash // Hash of the current head of the header chain (prevent recomputing all the time) headerCache *lru.Cache // Cache for the most recent block headers tdCache *lru.Cache // Cache for the most recent block total difficulties @@ -95,13 +96,13 @@ func NewHeaderChain(chainDb ethdb.Database, config *params.ChainConfig, engine c return nil, ErrNoGenesis } - hc.currentHeader = hc.genesisHeader + hc.currentHeader.Store(hc.genesisHeader) if head := GetHeadBlockHash(chainDb); head != (common.Hash{}) { if chead := hc.GetHeaderByHash(head); chead != nil { - hc.currentHeader = chead + hc.currentHeader.Store(chead) } } - hc.currentHeaderHash = hc.currentHeader.Hash() + hc.currentHeaderHash = hc.CurrentHeader().Hash() return hc, nil } @@ -139,7 +140,7 @@ func (hc *HeaderChain) WriteHeader(header *types.Header) (status WriteStatus, er if ptd == nil { return NonStatTy, consensus.ErrUnknownAncestor } - localTd := hc.GetTd(hc.currentHeaderHash, hc.currentHeader.Number.Uint64()) + localTd := hc.GetTd(hc.currentHeaderHash, hc.CurrentHeader().Number.Uint64()) externTd := new(big.Int).Add(header.Difficulty, ptd) // Irrelevant of the canonical status, write the td and header to the database @@ -181,7 +182,8 @@ func (hc *HeaderChain) WriteHeader(header *types.Header) (status WriteStatus, er if err := WriteHeadHeaderHash(hc.chainDb, hash); err != nil { log.Crit("Failed to insert head header hash", "err", err) } - hc.currentHeaderHash, hc.currentHeader = hash, types.CopyHeader(header) + hc.currentHeaderHash = hash + hc.currentHeader.Store(types.CopyHeader(header)) status = CanonStatTy } else { @@ -383,7 +385,7 @@ func (hc *HeaderChain) GetHeaderByNumber(number uint64) *types.Header { // CurrentHeader retrieves the current head header of the canonical chain. The // header is retrieved from the HeaderChain's internal cache. func (hc *HeaderChain) CurrentHeader() *types.Header { - return hc.currentHeader + return hc.currentHeader.Load().(*types.Header) } // SetCurrentHeader sets the current head header of the canonical chain. @@ -391,7 +393,7 @@ func (hc *HeaderChain) SetCurrentHeader(head *types.Header) { if err := WriteHeadHeaderHash(hc.chainDb, head.Hash()); err != nil { log.Crit("Failed to insert head header hash", "err", err) } - hc.currentHeader = head + hc.currentHeader.Store(head) hc.currentHeaderHash = head.Hash() } @@ -403,19 +405,20 @@ type DeleteCallback func(common.Hash, uint64) // will be deleted and the new one set. func (hc *HeaderChain) SetHead(head uint64, delFn DeleteCallback) { height := uint64(0) - if hc.currentHeader != nil { - height = hc.currentHeader.Number.Uint64() + + if hdr := hc.CurrentHeader(); hdr != nil { + height = hdr.Number.Uint64() } - for hc.currentHeader != nil && hc.currentHeader.Number.Uint64() > head { - hash := hc.currentHeader.Hash() - num := hc.currentHeader.Number.Uint64() + for hdr := hc.CurrentHeader(); hdr != nil && hdr.Number.Uint64() > head; hdr = hc.CurrentHeader() { + hash := hdr.Hash() + num := hdr.Number.Uint64() if delFn != nil { delFn(hash, num) } DeleteHeader(hc.chainDb, hash, num) DeleteTd(hc.chainDb, hash, num) - hc.currentHeader = hc.GetHeader(hc.currentHeader.ParentHash, hc.currentHeader.Number.Uint64()-1) + hc.currentHeader.Store(hc.GetHeader(hdr.ParentHash, hdr.Number.Uint64()-1)) } // Roll back the canonical chain numbering for i := height; i > head; i-- { @@ -426,10 +429,10 @@ func (hc *HeaderChain) SetHead(head uint64, delFn DeleteCallback) { hc.tdCache.Purge() hc.numberCache.Purge() - if hc.currentHeader == nil { - hc.currentHeader = hc.genesisHeader + if hc.CurrentHeader() == nil { + hc.currentHeader.Store(hc.genesisHeader) } - hc.currentHeaderHash = hc.currentHeader.Hash() + hc.currentHeaderHash = hc.CurrentHeader().Hash() if err := WriteHeadHeaderHash(hc.chainDb, hc.currentHeaderHash); err != nil { log.Crit("Failed to reset head header hash", "err", err) diff --git a/core/tx_pool.go b/core/tx_pool.go index dc3ddc423..0534fe57a 100644 --- a/core/tx_pool.go +++ b/core/tx_pool.go @@ -87,20 +87,20 @@ var ( var ( // Metrics for the pending pool - pendingDiscardCounter = metrics.NewCounter("txpool/pending/discard") - pendingReplaceCounter = metrics.NewCounter("txpool/pending/replace") - pendingRateLimitCounter = metrics.NewCounter("txpool/pending/ratelimit") // Dropped due to rate limiting - pendingNofundsCounter = metrics.NewCounter("txpool/pending/nofunds") // Dropped due to out-of-funds + pendingDiscardCounter = metrics.NewRegisteredCounter("txpool/pending/discard", nil) + pendingReplaceCounter = metrics.NewRegisteredCounter("txpool/pending/replace", nil) + pendingRateLimitCounter = metrics.NewRegisteredCounter("txpool/pending/ratelimit", nil) // Dropped due to rate limiting + pendingNofundsCounter = metrics.NewRegisteredCounter("txpool/pending/nofunds", nil) // Dropped due to out-of-funds // Metrics for the queued pool - queuedDiscardCounter = metrics.NewCounter("txpool/queued/discard") - queuedReplaceCounter = metrics.NewCounter("txpool/queued/replace") - queuedRateLimitCounter = metrics.NewCounter("txpool/queued/ratelimit") // Dropped due to rate limiting - queuedNofundsCounter = metrics.NewCounter("txpool/queued/nofunds") // Dropped due to out-of-funds + queuedDiscardCounter = metrics.NewRegisteredCounter("txpool/queued/discard", nil) + queuedReplaceCounter = metrics.NewRegisteredCounter("txpool/queued/replace", nil) + queuedRateLimitCounter = metrics.NewRegisteredCounter("txpool/queued/ratelimit", nil) // Dropped due to rate limiting + queuedNofundsCounter = metrics.NewRegisteredCounter("txpool/queued/nofunds", nil) // Dropped due to out-of-funds // General tx metrics - invalidTxCounter = metrics.NewCounter("txpool/invalid") - underpricedTxCounter = metrics.NewCounter("txpool/underpriced") + invalidTxCounter = metrics.NewRegisteredCounter("txpool/invalid", nil) + underpricedTxCounter = metrics.NewRegisteredCounter("txpool/underpriced", nil) ) // TxStatus is the current status of a transaction as seen by the pool. diff --git a/core/vm/contracts.go b/core/vm/contracts.go index 7344b6043..237450ea9 100644 --- a/core/vm/contracts.go +++ b/core/vm/contracts.go @@ -251,26 +251,12 @@ func (c *bigModExp) Run(input []byte) ([]byte, error) { return common.LeftPadBytes(base.Exp(base, exp, mod).Bytes(), int(modLen)), nil } -var ( - // errNotOnCurve is returned if a point being unmarshalled as a bn256 elliptic - // curve point is not on the curve. - errNotOnCurve = errors.New("point not on elliptic curve") - - // errInvalidCurvePoint is returned if a point being unmarshalled as a bn256 - // elliptic curve point is invalid. - errInvalidCurvePoint = errors.New("invalid elliptic curve point") -) - // newCurvePoint unmarshals a binary blob into a bn256 elliptic curve point, // returning it, or an error if the point is invalid. func newCurvePoint(blob []byte) (*bn256.G1, error) { - p, onCurve := new(bn256.G1).Unmarshal(blob) - if !onCurve { - return nil, errNotOnCurve - } - gx, gy, _, _ := p.CurvePoints() - if gx.Cmp(bn256.P) >= 0 || gy.Cmp(bn256.P) >= 0 { - return nil, errInvalidCurvePoint + p := new(bn256.G1) + if _, err := p.Unmarshal(blob); err != nil { + return nil, err } return p, nil } @@ -278,14 +264,9 @@ func newCurvePoint(blob []byte) (*bn256.G1, error) { // newTwistPoint unmarshals a binary blob into a bn256 elliptic curve point, // returning it, or an error if the point is invalid. func newTwistPoint(blob []byte) (*bn256.G2, error) { - p, onCurve := new(bn256.G2).Unmarshal(blob) - if !onCurve { - return nil, errNotOnCurve - } - x2, y2, _, _ := p.CurvePoints() - if x2.Real().Cmp(bn256.P) >= 0 || x2.Imag().Cmp(bn256.P) >= 0 || - y2.Real().Cmp(bn256.P) >= 0 || y2.Imag().Cmp(bn256.P) >= 0 { - return nil, errInvalidCurvePoint + p := new(bn256.G2) + if _, err := p.Unmarshal(blob); err != nil { + return nil, err } return p, nil } diff --git a/core/vm/instructions.go b/core/vm/instructions.go index 766172501..6daf4e10d 100644 --- a/core/vm/instructions.go +++ b/core/vm/instructions.go @@ -302,6 +302,66 @@ func opMulmod(pc *uint64, evm *EVM, contract *Contract, memory *Memory, stack *S return nil, nil } +// opSHL implements Shift Left +// The SHL instruction (shift left) pops 2 values from the stack, first arg1 and then arg2, +// and pushes on the stack arg2 shifted to the left by arg1 number of bits. +func opSHL(pc *uint64, evm *EVM, contract *Contract, memory *Memory, stack *Stack) ([]byte, error) { + // Note, second operand is left in the stack; accumulate result into it, and no need to push it afterwards + shift, value := math.U256(stack.pop()), math.U256(stack.peek()) + defer evm.interpreter.intPool.put(shift) // First operand back into the pool + + if shift.Cmp(common.Big256) >= 0 { + value.SetUint64(0) + return nil, nil + } + n := uint(shift.Uint64()) + math.U256(value.Lsh(value, n)) + + return nil, nil +} + +// opSHR implements Logical Shift Right +// The SHR instruction (logical shift right) pops 2 values from the stack, first arg1 and then arg2, +// and pushes on the stack arg2 shifted to the right by arg1 number of bits with zero fill. +func opSHR(pc *uint64, evm *EVM, contract *Contract, memory *Memory, stack *Stack) ([]byte, error) { + // Note, second operand is left in the stack; accumulate result into it, and no need to push it afterwards + shift, value := math.U256(stack.pop()), math.U256(stack.peek()) + defer evm.interpreter.intPool.put(shift) // First operand back into the pool + + if shift.Cmp(common.Big256) >= 0 { + value.SetUint64(0) + return nil, nil + } + n := uint(shift.Uint64()) + math.U256(value.Rsh(value, n)) + + return nil, nil +} + +// opSAR implements Arithmetic Shift Right +// The SAR instruction (arithmetic shift right) pops 2 values from the stack, first arg1 and then arg2, +// and pushes on the stack arg2 shifted to the right by arg1 number of bits with sign extension. +func opSAR(pc *uint64, evm *EVM, contract *Contract, memory *Memory, stack *Stack) ([]byte, error) { + // Note, S256 returns (potentially) a new bigint, so we're popping, not peeking this one + shift, value := math.U256(stack.pop()), math.S256(stack.pop()) + defer evm.interpreter.intPool.put(shift) // First operand back into the pool + + if shift.Cmp(common.Big256) >= 0 { + if value.Sign() > 0 { + value.SetUint64(0) + } else { + value.SetInt64(-1) + } + stack.push(math.U256(value)) + return nil, nil + } + n := uint(shift.Uint64()) + value.Rsh(value, n) + stack.push(math.U256(value)) + + return nil, nil +} + func opSha3(pc *uint64, evm *EVM, contract *Contract, memory *Memory, stack *Stack) ([]byte, error) { offset, size := stack.pop(), stack.pop() data := memory.Get(offset.Int64(), size.Int64()) diff --git a/core/vm/instructions_test.go b/core/vm/instructions_test.go index 180433ea8..eef4328bd 100644 --- a/core/vm/instructions_test.go +++ b/core/vm/instructions_test.go @@ -24,6 +24,48 @@ import ( "github.com/ethereum/go-ethereum/params" ) +type twoOperandTest struct { + x string + y string + expected string +} + +func testTwoOperandOp(t *testing.T, tests []twoOperandTest, opFn func(pc *uint64, evm *EVM, contract *Contract, memory *Memory, stack *Stack) ([]byte, error)) { + var ( + env = NewEVM(Context{}, nil, params.TestChainConfig, Config{EnableJit: false, ForceJit: false}) + stack = newstack() + pc = uint64(0) + ) + for i, test := range tests { + x := new(big.Int).SetBytes(common.Hex2Bytes(test.x)) + shift := new(big.Int).SetBytes(common.Hex2Bytes(test.y)) + expected := new(big.Int).SetBytes(common.Hex2Bytes(test.expected)) + stack.push(x) + stack.push(shift) + opFn(&pc, env, nil, nil, stack) + actual := stack.pop() + if actual.Cmp(expected) != 0 { + t.Errorf("Testcase %d, expected %v, got %v", i, expected, actual) + } + // Check pool usage + // 1.pool is not allowed to contain anything on the stack + // 2.pool is not allowed to contain the same pointers twice + if env.interpreter.intPool.pool.len() > 0 { + + poolvals := make(map[*big.Int]struct{}) + poolvals[actual] = struct{}{} + + for env.interpreter.intPool.pool.len() > 0 { + key := env.interpreter.intPool.get() + if _, exist := poolvals[key]; exist { + t.Errorf("Testcase %d, pool contains double-entry", i) + } + poolvals[key] = struct{}{} + } + } + } +} + func TestByteOp(t *testing.T) { var ( env = NewEVM(Context{}, nil, params.TestChainConfig, Config{EnableJit: false, ForceJit: false}) @@ -57,6 +99,98 @@ func TestByteOp(t *testing.T) { } } +func TestSHL(t *testing.T) { + // Testcases from https://github.com/ethereum/EIPs/blob/master/EIPS/eip-145.md#shl-shift-left + tests := []twoOperandTest{ + {"0000000000000000000000000000000000000000000000000000000000000001", "00", "0000000000000000000000000000000000000000000000000000000000000001"}, + {"0000000000000000000000000000000000000000000000000000000000000001", "01", "0000000000000000000000000000000000000000000000000000000000000002"}, + {"0000000000000000000000000000000000000000000000000000000000000001", "ff", "8000000000000000000000000000000000000000000000000000000000000000"}, + {"0000000000000000000000000000000000000000000000000000000000000001", "0100", "0000000000000000000000000000000000000000000000000000000000000000"}, + {"0000000000000000000000000000000000000000000000000000000000000001", "0101", "0000000000000000000000000000000000000000000000000000000000000000"}, + {"ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", "00", "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"}, + {"ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", "01", "fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffe"}, + {"ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", "ff", "8000000000000000000000000000000000000000000000000000000000000000"}, + {"ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", "0100", "0000000000000000000000000000000000000000000000000000000000000000"}, + {"0000000000000000000000000000000000000000000000000000000000000000", "01", "0000000000000000000000000000000000000000000000000000000000000000"}, + {"7fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", "01", "fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffe"}, + } + testTwoOperandOp(t, tests, opSHL) +} + +func TestSHR(t *testing.T) { + // Testcases from https://github.com/ethereum/EIPs/blob/master/EIPS/eip-145.md#shr-logical-shift-right + tests := []twoOperandTest{ + {"0000000000000000000000000000000000000000000000000000000000000001", "00", "0000000000000000000000000000000000000000000000000000000000000001"}, + {"0000000000000000000000000000000000000000000000000000000000000001", "01", "0000000000000000000000000000000000000000000000000000000000000000"}, + {"8000000000000000000000000000000000000000000000000000000000000000", "01", "4000000000000000000000000000000000000000000000000000000000000000"}, + {"8000000000000000000000000000000000000000000000000000000000000000", "ff", "0000000000000000000000000000000000000000000000000000000000000001"}, + {"8000000000000000000000000000000000000000000000000000000000000000", "0100", "0000000000000000000000000000000000000000000000000000000000000000"}, + {"8000000000000000000000000000000000000000000000000000000000000000", "0101", "0000000000000000000000000000000000000000000000000000000000000000"}, + {"ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", "00", "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"}, + {"ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", "01", "7fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"}, + {"ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", "ff", "0000000000000000000000000000000000000000000000000000000000000001"}, + {"ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", "0100", "0000000000000000000000000000000000000000000000000000000000000000"}, + {"0000000000000000000000000000000000000000000000000000000000000000", "01", "0000000000000000000000000000000000000000000000000000000000000000"}, + } + testTwoOperandOp(t, tests, opSHR) +} + +func TestSAR(t *testing.T) { + // Testcases from https://github.com/ethereum/EIPs/blob/master/EIPS/eip-145.md#sar-arithmetic-shift-right + tests := []twoOperandTest{ + {"0000000000000000000000000000000000000000000000000000000000000001", "00", "0000000000000000000000000000000000000000000000000000000000000001"}, + {"0000000000000000000000000000000000000000000000000000000000000001", "01", "0000000000000000000000000000000000000000000000000000000000000000"}, + {"8000000000000000000000000000000000000000000000000000000000000000", "01", "c000000000000000000000000000000000000000000000000000000000000000"}, + {"8000000000000000000000000000000000000000000000000000000000000000", "ff", "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"}, + {"8000000000000000000000000000000000000000000000000000000000000000", "0100", "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"}, + {"8000000000000000000000000000000000000000000000000000000000000000", "0101", "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"}, + {"ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", "00", "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"}, + {"ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", "01", "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"}, + {"ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", "ff", "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"}, + {"ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", "0100", "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"}, + {"0000000000000000000000000000000000000000000000000000000000000000", "01", "0000000000000000000000000000000000000000000000000000000000000000"}, + {"4000000000000000000000000000000000000000000000000000000000000000", "fe", "0000000000000000000000000000000000000000000000000000000000000001"}, + {"7fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", "f8", "000000000000000000000000000000000000000000000000000000000000007f"}, + {"7fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", "fe", "0000000000000000000000000000000000000000000000000000000000000001"}, + {"7fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", "ff", "0000000000000000000000000000000000000000000000000000000000000000"}, + {"7fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", "0100", "0000000000000000000000000000000000000000000000000000000000000000"}, + } + + testTwoOperandOp(t, tests, opSAR) +} + +func TestSGT(t *testing.T) { + tests := []twoOperandTest{ + {"0000000000000000000000000000000000000000000000000000000000000001", "0000000000000000000000000000000000000000000000000000000000000001", "0000000000000000000000000000000000000000000000000000000000000000"}, + {"ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", "0000000000000000000000000000000000000000000000000000000000000000"}, + {"7fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", "7fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", "0000000000000000000000000000000000000000000000000000000000000000"}, + {"0000000000000000000000000000000000000000000000000000000000000001", "7fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", "0000000000000000000000000000000000000000000000000000000000000001"}, + {"7fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", "0000000000000000000000000000000000000000000000000000000000000001", "0000000000000000000000000000000000000000000000000000000000000000"}, + {"ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", "0000000000000000000000000000000000000000000000000000000000000001", "0000000000000000000000000000000000000000000000000000000000000001"}, + {"0000000000000000000000000000000000000000000000000000000000000001", "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", "0000000000000000000000000000000000000000000000000000000000000000"}, + {"8000000000000000000000000000000000000000000000000000000000000001", "8000000000000000000000000000000000000000000000000000000000000001", "0000000000000000000000000000000000000000000000000000000000000000"}, + {"8000000000000000000000000000000000000000000000000000000000000001", "7fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", "0000000000000000000000000000000000000000000000000000000000000001"}, + {"7fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", "8000000000000000000000000000000000000000000000000000000000000001", "0000000000000000000000000000000000000000000000000000000000000000"}, + } + testTwoOperandOp(t, tests, opSgt) +} + +func TestSLT(t *testing.T) { + tests := []twoOperandTest{ + {"0000000000000000000000000000000000000000000000000000000000000001", "0000000000000000000000000000000000000000000000000000000000000001", "0000000000000000000000000000000000000000000000000000000000000000"}, + {"ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", "0000000000000000000000000000000000000000000000000000000000000000"}, + {"7fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", "7fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", "0000000000000000000000000000000000000000000000000000000000000000"}, + {"0000000000000000000000000000000000000000000000000000000000000001", "7fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", "0000000000000000000000000000000000000000000000000000000000000000"}, + {"7fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", "0000000000000000000000000000000000000000000000000000000000000001", "0000000000000000000000000000000000000000000000000000000000000001"}, + {"ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", "0000000000000000000000000000000000000000000000000000000000000001", "0000000000000000000000000000000000000000000000000000000000000000"}, + {"0000000000000000000000000000000000000000000000000000000000000001", "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", "0000000000000000000000000000000000000000000000000000000000000001"}, + {"8000000000000000000000000000000000000000000000000000000000000001", "8000000000000000000000000000000000000000000000000000000000000001", "0000000000000000000000000000000000000000000000000000000000000000"}, + {"8000000000000000000000000000000000000000000000000000000000000001", "7fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", "0000000000000000000000000000000000000000000000000000000000000000"}, + {"7fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", "8000000000000000000000000000000000000000000000000000000000000001", "0000000000000000000000000000000000000000000000000000000000000001"}, + } + testTwoOperandOp(t, tests, opSlt) +} + func opBenchmark(bench *testing.B, op func(pc *uint64, evm *EVM, contract *Contract, memory *Memory, stack *Stack) ([]byte, error), args ...string) { var ( env = NewEVM(Context{}, nil, params.TestChainConfig, Config{EnableJit: false, ForceJit: false}) @@ -259,3 +393,22 @@ func BenchmarkOpMulmod(b *testing.B) { opBenchmark(b, opMulmod, x, y, z) } + +func BenchmarkOpSHL(b *testing.B) { + x := "FBCDEF090807060504030201ffffffffFBCDEF090807060504030201ffffffff" + y := "ff" + + opBenchmark(b, opSHL, x, y) +} +func BenchmarkOpSHR(b *testing.B) { + x := "FBCDEF090807060504030201ffffffffFBCDEF090807060504030201ffffffff" + y := "ff" + + opBenchmark(b, opSHR, x, y) +} +func BenchmarkOpSAR(b *testing.B) { + x := "FBCDEF090807060504030201ffffffffFBCDEF090807060504030201ffffffff" + y := "ff" + + opBenchmark(b, opSAR, x, y) +} diff --git a/core/vm/interpreter.go b/core/vm/interpreter.go index 82a6d3de6..95490adfc 100644 --- a/core/vm/interpreter.go +++ b/core/vm/interpreter.go @@ -37,8 +37,6 @@ type Config struct { // NoRecursion disabled Interpreter call, callcode, // delegate call and create. NoRecursion bool - // Disable gas metering - DisableGasMetering bool // Enable recording of SHA3/keccak preimages EnablePreimageRecording bool // JumpTable contains the EVM instruction table. This @@ -68,6 +66,8 @@ func NewInterpreter(evm *EVM, cfg Config) *Interpreter { // we'll set the default jump table. if !cfg.JumpTable[STOP].valid { switch { + case evm.ChainConfig().IsConstantinople(evm.BlockNumber): + cfg.JumpTable = constantinopleInstructionSet case evm.ChainConfig().IsByzantium(evm.BlockNumber): cfg.JumpTable = byzantiumInstructionSet case evm.ChainConfig().IsHomestead(evm.BlockNumber): @@ -187,14 +187,11 @@ func (in *Interpreter) Run(contract *Contract, input []byte) (ret []byte, err er return nil, errGasUintOverflow } } - - if !in.cfg.DisableGasMetering { - // consume the gas and return an error if not enough gas is available. - // cost is explicitly set so that the capture state defer method cas get the proper cost - cost, err = operation.gasCost(in.gasTable, in.evm, contract, stack, mem, memorySize) - if err != nil || !contract.UseGas(cost) { - return nil, ErrOutOfGas - } + // consume the gas and return an error if not enough gas is available. + // cost is explicitly set so that the capture state defer method cas get the proper cost + cost, err = operation.gasCost(in.gasTable, in.evm, contract, stack, mem, memorySize) + if err != nil || !contract.UseGas(cost) { + return nil, ErrOutOfGas } if memorySize > 0 { mem.Resize(memorySize) diff --git a/core/vm/jump_table.go b/core/vm/jump_table.go index a1c5ad9c6..338994135 100644 --- a/core/vm/jump_table.go +++ b/core/vm/jump_table.go @@ -51,11 +51,38 @@ type operation struct { } var ( - frontierInstructionSet = NewFrontierInstructionSet() - homesteadInstructionSet = NewHomesteadInstructionSet() - byzantiumInstructionSet = NewByzantiumInstructionSet() + frontierInstructionSet = NewFrontierInstructionSet() + homesteadInstructionSet = NewHomesteadInstructionSet() + byzantiumInstructionSet = NewByzantiumInstructionSet() + constantinopleInstructionSet = NewConstantinopleInstructionSet() ) +// NewConstantinopleInstructionSet returns the frontier, homestead +// byzantium and contantinople instructions. +func NewConstantinopleInstructionSet() [256]operation { + // instructions that can be executed during the byzantium phase. + instructionSet := NewByzantiumInstructionSet() + instructionSet[SHL] = operation{ + execute: opSHL, + gasCost: constGasFunc(GasFastestStep), + validateStack: makeStackFunc(2, 1), + valid: true, + } + instructionSet[SHR] = operation{ + execute: opSHR, + gasCost: constGasFunc(GasFastestStep), + validateStack: makeStackFunc(2, 1), + valid: true, + } + instructionSet[SAR] = operation{ + execute: opSAR, + gasCost: constGasFunc(GasFastestStep), + validateStack: makeStackFunc(2, 1), + valid: true, + } + return instructionSet +} + // NewByzantiumInstructionSet returns the frontier, homestead and // byzantium instructions. func NewByzantiumInstructionSet() [256]operation { diff --git a/core/vm/opcodes.go b/core/vm/opcodes.go index 0c6550735..7fe55b72f 100644 --- a/core/vm/opcodes.go +++ b/core/vm/opcodes.go @@ -63,6 +63,9 @@ const ( XOR NOT BYTE + SHL + SHR + SAR SHA3 = 0x20 ) @@ -234,6 +237,9 @@ var opCodeToString = map[OpCode]string{ OR: "OR", XOR: "XOR", BYTE: "BYTE", + SHL: "SHL", + SHR: "SHR", + SAR: "SAR", ADDMOD: "ADDMOD", MULMOD: "MULMOD", @@ -400,6 +406,9 @@ var stringToOp = map[string]OpCode{ "OR": OR, "XOR": XOR, "BYTE": BYTE, + "SHL": SHL, + "SHR": SHR, + "SAR": SAR, "ADDMOD": ADDMOD, "MULMOD": MULMOD, "SHA3": SHA3, diff --git a/crypto/bn256/bn256_amd64.go b/crypto/bn256/bn256_amd64.go new file mode 100644 index 000000000..35b4839c2 --- /dev/null +++ b/crypto/bn256/bn256_amd64.go @@ -0,0 +1,63 @@ +// Copyright 2018 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. + +// +build amd64,!appengine,!gccgo + +// Package bn256 implements the Optimal Ate pairing over a 256-bit Barreto-Naehrig curve. +package bn256 + +import ( + "math/big" + + "github.com/ethereum/go-ethereum/crypto/bn256/cloudflare" +) + +// G1 is an abstract cyclic group. The zero value is suitable for use as the +// output of an operation, but cannot be used as an input. +type G1 struct { + bn256.G1 +} + +// Add sets e to a+b and then returns e. +func (e *G1) Add(a, b *G1) *G1 { + e.G1.Add(&a.G1, &b.G1) + return e +} + +// ScalarMult sets e to a*k and then returns e. +func (e *G1) ScalarMult(a *G1, k *big.Int) *G1 { + e.G1.ScalarMult(&a.G1, k) + return e +} + +// G2 is an abstract cyclic group. The zero value is suitable for use as the +// output of an operation, but cannot be used as an input. +type G2 struct { + bn256.G2 +} + +// PairingCheck calculates the Optimal Ate pairing for a set of points. +func PairingCheck(a []*G1, b []*G2) bool { + as := make([]*bn256.G1, len(a)) + for i, p := range a { + as[i] = &p.G1 + } + bs := make([]*bn256.G2, len(b)) + for i, p := range b { + bs[i] = &p.G2 + } + return bn256.PairingCheck(as, bs) +} diff --git a/crypto/bn256/bn256_other.go b/crypto/bn256/bn256_other.go new file mode 100644 index 000000000..81977a0a8 --- /dev/null +++ b/crypto/bn256/bn256_other.go @@ -0,0 +1,63 @@ +// Copyright 2018 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. + +// +build !amd64 appengine gccgo + +// Package bn256 implements the Optimal Ate pairing over a 256-bit Barreto-Naehrig curve. +package bn256 + +import ( + "math/big" + + "github.com/ethereum/go-ethereum/crypto/bn256/google" +) + +// G1 is an abstract cyclic group. The zero value is suitable for use as the +// output of an operation, but cannot be used as an input. +type G1 struct { + bn256.G1 +} + +// Add sets e to a+b and then returns e. +func (e *G1) Add(a, b *G1) *G1 { + e.G1.Add(&a.G1, &b.G1) + return e +} + +// ScalarMult sets e to a*k and then returns e. +func (e *G1) ScalarMult(a *G1, k *big.Int) *G1 { + e.G1.ScalarMult(&a.G1, k) + return e +} + +// G2 is an abstract cyclic group. The zero value is suitable for use as the +// output of an operation, but cannot be used as an input. +type G2 struct { + bn256.G2 +} + +// PairingCheck calculates the Optimal Ate pairing for a set of points. +func PairingCheck(a []*G1, b []*G2) bool { + as := make([]*bn256.G1, len(a)) + for i, p := range a { + as[i] = &p.G1 + } + bs := make([]*bn256.G2, len(b)) + for i, p := range b { + bs[i] = &p.G2 + } + return bn256.PairingCheck(as, bs) +} diff --git a/crypto/bn256/cloudflare/bn256.go b/crypto/bn256/cloudflare/bn256.go new file mode 100644 index 000000000..c6ea2d07e --- /dev/null +++ b/crypto/bn256/cloudflare/bn256.go @@ -0,0 +1,481 @@ +// Package bn256 implements a particular bilinear group at the 128-bit security +// level. +// +// Bilinear groups are the basis of many of the new cryptographic protocols that +// have been proposed over the past decade. They consist of a triplet of groups +// (G₁, G₂ and GT) such that there exists a function e(g₁ˣ,g₂ʸ)=gTˣʸ (where gₓ +// is a generator of the respective group). That function is called a pairing +// function. +// +// This package specifically implements the Optimal Ate pairing over a 256-bit +// Barreto-Naehrig curve as described in +// http://cryptojedi.org/papers/dclxvi-20100714.pdf. Its output is compatible +// with the implementation described in that paper. +package bn256 + +import ( + "crypto/rand" + "errors" + "io" + "math/big" +) + +func randomK(r io.Reader) (k *big.Int, err error) { + for { + k, err = rand.Int(r, Order) + if k.Sign() > 0 || err != nil { + return + } + } +} + +// G1 is an abstract cyclic group. The zero value is suitable for use as the +// output of an operation, but cannot be used as an input. +type G1 struct { + p *curvePoint +} + +// RandomG1 returns x and g₁ˣ where x is a random, non-zero number read from r. +func RandomG1(r io.Reader) (*big.Int, *G1, error) { + k, err := randomK(r) + if err != nil { + return nil, nil, err + } + + return k, new(G1).ScalarBaseMult(k), nil +} + +func (g *G1) String() string { + return "bn256.G1" + g.p.String() +} + +// ScalarBaseMult sets e to g*k where g is the generator of the group and then +// returns e. +func (e *G1) ScalarBaseMult(k *big.Int) *G1 { + if e.p == nil { + e.p = &curvePoint{} + } + e.p.Mul(curveGen, k) + return e +} + +// ScalarMult sets e to a*k and then returns e. +func (e *G1) ScalarMult(a *G1, k *big.Int) *G1 { + if e.p == nil { + e.p = &curvePoint{} + } + e.p.Mul(a.p, k) + return e +} + +// Add sets e to a+b and then returns e. +func (e *G1) Add(a, b *G1) *G1 { + if e.p == nil { + e.p = &curvePoint{} + } + e.p.Add(a.p, b.p) + return e +} + +// Neg sets e to -a and then returns e. +func (e *G1) Neg(a *G1) *G1 { + if e.p == nil { + e.p = &curvePoint{} + } + e.p.Neg(a.p) + return e +} + +// Set sets e to a and then returns e. +func (e *G1) Set(a *G1) *G1 { + if e.p == nil { + e.p = &curvePoint{} + } + e.p.Set(a.p) + return e +} + +// Marshal converts e to a byte slice. +func (e *G1) Marshal() []byte { + // Each value is a 256-bit number. + const numBytes = 256 / 8 + + e.p.MakeAffine() + ret := make([]byte, numBytes*2) + if e.p.IsInfinity() { + return ret + } + temp := &gfP{} + + montDecode(temp, &e.p.x) + temp.Marshal(ret) + montDecode(temp, &e.p.y) + temp.Marshal(ret[numBytes:]) + + return ret +} + +// Unmarshal sets e to the result of converting the output of Marshal back into +// a group element and then returns e. +func (e *G1) Unmarshal(m []byte) ([]byte, error) { + // Each value is a 256-bit number. + const numBytes = 256 / 8 + if len(m) < 2*numBytes { + return nil, errors.New("bn256: not enough data") + } + // Unmarshal the points and check their caps + if e.p == nil { + e.p = &curvePoint{} + } else { + e.p.x, e.p.y = gfP{0}, gfP{0} + } + var err error + if err = e.p.x.Unmarshal(m); err != nil { + return nil, err + } + if err = e.p.y.Unmarshal(m[numBytes:]); err != nil { + return nil, err + } + // Encode into Montgomery form and ensure it's on the curve + montEncode(&e.p.x, &e.p.x) + montEncode(&e.p.y, &e.p.y) + + zero := gfP{0} + if e.p.x == zero && e.p.y == zero { + // This is the point at infinity. + e.p.y = *newGFp(1) + e.p.z = gfP{0} + e.p.t = gfP{0} + } else { + e.p.z = *newGFp(1) + e.p.t = *newGFp(1) + + if !e.p.IsOnCurve() { + return nil, errors.New("bn256: malformed point") + } + } + return m[2*numBytes:], nil +} + +// G2 is an abstract cyclic group. The zero value is suitable for use as the +// output of an operation, but cannot be used as an input. +type G2 struct { + p *twistPoint +} + +// RandomG2 returns x and g₂ˣ where x is a random, non-zero number read from r. +func RandomG2(r io.Reader) (*big.Int, *G2, error) { + k, err := randomK(r) + if err != nil { + return nil, nil, err + } + + return k, new(G2).ScalarBaseMult(k), nil +} + +func (e *G2) String() string { + return "bn256.G2" + e.p.String() +} + +// ScalarBaseMult sets e to g*k where g is the generator of the group and then +// returns out. +func (e *G2) ScalarBaseMult(k *big.Int) *G2 { + if e.p == nil { + e.p = &twistPoint{} + } + e.p.Mul(twistGen, k) + return e +} + +// ScalarMult sets e to a*k and then returns e. +func (e *G2) ScalarMult(a *G2, k *big.Int) *G2 { + if e.p == nil { + e.p = &twistPoint{} + } + e.p.Mul(a.p, k) + return e +} + +// Add sets e to a+b and then returns e. +func (e *G2) Add(a, b *G2) *G2 { + if e.p == nil { + e.p = &twistPoint{} + } + e.p.Add(a.p, b.p) + return e +} + +// Neg sets e to -a and then returns e. +func (e *G2) Neg(a *G2) *G2 { + if e.p == nil { + e.p = &twistPoint{} + } + e.p.Neg(a.p) + return e +} + +// Set sets e to a and then returns e. +func (e *G2) Set(a *G2) *G2 { + if e.p == nil { + e.p = &twistPoint{} + } + e.p.Set(a.p) + return e +} + +// Marshal converts e into a byte slice. +func (e *G2) Marshal() []byte { + // Each value is a 256-bit number. + const numBytes = 256 / 8 + + if e.p == nil { + e.p = &twistPoint{} + } + + e.p.MakeAffine() + ret := make([]byte, numBytes*4) + if e.p.IsInfinity() { + return ret + } + temp := &gfP{} + + montDecode(temp, &e.p.x.x) + temp.Marshal(ret) + montDecode(temp, &e.p.x.y) + temp.Marshal(ret[numBytes:]) + montDecode(temp, &e.p.y.x) + temp.Marshal(ret[2*numBytes:]) + montDecode(temp, &e.p.y.y) + temp.Marshal(ret[3*numBytes:]) + + return ret +} + +// Unmarshal sets e to the result of converting the output of Marshal back into +// a group element and then returns e. +func (e *G2) Unmarshal(m []byte) ([]byte, error) { + // Each value is a 256-bit number. + const numBytes = 256 / 8 + if len(m) < 4*numBytes { + return nil, errors.New("bn256: not enough data") + } + // Unmarshal the points and check their caps + if e.p == nil { + e.p = &twistPoint{} + } + var err error + if err = e.p.x.x.Unmarshal(m); err != nil { + return nil, err + } + if err = e.p.x.y.Unmarshal(m[numBytes:]); err != nil { + return nil, err + } + if err = e.p.y.x.Unmarshal(m[2*numBytes:]); err != nil { + return nil, err + } + if err = e.p.y.y.Unmarshal(m[3*numBytes:]); err != nil { + return nil, err + } + // Encode into Montgomery form and ensure it's on the curve + montEncode(&e.p.x.x, &e.p.x.x) + montEncode(&e.p.x.y, &e.p.x.y) + montEncode(&e.p.y.x, &e.p.y.x) + montEncode(&e.p.y.y, &e.p.y.y) + + if e.p.x.IsZero() && e.p.y.IsZero() { + // This is the point at infinity. + e.p.y.SetOne() + e.p.z.SetZero() + e.p.t.SetZero() + } else { + e.p.z.SetOne() + e.p.t.SetOne() + + if !e.p.IsOnCurve() { + return nil, errors.New("bn256: malformed point") + } + } + return m[4*numBytes:], nil +} + +// GT is an abstract cyclic group. The zero value is suitable for use as the +// output of an operation, but cannot be used as an input. +type GT struct { + p *gfP12 +} + +// Pair calculates an Optimal Ate pairing. +func Pair(g1 *G1, g2 *G2) *GT { + return >{optimalAte(g2.p, g1.p)} +} + +// PairingCheck calculates the Optimal Ate pairing for a set of points. +func PairingCheck(a []*G1, b []*G2) bool { + acc := new(gfP12) + acc.SetOne() + + for i := 0; i < len(a); i++ { + if a[i].p.IsInfinity() || b[i].p.IsInfinity() { + continue + } + acc.Mul(acc, miller(b[i].p, a[i].p)) + } + return finalExponentiation(acc).IsOne() +} + +// Miller applies Miller's algorithm, which is a bilinear function from the +// source groups to F_p^12. Miller(g1, g2).Finalize() is equivalent to Pair(g1, +// g2). +func Miller(g1 *G1, g2 *G2) *GT { + return >{miller(g2.p, g1.p)} +} + +func (g *GT) String() string { + return "bn256.GT" + g.p.String() +} + +// ScalarMult sets e to a*k and then returns e. +func (e *GT) ScalarMult(a *GT, k *big.Int) *GT { + if e.p == nil { + e.p = &gfP12{} + } + e.p.Exp(a.p, k) + return e +} + +// Add sets e to a+b and then returns e. +func (e *GT) Add(a, b *GT) *GT { + if e.p == nil { + e.p = &gfP12{} + } + e.p.Mul(a.p, b.p) + return e +} + +// Neg sets e to -a and then returns e. +func (e *GT) Neg(a *GT) *GT { + if e.p == nil { + e.p = &gfP12{} + } + e.p.Conjugate(a.p) + return e +} + +// Set sets e to a and then returns e. +func (e *GT) Set(a *GT) *GT { + if e.p == nil { + e.p = &gfP12{} + } + e.p.Set(a.p) + return e +} + +// Finalize is a linear function from F_p^12 to GT. +func (e *GT) Finalize() *GT { + ret := finalExponentiation(e.p) + e.p.Set(ret) + return e +} + +// Marshal converts e into a byte slice. +func (e *GT) Marshal() []byte { + // Each value is a 256-bit number. + const numBytes = 256 / 8 + + ret := make([]byte, numBytes*12) + temp := &gfP{} + + montDecode(temp, &e.p.x.x.x) + temp.Marshal(ret) + montDecode(temp, &e.p.x.x.y) + temp.Marshal(ret[numBytes:]) + montDecode(temp, &e.p.x.y.x) + temp.Marshal(ret[2*numBytes:]) + montDecode(temp, &e.p.x.y.y) + temp.Marshal(ret[3*numBytes:]) + montDecode(temp, &e.p.x.z.x) + temp.Marshal(ret[4*numBytes:]) + montDecode(temp, &e.p.x.z.y) + temp.Marshal(ret[5*numBytes:]) + montDecode(temp, &e.p.y.x.x) + temp.Marshal(ret[6*numBytes:]) + montDecode(temp, &e.p.y.x.y) + temp.Marshal(ret[7*numBytes:]) + montDecode(temp, &e.p.y.y.x) + temp.Marshal(ret[8*numBytes:]) + montDecode(temp, &e.p.y.y.y) + temp.Marshal(ret[9*numBytes:]) + montDecode(temp, &e.p.y.z.x) + temp.Marshal(ret[10*numBytes:]) + montDecode(temp, &e.p.y.z.y) + temp.Marshal(ret[11*numBytes:]) + + return ret +} + +// Unmarshal sets e to the result of converting the output of Marshal back into +// a group element and then returns e. +func (e *GT) Unmarshal(m []byte) ([]byte, error) { + // Each value is a 256-bit number. + const numBytes = 256 / 8 + + if len(m) < 12*numBytes { + return nil, errors.New("bn256: not enough data") + } + + if e.p == nil { + e.p = &gfP12{} + } + + var err error + if err = e.p.x.x.x.Unmarshal(m); err != nil { + return nil, err + } + if err = e.p.x.x.y.Unmarshal(m[numBytes:]); err != nil { + return nil, err + } + if err = e.p.x.y.x.Unmarshal(m[2*numBytes:]); err != nil { + return nil, err + } + if err = e.p.x.y.y.Unmarshal(m[3*numBytes:]); err != nil { + return nil, err + } + if err = e.p.x.z.x.Unmarshal(m[4*numBytes:]); err != nil { + return nil, err + } + if err = e.p.x.z.y.Unmarshal(m[5*numBytes:]); err != nil { + return nil, err + } + if err = e.p.y.x.x.Unmarshal(m[6*numBytes:]); err != nil { + return nil, err + } + if err = e.p.y.x.y.Unmarshal(m[7*numBytes:]); err != nil { + return nil, err + } + if err = e.p.y.y.x.Unmarshal(m[8*numBytes:]); err != nil { + return nil, err + } + if err = e.p.y.y.y.Unmarshal(m[9*numBytes:]); err != nil { + return nil, err + } + if err = e.p.y.z.x.Unmarshal(m[10*numBytes:]); err != nil { + return nil, err + } + if err = e.p.y.z.y.Unmarshal(m[11*numBytes:]); err != nil { + return nil, err + } + montEncode(&e.p.x.x.x, &e.p.x.x.x) + montEncode(&e.p.x.x.y, &e.p.x.x.y) + montEncode(&e.p.x.y.x, &e.p.x.y.x) + montEncode(&e.p.x.y.y, &e.p.x.y.y) + montEncode(&e.p.x.z.x, &e.p.x.z.x) + montEncode(&e.p.x.z.y, &e.p.x.z.y) + montEncode(&e.p.y.x.x, &e.p.y.x.x) + montEncode(&e.p.y.x.y, &e.p.y.x.y) + montEncode(&e.p.y.y.x, &e.p.y.y.x) + montEncode(&e.p.y.y.y, &e.p.y.y.y) + montEncode(&e.p.y.z.x, &e.p.y.z.x) + montEncode(&e.p.y.z.y, &e.p.y.z.y) + + return m[12*numBytes:], nil +} diff --git a/crypto/bn256/cloudflare/bn256_test.go b/crypto/bn256/cloudflare/bn256_test.go new file mode 100644 index 000000000..369a3edaa --- /dev/null +++ b/crypto/bn256/cloudflare/bn256_test.go @@ -0,0 +1,118 @@ +// +build amd64,!appengine,!gccgo + +package bn256 + +import ( + "bytes" + "crypto/rand" + "testing" +) + +func TestG1Marshal(t *testing.T) { + _, Ga, err := RandomG1(rand.Reader) + if err != nil { + t.Fatal(err) + } + ma := Ga.Marshal() + + Gb := new(G1) + _, err = Gb.Unmarshal(ma) + if err != nil { + t.Fatal(err) + } + mb := Gb.Marshal() + + if !bytes.Equal(ma, mb) { + t.Fatal("bytes are different") + } +} + +func TestG2Marshal(t *testing.T) { + _, Ga, err := RandomG2(rand.Reader) + if err != nil { + t.Fatal(err) + } + ma := Ga.Marshal() + + Gb := new(G2) + _, err = Gb.Unmarshal(ma) + if err != nil { + t.Fatal(err) + } + mb := Gb.Marshal() + + if !bytes.Equal(ma, mb) { + t.Fatal("bytes are different") + } +} + +func TestBilinearity(t *testing.T) { + for i := 0; i < 2; i++ { + a, p1, _ := RandomG1(rand.Reader) + b, p2, _ := RandomG2(rand.Reader) + e1 := Pair(p1, p2) + + e2 := Pair(&G1{curveGen}, &G2{twistGen}) + e2.ScalarMult(e2, a) + e2.ScalarMult(e2, b) + + if *e1.p != *e2.p { + t.Fatalf("bad pairing result: %s", e1) + } + } +} + +func TestTripartiteDiffieHellman(t *testing.T) { + a, _ := rand.Int(rand.Reader, Order) + b, _ := rand.Int(rand.Reader, Order) + c, _ := rand.Int(rand.Reader, Order) + + pa, pb, pc := new(G1), new(G1), new(G1) + qa, qb, qc := new(G2), new(G2), new(G2) + + pa.Unmarshal(new(G1).ScalarBaseMult(a).Marshal()) + qa.Unmarshal(new(G2).ScalarBaseMult(a).Marshal()) + pb.Unmarshal(new(G1).ScalarBaseMult(b).Marshal()) + qb.Unmarshal(new(G2).ScalarBaseMult(b).Marshal()) + pc.Unmarshal(new(G1).ScalarBaseMult(c).Marshal()) + qc.Unmarshal(new(G2).ScalarBaseMult(c).Marshal()) + + k1 := Pair(pb, qc) + k1.ScalarMult(k1, a) + k1Bytes := k1.Marshal() + + k2 := Pair(pc, qa) + k2.ScalarMult(k2, b) + k2Bytes := k2.Marshal() + + k3 := Pair(pa, qb) + k3.ScalarMult(k3, c) + k3Bytes := k3.Marshal() + + if !bytes.Equal(k1Bytes, k2Bytes) || !bytes.Equal(k2Bytes, k3Bytes) { + t.Errorf("keys didn't agree") + } +} + +func BenchmarkG1(b *testing.B) { + x, _ := rand.Int(rand.Reader, Order) + b.ResetTimer() + + for i := 0; i < b.N; i++ { + new(G1).ScalarBaseMult(x) + } +} + +func BenchmarkG2(b *testing.B) { + x, _ := rand.Int(rand.Reader, Order) + b.ResetTimer() + + for i := 0; i < b.N; i++ { + new(G2).ScalarBaseMult(x) + } +} +func BenchmarkPairing(b *testing.B) { + for i := 0; i < b.N; i++ { + Pair(&G1{curveGen}, &G2{twistGen}) + } +} diff --git a/crypto/bn256/cloudflare/constants.go b/crypto/bn256/cloudflare/constants.go new file mode 100644 index 000000000..5122aae64 --- /dev/null +++ b/crypto/bn256/cloudflare/constants.go @@ -0,0 +1,59 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package bn256 + +import ( + "math/big" +) + +func bigFromBase10(s string) *big.Int { + n, _ := new(big.Int).SetString(s, 10) + return n +} + +// u is the BN parameter that determines the prime: 1868033³. +var u = bigFromBase10("4965661367192848881") + +// Order is the number of elements in both G₁ and G₂: 36u⁴+36u³+18u²+6u+1. +var Order = bigFromBase10("21888242871839275222246405745257275088548364400416034343698204186575808495617") + +// P is a prime over which we form a basic field: 36u⁴+36u³+24u²+6u+1. +var P = bigFromBase10("21888242871839275222246405745257275088696311157297823662689037894645226208583") + +// p2 is p, represented as little-endian 64-bit words. +var p2 = [4]uint64{0x3c208c16d87cfd47, 0x97816a916871ca8d, 0xb85045b68181585d, 0x30644e72e131a029} + +// np is the negative inverse of p, mod 2^256. +var np = [4]uint64{0x87d20782e4866389, 0x9ede7d651eca6ac9, 0xd8afcbd01833da80, 0xf57a22b791888c6b} + +// rN1 is R^-1 where R = 2^256 mod p. +var rN1 = &gfP{0xed84884a014afa37, 0xeb2022850278edf8, 0xcf63e9cfb74492d9, 0x2e67157159e5c639} + +// r2 is R^2 where R = 2^256 mod p. +var r2 = &gfP{0xf32cfc5b538afa89, 0xb5e71911d44501fb, 0x47ab1eff0a417ff6, 0x06d89f71cab8351f} + +// r3 is R^3 where R = 2^256 mod p. +var r3 = &gfP{0xb1cd6dafda1530df, 0x62f210e6a7283db6, 0xef7f0b0c0ada0afb, 0x20fd6e902d592544} + +// xiToPMinus1Over6 is ξ^((p-1)/6) where ξ = i+9. +var xiToPMinus1Over6 = &gfP2{gfP{0xa222ae234c492d72, 0xd00f02a4565de15b, 0xdc2ff3a253dfc926, 0x10a75716b3899551}, gfP{0xaf9ba69633144907, 0xca6b1d7387afb78a, 0x11bded5ef08a2087, 0x02f34d751a1f3a7c}} + +// xiToPMinus1Over3 is ξ^((p-1)/3) where ξ = i+9. +var xiToPMinus1Over3 = &gfP2{gfP{0x6e849f1ea0aa4757, 0xaa1c7b6d89f89141, 0xb6e713cdfae0ca3a, 0x26694fbb4e82ebc3}, gfP{0xb5773b104563ab30, 0x347f91c8a9aa6454, 0x7a007127242e0991, 0x1956bcd8118214ec}} + +// xiToPMinus1Over2 is ξ^((p-1)/2) where ξ = i+9. +var xiToPMinus1Over2 = &gfP2{gfP{0xa1d77ce45ffe77c7, 0x07affd117826d1db, 0x6d16bd27bb7edc6b, 0x2c87200285defecc}, gfP{0xe4bbdd0c2936b629, 0xbb30f162e133bacb, 0x31a9d1b6f9645366, 0x253570bea500f8dd}} + +// xiToPSquaredMinus1Over3 is ξ^((p²-1)/3) where ξ = i+9. +var xiToPSquaredMinus1Over3 = &gfP{0x3350c88e13e80b9c, 0x7dce557cdb5e56b9, 0x6001b4b8b615564a, 0x2682e617020217e0} + +// xiTo2PSquaredMinus2Over3 is ξ^((2p²-2)/3) where ξ = i+9 (a cubic root of unity, mod p). +var xiTo2PSquaredMinus2Over3 = &gfP{0x71930c11d782e155, 0xa6bb947cffbe3323, 0xaa303344d4741444, 0x2c3b3f0d26594943} + +// xiToPSquaredMinus1Over6 is ξ^((1p²-1)/6) where ξ = i+9 (a cubic root of -1, mod p). +var xiToPSquaredMinus1Over6 = &gfP{0xca8d800500fa1bf2, 0xf0c5d61468b39769, 0x0e201271ad0d4418, 0x04290f65bad856e6} + +// xiTo2PMinus2Over3 is ξ^((2p-2)/3) where ξ = i+9. +var xiTo2PMinus2Over3 = &gfP2{gfP{0x5dddfd154bd8c949, 0x62cb29a5a4445b60, 0x37bc870a0c7dd2b9, 0x24830a9d3171f0fd}, gfP{0x7361d77f843abe92, 0xa5bb2bd3273411fb, 0x9c941f314b3e2399, 0x15df9cddbb9fd3ec}} diff --git a/crypto/bn256/cloudflare/curve.go b/crypto/bn256/cloudflare/curve.go new file mode 100644 index 000000000..b6aecc0a6 --- /dev/null +++ b/crypto/bn256/cloudflare/curve.go @@ -0,0 +1,229 @@ +package bn256 + +import ( + "math/big" +) + +// curvePoint implements the elliptic curve y²=x³+3. Points are kept in Jacobian +// form and t=z² when valid. G₁ is the set of points of this curve on GF(p). +type curvePoint struct { + x, y, z, t gfP +} + +var curveB = newGFp(3) + +// curveGen is the generator of G₁. +var curveGen = &curvePoint{ + x: *newGFp(1), + y: *newGFp(2), + z: *newGFp(1), + t: *newGFp(1), +} + +func (c *curvePoint) String() string { + c.MakeAffine() + x, y := &gfP{}, &gfP{} + montDecode(x, &c.x) + montDecode(y, &c.y) + return "(" + x.String() + ", " + y.String() + ")" +} + +func (c *curvePoint) Set(a *curvePoint) { + c.x.Set(&a.x) + c.y.Set(&a.y) + c.z.Set(&a.z) + c.t.Set(&a.t) +} + +// IsOnCurve returns true iff c is on the curve. +func (c *curvePoint) IsOnCurve() bool { + c.MakeAffine() + if c.IsInfinity() { + return true + } + + y2, x3 := &gfP{}, &gfP{} + gfpMul(y2, &c.y, &c.y) + gfpMul(x3, &c.x, &c.x) + gfpMul(x3, x3, &c.x) + gfpAdd(x3, x3, curveB) + + return *y2 == *x3 +} + +func (c *curvePoint) SetInfinity() { + c.x = gfP{0} + c.y = *newGFp(1) + c.z = gfP{0} + c.t = gfP{0} +} + +func (c *curvePoint) IsInfinity() bool { + return c.z == gfP{0} +} + +func (c *curvePoint) Add(a, b *curvePoint) { + if a.IsInfinity() { + c.Set(b) + return + } + if b.IsInfinity() { + c.Set(a) + return + } + + // See http://hyperelliptic.org/EFD/g1p/auto-code/shortw/jacobian-0/addition/add-2007-bl.op3 + + // Normalize the points by replacing a = [x1:y1:z1] and b = [x2:y2:z2] + // by [u1:s1:z1·z2] and [u2:s2:z1·z2] + // where u1 = x1·z2², s1 = y1·z2³ and u1 = x2·z1², s2 = y2·z1³ + z12, z22 := &gfP{}, &gfP{} + gfpMul(z12, &a.z, &a.z) + gfpMul(z22, &b.z, &b.z) + + u1, u2 := &gfP{}, &gfP{} + gfpMul(u1, &a.x, z22) + gfpMul(u2, &b.x, z12) + + t, s1 := &gfP{}, &gfP{} + gfpMul(t, &b.z, z22) + gfpMul(s1, &a.y, t) + + s2 := &gfP{} + gfpMul(t, &a.z, z12) + gfpMul(s2, &b.y, t) + + // Compute x = (2h)²(s²-u1-u2) + // where s = (s2-s1)/(u2-u1) is the slope of the line through + // (u1,s1) and (u2,s2). The extra factor 2h = 2(u2-u1) comes from the value of z below. + // This is also: + // 4(s2-s1)² - 4h²(u1+u2) = 4(s2-s1)² - 4h³ - 4h²(2u1) + // = r² - j - 2v + // with the notations below. + h := &gfP{} + gfpSub(h, u2, u1) + xEqual := *h == gfP{0} + + gfpAdd(t, h, h) + // i = 4h² + i := &gfP{} + gfpMul(i, t, t) + // j = 4h³ + j := &gfP{} + gfpMul(j, h, i) + + gfpSub(t, s2, s1) + yEqual := *t == gfP{0} + if xEqual && yEqual { + c.Double(a) + return + } + r := &gfP{} + gfpAdd(r, t, t) + + v := &gfP{} + gfpMul(v, u1, i) + + // t4 = 4(s2-s1)² + t4, t6 := &gfP{}, &gfP{} + gfpMul(t4, r, r) + gfpAdd(t, v, v) + gfpSub(t6, t4, j) + + gfpSub(&c.x, t6, t) + + // Set y = -(2h)³(s1 + s*(x/4h²-u1)) + // This is also + // y = - 2·s1·j - (s2-s1)(2x - 2i·u1) = r(v-x) - 2·s1·j + gfpSub(t, v, &c.x) // t7 + gfpMul(t4, s1, j) // t8 + gfpAdd(t6, t4, t4) // t9 + gfpMul(t4, r, t) // t10 + gfpSub(&c.y, t4, t6) + + // Set z = 2(u2-u1)·z1·z2 = 2h·z1·z2 + gfpAdd(t, &a.z, &b.z) // t11 + gfpMul(t4, t, t) // t12 + gfpSub(t, t4, z12) // t13 + gfpSub(t4, t, z22) // t14 + gfpMul(&c.z, t4, h) +} + +func (c *curvePoint) Double(a *curvePoint) { + // See http://hyperelliptic.org/EFD/g1p/auto-code/shortw/jacobian-0/doubling/dbl-2009-l.op3 + A, B, C := &gfP{}, &gfP{}, &gfP{} + gfpMul(A, &a.x, &a.x) + gfpMul(B, &a.y, &a.y) + gfpMul(C, B, B) + + t, t2 := &gfP{}, &gfP{} + gfpAdd(t, &a.x, B) + gfpMul(t2, t, t) + gfpSub(t, t2, A) + gfpSub(t2, t, C) + + d, e, f := &gfP{}, &gfP{}, &gfP{} + gfpAdd(d, t2, t2) + gfpAdd(t, A, A) + gfpAdd(e, t, A) + gfpMul(f, e, e) + + gfpAdd(t, d, d) + gfpSub(&c.x, f, t) + + gfpAdd(t, C, C) + gfpAdd(t2, t, t) + gfpAdd(t, t2, t2) + gfpSub(&c.y, d, &c.x) + gfpMul(t2, e, &c.y) + gfpSub(&c.y, t2, t) + + gfpMul(t, &a.y, &a.z) + gfpAdd(&c.z, t, t) +} + +func (c *curvePoint) Mul(a *curvePoint, scalar *big.Int) { + sum, t := &curvePoint{}, &curvePoint{} + sum.SetInfinity() + + for i := scalar.BitLen(); i >= 0; i-- { + t.Double(sum) + if scalar.Bit(i) != 0 { + sum.Add(t, a) + } else { + sum.Set(t) + } + } + c.Set(sum) +} + +func (c *curvePoint) MakeAffine() { + if c.z == *newGFp(1) { + return + } else if c.z == *newGFp(0) { + c.x = gfP{0} + c.y = *newGFp(1) + c.t = gfP{0} + return + } + + zInv := &gfP{} + zInv.Invert(&c.z) + + t, zInv2 := &gfP{}, &gfP{} + gfpMul(t, &c.y, zInv) + gfpMul(zInv2, zInv, zInv) + + gfpMul(&c.x, &c.x, zInv2) + gfpMul(&c.y, t, zInv2) + + c.z = *newGFp(1) + c.t = *newGFp(1) +} + +func (c *curvePoint) Neg(a *curvePoint) { + c.x.Set(&a.x) + gfpNeg(&c.y, &a.y) + c.z.Set(&a.z) + c.t = gfP{0} +} diff --git a/crypto/bn256/cloudflare/example_test.go b/crypto/bn256/cloudflare/example_test.go new file mode 100644 index 000000000..2ee545c67 --- /dev/null +++ b/crypto/bn256/cloudflare/example_test.go @@ -0,0 +1,45 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build amd64,!appengine,!gccgo + +package bn256 + +import ( + "crypto/rand" +) + +func ExamplePair() { + // This implements the tripartite Diffie-Hellman algorithm from "A One + // Round Protocol for Tripartite Diffie-Hellman", A. Joux. + // http://www.springerlink.com/content/cddc57yyva0hburb/fulltext.pdf + + // Each of three parties, a, b and c, generate a private value. + a, _ := rand.Int(rand.Reader, Order) + b, _ := rand.Int(rand.Reader, Order) + c, _ := rand.Int(rand.Reader, Order) + + // Then each party calculates g₁ and g₂ times their private value. + pa := new(G1).ScalarBaseMult(a) + qa := new(G2).ScalarBaseMult(a) + + pb := new(G1).ScalarBaseMult(b) + qb := new(G2).ScalarBaseMult(b) + + pc := new(G1).ScalarBaseMult(c) + qc := new(G2).ScalarBaseMult(c) + + // Now each party exchanges its public values with the other two and + // all parties can calculate the shared key. + k1 := Pair(pb, qc) + k1.ScalarMult(k1, a) + + k2 := Pair(pc, qa) + k2.ScalarMult(k2, b) + + k3 := Pair(pa, qb) + k3.ScalarMult(k3, c) + + // k1, k2 and k3 will all be equal. +} diff --git a/crypto/bn256/cloudflare/gfp.go b/crypto/bn256/cloudflare/gfp.go new file mode 100644 index 000000000..e8e84e7b3 --- /dev/null +++ b/crypto/bn256/cloudflare/gfp.go @@ -0,0 +1,81 @@ +package bn256 + +import ( + "errors" + "fmt" +) + +type gfP [4]uint64 + +func newGFp(x int64) (out *gfP) { + if x >= 0 { + out = &gfP{uint64(x)} + } else { + out = &gfP{uint64(-x)} + gfpNeg(out, out) + } + + montEncode(out, out) + return out +} + +func (e *gfP) String() string { + return fmt.Sprintf("%16.16x%16.16x%16.16x%16.16x", e[3], e[2], e[1], e[0]) +} + +func (e *gfP) Set(f *gfP) { + e[0] = f[0] + e[1] = f[1] + e[2] = f[2] + e[3] = f[3] +} + +func (e *gfP) Invert(f *gfP) { + bits := [4]uint64{0x3c208c16d87cfd45, 0x97816a916871ca8d, 0xb85045b68181585d, 0x30644e72e131a029} + + sum, power := &gfP{}, &gfP{} + sum.Set(rN1) + power.Set(f) + + for word := 0; word < 4; word++ { + for bit := uint(0); bit < 64; bit++ { + if (bits[word]>>bit)&1 == 1 { + gfpMul(sum, sum, power) + } + gfpMul(power, power, power) + } + } + + gfpMul(sum, sum, r3) + e.Set(sum) +} + +func (e *gfP) Marshal(out []byte) { + for w := uint(0); w < 4; w++ { + for b := uint(0); b < 8; b++ { + out[8*w+b] = byte(e[3-w] >> (56 - 8*b)) + } + } +} + +func (e *gfP) Unmarshal(in []byte) error { + // Unmarshal the bytes into little endian form + for w := uint(0); w < 4; w++ { + for b := uint(0); b < 8; b++ { + e[3-w] += uint64(in[8*w+b]) << (56 - 8*b) + } + } + // Ensure the point respects the curve modulus + for i := 3; i >= 0; i-- { + if e[i] < p2[i] { + return nil + } + if e[i] > p2[i] { + return errors.New("bn256: coordinate exceeds modulus") + } + } + return errors.New("bn256: coordinate equals modulus") +} + +func montEncode(c, a *gfP) { gfpMul(c, a, r2) } +func montDecode(c, a *gfP) { gfpMul(c, a, &gfP{1}) } diff --git a/crypto/bn256/cloudflare/gfp.h b/crypto/bn256/cloudflare/gfp.h new file mode 100644 index 000000000..66f5a4d07 --- /dev/null +++ b/crypto/bn256/cloudflare/gfp.h @@ -0,0 +1,32 @@ +#define storeBlock(a0,a1,a2,a3, r) \ + MOVQ a0, 0+r \ + MOVQ a1, 8+r \ + MOVQ a2, 16+r \ + MOVQ a3, 24+r + +#define loadBlock(r, a0,a1,a2,a3) \ + MOVQ 0+r, a0 \ + MOVQ 8+r, a1 \ + MOVQ 16+r, a2 \ + MOVQ 24+r, a3 + +#define gfpCarry(a0,a1,a2,a3,a4, b0,b1,b2,b3,b4) \ + \ // b = a-p + MOVQ a0, b0 \ + MOVQ a1, b1 \ + MOVQ a2, b2 \ + MOVQ a3, b3 \ + MOVQ a4, b4 \ + \ + SUBQ ·p2+0(SB), b0 \ + SBBQ ·p2+8(SB), b1 \ + SBBQ ·p2+16(SB), b2 \ + SBBQ ·p2+24(SB), b3 \ + SBBQ $0, b4 \ + \ + \ // if b is negative then return a + \ // else return b + CMOVQCC b0, a0 \ + CMOVQCC b1, a1 \ + CMOVQCC b2, a2 \ + CMOVQCC b3, a3 diff --git a/crypto/bn256/cloudflare/gfp12.go b/crypto/bn256/cloudflare/gfp12.go new file mode 100644 index 000000000..93fb368a7 --- /dev/null +++ b/crypto/bn256/cloudflare/gfp12.go @@ -0,0 +1,160 @@ +package bn256 + +// For details of the algorithms used, see "Multiplication and Squaring on +// Pairing-Friendly Fields, Devegili et al. +// http://eprint.iacr.org/2006/471.pdf. + +import ( + "math/big" +) + +// gfP12 implements the field of size p¹² as a quadratic extension of gfP6 +// where ω²=τ. +type gfP12 struct { + x, y gfP6 // value is xω + y +} + +func (e *gfP12) String() string { + return "(" + e.x.String() + "," + e.y.String() + ")" +} + +func (e *gfP12) Set(a *gfP12) *gfP12 { + e.x.Set(&a.x) + e.y.Set(&a.y) + return e +} + +func (e *gfP12) SetZero() *gfP12 { + e.x.SetZero() + e.y.SetZero() + return e +} + +func (e *gfP12) SetOne() *gfP12 { + e.x.SetZero() + e.y.SetOne() + return e +} + +func (e *gfP12) IsZero() bool { + return e.x.IsZero() && e.y.IsZero() +} + +func (e *gfP12) IsOne() bool { + return e.x.IsZero() && e.y.IsOne() +} + +func (e *gfP12) Conjugate(a *gfP12) *gfP12 { + e.x.Neg(&a.x) + e.y.Set(&a.y) + return e +} + +func (e *gfP12) Neg(a *gfP12) *gfP12 { + e.x.Neg(&a.x) + e.y.Neg(&a.y) + return e +} + +// Frobenius computes (xω+y)^p = x^p ω·ξ^((p-1)/6) + y^p +func (e *gfP12) Frobenius(a *gfP12) *gfP12 { + e.x.Frobenius(&a.x) + e.y.Frobenius(&a.y) + e.x.MulScalar(&e.x, xiToPMinus1Over6) + return e +} + +// FrobeniusP2 computes (xω+y)^p² = x^p² ω·ξ^((p²-1)/6) + y^p² +func (e *gfP12) FrobeniusP2(a *gfP12) *gfP12 { + e.x.FrobeniusP2(&a.x) + e.x.MulGFP(&e.x, xiToPSquaredMinus1Over6) + e.y.FrobeniusP2(&a.y) + return e +} + +func (e *gfP12) FrobeniusP4(a *gfP12) *gfP12 { + e.x.FrobeniusP4(&a.x) + e.x.MulGFP(&e.x, xiToPSquaredMinus1Over3) + e.y.FrobeniusP4(&a.y) + return e +} + +func (e *gfP12) Add(a, b *gfP12) *gfP12 { + e.x.Add(&a.x, &b.x) + e.y.Add(&a.y, &b.y) + return e +} + +func (e *gfP12) Sub(a, b *gfP12) *gfP12 { + e.x.Sub(&a.x, &b.x) + e.y.Sub(&a.y, &b.y) + return e +} + +func (e *gfP12) Mul(a, b *gfP12) *gfP12 { + tx := (&gfP6{}).Mul(&a.x, &b.y) + t := (&gfP6{}).Mul(&b.x, &a.y) + tx.Add(tx, t) + + ty := (&gfP6{}).Mul(&a.y, &b.y) + t.Mul(&a.x, &b.x).MulTau(t) + + e.x.Set(tx) + e.y.Add(ty, t) + return e +} + +func (e *gfP12) MulScalar(a *gfP12, b *gfP6) *gfP12 { + e.x.Mul(&e.x, b) + e.y.Mul(&e.y, b) + return e +} + +func (c *gfP12) Exp(a *gfP12, power *big.Int) *gfP12 { + sum := (&gfP12{}).SetOne() + t := &gfP12{} + + for i := power.BitLen() - 1; i >= 0; i-- { + t.Square(sum) + if power.Bit(i) != 0 { + sum.Mul(t, a) + } else { + sum.Set(t) + } + } + + c.Set(sum) + return c +} + +func (e *gfP12) Square(a *gfP12) *gfP12 { + // Complex squaring algorithm + v0 := (&gfP6{}).Mul(&a.x, &a.y) + + t := (&gfP6{}).MulTau(&a.x) + t.Add(&a.y, t) + ty := (&gfP6{}).Add(&a.x, &a.y) + ty.Mul(ty, t).Sub(ty, v0) + t.MulTau(v0) + ty.Sub(ty, t) + + e.x.Add(v0, v0) + e.y.Set(ty) + return e +} + +func (e *gfP12) Invert(a *gfP12) *gfP12 { + // See "Implementing cryptographic pairings", M. Scott, section 3.2. + // ftp://136.206.11.249/pub/crypto/pairings.pdf + t1, t2 := &gfP6{}, &gfP6{} + + t1.Square(&a.x) + t2.Square(&a.y) + t1.MulTau(t1).Sub(t2, t1) + t2.Invert(t1) + + e.x.Neg(&a.x) + e.y.Set(&a.y) + e.MulScalar(e, t2) + return e +} diff --git a/crypto/bn256/cloudflare/gfp2.go b/crypto/bn256/cloudflare/gfp2.go new file mode 100644 index 000000000..90a89e8b4 --- /dev/null +++ b/crypto/bn256/cloudflare/gfp2.go @@ -0,0 +1,156 @@ +package bn256 + +// For details of the algorithms used, see "Multiplication and Squaring on +// Pairing-Friendly Fields, Devegili et al. +// http://eprint.iacr.org/2006/471.pdf. + +// gfP2 implements a field of size p² as a quadratic extension of the base field +// where i²=-1. +type gfP2 struct { + x, y gfP // value is xi+y. +} + +func gfP2Decode(in *gfP2) *gfP2 { + out := &gfP2{} + montDecode(&out.x, &in.x) + montDecode(&out.y, &in.y) + return out +} + +func (e *gfP2) String() string { + return "(" + e.x.String() + ", " + e.y.String() + ")" +} + +func (e *gfP2) Set(a *gfP2) *gfP2 { + e.x.Set(&a.x) + e.y.Set(&a.y) + return e +} + +func (e *gfP2) SetZero() *gfP2 { + e.x = gfP{0} + e.y = gfP{0} + return e +} + +func (e *gfP2) SetOne() *gfP2 { + e.x = gfP{0} + e.y = *newGFp(1) + return e +} + +func (e *gfP2) IsZero() bool { + zero := gfP{0} + return e.x == zero && e.y == zero +} + +func (e *gfP2) IsOne() bool { + zero, one := gfP{0}, *newGFp(1) + return e.x == zero && e.y == one +} + +func (e *gfP2) Conjugate(a *gfP2) *gfP2 { + e.y.Set(&a.y) + gfpNeg(&e.x, &a.x) + return e +} + +func (e *gfP2) Neg(a *gfP2) *gfP2 { + gfpNeg(&e.x, &a.x) + gfpNeg(&e.y, &a.y) + return e +} + +func (e *gfP2) Add(a, b *gfP2) *gfP2 { + gfpAdd(&e.x, &a.x, &b.x) + gfpAdd(&e.y, &a.y, &b.y) + return e +} + +func (e *gfP2) Sub(a, b *gfP2) *gfP2 { + gfpSub(&e.x, &a.x, &b.x) + gfpSub(&e.y, &a.y, &b.y) + return e +} + +// See "Multiplication and Squaring in Pairing-Friendly Fields", +// http://eprint.iacr.org/2006/471.pdf +func (e *gfP2) Mul(a, b *gfP2) *gfP2 { + tx, t := &gfP{}, &gfP{} + gfpMul(tx, &a.x, &b.y) + gfpMul(t, &b.x, &a.y) + gfpAdd(tx, tx, t) + + ty := &gfP{} + gfpMul(ty, &a.y, &b.y) + gfpMul(t, &a.x, &b.x) + gfpSub(ty, ty, t) + + e.x.Set(tx) + e.y.Set(ty) + return e +} + +func (e *gfP2) MulScalar(a *gfP2, b *gfP) *gfP2 { + gfpMul(&e.x, &a.x, b) + gfpMul(&e.y, &a.y, b) + return e +} + +// MulXi sets e=ξa where ξ=i+9 and then returns e. +func (e *gfP2) MulXi(a *gfP2) *gfP2 { + // (xi+y)(i+9) = (9x+y)i+(9y-x) + tx := &gfP{} + gfpAdd(tx, &a.x, &a.x) + gfpAdd(tx, tx, tx) + gfpAdd(tx, tx, tx) + gfpAdd(tx, tx, &a.x) + + gfpAdd(tx, tx, &a.y) + + ty := &gfP{} + gfpAdd(ty, &a.y, &a.y) + gfpAdd(ty, ty, ty) + gfpAdd(ty, ty, ty) + gfpAdd(ty, ty, &a.y) + + gfpSub(ty, ty, &a.x) + + e.x.Set(tx) + e.y.Set(ty) + return e +} + +func (e *gfP2) Square(a *gfP2) *gfP2 { + // Complex squaring algorithm: + // (xi+y)² = (x+y)(y-x) + 2*i*x*y + tx, ty := &gfP{}, &gfP{} + gfpSub(tx, &a.y, &a.x) + gfpAdd(ty, &a.x, &a.y) + gfpMul(ty, tx, ty) + + gfpMul(tx, &a.x, &a.y) + gfpAdd(tx, tx, tx) + + e.x.Set(tx) + e.y.Set(ty) + return e +} + +func (e *gfP2) Invert(a *gfP2) *gfP2 { + // See "Implementing cryptographic pairings", M. Scott, section 3.2. + // ftp://136.206.11.249/pub/crypto/pairings.pdf + t1, t2 := &gfP{}, &gfP{} + gfpMul(t1, &a.x, &a.x) + gfpMul(t2, &a.y, &a.y) + gfpAdd(t1, t1, t2) + + inv := &gfP{} + inv.Invert(t1) + + gfpNeg(t1, &a.x) + + gfpMul(&e.x, t1, inv) + gfpMul(&e.y, &a.y, inv) + return e +} diff --git a/crypto/bn256/cloudflare/gfp6.go b/crypto/bn256/cloudflare/gfp6.go new file mode 100644 index 000000000..83d61b781 --- /dev/null +++ b/crypto/bn256/cloudflare/gfp6.go @@ -0,0 +1,213 @@ +package bn256 + +// For details of the algorithms used, see "Multiplication and Squaring on +// Pairing-Friendly Fields, Devegili et al. +// http://eprint.iacr.org/2006/471.pdf. + +// gfP6 implements the field of size p⁶ as a cubic extension of gfP2 where τ³=ξ +// and ξ=i+3. +type gfP6 struct { + x, y, z gfP2 // value is xτ² + yτ + z +} + +func (e *gfP6) String() string { + return "(" + e.x.String() + ", " + e.y.String() + ", " + e.z.String() + ")" +} + +func (e *gfP6) Set(a *gfP6) *gfP6 { + e.x.Set(&a.x) + e.y.Set(&a.y) + e.z.Set(&a.z) + return e +} + +func (e *gfP6) SetZero() *gfP6 { + e.x.SetZero() + e.y.SetZero() + e.z.SetZero() + return e +} + +func (e *gfP6) SetOne() *gfP6 { + e.x.SetZero() + e.y.SetZero() + e.z.SetOne() + return e +} + +func (e *gfP6) IsZero() bool { + return e.x.IsZero() && e.y.IsZero() && e.z.IsZero() +} + +func (e *gfP6) IsOne() bool { + return e.x.IsZero() && e.y.IsZero() && e.z.IsOne() +} + +func (e *gfP6) Neg(a *gfP6) *gfP6 { + e.x.Neg(&a.x) + e.y.Neg(&a.y) + e.z.Neg(&a.z) + return e +} + +func (e *gfP6) Frobenius(a *gfP6) *gfP6 { + e.x.Conjugate(&a.x) + e.y.Conjugate(&a.y) + e.z.Conjugate(&a.z) + + e.x.Mul(&e.x, xiTo2PMinus2Over3) + e.y.Mul(&e.y, xiToPMinus1Over3) + return e +} + +// FrobeniusP2 computes (xτ²+yτ+z)^(p²) = xτ^(2p²) + yτ^(p²) + z +func (e *gfP6) FrobeniusP2(a *gfP6) *gfP6 { + // τ^(2p²) = τ²τ^(2p²-2) = τ²ξ^((2p²-2)/3) + e.x.MulScalar(&a.x, xiTo2PSquaredMinus2Over3) + // τ^(p²) = ττ^(p²-1) = τξ^((p²-1)/3) + e.y.MulScalar(&a.y, xiToPSquaredMinus1Over3) + e.z.Set(&a.z) + return e +} + +func (e *gfP6) FrobeniusP4(a *gfP6) *gfP6 { + e.x.MulScalar(&a.x, xiToPSquaredMinus1Over3) + e.y.MulScalar(&a.y, xiTo2PSquaredMinus2Over3) + e.z.Set(&a.z) + return e +} + +func (e *gfP6) Add(a, b *gfP6) *gfP6 { + e.x.Add(&a.x, &b.x) + e.y.Add(&a.y, &b.y) + e.z.Add(&a.z, &b.z) + return e +} + +func (e *gfP6) Sub(a, b *gfP6) *gfP6 { + e.x.Sub(&a.x, &b.x) + e.y.Sub(&a.y, &b.y) + e.z.Sub(&a.z, &b.z) + return e +} + +func (e *gfP6) Mul(a, b *gfP6) *gfP6 { + // "Multiplication and Squaring on Pairing-Friendly Fields" + // Section 4, Karatsuba method. + // http://eprint.iacr.org/2006/471.pdf + v0 := (&gfP2{}).Mul(&a.z, &b.z) + v1 := (&gfP2{}).Mul(&a.y, &b.y) + v2 := (&gfP2{}).Mul(&a.x, &b.x) + + t0 := (&gfP2{}).Add(&a.x, &a.y) + t1 := (&gfP2{}).Add(&b.x, &b.y) + tz := (&gfP2{}).Mul(t0, t1) + tz.Sub(tz, v1).Sub(tz, v2).MulXi(tz).Add(tz, v0) + + t0.Add(&a.y, &a.z) + t1.Add(&b.y, &b.z) + ty := (&gfP2{}).Mul(t0, t1) + t0.MulXi(v2) + ty.Sub(ty, v0).Sub(ty, v1).Add(ty, t0) + + t0.Add(&a.x, &a.z) + t1.Add(&b.x, &b.z) + tx := (&gfP2{}).Mul(t0, t1) + tx.Sub(tx, v0).Add(tx, v1).Sub(tx, v2) + + e.x.Set(tx) + e.y.Set(ty) + e.z.Set(tz) + return e +} + +func (e *gfP6) MulScalar(a *gfP6, b *gfP2) *gfP6 { + e.x.Mul(&a.x, b) + e.y.Mul(&a.y, b) + e.z.Mul(&a.z, b) + return e +} + +func (e *gfP6) MulGFP(a *gfP6, b *gfP) *gfP6 { + e.x.MulScalar(&a.x, b) + e.y.MulScalar(&a.y, b) + e.z.MulScalar(&a.z, b) + return e +} + +// MulTau computes τ·(aτ²+bτ+c) = bτ²+cτ+aξ +func (e *gfP6) MulTau(a *gfP6) *gfP6 { + tz := (&gfP2{}).MulXi(&a.x) + ty := (&gfP2{}).Set(&a.y) + + e.y.Set(&a.z) + e.x.Set(ty) + e.z.Set(tz) + return e +} + +func (e *gfP6) Square(a *gfP6) *gfP6 { + v0 := (&gfP2{}).Square(&a.z) + v1 := (&gfP2{}).Square(&a.y) + v2 := (&gfP2{}).Square(&a.x) + + c0 := (&gfP2{}).Add(&a.x, &a.y) + c0.Square(c0).Sub(c0, v1).Sub(c0, v2).MulXi(c0).Add(c0, v0) + + c1 := (&gfP2{}).Add(&a.y, &a.z) + c1.Square(c1).Sub(c1, v0).Sub(c1, v1) + xiV2 := (&gfP2{}).MulXi(v2) + c1.Add(c1, xiV2) + + c2 := (&gfP2{}).Add(&a.x, &a.z) + c2.Square(c2).Sub(c2, v0).Add(c2, v1).Sub(c2, v2) + + e.x.Set(c2) + e.y.Set(c1) + e.z.Set(c0) + return e +} + +func (e *gfP6) Invert(a *gfP6) *gfP6 { + // See "Implementing cryptographic pairings", M. Scott, section 3.2. + // ftp://136.206.11.249/pub/crypto/pairings.pdf + + // Here we can give a short explanation of how it works: let j be a cubic root of + // unity in GF(p²) so that 1+j+j²=0. + // Then (xτ² + yτ + z)(xj²τ² + yjτ + z)(xjτ² + yj²τ + z) + // = (xτ² + yτ + z)(Cτ²+Bτ+A) + // = (x³ξ²+y³ξ+z³-3ξxyz) = F is an element of the base field (the norm). + // + // On the other hand (xj²τ² + yjτ + z)(xjτ² + yj²τ + z) + // = τ²(y²-ξxz) + τ(ξx²-yz) + (z²-ξxy) + // + // So that's why A = (z²-ξxy), B = (ξx²-yz), C = (y²-ξxz) + t1 := (&gfP2{}).Mul(&a.x, &a.y) + t1.MulXi(t1) + + A := (&gfP2{}).Square(&a.z) + A.Sub(A, t1) + + B := (&gfP2{}).Square(&a.x) + B.MulXi(B) + t1.Mul(&a.y, &a.z) + B.Sub(B, t1) + + C := (&gfP2{}).Square(&a.y) + t1.Mul(&a.x, &a.z) + C.Sub(C, t1) + + F := (&gfP2{}).Mul(C, &a.y) + F.MulXi(F) + t1.Mul(A, &a.z) + F.Add(F, t1) + t1.Mul(B, &a.x).MulXi(t1) + F.Add(F, t1) + + F.Invert(F) + + e.x.Mul(C, F) + e.y.Mul(B, F) + e.z.Mul(A, F) + return e +} diff --git a/crypto/bn256/cloudflare/gfp_amd64.go b/crypto/bn256/cloudflare/gfp_amd64.go new file mode 100644 index 000000000..ac4f1a9c6 --- /dev/null +++ b/crypto/bn256/cloudflare/gfp_amd64.go @@ -0,0 +1,15 @@ +// +build amd64,!appengine,!gccgo + +package bn256 + +// go:noescape +func gfpNeg(c, a *gfP) + +//go:noescape +func gfpAdd(c, a, b *gfP) + +//go:noescape +func gfpSub(c, a, b *gfP) + +//go:noescape +func gfpMul(c, a, b *gfP) diff --git a/crypto/bn256/cloudflare/gfp_amd64.s b/crypto/bn256/cloudflare/gfp_amd64.s new file mode 100644 index 000000000..2d0176f2e --- /dev/null +++ b/crypto/bn256/cloudflare/gfp_amd64.s @@ -0,0 +1,97 @@ +// +build amd64,!appengine,!gccgo + +#include "gfp.h" +#include "mul.h" +#include "mul_bmi2.h" + +TEXT ·gfpNeg(SB),0,$0-16 + MOVQ ·p2+0(SB), R8 + MOVQ ·p2+8(SB), R9 + MOVQ ·p2+16(SB), R10 + MOVQ ·p2+24(SB), R11 + + MOVQ a+8(FP), DI + SUBQ 0(DI), R8 + SBBQ 8(DI), R9 + SBBQ 16(DI), R10 + SBBQ 24(DI), R11 + + MOVQ $0, AX + gfpCarry(R8,R9,R10,R11,AX, R12,R13,R14,R15,BX) + + MOVQ c+0(FP), DI + storeBlock(R8,R9,R10,R11, 0(DI)) + RET + +TEXT ·gfpAdd(SB),0,$0-24 + MOVQ a+8(FP), DI + MOVQ b+16(FP), SI + + loadBlock(0(DI), R8,R9,R10,R11) + MOVQ $0, R12 + + ADDQ 0(SI), R8 + ADCQ 8(SI), R9 + ADCQ 16(SI), R10 + ADCQ 24(SI), R11 + ADCQ $0, R12 + + gfpCarry(R8,R9,R10,R11,R12, R13,R14,R15,AX,BX) + + MOVQ c+0(FP), DI + storeBlock(R8,R9,R10,R11, 0(DI)) + RET + +TEXT ·gfpSub(SB),0,$0-24 + MOVQ a+8(FP), DI + MOVQ b+16(FP), SI + + loadBlock(0(DI), R8,R9,R10,R11) + + MOVQ ·p2+0(SB), R12 + MOVQ ·p2+8(SB), R13 + MOVQ ·p2+16(SB), R14 + MOVQ ·p2+24(SB), R15 + MOVQ $0, AX + + SUBQ 0(SI), R8 + SBBQ 8(SI), R9 + SBBQ 16(SI), R10 + SBBQ 24(SI), R11 + + CMOVQCC AX, R12 + CMOVQCC AX, R13 + CMOVQCC AX, R14 + CMOVQCC AX, R15 + + ADDQ R12, R8 + ADCQ R13, R9 + ADCQ R14, R10 + ADCQ R15, R11 + + MOVQ c+0(FP), DI + storeBlock(R8,R9,R10,R11, 0(DI)) + RET + +TEXT ·gfpMul(SB),0,$160-24 + MOVQ a+8(FP), DI + MOVQ b+16(FP), SI + + // Jump to a slightly different implementation if MULX isn't supported. + CMPB runtime·support_bmi2(SB), $0 + JE nobmi2Mul + + mulBMI2(0(DI),8(DI),16(DI),24(DI), 0(SI)) + storeBlock( R8, R9,R10,R11, 0(SP)) + storeBlock(R12,R13,R14,R15, 32(SP)) + gfpReduceBMI2() + JMP end + +nobmi2Mul: + mul(0(DI),8(DI),16(DI),24(DI), 0(SI), 0(SP)) + gfpReduce(0(SP)) + +end: + MOVQ c+0(FP), DI + storeBlock(R12,R13,R14,R15, 0(DI)) + RET diff --git a/crypto/bn256/cloudflare/gfp_pure.go b/crypto/bn256/cloudflare/gfp_pure.go new file mode 100644 index 000000000..8fa5d3053 --- /dev/null +++ b/crypto/bn256/cloudflare/gfp_pure.go @@ -0,0 +1,19 @@ +// +build !amd64 appengine gccgo + +package bn256 + +func gfpNeg(c, a *gfP) { + panic("unsupported architecture") +} + +func gfpAdd(c, a, b *gfP) { + panic("unsupported architecture") +} + +func gfpSub(c, a, b *gfP) { + panic("unsupported architecture") +} + +func gfpMul(c, a, b *gfP) { + panic("unsupported architecture") +} diff --git a/crypto/bn256/cloudflare/gfp_test.go b/crypto/bn256/cloudflare/gfp_test.go new file mode 100644 index 000000000..aff5e0531 --- /dev/null +++ b/crypto/bn256/cloudflare/gfp_test.go @@ -0,0 +1,62 @@ +// +build amd64,!appengine,!gccgo + +package bn256 + +import ( + "testing" +) + +// Tests that negation works the same way on both assembly-optimized and pure Go +// implementation. +func TestGFpNeg(t *testing.T) { + n := &gfP{0x0123456789abcdef, 0xfedcba9876543210, 0xdeadbeefdeadbeef, 0xfeebdaedfeebdaed} + w := &gfP{0xfedcba9876543211, 0x0123456789abcdef, 0x2152411021524110, 0x0114251201142512} + h := &gfP{} + + gfpNeg(h, n) + if *h != *w { + t.Errorf("negation mismatch: have %#x, want %#x", *h, *w) + } +} + +// Tests that addition works the same way on both assembly-optimized and pure Go +// implementation. +func TestGFpAdd(t *testing.T) { + a := &gfP{0x0123456789abcdef, 0xfedcba9876543210, 0xdeadbeefdeadbeef, 0xfeebdaedfeebdaed} + b := &gfP{0xfedcba9876543210, 0x0123456789abcdef, 0xfeebdaedfeebdaed, 0xdeadbeefdeadbeef} + w := &gfP{0xc3df73e9278302b8, 0x687e956e978e3572, 0x254954275c18417f, 0xad354b6afc67f9b4} + h := &gfP{} + + gfpAdd(h, a, b) + if *h != *w { + t.Errorf("addition mismatch: have %#x, want %#x", *h, *w) + } +} + +// Tests that subtraction works the same way on both assembly-optimized and pure Go +// implementation. +func TestGFpSub(t *testing.T) { + a := &gfP{0x0123456789abcdef, 0xfedcba9876543210, 0xdeadbeefdeadbeef, 0xfeebdaedfeebdaed} + b := &gfP{0xfedcba9876543210, 0x0123456789abcdef, 0xfeebdaedfeebdaed, 0xdeadbeefdeadbeef} + w := &gfP{0x02468acf13579bdf, 0xfdb97530eca86420, 0xdfc1e401dfc1e402, 0x203e1bfe203e1bfd} + h := &gfP{} + + gfpSub(h, a, b) + if *h != *w { + t.Errorf("subtraction mismatch: have %#x, want %#x", *h, *w) + } +} + +// Tests that multiplication works the same way on both assembly-optimized and pure Go +// implementation. +func TestGFpMul(t *testing.T) { + a := &gfP{0x0123456789abcdef, 0xfedcba9876543210, 0xdeadbeefdeadbeef, 0xfeebdaedfeebdaed} + b := &gfP{0xfedcba9876543210, 0x0123456789abcdef, 0xfeebdaedfeebdaed, 0xdeadbeefdeadbeef} + w := &gfP{0xcbcbd377f7ad22d3, 0x3b89ba5d849379bf, 0x87b61627bd38b6d2, 0xc44052a2a0e654b2} + h := &gfP{} + + gfpMul(h, a, b) + if *h != *w { + t.Errorf("multiplication mismatch: have %#x, want %#x", *h, *w) + } +} diff --git a/crypto/bn256/cloudflare/main_test.go b/crypto/bn256/cloudflare/main_test.go new file mode 100644 index 000000000..f0d59a404 --- /dev/null +++ b/crypto/bn256/cloudflare/main_test.go @@ -0,0 +1,73 @@ +// +build amd64,!appengine,!gccgo + +package bn256 + +import ( + "testing" + + "crypto/rand" +) + +func TestRandomG2Marshal(t *testing.T) { + for i := 0; i < 10; i++ { + n, g2, err := RandomG2(rand.Reader) + if err != nil { + t.Error(err) + continue + } + t.Logf("%d: %x\n", n, g2.Marshal()) + } +} + +func TestPairings(t *testing.T) { + a1 := new(G1).ScalarBaseMult(bigFromBase10("1")) + a2 := new(G1).ScalarBaseMult(bigFromBase10("2")) + a37 := new(G1).ScalarBaseMult(bigFromBase10("37")) + an1 := new(G1).ScalarBaseMult(bigFromBase10("21888242871839275222246405745257275088548364400416034343698204186575808495616")) + + b0 := new(G2).ScalarBaseMult(bigFromBase10("0")) + b1 := new(G2).ScalarBaseMult(bigFromBase10("1")) + b2 := new(G2).ScalarBaseMult(bigFromBase10("2")) + b27 := new(G2).ScalarBaseMult(bigFromBase10("27")) + b999 := new(G2).ScalarBaseMult(bigFromBase10("999")) + bn1 := new(G2).ScalarBaseMult(bigFromBase10("21888242871839275222246405745257275088548364400416034343698204186575808495616")) + + p1 := Pair(a1, b1) + pn1 := Pair(a1, bn1) + np1 := Pair(an1, b1) + if pn1.String() != np1.String() { + t.Error("Pairing mismatch: e(a, -b) != e(-a, b)") + } + if !PairingCheck([]*G1{a1, an1}, []*G2{b1, b1}) { + t.Error("MultiAte check gave false negative!") + } + p0 := new(GT).Add(p1, pn1) + p0_2 := Pair(a1, b0) + if p0.String() != p0_2.String() { + t.Error("Pairing mismatch: e(a, b) * e(a, -b) != 1") + } + p0_3 := new(GT).ScalarMult(p1, bigFromBase10("21888242871839275222246405745257275088548364400416034343698204186575808495617")) + if p0.String() != p0_3.String() { + t.Error("Pairing mismatch: e(a, b) has wrong order") + } + p2 := Pair(a2, b1) + p2_2 := Pair(a1, b2) + p2_3 := new(GT).ScalarMult(p1, bigFromBase10("2")) + if p2.String() != p2_2.String() { + t.Error("Pairing mismatch: e(a, b * 2) != e(a * 2, b)") + } + if p2.String() != p2_3.String() { + t.Error("Pairing mismatch: e(a, b * 2) != e(a, b) ** 2") + } + if p2.String() == p1.String() { + t.Error("Pairing is degenerate!") + } + if PairingCheck([]*G1{a1, a1}, []*G2{b1, b1}) { + t.Error("MultiAte check gave false positive!") + } + p999 := Pair(a37, b27) + p999_2 := Pair(a1, b999) + if p999.String() != p999_2.String() { + t.Error("Pairing mismatch: e(a * 37, b * 27) != e(a, b * 999)") + } +} diff --git a/crypto/bn256/cloudflare/mul.h b/crypto/bn256/cloudflare/mul.h new file mode 100644 index 000000000..bab5da831 --- /dev/null +++ b/crypto/bn256/cloudflare/mul.h @@ -0,0 +1,181 @@ +#define mul(a0,a1,a2,a3, rb, stack) \ + MOVQ a0, AX \ + MULQ 0+rb \ + MOVQ AX, R8 \ + MOVQ DX, R9 \ + MOVQ a0, AX \ + MULQ 8+rb \ + ADDQ AX, R9 \ + ADCQ $0, DX \ + MOVQ DX, R10 \ + MOVQ a0, AX \ + MULQ 16+rb \ + ADDQ AX, R10 \ + ADCQ $0, DX \ + MOVQ DX, R11 \ + MOVQ a0, AX \ + MULQ 24+rb \ + ADDQ AX, R11 \ + ADCQ $0, DX \ + MOVQ DX, R12 \ + \ + storeBlock(R8,R9,R10,R11, 0+stack) \ + MOVQ R12, 32+stack \ + \ + MOVQ a1, AX \ + MULQ 0+rb \ + MOVQ AX, R8 \ + MOVQ DX, R9 \ + MOVQ a1, AX \ + MULQ 8+rb \ + ADDQ AX, R9 \ + ADCQ $0, DX \ + MOVQ DX, R10 \ + MOVQ a1, AX \ + MULQ 16+rb \ + ADDQ AX, R10 \ + ADCQ $0, DX \ + MOVQ DX, R11 \ + MOVQ a1, AX \ + MULQ 24+rb \ + ADDQ AX, R11 \ + ADCQ $0, DX \ + MOVQ DX, R12 \ + \ + ADDQ 8+stack, R8 \ + ADCQ 16+stack, R9 \ + ADCQ 24+stack, R10 \ + ADCQ 32+stack, R11 \ + ADCQ $0, R12 \ + storeBlock(R8,R9,R10,R11, 8+stack) \ + MOVQ R12, 40+stack \ + \ + MOVQ a2, AX \ + MULQ 0+rb \ + MOVQ AX, R8 \ + MOVQ DX, R9 \ + MOVQ a2, AX \ + MULQ 8+rb \ + ADDQ AX, R9 \ + ADCQ $0, DX \ + MOVQ DX, R10 \ + MOVQ a2, AX \ + MULQ 16+rb \ + ADDQ AX, R10 \ + ADCQ $0, DX \ + MOVQ DX, R11 \ + MOVQ a2, AX \ + MULQ 24+rb \ + ADDQ AX, R11 \ + ADCQ $0, DX \ + MOVQ DX, R12 \ + \ + ADDQ 16+stack, R8 \ + ADCQ 24+stack, R9 \ + ADCQ 32+stack, R10 \ + ADCQ 40+stack, R11 \ + ADCQ $0, R12 \ + storeBlock(R8,R9,R10,R11, 16+stack) \ + MOVQ R12, 48+stack \ + \ + MOVQ a3, AX \ + MULQ 0+rb \ + MOVQ AX, R8 \ + MOVQ DX, R9 \ + MOVQ a3, AX \ + MULQ 8+rb \ + ADDQ AX, R9 \ + ADCQ $0, DX \ + MOVQ DX, R10 \ + MOVQ a3, AX \ + MULQ 16+rb \ + ADDQ AX, R10 \ + ADCQ $0, DX \ + MOVQ DX, R11 \ + MOVQ a3, AX \ + MULQ 24+rb \ + ADDQ AX, R11 \ + ADCQ $0, DX \ + MOVQ DX, R12 \ + \ + ADDQ 24+stack, R8 \ + ADCQ 32+stack, R9 \ + ADCQ 40+stack, R10 \ + ADCQ 48+stack, R11 \ + ADCQ $0, R12 \ + storeBlock(R8,R9,R10,R11, 24+stack) \ + MOVQ R12, 56+stack + +#define gfpReduce(stack) \ + \ // m = (T * N') mod R, store m in R8:R9:R10:R11 + MOVQ ·np+0(SB), AX \ + MULQ 0+stack \ + MOVQ AX, R8 \ + MOVQ DX, R9 \ + MOVQ ·np+0(SB), AX \ + MULQ 8+stack \ + ADDQ AX, R9 \ + ADCQ $0, DX \ + MOVQ DX, R10 \ + MOVQ ·np+0(SB), AX \ + MULQ 16+stack \ + ADDQ AX, R10 \ + ADCQ $0, DX \ + MOVQ DX, R11 \ + MOVQ ·np+0(SB), AX \ + MULQ 24+stack \ + ADDQ AX, R11 \ + \ + MOVQ ·np+8(SB), AX \ + MULQ 0+stack \ + MOVQ AX, R12 \ + MOVQ DX, R13 \ + MOVQ ·np+8(SB), AX \ + MULQ 8+stack \ + ADDQ AX, R13 \ + ADCQ $0, DX \ + MOVQ DX, R14 \ + MOVQ ·np+8(SB), AX \ + MULQ 16+stack \ + ADDQ AX, R14 \ + \ + ADDQ R12, R9 \ + ADCQ R13, R10 \ + ADCQ R14, R11 \ + \ + MOVQ ·np+16(SB), AX \ + MULQ 0+stack \ + MOVQ AX, R12 \ + MOVQ DX, R13 \ + MOVQ ·np+16(SB), AX \ + MULQ 8+stack \ + ADDQ AX, R13 \ + \ + ADDQ R12, R10 \ + ADCQ R13, R11 \ + \ + MOVQ ·np+24(SB), AX \ + MULQ 0+stack \ + ADDQ AX, R11 \ + \ + storeBlock(R8,R9,R10,R11, 64+stack) \ + \ + \ // m * N + mul(·p2+0(SB),·p2+8(SB),·p2+16(SB),·p2+24(SB), 64+stack, 96+stack) \ + \ + \ // Add the 512-bit intermediate to m*N + loadBlock(96+stack, R8,R9,R10,R11) \ + loadBlock(128+stack, R12,R13,R14,R15) \ + \ + MOVQ $0, AX \ + ADDQ 0+stack, R8 \ + ADCQ 8+stack, R9 \ + ADCQ 16+stack, R10 \ + ADCQ 24+stack, R11 \ + ADCQ 32+stack, R12 \ + ADCQ 40+stack, R13 \ + ADCQ 48+stack, R14 \ + ADCQ 56+stack, R15 \ + ADCQ $0, AX \ + \ + gfpCarry(R12,R13,R14,R15,AX, R8,R9,R10,R11,BX) diff --git a/crypto/bn256/cloudflare/mul_bmi2.h b/crypto/bn256/cloudflare/mul_bmi2.h new file mode 100644 index 000000000..71ad0499a --- /dev/null +++ b/crypto/bn256/cloudflare/mul_bmi2.h @@ -0,0 +1,112 @@ +#define mulBMI2(a0,a1,a2,a3, rb) \ + MOVQ a0, DX \ + MOVQ $0, R13 \ + MULXQ 0+rb, R8, R9 \ + MULXQ 8+rb, AX, R10 \ + ADDQ AX, R9 \ + MULXQ 16+rb, AX, R11 \ + ADCQ AX, R10 \ + MULXQ 24+rb, AX, R12 \ + ADCQ AX, R11 \ + ADCQ $0, R12 \ + ADCQ $0, R13 \ + \ + MOVQ a1, DX \ + MOVQ $0, R14 \ + MULXQ 0+rb, AX, BX \ + ADDQ AX, R9 \ + ADCQ BX, R10 \ + MULXQ 16+rb, AX, BX \ + ADCQ AX, R11 \ + ADCQ BX, R12 \ + ADCQ $0, R13 \ + MULXQ 8+rb, AX, BX \ + ADDQ AX, R10 \ + ADCQ BX, R11 \ + MULXQ 24+rb, AX, BX \ + ADCQ AX, R12 \ + ADCQ BX, R13 \ + ADCQ $0, R14 \ + \ + MOVQ a2, DX \ + MOVQ $0, R15 \ + MULXQ 0+rb, AX, BX \ + ADDQ AX, R10 \ + ADCQ BX, R11 \ + MULXQ 16+rb, AX, BX \ + ADCQ AX, R12 \ + ADCQ BX, R13 \ + ADCQ $0, R14 \ + MULXQ 8+rb, AX, BX \ + ADDQ AX, R11 \ + ADCQ BX, R12 \ + MULXQ 24+rb, AX, BX \ + ADCQ AX, R13 \ + ADCQ BX, R14 \ + ADCQ $0, R15 \ + \ + MOVQ a3, DX \ + MULXQ 0+rb, AX, BX \ + ADDQ AX, R11 \ + ADCQ BX, R12 \ + MULXQ 16+rb, AX, BX \ + ADCQ AX, R13 \ + ADCQ BX, R14 \ + ADCQ $0, R15 \ + MULXQ 8+rb, AX, BX \ + ADDQ AX, R12 \ + ADCQ BX, R13 \ + MULXQ 24+rb, AX, BX \ + ADCQ AX, R14 \ + ADCQ BX, R15 + +#define gfpReduceBMI2() \ + \ // m = (T * N') mod R, store m in R8:R9:R10:R11 + MOVQ ·np+0(SB), DX \ + MULXQ 0(SP), R8, R9 \ + MULXQ 8(SP), AX, R10 \ + ADDQ AX, R9 \ + MULXQ 16(SP), AX, R11 \ + ADCQ AX, R10 \ + MULXQ 24(SP), AX, BX \ + ADCQ AX, R11 \ + \ + MOVQ ·np+8(SB), DX \ + MULXQ 0(SP), AX, BX \ + ADDQ AX, R9 \ + ADCQ BX, R10 \ + MULXQ 16(SP), AX, BX \ + ADCQ AX, R11 \ + MULXQ 8(SP), AX, BX \ + ADDQ AX, R10 \ + ADCQ BX, R11 \ + \ + MOVQ ·np+16(SB), DX \ + MULXQ 0(SP), AX, BX \ + ADDQ AX, R10 \ + ADCQ BX, R11 \ + MULXQ 8(SP), AX, BX \ + ADDQ AX, R11 \ + \ + MOVQ ·np+24(SB), DX \ + MULXQ 0(SP), AX, BX \ + ADDQ AX, R11 \ + \ + storeBlock(R8,R9,R10,R11, 64(SP)) \ + \ + \ // m * N + mulBMI2(·p2+0(SB),·p2+8(SB),·p2+16(SB),·p2+24(SB), 64(SP)) \ + \ + \ // Add the 512-bit intermediate to m*N + MOVQ $0, AX \ + ADDQ 0(SP), R8 \ + ADCQ 8(SP), R9 \ + ADCQ 16(SP), R10 \ + ADCQ 24(SP), R11 \ + ADCQ 32(SP), R12 \ + ADCQ 40(SP), R13 \ + ADCQ 48(SP), R14 \ + ADCQ 56(SP), R15 \ + ADCQ $0, AX \ + \ + gfpCarry(R12,R13,R14,R15,AX, R8,R9,R10,R11,BX) diff --git a/crypto/bn256/cloudflare/optate.go b/crypto/bn256/cloudflare/optate.go new file mode 100644 index 000000000..b71e50e3a --- /dev/null +++ b/crypto/bn256/cloudflare/optate.go @@ -0,0 +1,271 @@ +package bn256 + +func lineFunctionAdd(r, p *twistPoint, q *curvePoint, r2 *gfP2) (a, b, c *gfP2, rOut *twistPoint) { + // See the mixed addition algorithm from "Faster Computation of the + // Tate Pairing", http://arxiv.org/pdf/0904.0854v3.pdf + B := (&gfP2{}).Mul(&p.x, &r.t) + + D := (&gfP2{}).Add(&p.y, &r.z) + D.Square(D).Sub(D, r2).Sub(D, &r.t).Mul(D, &r.t) + + H := (&gfP2{}).Sub(B, &r.x) + I := (&gfP2{}).Square(H) + + E := (&gfP2{}).Add(I, I) + E.Add(E, E) + + J := (&gfP2{}).Mul(H, E) + + L1 := (&gfP2{}).Sub(D, &r.y) + L1.Sub(L1, &r.y) + + V := (&gfP2{}).Mul(&r.x, E) + + rOut = &twistPoint{} + rOut.x.Square(L1).Sub(&rOut.x, J).Sub(&rOut.x, V).Sub(&rOut.x, V) + + rOut.z.Add(&r.z, H).Square(&rOut.z).Sub(&rOut.z, &r.t).Sub(&rOut.z, I) + + t := (&gfP2{}).Sub(V, &rOut.x) + t.Mul(t, L1) + t2 := (&gfP2{}).Mul(&r.y, J) + t2.Add(t2, t2) + rOut.y.Sub(t, t2) + + rOut.t.Square(&rOut.z) + + t.Add(&p.y, &rOut.z).Square(t).Sub(t, r2).Sub(t, &rOut.t) + + t2.Mul(L1, &p.x) + t2.Add(t2, t2) + a = (&gfP2{}).Sub(t2, t) + + c = (&gfP2{}).MulScalar(&rOut.z, &q.y) + c.Add(c, c) + + b = (&gfP2{}).Neg(L1) + b.MulScalar(b, &q.x).Add(b, b) + + return +} + +func lineFunctionDouble(r *twistPoint, q *curvePoint) (a, b, c *gfP2, rOut *twistPoint) { + // See the doubling algorithm for a=0 from "Faster Computation of the + // Tate Pairing", http://arxiv.org/pdf/0904.0854v3.pdf + A := (&gfP2{}).Square(&r.x) + B := (&gfP2{}).Square(&r.y) + C := (&gfP2{}).Square(B) + + D := (&gfP2{}).Add(&r.x, B) + D.Square(D).Sub(D, A).Sub(D, C).Add(D, D) + + E := (&gfP2{}).Add(A, A) + E.Add(E, A) + + G := (&gfP2{}).Square(E) + + rOut = &twistPoint{} + rOut.x.Sub(G, D).Sub(&rOut.x, D) + + rOut.z.Add(&r.y, &r.z).Square(&rOut.z).Sub(&rOut.z, B).Sub(&rOut.z, &r.t) + + rOut.y.Sub(D, &rOut.x).Mul(&rOut.y, E) + t := (&gfP2{}).Add(C, C) + t.Add(t, t).Add(t, t) + rOut.y.Sub(&rOut.y, t) + + rOut.t.Square(&rOut.z) + + t.Mul(E, &r.t).Add(t, t) + b = (&gfP2{}).Neg(t) + b.MulScalar(b, &q.x) + + a = (&gfP2{}).Add(&r.x, E) + a.Square(a).Sub(a, A).Sub(a, G) + t.Add(B, B).Add(t, t) + a.Sub(a, t) + + c = (&gfP2{}).Mul(&rOut.z, &r.t) + c.Add(c, c).MulScalar(c, &q.y) + + return +} + +func mulLine(ret *gfP12, a, b, c *gfP2) { + a2 := &gfP6{} + a2.y.Set(a) + a2.z.Set(b) + a2.Mul(a2, &ret.x) + t3 := (&gfP6{}).MulScalar(&ret.y, c) + + t := (&gfP2{}).Add(b, c) + t2 := &gfP6{} + t2.y.Set(a) + t2.z.Set(t) + ret.x.Add(&ret.x, &ret.y) + + ret.y.Set(t3) + + ret.x.Mul(&ret.x, t2).Sub(&ret.x, a2).Sub(&ret.x, &ret.y) + a2.MulTau(a2) + ret.y.Add(&ret.y, a2) +} + +// sixuPlus2NAF is 6u+2 in non-adjacent form. +var sixuPlus2NAF = []int8{0, 0, 0, 1, 0, 1, 0, -1, 0, 0, 1, -1, 0, 0, 1, 0, + 0, 1, 1, 0, -1, 0, 0, 1, 0, -1, 0, 0, 0, 0, 1, 1, + 1, 0, 0, -1, 0, 0, 1, 0, 0, 0, 0, 0, -1, 0, 0, 1, + 1, 0, 0, -1, 0, 0, 0, 1, 1, 0, -1, 0, 0, 1, 0, 1, 1} + +// miller implements the Miller loop for calculating the Optimal Ate pairing. +// See algorithm 1 from http://cryptojedi.org/papers/dclxvi-20100714.pdf +func miller(q *twistPoint, p *curvePoint) *gfP12 { + ret := (&gfP12{}).SetOne() + + aAffine := &twistPoint{} + aAffine.Set(q) + aAffine.MakeAffine() + + bAffine := &curvePoint{} + bAffine.Set(p) + bAffine.MakeAffine() + + minusA := &twistPoint{} + minusA.Neg(aAffine) + + r := &twistPoint{} + r.Set(aAffine) + + r2 := (&gfP2{}).Square(&aAffine.y) + + for i := len(sixuPlus2NAF) - 1; i > 0; i-- { + a, b, c, newR := lineFunctionDouble(r, bAffine) + if i != len(sixuPlus2NAF)-1 { + ret.Square(ret) + } + + mulLine(ret, a, b, c) + r = newR + + switch sixuPlus2NAF[i-1] { + case 1: + a, b, c, newR = lineFunctionAdd(r, aAffine, bAffine, r2) + case -1: + a, b, c, newR = lineFunctionAdd(r, minusA, bAffine, r2) + default: + continue + } + + mulLine(ret, a, b, c) + r = newR + } + + // In order to calculate Q1 we have to convert q from the sextic twist + // to the full GF(p^12) group, apply the Frobenius there, and convert + // back. + // + // The twist isomorphism is (x', y') -> (xω², yω³). If we consider just + // x for a moment, then after applying the Frobenius, we have x̄ω^(2p) + // where x̄ is the conjugate of x. If we are going to apply the inverse + // isomorphism we need a value with a single coefficient of ω² so we + // rewrite this as x̄ω^(2p-2)ω². ξ⁶ = ω and, due to the construction of + // p, 2p-2 is a multiple of six. Therefore we can rewrite as + // x̄ξ^((p-1)/3)ω² and applying the inverse isomorphism eliminates the + // ω². + // + // A similar argument can be made for the y value. + + q1 := &twistPoint{} + q1.x.Conjugate(&aAffine.x).Mul(&q1.x, xiToPMinus1Over3) + q1.y.Conjugate(&aAffine.y).Mul(&q1.y, xiToPMinus1Over2) + q1.z.SetOne() + q1.t.SetOne() + + // For Q2 we are applying the p² Frobenius. The two conjugations cancel + // out and we are left only with the factors from the isomorphism. In + // the case of x, we end up with a pure number which is why + // xiToPSquaredMinus1Over3 is ∈ GF(p). With y we get a factor of -1. We + // ignore this to end up with -Q2. + + minusQ2 := &twistPoint{} + minusQ2.x.MulScalar(&aAffine.x, xiToPSquaredMinus1Over3) + minusQ2.y.Set(&aAffine.y) + minusQ2.z.SetOne() + minusQ2.t.SetOne() + + r2.Square(&q1.y) + a, b, c, newR := lineFunctionAdd(r, q1, bAffine, r2) + mulLine(ret, a, b, c) + r = newR + + r2.Square(&minusQ2.y) + a, b, c, newR = lineFunctionAdd(r, minusQ2, bAffine, r2) + mulLine(ret, a, b, c) + r = newR + + return ret +} + +// finalExponentiation computes the (p¹²-1)/Order-th power of an element of +// GF(p¹²) to obtain an element of GT (steps 13-15 of algorithm 1 from +// http://cryptojedi.org/papers/dclxvi-20100714.pdf) +func finalExponentiation(in *gfP12) *gfP12 { + t1 := &gfP12{} + + // This is the p^6-Frobenius + t1.x.Neg(&in.x) + t1.y.Set(&in.y) + + inv := &gfP12{} + inv.Invert(in) + t1.Mul(t1, inv) + + t2 := (&gfP12{}).FrobeniusP2(t1) + t1.Mul(t1, t2) + + fp := (&gfP12{}).Frobenius(t1) + fp2 := (&gfP12{}).FrobeniusP2(t1) + fp3 := (&gfP12{}).Frobenius(fp2) + + fu := (&gfP12{}).Exp(t1, u) + fu2 := (&gfP12{}).Exp(fu, u) + fu3 := (&gfP12{}).Exp(fu2, u) + + y3 := (&gfP12{}).Frobenius(fu) + fu2p := (&gfP12{}).Frobenius(fu2) + fu3p := (&gfP12{}).Frobenius(fu3) + y2 := (&gfP12{}).FrobeniusP2(fu2) + + y0 := &gfP12{} + y0.Mul(fp, fp2).Mul(y0, fp3) + + y1 := (&gfP12{}).Conjugate(t1) + y5 := (&gfP12{}).Conjugate(fu2) + y3.Conjugate(y3) + y4 := (&gfP12{}).Mul(fu, fu2p) + y4.Conjugate(y4) + + y6 := (&gfP12{}).Mul(fu3, fu3p) + y6.Conjugate(y6) + + t0 := (&gfP12{}).Square(y6) + t0.Mul(t0, y4).Mul(t0, y5) + t1.Mul(y3, y5).Mul(t1, t0) + t0.Mul(t0, y2) + t1.Square(t1).Mul(t1, t0).Square(t1) + t0.Mul(t1, y1) + t1.Mul(t1, y0) + t0.Square(t0).Mul(t0, t1) + + return t0 +} + +func optimalAte(a *twistPoint, b *curvePoint) *gfP12 { + e := miller(a, b) + ret := finalExponentiation(e) + + if a.IsInfinity() || b.IsInfinity() { + ret.SetOne() + } + return ret +} diff --git a/crypto/bn256/cloudflare/twist.go b/crypto/bn256/cloudflare/twist.go new file mode 100644 index 000000000..0c2f80d4e --- /dev/null +++ b/crypto/bn256/cloudflare/twist.go @@ -0,0 +1,204 @@ +package bn256 + +import ( + "math/big" +) + +// twistPoint implements the elliptic curve y²=x³+3/ξ over GF(p²). Points are +// kept in Jacobian form and t=z² when valid. The group G₂ is the set of +// n-torsion points of this curve over GF(p²) (where n = Order) +type twistPoint struct { + x, y, z, t gfP2 +} + +var twistB = &gfP2{ + gfP{0x38e7ecccd1dcff67, 0x65f0b37d93ce0d3e, 0xd749d0dd22ac00aa, 0x0141b9ce4a688d4d}, + gfP{0x3bf938e377b802a8, 0x020b1b273633535d, 0x26b7edf049755260, 0x2514c6324384a86d}, +} + +// twistGen is the generator of group G₂. +var twistGen = &twistPoint{ + gfP2{ + gfP{0xafb4737da84c6140, 0x6043dd5a5802d8c4, 0x09e950fc52a02f86, 0x14fef0833aea7b6b}, + gfP{0x8e83b5d102bc2026, 0xdceb1935497b0172, 0xfbb8264797811adf, 0x19573841af96503b}, + }, + gfP2{ + gfP{0x64095b56c71856ee, 0xdc57f922327d3cbb, 0x55f935be33351076, 0x0da4a0e693fd6482}, + gfP{0x619dfa9d886be9f6, 0xfe7fd297f59e9b78, 0xff9e1a62231b7dfe, 0x28fd7eebae9e4206}, + }, + gfP2{*newGFp(0), *newGFp(1)}, + gfP2{*newGFp(0), *newGFp(1)}, +} + +func (c *twistPoint) String() string { + c.MakeAffine() + x, y := gfP2Decode(&c.x), gfP2Decode(&c.y) + return "(" + x.String() + ", " + y.String() + ")" +} + +func (c *twistPoint) Set(a *twistPoint) { + c.x.Set(&a.x) + c.y.Set(&a.y) + c.z.Set(&a.z) + c.t.Set(&a.t) +} + +// IsOnCurve returns true iff c is on the curve. +func (c *twistPoint) IsOnCurve() bool { + c.MakeAffine() + if c.IsInfinity() { + return true + } + + y2, x3 := &gfP2{}, &gfP2{} + y2.Square(&c.y) + x3.Square(&c.x).Mul(x3, &c.x).Add(x3, twistB) + + if *y2 != *x3 { + return false + } + cneg := &twistPoint{} + cneg.Mul(c, Order) + return cneg.z.IsZero() +} + +func (c *twistPoint) SetInfinity() { + c.x.SetZero() + c.y.SetOne() + c.z.SetZero() + c.t.SetZero() +} + +func (c *twistPoint) IsInfinity() bool { + return c.z.IsZero() +} + +func (c *twistPoint) Add(a, b *twistPoint) { + // For additional comments, see the same function in curve.go. + + if a.IsInfinity() { + c.Set(b) + return + } + if b.IsInfinity() { + c.Set(a) + return + } + + // See http://hyperelliptic.org/EFD/g1p/auto-code/shortw/jacobian-0/addition/add-2007-bl.op3 + z12 := (&gfP2{}).Square(&a.z) + z22 := (&gfP2{}).Square(&b.z) + u1 := (&gfP2{}).Mul(&a.x, z22) + u2 := (&gfP2{}).Mul(&b.x, z12) + + t := (&gfP2{}).Mul(&b.z, z22) + s1 := (&gfP2{}).Mul(&a.y, t) + + t.Mul(&a.z, z12) + s2 := (&gfP2{}).Mul(&b.y, t) + + h := (&gfP2{}).Sub(u2, u1) + xEqual := h.IsZero() + + t.Add(h, h) + i := (&gfP2{}).Square(t) + j := (&gfP2{}).Mul(h, i) + + t.Sub(s2, s1) + yEqual := t.IsZero() + if xEqual && yEqual { + c.Double(a) + return + } + r := (&gfP2{}).Add(t, t) + + v := (&gfP2{}).Mul(u1, i) + + t4 := (&gfP2{}).Square(r) + t.Add(v, v) + t6 := (&gfP2{}).Sub(t4, j) + c.x.Sub(t6, t) + + t.Sub(v, &c.x) // t7 + t4.Mul(s1, j) // t8 + t6.Add(t4, t4) // t9 + t4.Mul(r, t) // t10 + c.y.Sub(t4, t6) + + t.Add(&a.z, &b.z) // t11 + t4.Square(t) // t12 + t.Sub(t4, z12) // t13 + t4.Sub(t, z22) // t14 + c.z.Mul(t4, h) +} + +func (c *twistPoint) Double(a *twistPoint) { + // See http://hyperelliptic.org/EFD/g1p/auto-code/shortw/jacobian-0/doubling/dbl-2009-l.op3 + A := (&gfP2{}).Square(&a.x) + B := (&gfP2{}).Square(&a.y) + C := (&gfP2{}).Square(B) + + t := (&gfP2{}).Add(&a.x, B) + t2 := (&gfP2{}).Square(t) + t.Sub(t2, A) + t2.Sub(t, C) + d := (&gfP2{}).Add(t2, t2) + t.Add(A, A) + e := (&gfP2{}).Add(t, A) + f := (&gfP2{}).Square(e) + + t.Add(d, d) + c.x.Sub(f, t) + + t.Add(C, C) + t2.Add(t, t) + t.Add(t2, t2) + c.y.Sub(d, &c.x) + t2.Mul(e, &c.y) + c.y.Sub(t2, t) + + t.Mul(&a.y, &a.z) + c.z.Add(t, t) +} + +func (c *twistPoint) Mul(a *twistPoint, scalar *big.Int) { + sum, t := &twistPoint{}, &twistPoint{} + + for i := scalar.BitLen(); i >= 0; i-- { + t.Double(sum) + if scalar.Bit(i) != 0 { + sum.Add(t, a) + } else { + sum.Set(t) + } + } + + c.Set(sum) +} + +func (c *twistPoint) MakeAffine() { + if c.z.IsOne() { + return + } else if c.z.IsZero() { + c.x.SetZero() + c.y.SetOne() + c.t.SetZero() + return + } + + zInv := (&gfP2{}).Invert(&c.z) + t := (&gfP2{}).Mul(&c.y, zInv) + zInv2 := (&gfP2{}).Square(zInv) + c.y.Mul(t, zInv2) + t.Mul(&c.x, zInv2) + c.x.Set(t) + c.z.SetOne() + c.t.SetOne() +} + +func (c *twistPoint) Neg(a *twistPoint) { + c.x.Set(&a.x) + c.y.Neg(&a.y) + c.z.Set(&a.z) + c.t.SetZero() +} diff --git a/crypto/bn256/bn256.go b/crypto/bn256/google/bn256.go index 7144c31a8..5da83e033 100644 --- a/crypto/bn256/bn256.go +++ b/crypto/bn256/google/bn256.go @@ -18,6 +18,7 @@ package bn256 import ( "crypto/rand" + "errors" "io" "math/big" ) @@ -115,21 +116,25 @@ func (n *G1) Marshal() []byte { // Unmarshal sets e to the result of converting the output of Marshal back into // a group element and then returns e. -func (e *G1) Unmarshal(m []byte) (*G1, bool) { +func (e *G1) Unmarshal(m []byte) ([]byte, error) { // Each value is a 256-bit number. const numBytes = 256 / 8 - if len(m) != 2*numBytes { - return nil, false + return nil, errors.New("bn256: not enough data") } - + // Unmarshal the points and check their caps if e.p == nil { e.p = newCurvePoint(nil) } - e.p.x.SetBytes(m[0*numBytes : 1*numBytes]) + if e.p.x.Cmp(P) >= 0 { + return nil, errors.New("bn256: coordinate exceeds modulus") + } e.p.y.SetBytes(m[1*numBytes : 2*numBytes]) - + if e.p.y.Cmp(P) >= 0 { + return nil, errors.New("bn256: coordinate exceeds modulus") + } + // Ensure the point is on the curve if e.p.x.Sign() == 0 && e.p.y.Sign() == 0 { // This is the point at infinity. e.p.y.SetInt64(1) @@ -140,11 +145,10 @@ func (e *G1) Unmarshal(m []byte) (*G1, bool) { e.p.t.SetInt64(1) if !e.p.IsOnCurve() { - return nil, false + return nil, errors.New("bn256: malformed point") } } - - return e, true + return m[2*numBytes:], nil } // G2 is an abstract cyclic group. The zero value is suitable for use as the @@ -233,23 +237,33 @@ func (n *G2) Marshal() []byte { // Unmarshal sets e to the result of converting the output of Marshal back into // a group element and then returns e. -func (e *G2) Unmarshal(m []byte) (*G2, bool) { +func (e *G2) Unmarshal(m []byte) ([]byte, error) { // Each value is a 256-bit number. const numBytes = 256 / 8 - if len(m) != 4*numBytes { - return nil, false + return nil, errors.New("bn256: not enough data") } - + // Unmarshal the points and check their caps if e.p == nil { e.p = newTwistPoint(nil) } - e.p.x.x.SetBytes(m[0*numBytes : 1*numBytes]) + if e.p.x.x.Cmp(P) >= 0 { + return nil, errors.New("bn256: coordinate exceeds modulus") + } e.p.x.y.SetBytes(m[1*numBytes : 2*numBytes]) + if e.p.x.y.Cmp(P) >= 0 { + return nil, errors.New("bn256: coordinate exceeds modulus") + } e.p.y.x.SetBytes(m[2*numBytes : 3*numBytes]) + if e.p.y.x.Cmp(P) >= 0 { + return nil, errors.New("bn256: coordinate exceeds modulus") + } e.p.y.y.SetBytes(m[3*numBytes : 4*numBytes]) - + if e.p.y.y.Cmp(P) >= 0 { + return nil, errors.New("bn256: coordinate exceeds modulus") + } + // Ensure the point is on the curve if e.p.x.x.Sign() == 0 && e.p.x.y.Sign() == 0 && e.p.y.x.Sign() == 0 && @@ -263,11 +277,10 @@ func (e *G2) Unmarshal(m []byte) (*G2, bool) { e.p.t.SetOne() if !e.p.IsOnCurve() { - return nil, false + return nil, errors.New("bn256: malformed point") } } - - return e, true + return m[4*numBytes:], nil } // GT is an abstract cyclic group. The zero value is suitable for use as the diff --git a/crypto/bn256/bn256_test.go b/crypto/bn256/google/bn256_test.go index 866065d0c..a4497ada9 100644 --- a/crypto/bn256/bn256_test.go +++ b/crypto/bn256/google/bn256_test.go @@ -219,15 +219,16 @@ func TestBilinearity(t *testing.T) { func TestG1Marshal(t *testing.T) { g := new(G1).ScalarBaseMult(new(big.Int).SetInt64(1)) form := g.Marshal() - _, ok := new(G1).Unmarshal(form) - if !ok { + _, err := new(G1).Unmarshal(form) + if err != nil { t.Fatalf("failed to unmarshal") } g.ScalarBaseMult(Order) form = g.Marshal() - g2, ok := new(G1).Unmarshal(form) - if !ok { + + g2 := new(G1) + if _, err = g2.Unmarshal(form); err != nil { t.Fatalf("failed to unmarshal ∞") } if !g2.p.IsInfinity() { @@ -238,15 +239,15 @@ func TestG1Marshal(t *testing.T) { func TestG2Marshal(t *testing.T) { g := new(G2).ScalarBaseMult(new(big.Int).SetInt64(1)) form := g.Marshal() - _, ok := new(G2).Unmarshal(form) - if !ok { + _, err := new(G2).Unmarshal(form) + if err != nil { t.Fatalf("failed to unmarshal") } g.ScalarBaseMult(Order) form = g.Marshal() - g2, ok := new(G2).Unmarshal(form) - if !ok { + g2 := new(G2) + if _, err = g2.Unmarshal(form); err != nil { t.Fatalf("failed to unmarshal ∞") } if !g2.p.IsInfinity() { @@ -273,12 +274,18 @@ func TestTripartiteDiffieHellman(t *testing.T) { b, _ := rand.Int(rand.Reader, Order) c, _ := rand.Int(rand.Reader, Order) - pa, _ := new(G1).Unmarshal(new(G1).ScalarBaseMult(a).Marshal()) - qa, _ := new(G2).Unmarshal(new(G2).ScalarBaseMult(a).Marshal()) - pb, _ := new(G1).Unmarshal(new(G1).ScalarBaseMult(b).Marshal()) - qb, _ := new(G2).Unmarshal(new(G2).ScalarBaseMult(b).Marshal()) - pc, _ := new(G1).Unmarshal(new(G1).ScalarBaseMult(c).Marshal()) - qc, _ := new(G2).Unmarshal(new(G2).ScalarBaseMult(c).Marshal()) + pa := new(G1) + pa.Unmarshal(new(G1).ScalarBaseMult(a).Marshal()) + qa := new(G2) + qa.Unmarshal(new(G2).ScalarBaseMult(a).Marshal()) + pb := new(G1) + pb.Unmarshal(new(G1).ScalarBaseMult(b).Marshal()) + qb := new(G2) + qb.Unmarshal(new(G2).ScalarBaseMult(b).Marshal()) + pc := new(G1) + pc.Unmarshal(new(G1).ScalarBaseMult(c).Marshal()) + qc := new(G2) + qc.Unmarshal(new(G2).ScalarBaseMult(c).Marshal()) k1 := Pair(pb, qc) k1.ScalarMult(k1, a) diff --git a/crypto/bn256/constants.go b/crypto/bn256/google/constants.go index ab649d7f3..ab649d7f3 100644 --- a/crypto/bn256/constants.go +++ b/crypto/bn256/google/constants.go diff --git a/crypto/bn256/curve.go b/crypto/bn256/google/curve.go index 3e679fdc7..3e679fdc7 100644 --- a/crypto/bn256/curve.go +++ b/crypto/bn256/google/curve.go diff --git a/crypto/bn256/example_test.go b/crypto/bn256/google/example_test.go index b2d19807a..b2d19807a 100644 --- a/crypto/bn256/example_test.go +++ b/crypto/bn256/google/example_test.go diff --git a/crypto/bn256/gfp12.go b/crypto/bn256/google/gfp12.go index f084eddf2..f084eddf2 100644 --- a/crypto/bn256/gfp12.go +++ b/crypto/bn256/google/gfp12.go diff --git a/crypto/bn256/gfp2.go b/crypto/bn256/google/gfp2.go index 3981f6cb4..3981f6cb4 100644 --- a/crypto/bn256/gfp2.go +++ b/crypto/bn256/google/gfp2.go diff --git a/crypto/bn256/gfp6.go b/crypto/bn256/google/gfp6.go index 218856617..218856617 100644 --- a/crypto/bn256/gfp6.go +++ b/crypto/bn256/google/gfp6.go diff --git a/crypto/bn256/main_test.go b/crypto/bn256/google/main_test.go index 0230f1b19..0230f1b19 100644 --- a/crypto/bn256/main_test.go +++ b/crypto/bn256/google/main_test.go diff --git a/crypto/bn256/optate.go b/crypto/bn256/google/optate.go index 9d6957062..9d6957062 100644 --- a/crypto/bn256/optate.go +++ b/crypto/bn256/google/optate.go diff --git a/crypto/bn256/twist.go b/crypto/bn256/google/twist.go index 95b966e2e..1f5a4d9de 100644 --- a/crypto/bn256/twist.go +++ b/crypto/bn256/google/twist.go @@ -76,7 +76,13 @@ func (c *twistPoint) IsOnCurve() bool { yy.Sub(yy, xxx) yy.Sub(yy, twistB) yy.Minimal() - return yy.x.Sign() == 0 && yy.y.Sign() == 0 + + if yy.x.Sign() != 0 || yy.y.Sign() != 0 { + return false + } + cneg := newTwistPoint(pool) + cneg.Mul(c, Order, pool) + return cneg.z.IsZero() } func (c *twistPoint) SetInfinity() { diff --git a/dashboard/dashboard.go b/dashboard/dashboard.go index 09038638e..2ca795187 100644 --- a/dashboard/dashboard.go +++ b/dashboard/dashboard.go @@ -36,10 +36,10 @@ import ( "github.com/elastic/gosigar" "github.com/ethereum/go-ethereum/log" + "github.com/ethereum/go-ethereum/metrics" "github.com/ethereum/go-ethereum/p2p" "github.com/ethereum/go-ethereum/params" "github.com/ethereum/go-ethereum/rpc" - "github.com/rcrowley/go-metrics" "golang.org/x/net/websocket" ) diff --git a/eth/api_backend.go b/eth/api_backend.go index 91f392f94..ecd5488a2 100644 --- a/eth/api_backend.go +++ b/eth/api_backend.go @@ -104,6 +104,18 @@ func (b *EthApiBackend) GetReceipts(ctx context.Context, blockHash common.Hash) return core.GetBlockReceipts(b.eth.chainDb, blockHash, core.GetBlockNumber(b.eth.chainDb, blockHash)), nil } +func (b *EthApiBackend) GetLogs(ctx context.Context, blockHash common.Hash) ([][]*types.Log, error) { + receipts := core.GetBlockReceipts(b.eth.chainDb, blockHash, core.GetBlockNumber(b.eth.chainDb, blockHash)) + if receipts == nil { + return nil, nil + } + logs := make([][]*types.Log, len(receipts)) + for i, receipt := range receipts { + logs[i] = receipt.Logs + } + return logs, nil +} + func (b *EthApiBackend) GetTd(blockHash common.Hash) *big.Int { return b.eth.blockchain.GetTdByHash(blockHash) } diff --git a/eth/downloader/downloader.go b/eth/downloader/downloader.go index 7ede530a9..70febf4cb 100644 --- a/eth/downloader/downloader.go +++ b/eth/downloader/downloader.go @@ -27,12 +27,13 @@ import ( ethereum "github.com/ethereum/go-ethereum" "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core" "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/event" "github.com/ethereum/go-ethereum/log" + "github.com/ethereum/go-ethereum/metrics" "github.com/ethereum/go-ethereum/params" - "github.com/rcrowley/go-metrics" ) var ( @@ -221,7 +222,10 @@ func New(mode SyncMode, stateDb ethdb.Database, mux *event.TypeMux, chain BlockC quitCh: make(chan struct{}), stateCh: make(chan dataPack), stateSyncStart: make(chan *stateSync), - trackStateReq: make(chan *stateReq), + syncStatsState: stateSyncStats{ + processed: core.GetTrieSyncProgress(stateDb), + }, + trackStateReq: make(chan *stateReq), } go dl.qosTuner() go dl.stateFetcher() diff --git a/eth/downloader/metrics.go b/eth/downloader/metrics.go index 58764ccf0..d4eb33794 100644 --- a/eth/downloader/metrics.go +++ b/eth/downloader/metrics.go @@ -23,21 +23,21 @@ import ( ) var ( - headerInMeter = metrics.NewMeter("eth/downloader/headers/in") - headerReqTimer = metrics.NewTimer("eth/downloader/headers/req") - headerDropMeter = metrics.NewMeter("eth/downloader/headers/drop") - headerTimeoutMeter = metrics.NewMeter("eth/downloader/headers/timeout") + headerInMeter = metrics.NewRegisteredMeter("eth/downloader/headers/in", nil) + headerReqTimer = metrics.NewRegisteredTimer("eth/downloader/headers/req", nil) + headerDropMeter = metrics.NewRegisteredMeter("eth/downloader/headers/drop", nil) + headerTimeoutMeter = metrics.NewRegisteredMeter("eth/downloader/headers/timeout", nil) - bodyInMeter = metrics.NewMeter("eth/downloader/bodies/in") - bodyReqTimer = metrics.NewTimer("eth/downloader/bodies/req") - bodyDropMeter = metrics.NewMeter("eth/downloader/bodies/drop") - bodyTimeoutMeter = metrics.NewMeter("eth/downloader/bodies/timeout") + bodyInMeter = metrics.NewRegisteredMeter("eth/downloader/bodies/in", nil) + bodyReqTimer = metrics.NewRegisteredTimer("eth/downloader/bodies/req", nil) + bodyDropMeter = metrics.NewRegisteredMeter("eth/downloader/bodies/drop", nil) + bodyTimeoutMeter = metrics.NewRegisteredMeter("eth/downloader/bodies/timeout", nil) - receiptInMeter = metrics.NewMeter("eth/downloader/receipts/in") - receiptReqTimer = metrics.NewTimer("eth/downloader/receipts/req") - receiptDropMeter = metrics.NewMeter("eth/downloader/receipts/drop") - receiptTimeoutMeter = metrics.NewMeter("eth/downloader/receipts/timeout") + receiptInMeter = metrics.NewRegisteredMeter("eth/downloader/receipts/in", nil) + receiptReqTimer = metrics.NewRegisteredTimer("eth/downloader/receipts/req", nil) + receiptDropMeter = metrics.NewRegisteredMeter("eth/downloader/receipts/drop", nil) + receiptTimeoutMeter = metrics.NewRegisteredMeter("eth/downloader/receipts/timeout", nil) - stateInMeter = metrics.NewMeter("eth/downloader/states/in") - stateDropMeter = metrics.NewMeter("eth/downloader/states/drop") + stateInMeter = metrics.NewRegisteredMeter("eth/downloader/states/in", nil) + stateDropMeter = metrics.NewRegisteredMeter("eth/downloader/states/drop", nil) ) diff --git a/eth/downloader/queue.go b/eth/downloader/queue.go index a1a70e46e..359cce54b 100644 --- a/eth/downloader/queue.go +++ b/eth/downloader/queue.go @@ -28,7 +28,7 @@ import ( "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/log" - "github.com/rcrowley/go-metrics" + "github.com/ethereum/go-ethereum/metrics" "gopkg.in/karalabe/cookiejar.v2/collections/prque" ) diff --git a/eth/downloader/statesync.go b/eth/downloader/statesync.go index 9cc65a208..ee6c7b491 100644 --- a/eth/downloader/statesync.go +++ b/eth/downloader/statesync.go @@ -23,6 +23,7 @@ import ( "time" "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core" "github.com/ethereum/go-ethereum/core/state" "github.com/ethereum/go-ethereum/crypto/sha3" "github.com/ethereum/go-ethereum/ethdb" @@ -466,4 +467,7 @@ func (s *stateSync) updateStats(written, duplicate, unexpected int, duration tim if written > 0 || duplicate > 0 || unexpected > 0 { log.Info("Imported new state entries", "count", written, "elapsed", common.PrettyDuration(duration), "processed", s.d.syncStatsState.processed, "pending", s.d.syncStatsState.pending, "retry", len(s.tasks), "duplicate", s.d.syncStatsState.duplicate, "unexpected", s.d.syncStatsState.unexpected) } + if written > 0 { + core.WriteTrieSyncProgress(s.d.stateDB, s.d.syncStatsState.processed) + } } diff --git a/eth/fetcher/metrics.go b/eth/fetcher/metrics.go index 1ed8075bf..d68d12f00 100644 --- a/eth/fetcher/metrics.go +++ b/eth/fetcher/metrics.go @@ -23,21 +23,21 @@ import ( ) var ( - propAnnounceInMeter = metrics.NewMeter("eth/fetcher/prop/announces/in") - propAnnounceOutTimer = metrics.NewTimer("eth/fetcher/prop/announces/out") - propAnnounceDropMeter = metrics.NewMeter("eth/fetcher/prop/announces/drop") - propAnnounceDOSMeter = metrics.NewMeter("eth/fetcher/prop/announces/dos") + propAnnounceInMeter = metrics.NewRegisteredMeter("eth/fetcher/prop/announces/in", nil) + propAnnounceOutTimer = metrics.NewRegisteredTimer("eth/fetcher/prop/announces/out", nil) + propAnnounceDropMeter = metrics.NewRegisteredMeter("eth/fetcher/prop/announces/drop", nil) + propAnnounceDOSMeter = metrics.NewRegisteredMeter("eth/fetcher/prop/announces/dos", nil) - propBroadcastInMeter = metrics.NewMeter("eth/fetcher/prop/broadcasts/in") - propBroadcastOutTimer = metrics.NewTimer("eth/fetcher/prop/broadcasts/out") - propBroadcastDropMeter = metrics.NewMeter("eth/fetcher/prop/broadcasts/drop") - propBroadcastDOSMeter = metrics.NewMeter("eth/fetcher/prop/broadcasts/dos") + propBroadcastInMeter = metrics.NewRegisteredMeter("eth/fetcher/prop/broadcasts/in", nil) + propBroadcastOutTimer = metrics.NewRegisteredTimer("eth/fetcher/prop/broadcasts/out", nil) + propBroadcastDropMeter = metrics.NewRegisteredMeter("eth/fetcher/prop/broadcasts/drop", nil) + propBroadcastDOSMeter = metrics.NewRegisteredMeter("eth/fetcher/prop/broadcasts/dos", nil) - headerFetchMeter = metrics.NewMeter("eth/fetcher/fetch/headers") - bodyFetchMeter = metrics.NewMeter("eth/fetcher/fetch/bodies") + headerFetchMeter = metrics.NewRegisteredMeter("eth/fetcher/fetch/headers", nil) + bodyFetchMeter = metrics.NewRegisteredMeter("eth/fetcher/fetch/bodies", nil) - headerFilterInMeter = metrics.NewMeter("eth/fetcher/filter/headers/in") - headerFilterOutMeter = metrics.NewMeter("eth/fetcher/filter/headers/out") - bodyFilterInMeter = metrics.NewMeter("eth/fetcher/filter/bodies/in") - bodyFilterOutMeter = metrics.NewMeter("eth/fetcher/filter/bodies/out") + headerFilterInMeter = metrics.NewRegisteredMeter("eth/fetcher/filter/headers/in", nil) + headerFilterOutMeter = metrics.NewRegisteredMeter("eth/fetcher/filter/headers/out", nil) + bodyFilterInMeter = metrics.NewRegisteredMeter("eth/fetcher/filter/bodies/in", nil) + bodyFilterOutMeter = metrics.NewRegisteredMeter("eth/fetcher/filter/bodies/out", nil) ) diff --git a/eth/filters/filter.go b/eth/filters/filter.go index 43d7e2a81..5dfe60e77 100644 --- a/eth/filters/filter.go +++ b/eth/filters/filter.go @@ -34,6 +34,7 @@ type Backend interface { EventMux() *event.TypeMux HeaderByNumber(ctx context.Context, blockNr rpc.BlockNumber) (*types.Header, error) GetReceipts(ctx context.Context, blockHash common.Hash) (types.Receipts, error) + GetLogs(ctx context.Context, blockHash common.Hash) ([][]*types.Log, error) SubscribeTxPreEvent(chan<- core.TxPreEvent) event.Subscription SubscribeChainEvent(ch chan<- core.ChainEvent) event.Subscription @@ -201,16 +202,28 @@ func (f *Filter) unindexedLogs(ctx context.Context, end uint64) ([]*types.Log, e // match the filter criteria. This function is called when the bloom filter signals a potential match. func (f *Filter) checkMatches(ctx context.Context, header *types.Header) (logs []*types.Log, err error) { // Get the logs of the block - receipts, err := f.backend.GetReceipts(ctx, header.Hash()) + logsList, err := f.backend.GetLogs(ctx, header.Hash()) if err != nil { return nil, err } var unfiltered []*types.Log - for _, receipt := range receipts { - unfiltered = append(unfiltered, receipt.Logs...) + for _, logs := range logsList { + unfiltered = append(unfiltered, logs...) } logs = filterLogs(unfiltered, nil, nil, f.addresses, f.topics) if len(logs) > 0 { + // We have matching logs, check if we need to resolve full logs via the light client + if logs[0].TxHash == (common.Hash{}) { + receipts, err := f.backend.GetReceipts(ctx, header.Hash()) + if err != nil { + return nil, err + } + unfiltered = unfiltered[:0] + for _, receipt := range receipts { + unfiltered = append(unfiltered, receipt.Logs...) + } + logs = filterLogs(unfiltered, nil, nil, f.addresses, f.topics) + } return logs, nil } return nil, nil diff --git a/eth/filters/filter_system.go b/eth/filters/filter_system.go index b09998f9c..f8097c7b9 100644 --- a/eth/filters/filter_system.go +++ b/eth/filters/filter_system.go @@ -375,19 +375,35 @@ func (es *EventSystem) lightFilterLogs(header *types.Header, addresses []common. // Get the logs of the block ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) defer cancel() - receipts, err := es.backend.GetReceipts(ctx, header.Hash()) + logsList, err := es.backend.GetLogs(ctx, header.Hash()) if err != nil { return nil } var unfiltered []*types.Log - for _, receipt := range receipts { - for _, log := range receipt.Logs { + for _, logs := range logsList { + for _, log := range logs { logcopy := *log logcopy.Removed = remove unfiltered = append(unfiltered, &logcopy) } } logs := filterLogs(unfiltered, nil, nil, addresses, topics) + if len(logs) > 0 && logs[0].TxHash == (common.Hash{}) { + // We have matching but non-derived logs + receipts, err := es.backend.GetReceipts(ctx, header.Hash()) + if err != nil { + return nil + } + unfiltered = unfiltered[:0] + for _, receipt := range receipts { + for _, log := range receipt.Logs { + logcopy := *log + logcopy.Removed = remove + unfiltered = append(unfiltered, &logcopy) + } + } + logs = filterLogs(unfiltered, nil, nil, addresses, topics) + } return logs } return nil diff --git a/eth/filters/filter_system_test.go b/eth/filters/filter_system_test.go index 7ec3b4be7..61761151a 100644 --- a/eth/filters/filter_system_test.go +++ b/eth/filters/filter_system_test.go @@ -69,8 +69,19 @@ func (b *testBackend) HeaderByNumber(ctx context.Context, blockNr rpc.BlockNumbe } func (b *testBackend) GetReceipts(ctx context.Context, blockHash common.Hash) (types.Receipts, error) { - num := core.GetBlockNumber(b.db, blockHash) - return core.GetBlockReceipts(b.db, blockHash, num), nil + number := core.GetBlockNumber(b.db, blockHash) + return core.GetBlockReceipts(b.db, blockHash, number), nil +} + +func (b *testBackend) GetLogs(ctx context.Context, blockHash common.Hash) ([][]*types.Log, error) { + number := core.GetBlockNumber(b.db, blockHash) + receipts := core.GetBlockReceipts(b.db, blockHash, number) + + logs := make([][]*types.Log, len(receipts)) + for i, receipt := range receipts { + logs[i] = receipt.Logs + } + return logs, nil } func (b *testBackend) SubscribeTxPreEvent(ch chan<- core.TxPreEvent) event.Subscription { diff --git a/eth/handler.go b/eth/handler.go index c2426544f..3fae0cd00 100644 --- a/eth/handler.go +++ b/eth/handler.go @@ -249,7 +249,8 @@ func (pm *ProtocolManager) newPeer(pv int, p *p2p.Peer, rw p2p.MsgReadWriter) *p // handle is the callback invoked to manage the life cycle of an eth peer. When // this function terminates, the peer is disconnected. func (pm *ProtocolManager) handle(p *peer) error { - if pm.peers.Len() >= pm.maxPeers { + // Ignore maxPeers if this is a trusted peer + if pm.peers.Len() >= pm.maxPeers && !p.Peer.Info().Network.Trusted { return p2p.DiscTooManyPeers } p.Log().Debug("Ethereum peer connected", "name", p.Name()) diff --git a/eth/metrics.go b/eth/metrics.go index 5fa2597d4..0533a2a87 100644 --- a/eth/metrics.go +++ b/eth/metrics.go @@ -22,38 +22,38 @@ import ( ) var ( - propTxnInPacketsMeter = metrics.NewMeter("eth/prop/txns/in/packets") - propTxnInTrafficMeter = metrics.NewMeter("eth/prop/txns/in/traffic") - propTxnOutPacketsMeter = metrics.NewMeter("eth/prop/txns/out/packets") - propTxnOutTrafficMeter = metrics.NewMeter("eth/prop/txns/out/traffic") - propHashInPacketsMeter = metrics.NewMeter("eth/prop/hashes/in/packets") - propHashInTrafficMeter = metrics.NewMeter("eth/prop/hashes/in/traffic") - propHashOutPacketsMeter = metrics.NewMeter("eth/prop/hashes/out/packets") - propHashOutTrafficMeter = metrics.NewMeter("eth/prop/hashes/out/traffic") - propBlockInPacketsMeter = metrics.NewMeter("eth/prop/blocks/in/packets") - propBlockInTrafficMeter = metrics.NewMeter("eth/prop/blocks/in/traffic") - propBlockOutPacketsMeter = metrics.NewMeter("eth/prop/blocks/out/packets") - propBlockOutTrafficMeter = metrics.NewMeter("eth/prop/blocks/out/traffic") - reqHeaderInPacketsMeter = metrics.NewMeter("eth/req/headers/in/packets") - reqHeaderInTrafficMeter = metrics.NewMeter("eth/req/headers/in/traffic") - reqHeaderOutPacketsMeter = metrics.NewMeter("eth/req/headers/out/packets") - reqHeaderOutTrafficMeter = metrics.NewMeter("eth/req/headers/out/traffic") - reqBodyInPacketsMeter = metrics.NewMeter("eth/req/bodies/in/packets") - reqBodyInTrafficMeter = metrics.NewMeter("eth/req/bodies/in/traffic") - reqBodyOutPacketsMeter = metrics.NewMeter("eth/req/bodies/out/packets") - reqBodyOutTrafficMeter = metrics.NewMeter("eth/req/bodies/out/traffic") - reqStateInPacketsMeter = metrics.NewMeter("eth/req/states/in/packets") - reqStateInTrafficMeter = metrics.NewMeter("eth/req/states/in/traffic") - reqStateOutPacketsMeter = metrics.NewMeter("eth/req/states/out/packets") - reqStateOutTrafficMeter = metrics.NewMeter("eth/req/states/out/traffic") - reqReceiptInPacketsMeter = metrics.NewMeter("eth/req/receipts/in/packets") - reqReceiptInTrafficMeter = metrics.NewMeter("eth/req/receipts/in/traffic") - reqReceiptOutPacketsMeter = metrics.NewMeter("eth/req/receipts/out/packets") - reqReceiptOutTrafficMeter = metrics.NewMeter("eth/req/receipts/out/traffic") - miscInPacketsMeter = metrics.NewMeter("eth/misc/in/packets") - miscInTrafficMeter = metrics.NewMeter("eth/misc/in/traffic") - miscOutPacketsMeter = metrics.NewMeter("eth/misc/out/packets") - miscOutTrafficMeter = metrics.NewMeter("eth/misc/out/traffic") + propTxnInPacketsMeter = metrics.NewRegisteredMeter("eth/prop/txns/in/packets", nil) + propTxnInTrafficMeter = metrics.NewRegisteredMeter("eth/prop/txns/in/traffic", nil) + propTxnOutPacketsMeter = metrics.NewRegisteredMeter("eth/prop/txns/out/packets", nil) + propTxnOutTrafficMeter = metrics.NewRegisteredMeter("eth/prop/txns/out/traffic", nil) + propHashInPacketsMeter = metrics.NewRegisteredMeter("eth/prop/hashes/in/packets", nil) + propHashInTrafficMeter = metrics.NewRegisteredMeter("eth/prop/hashes/in/traffic", nil) + propHashOutPacketsMeter = metrics.NewRegisteredMeter("eth/prop/hashes/out/packets", nil) + propHashOutTrafficMeter = metrics.NewRegisteredMeter("eth/prop/hashes/out/traffic", nil) + propBlockInPacketsMeter = metrics.NewRegisteredMeter("eth/prop/blocks/in/packets", nil) + propBlockInTrafficMeter = metrics.NewRegisteredMeter("eth/prop/blocks/in/traffic", nil) + propBlockOutPacketsMeter = metrics.NewRegisteredMeter("eth/prop/blocks/out/packets", nil) + propBlockOutTrafficMeter = metrics.NewRegisteredMeter("eth/prop/blocks/out/traffic", nil) + reqHeaderInPacketsMeter = metrics.NewRegisteredMeter("eth/req/headers/in/packets", nil) + reqHeaderInTrafficMeter = metrics.NewRegisteredMeter("eth/req/headers/in/traffic", nil) + reqHeaderOutPacketsMeter = metrics.NewRegisteredMeter("eth/req/headers/out/packets", nil) + reqHeaderOutTrafficMeter = metrics.NewRegisteredMeter("eth/req/headers/out/traffic", nil) + reqBodyInPacketsMeter = metrics.NewRegisteredMeter("eth/req/bodies/in/packets", nil) + reqBodyInTrafficMeter = metrics.NewRegisteredMeter("eth/req/bodies/in/traffic", nil) + reqBodyOutPacketsMeter = metrics.NewRegisteredMeter("eth/req/bodies/out/packets", nil) + reqBodyOutTrafficMeter = metrics.NewRegisteredMeter("eth/req/bodies/out/traffic", nil) + reqStateInPacketsMeter = metrics.NewRegisteredMeter("eth/req/states/in/packets", nil) + reqStateInTrafficMeter = metrics.NewRegisteredMeter("eth/req/states/in/traffic", nil) + reqStateOutPacketsMeter = metrics.NewRegisteredMeter("eth/req/states/out/packets", nil) + reqStateOutTrafficMeter = metrics.NewRegisteredMeter("eth/req/states/out/traffic", nil) + reqReceiptInPacketsMeter = metrics.NewRegisteredMeter("eth/req/receipts/in/packets", nil) + reqReceiptInTrafficMeter = metrics.NewRegisteredMeter("eth/req/receipts/in/traffic", nil) + reqReceiptOutPacketsMeter = metrics.NewRegisteredMeter("eth/req/receipts/out/packets", nil) + reqReceiptOutTrafficMeter = metrics.NewRegisteredMeter("eth/req/receipts/out/traffic", nil) + miscInPacketsMeter = metrics.NewRegisteredMeter("eth/misc/in/packets", nil) + miscInTrafficMeter = metrics.NewRegisteredMeter("eth/misc/in/traffic", nil) + miscOutPacketsMeter = metrics.NewRegisteredMeter("eth/misc/out/packets", nil) + miscOutTrafficMeter = metrics.NewRegisteredMeter("eth/misc/out/traffic", nil) ) // meteredMsgReadWriter is a wrapper around a p2p.MsgReadWriter, capable of diff --git a/ethdb/database.go b/ethdb/database.go index d86585f07..57d38f7f5 100644 --- a/ethdb/database.go +++ b/ethdb/database.go @@ -29,8 +29,6 @@ import ( "github.com/syndtr/goleveldb/leveldb/filter" "github.com/syndtr/goleveldb/leveldb/iterator" "github.com/syndtr/goleveldb/leveldb/opt" - - gometrics "github.com/rcrowley/go-metrics" ) var OpenFileLimit = 64 @@ -39,15 +37,15 @@ type LDBDatabase struct { fn string // filename for reporting db *leveldb.DB // LevelDB instance - getTimer gometrics.Timer // Timer for measuring the database get request counts and latencies - putTimer gometrics.Timer // Timer for measuring the database put request counts and latencies - delTimer gometrics.Timer // Timer for measuring the database delete request counts and latencies - missMeter gometrics.Meter // Meter for measuring the missed database get requests - readMeter gometrics.Meter // Meter for measuring the database get request data usage - writeMeter gometrics.Meter // Meter for measuring the database put request data usage - compTimeMeter gometrics.Meter // Meter for measuring the total time spent in database compaction - compReadMeter gometrics.Meter // Meter for measuring the data read during compaction - compWriteMeter gometrics.Meter // Meter for measuring the data written during compaction + getTimer metrics.Timer // Timer for measuring the database get request counts and latencies + putTimer metrics.Timer // Timer for measuring the database put request counts and latencies + delTimer metrics.Timer // Timer for measuring the database delete request counts and latencies + missMeter metrics.Meter // Meter for measuring the missed database get requests + readMeter metrics.Meter // Meter for measuring the database get request data usage + writeMeter metrics.Meter // Meter for measuring the database put request data usage + compTimeMeter metrics.Meter // Meter for measuring the total time spent in database compaction + compReadMeter metrics.Meter // Meter for measuring the data read during compaction + compWriteMeter metrics.Meter // Meter for measuring the data written during compaction quitLock sync.Mutex // Mutex protecting the quit channel access quitChan chan chan error // Quit channel to stop the metrics collection before closing the database @@ -180,15 +178,15 @@ func (db *LDBDatabase) Meter(prefix string) { return } // Initialize all the metrics collector at the requested prefix - db.getTimer = metrics.NewTimer(prefix + "user/gets") - db.putTimer = metrics.NewTimer(prefix + "user/puts") - db.delTimer = metrics.NewTimer(prefix + "user/dels") - db.missMeter = metrics.NewMeter(prefix + "user/misses") - db.readMeter = metrics.NewMeter(prefix + "user/reads") - db.writeMeter = metrics.NewMeter(prefix + "user/writes") - db.compTimeMeter = metrics.NewMeter(prefix + "compact/time") - db.compReadMeter = metrics.NewMeter(prefix + "compact/input") - db.compWriteMeter = metrics.NewMeter(prefix + "compact/output") + db.getTimer = metrics.NewRegisteredTimer(prefix+"user/gets", nil) + db.putTimer = metrics.NewRegisteredTimer(prefix+"user/puts", nil) + db.delTimer = metrics.NewRegisteredTimer(prefix+"user/dels", nil) + db.missMeter = metrics.NewRegisteredMeter(prefix+"user/misses", nil) + db.readMeter = metrics.NewRegisteredMeter(prefix+"user/reads", nil) + db.writeMeter = metrics.NewRegisteredMeter(prefix+"user/writes", nil) + db.compTimeMeter = metrics.NewRegisteredMeter(prefix+"compact/time", nil) + db.compReadMeter = metrics.NewRegisteredMeter(prefix+"compact/input", nil) + db.compWriteMeter = metrics.NewRegisteredMeter(prefix+"compact/output", nil) // Create a quit channel for the periodic collector and run it db.quitLock.Lock() diff --git a/internal/debug/api.go b/internal/debug/api.go index 3547b0564..048b7d763 100644 --- a/internal/debug/api.go +++ b/internal/debug/api.go @@ -140,10 +140,9 @@ func (h *HandlerT) GoTrace(file string, nsec uint) error { return nil } -// BlockProfile turns on CPU profiling for nsec seconds and writes -// profile data to file. It uses a profile rate of 1 for most accurate -// information. If a different rate is desired, set the rate -// and write the profile manually. +// BlockProfile turns on goroutine profiling for nsec seconds and writes profile data to +// file. It uses a profile rate of 1 for most accurate information. If a different rate is +// desired, set the rate and write the profile manually. func (*HandlerT) BlockProfile(file string, nsec uint) error { runtime.SetBlockProfileRate(1) time.Sleep(time.Duration(nsec) * time.Second) @@ -162,6 +161,26 @@ func (*HandlerT) WriteBlockProfile(file string) error { return writeProfile("block", file) } +// MutexProfile turns on mutex profiling for nsec seconds and writes profile data to file. +// It uses a profile rate of 1 for most accurate information. If a different rate is +// desired, set the rate and write the profile manually. +func (*HandlerT) MutexProfile(file string, nsec uint) error { + runtime.SetMutexProfileFraction(1) + time.Sleep(time.Duration(nsec) * time.Second) + defer runtime.SetMutexProfileFraction(0) + return writeProfile("mutex", file) +} + +// SetMutexProfileFraction sets the rate of mutex profiling. +func (*HandlerT) SetMutexProfileFraction(rate int) { + runtime.SetMutexProfileFraction(rate) +} + +// WriteMutexProfile writes a goroutine blocking profile to the given file. +func (*HandlerT) WriteMutexProfile(file string) error { + return writeProfile("mutex", file) +} + // WriteMemProfile writes an allocation profile to the given file. // Note that the profiling rate cannot be set through the API, // it must be set on the command line. diff --git a/internal/debug/flags.go b/internal/debug/flags.go index 6247cc7dc..1f181bf8b 100644 --- a/internal/debug/flags.go +++ b/internal/debug/flags.go @@ -26,6 +26,8 @@ import ( "github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/log/term" + "github.com/ethereum/go-ethereum/metrics" + "github.com/ethereum/go-ethereum/metrics/exp" colorable "github.com/mattn/go-colorable" "gopkg.in/urfave/cli.v1" ) @@ -127,6 +129,10 @@ func Setup(ctx *cli.Context) error { // pprof server if ctx.GlobalBool(pprofFlag.Name) { + // Hook go-metrics into expvar on any /debug/metrics request, load all vars + // from the registry into expvar, and execute regular expvar handler. + exp.Exp(metrics.DefaultRegistry) + address := fmt.Sprintf("%s:%d", ctx.GlobalString(pprofAddrFlag.Name), ctx.GlobalInt(pprofPortFlag.Name)) go func() { log.Info("Starting pprof server", "addr", fmt.Sprintf("http://%s/debug/pprof", address)) diff --git a/internal/ethapi/api.go b/internal/ethapi/api.go index 314086335..e49244404 100644 --- a/internal/ethapi/api.go +++ b/internal/ethapi/api.go @@ -611,7 +611,7 @@ type CallArgs struct { Data hexutil.Bytes `json:"data"` } -func (s *PublicBlockChainAPI) doCall(ctx context.Context, args CallArgs, blockNr rpc.BlockNumber, vmCfg vm.Config) ([]byte, uint64, bool, error) { +func (s *PublicBlockChainAPI) doCall(ctx context.Context, args CallArgs, blockNr rpc.BlockNumber, vmCfg vm.Config, timeout time.Duration) ([]byte, uint64, bool, error) { defer func(start time.Time) { log.Debug("Executing EVM call finished", "runtime", time.Since(start)) }(time.Now()) state, header, err := s.b.StateAndHeaderByNumber(ctx, blockNr) @@ -630,7 +630,7 @@ func (s *PublicBlockChainAPI) doCall(ctx context.Context, args CallArgs, blockNr // Set default gas & gas price if none were set gas, gasPrice := uint64(args.Gas), args.GasPrice.ToInt() if gas == 0 { - gas = 50000000 + gas = math.MaxUint64 / 2 } if gasPrice.Sign() == 0 { gasPrice = new(big.Int).SetUint64(defaultGasPrice) @@ -642,14 +642,14 @@ func (s *PublicBlockChainAPI) doCall(ctx context.Context, args CallArgs, blockNr // Setup context so it may be cancelled the call has completed // or, in case of unmetered gas, setup a context with a timeout. var cancel context.CancelFunc - if vmCfg.DisableGasMetering { - ctx, cancel = context.WithTimeout(ctx, time.Second*5) + if timeout > 0 { + ctx, cancel = context.WithTimeout(ctx, timeout) } else { ctx, cancel = context.WithCancel(ctx) } // Make sure the context is cancelled when the call has completed // this makes sure resources are cleaned up. - defer func() { cancel() }() + defer cancel() // Get a new instance of the EVM. evm, vmError, err := s.b.GetEVM(ctx, msg, state, header, vmCfg) @@ -676,7 +676,7 @@ func (s *PublicBlockChainAPI) doCall(ctx context.Context, args CallArgs, blockNr // Call executes the given transaction on the state for the given block number. // It doesn't make and changes in the state/blockchain and is useful to execute and retrieve values. func (s *PublicBlockChainAPI) Call(ctx context.Context, args CallArgs, blockNr rpc.BlockNumber) (hexutil.Bytes, error) { - result, _, _, err := s.doCall(ctx, args, blockNr, vm.Config{DisableGasMetering: true}) + result, _, _, err := s.doCall(ctx, args, blockNr, vm.Config{}, 5*time.Second) return (hexutil.Bytes)(result), err } @@ -705,7 +705,7 @@ func (s *PublicBlockChainAPI) EstimateGas(ctx context.Context, args CallArgs) (h executable := func(gas uint64) bool { args.Gas = hexutil.Uint64(gas) - _, _, failed, err := s.doCall(ctx, args, rpc.PendingBlockNumber, vm.Config{}) + _, _, failed, err := s.doCall(ctx, args, rpc.PendingBlockNumber, vm.Config{}, 0) if err != nil || failed { return false } @@ -1032,15 +1032,19 @@ func (s *PublicTransactionPoolAPI) GetRawTransactionByHash(ctx context.Context, } // GetTransactionReceipt returns the transaction receipt for the given transaction hash. -func (s *PublicTransactionPoolAPI) GetTransactionReceipt(hash common.Hash) (map[string]interface{}, error) { +func (s *PublicTransactionPoolAPI) GetTransactionReceipt(ctx context.Context, hash common.Hash) (map[string]interface{}, error) { tx, blockHash, blockNumber, index := core.GetTransaction(s.b.ChainDb(), hash) if tx == nil { - return nil, errors.New("unknown transaction") + return nil, nil } - receipt, _, _, _ := core.GetReceipt(s.b.ChainDb(), hash) // Old receipts don't have the lookup data available - if receipt == nil { - return nil, errors.New("unknown receipt") + receipts, err := s.b.GetReceipts(ctx, blockHash) + if err != nil { + return nil, err } + if len(receipts) <= int(index) { + return nil, nil + } + receipt := receipts[index] var signer types.Signer = types.FrontierSigner{} if tx.Protected() { @@ -1135,6 +1139,18 @@ func (args *SendTxArgs) setDefaults(ctx context.Context, b Backend) error { if args.Data != nil && args.Input != nil && !bytes.Equal(*args.Data, *args.Input) { return errors.New(`Both "data" and "input" are set and not equal. Please use "input" to pass transaction call data.`) } + if args.To == nil { + // Contract creation + var input []byte + if args.Data != nil { + input = *args.Data + } else if args.Input != nil { + input = *args.Input + } + if len(input) == 0 { + return errors.New(`contract creation without any data provided`) + } + } return nil } diff --git a/internal/web3ext/web3ext.go b/internal/web3ext/web3ext.go index a6b81b4c2..9d6ce8c6c 100644 --- a/internal/web3ext/web3ext.go +++ b/internal/web3ext/web3ext.go @@ -308,6 +308,21 @@ web3._extend({ params: 1 }), new web3._extend.Method({ + name: 'mutexProfile', + call: 'debug_mutexProfile', + params: 2 + }), + new web3._extend.Method({ + name: 'setMutexProfileRate', + call: 'debug_setMutexProfileRate', + params: 1 + }), + new web3._extend.Method({ + name: 'writeMutexProfile', + call: 'debug_writeMutexProfile', + params: 1 + }), + new web3._extend.Method({ name: 'writeMemProfile', call: 'debug_writeMemProfile', params: 1 diff --git a/les/api_backend.go b/les/api_backend.go index 56f617a7d..3fc5c33a4 100644 --- a/les/api_backend.go +++ b/les/api_backend.go @@ -87,6 +87,10 @@ func (b *LesApiBackend) GetReceipts(ctx context.Context, blockHash common.Hash) return light.GetBlockReceipts(ctx, b.eth.odr, blockHash, core.GetBlockNumber(b.eth.chainDb, blockHash)) } +func (b *LesApiBackend) GetLogs(ctx context.Context, blockHash common.Hash) ([][]*types.Log, error) { + return light.GetBlockLogs(ctx, b.eth.odr, blockHash, core.GetBlockNumber(b.eth.chainDb, blockHash)) +} + func (b *LesApiBackend) GetTd(blockHash common.Hash) *big.Int { return b.eth.blockchain.GetTdByHash(blockHash) } diff --git a/les/fetcher.go b/les/fetcher.go index 9d224176f..e12a2c78a 100644 --- a/les/fetcher.go +++ b/les/fetcher.go @@ -36,24 +36,26 @@ const ( maxNodeCount = 20 // maximum number of fetcherTreeNode entries remembered for each peer ) -// lightFetcher +// lightFetcher implements retrieval of newly announced headers. It also provides a peerHasBlock function for the +// ODR system to ensure that we only request data related to a certain block from peers who have already processed +// and announced that block. type lightFetcher struct { pm *ProtocolManager odr *LesOdr chain *light.LightChain + lock sync.Mutex // lock protects access to the fetcher's internal state variables except sent requests maxConfirmedTd *big.Int peers map[*peer]*fetcherPeerInfo lastUpdateStats *updateStatsEntry + syncing bool + syncDone chan *peer - lock sync.Mutex // qwerqwerqwe - deliverChn chan fetchResponse - reqMu sync.RWMutex + reqMu sync.RWMutex // reqMu protects access to sent header fetch requests requested map[uint64]fetchRequest + deliverChn chan fetchResponse timeoutChn chan uint64 requestChn chan bool // true if initiated from outside - syncing bool - syncDone chan *peer } // fetcherPeerInfo holds fetcher-specific information about each active peer @@ -560,8 +562,13 @@ func (f *lightFetcher) checkAnnouncedHeaders(fp *fetcherPeerInfo, headers []*typ return true } // we ran out of recently delivered headers but have not reached a node known by this peer yet, continue matching - td = f.chain.GetTd(header.ParentHash, header.Number.Uint64()-1) - header = f.chain.GetHeader(header.ParentHash, header.Number.Uint64()-1) + hash, number := header.ParentHash, header.Number.Uint64()-1 + td = f.chain.GetTd(hash, number) + header = f.chain.GetHeader(hash, number) + if header == nil || td == nil { + log.Error("Missing parent of validated header", "hash", hash, "number", number) + return false + } } else { header = headers[i] td = tds[i] @@ -645,13 +652,18 @@ func (f *lightFetcher) checkKnownNode(p *peer, n *fetcherTreeNode) bool { if td == nil { return false } + header := f.chain.GetHeader(n.hash, n.number) + // check the availability of both header and td because reads are not protected by chain db mutex + // Note: returning false is always safe here + if header == nil { + return false + } fp := f.peers[p] if fp == nil { p.Log().Debug("Unknown peer to check known nodes") return false } - header := f.chain.GetHeader(n.hash, n.number) if !f.checkAnnouncedHeaders(fp, []*types.Header{header}, []*big.Int{td}) { p.Log().Debug("Inconsistent announcement") go f.pm.removePeer(p.id) diff --git a/les/handler.go b/les/handler.go index 864abe605..9627f392b 100644 --- a/les/handler.go +++ b/les/handler.go @@ -260,7 +260,8 @@ func (pm *ProtocolManager) newPeer(pv int, nv uint64, p *p2p.Peer, rw p2p.MsgRea // handle is the callback invoked to manage the life cycle of a les peer. When // this function terminates, the peer is disconnected. func (pm *ProtocolManager) handle(p *peer) error { - if pm.peers.Len() >= pm.maxPeers { + // Ignore maxPeers if this is a trusted peer + if pm.peers.Len() >= pm.maxPeers && !p.Peer.Info().Network.Trusted { return p2p.DiscTooManyPeers } diff --git a/les/metrics.go b/les/metrics.go index 0162a1d1a..c282a62a1 100644 --- a/les/metrics.go +++ b/les/metrics.go @@ -58,10 +58,10 @@ var ( reqReceiptInTrafficMeter = metrics.NewMeter("eth/req/receipts/in/traffic") reqReceiptOutPacketsMeter = metrics.NewMeter("eth/req/receipts/out/packets") reqReceiptOutTrafficMeter = metrics.NewMeter("eth/req/receipts/out/traffic")*/ - miscInPacketsMeter = metrics.NewMeter("les/misc/in/packets") - miscInTrafficMeter = metrics.NewMeter("les/misc/in/traffic") - miscOutPacketsMeter = metrics.NewMeter("les/misc/out/packets") - miscOutTrafficMeter = metrics.NewMeter("les/misc/out/traffic") + miscInPacketsMeter = metrics.NewRegisteredMeter("les/misc/in/packets", nil) + miscInTrafficMeter = metrics.NewRegisteredMeter("les/misc/in/traffic", nil) + miscOutPacketsMeter = metrics.NewRegisteredMeter("les/misc/out/packets", nil) + miscOutTrafficMeter = metrics.NewRegisteredMeter("les/misc/out/traffic", nil) ) // meteredMsgReadWriter is a wrapper around a p2p.MsgReadWriter, capable of diff --git a/light/lightchain.go b/light/lightchain.go index 181a1c2a6..2784615d3 100644 --- a/light/lightchain.go +++ b/light/lightchain.go @@ -171,9 +171,6 @@ func (bc *LightChain) SetHead(head uint64) { // GasLimit returns the gas limit of the current HEAD block. func (self *LightChain) GasLimit() uint64 { - self.mu.RLock() - defer self.mu.RUnlock() - return self.hc.CurrentHeader().GasLimit } @@ -387,9 +384,6 @@ func (self *LightChain) InsertHeaderChain(chain []*types.Header, checkFreq int) // CurrentHeader retrieves the current head header of the canonical chain. The // header is retrieved from the HeaderChain's internal cache. func (self *LightChain) CurrentHeader() *types.Header { - self.mu.RLock() - defer self.mu.RUnlock() - return self.hc.CurrentHeader() } diff --git a/light/odr_util.go b/light/odr_util.go index 8f92d6442..97ba440ac 100644 --- a/light/odr_util.go +++ b/light/odr_util.go @@ -126,15 +126,48 @@ func GetBlock(ctx context.Context, odr OdrBackend, hash common.Hash, number uint // GetBlockReceipts retrieves the receipts generated by the transactions included // in a block given by its hash. func GetBlockReceipts(ctx context.Context, odr OdrBackend, hash common.Hash, number uint64) (types.Receipts, error) { + // Retrieve the potentially incomplete receipts from disk or network receipts := core.GetBlockReceipts(odr.Database(), hash, number) - if receipts != nil { - return receipts, nil + if receipts == nil { + r := &ReceiptsRequest{Hash: hash, Number: number} + if err := odr.Retrieve(ctx, r); err != nil { + return nil, err + } + receipts = r.Receipts } - r := &ReceiptsRequest{Hash: hash, Number: number} - if err := odr.Retrieve(ctx, r); err != nil { - return nil, err + // If the receipts are incomplete, fill the derived fields + if len(receipts) > 0 && receipts[0].TxHash == (common.Hash{}) { + block, err := GetBlock(ctx, odr, hash, number) + if err != nil { + return nil, err + } + genesis := core.GetCanonicalHash(odr.Database(), 0) + config, _ := core.GetChainConfig(odr.Database(), genesis) + + core.SetReceiptsData(config, block, receipts) + core.WriteBlockReceipts(odr.Database(), hash, number, receipts) + } + return receipts, nil +} + +// GetBlockLogs retrieves the logs generated by the transactions included in a +// block given by its hash. +func GetBlockLogs(ctx context.Context, odr OdrBackend, hash common.Hash, number uint64) ([][]*types.Log, error) { + // Retrieve the potentially incomplete receipts from disk or network + receipts := core.GetBlockReceipts(odr.Database(), hash, number) + if receipts == nil { + r := &ReceiptsRequest{Hash: hash, Number: number} + if err := odr.Retrieve(ctx, r); err != nil { + return nil, err + } + receipts = r.Receipts + } + // Return the logs without deriving any computed fields on the receipts + logs := make([][]*types.Log, len(receipts)) + for i, receipt := range receipts { + logs[i] = receipt.Logs } - return r.Receipts, nil + return logs, nil } // GetBloomBits retrieves a batch of compressed bloomBits vectors belonging to the given bit index and section indexes diff --git a/light/postprocess.go b/light/postprocess.go index 84149fdaa..384a635f7 100644 --- a/light/postprocess.go +++ b/light/postprocess.go @@ -58,18 +58,18 @@ type trustedCheckpoint struct { var ( mainnetCheckpoint = trustedCheckpoint{ name: "mainnet", - sectionIdx: 153, - sectionHead: common.HexToHash("04c2114a8cbe49ba5c37a03cc4b4b8d3adfc0bd2c78e0e726405dd84afca1d63"), - chtRoot: common.HexToHash("d7ec603e5d30b567a6e894ee7704e4603232f206d3e5a589794cec0c57bf318e"), - bloomTrieRoot: common.HexToHash("0b139b8fb692e21f663ff200da287192201c28ef5813c1ac6ba02a0a4799eef9"), + sectionIdx: 157, + sectionHead: common.HexToHash("1963c080887ca7f406c2bb114293eea83e54f783f94df24b447f7e3b6317c747"), + chtRoot: common.HexToHash("42abc436567dfb678a38fa6a9f881aa4c8a4cc8eaa2def08359292c3d0bd48ec"), + bloomTrieRoot: common.HexToHash("281c9f8fb3cb8b37ae45e9907ef8f3b19cd22c54e297c2d6c09c1db1593dce42"), } ropstenCheckpoint = trustedCheckpoint{ name: "ropsten", - sectionIdx: 79, - sectionHead: common.HexToHash("1b1ba890510e06411fdee9bb64ca7705c56a1a4ce3559ddb34b3680c526cb419"), - chtRoot: common.HexToHash("71d60207af74e5a22a3e1cfbfc89f9944f91b49aa980c86fba94d568369eaf44"), - bloomTrieRoot: common.HexToHash("70aca4b3b6d08dde8704c95cedb1420394453c1aec390947751e69ff8c436360"), + sectionIdx: 83, + sectionHead: common.HexToHash("3ca623586bc0da35f1fc8d9b6b55950f3b1f69be9c6501846a2df672adb61236"), + chtRoot: common.HexToHash("8f08ec7783969768c6ef06e5fe3398223cbf4ae2907b676da7b6fe6c7f55b059"), + bloomTrieRoot: common.HexToHash("02d86d3c6a87f8f8a92c2a59bbba2132ff6f9f61b0915a5dc28a9d8279219fd0"), } ) diff --git a/metrics/FORK.md b/metrics/FORK.md new file mode 100644 index 000000000..b19985bf5 --- /dev/null +++ b/metrics/FORK.md @@ -0,0 +1 @@ +This repo has been forked from https://github.com/rcrowley/go-metrics at commit e181e09 diff --git a/vendor/github.com/rcrowley/go-metrics/LICENSE b/metrics/LICENSE index 363fa9ee7..363fa9ee7 100644 --- a/vendor/github.com/rcrowley/go-metrics/LICENSE +++ b/metrics/LICENSE diff --git a/vendor/github.com/rcrowley/go-metrics/README.md b/metrics/README.md index 2d1a6dcfa..bc2a45a83 100644 --- a/vendor/github.com/rcrowley/go-metrics/README.md +++ b/metrics/README.md @@ -42,12 +42,22 @@ t.Update(47) Register() is not threadsafe. For threadsafe metric registration use GetOrRegister: -``` +```go t := metrics.GetOrRegisterTimer("account.create.latency", nil) t.Time(func() {}) t.Update(47) ``` +**NOTE:** Be sure to unregister short-lived meters and timers otherwise they will +leak memory: + +```go +// Will call Stop() on the Meter to allow for garbage collection +metrics.Unregister("quux") +// Or similarly for a Timer that embeds a Meter +metrics.Unregister("bang") +``` + Periodically log every metric in human-readable form to standard error: ```go @@ -81,12 +91,13 @@ issues [#121](https://github.com/rcrowley/go-metrics/issues/121) and ```go import "github.com/vrischmann/go-metrics-influxdb" -go influxdb.Influxdb(metrics.DefaultRegistry, 10e9, &influxdb.Config{ - Host: "127.0.0.1:8086", - Database: "metrics", - Username: "test", - Password: "test", -}) +go influxdb.InfluxDB(metrics.DefaultRegistry, + 10e9, + "127.0.0.1:8086", + "database-name", + "username", + "password" +) ``` Periodically upload every metric to Librato using the [Librato client](https://github.com/mihasya/go-metrics-librato): @@ -146,8 +157,10 @@ Publishing Metrics Clients are available for the following destinations: -* Librato - [https://github.com/mihasya/go-metrics-librato](https://github.com/mihasya/go-metrics-librato) -* Graphite - [https://github.com/cyberdelia/go-metrics-graphite](https://github.com/cyberdelia/go-metrics-graphite) -* InfluxDB - [https://github.com/vrischmann/go-metrics-influxdb](https://github.com/vrischmann/go-metrics-influxdb) -* Ganglia - [https://github.com/appscode/metlia](https://github.com/appscode/metlia) -* Prometheus - [https://github.com/deathowl/go-metrics-prometheus](https://github.com/deathowl/go-metrics-prometheus) +* Librato - https://github.com/mihasya/go-metrics-librato +* Graphite - https://github.com/cyberdelia/go-metrics-graphite +* InfluxDB - https://github.com/vrischmann/go-metrics-influxdb +* Ganglia - https://github.com/appscode/metlia +* Prometheus - https://github.com/deathowl/go-metrics-prometheus +* DataDog - https://github.com/syntaqx/go-metrics-datadog +* SignalFX - https://github.com/pascallouisperez/go-metrics-signalfx diff --git a/vendor/github.com/rcrowley/go-metrics/counter.go b/metrics/counter.go index bb7b039cb..c7f2b4bd3 100644 --- a/vendor/github.com/rcrowley/go-metrics/counter.go +++ b/metrics/counter.go @@ -22,7 +22,7 @@ func GetOrRegisterCounter(name string, r Registry) Counter { // NewCounter constructs a new StandardCounter. func NewCounter() Counter { - if UseNilMetrics { + if !Enabled { return NilCounter{} } return &StandardCounter{0} diff --git a/metrics/counter_test.go b/metrics/counter_test.go new file mode 100644 index 000000000..dfb03b4e8 --- /dev/null +++ b/metrics/counter_test.go @@ -0,0 +1,77 @@ +package metrics + +import "testing" + +func BenchmarkCounter(b *testing.B) { + c := NewCounter() + b.ResetTimer() + for i := 0; i < b.N; i++ { + c.Inc(1) + } +} + +func TestCounterClear(t *testing.T) { + c := NewCounter() + c.Inc(1) + c.Clear() + if count := c.Count(); 0 != count { + t.Errorf("c.Count(): 0 != %v\n", count) + } +} + +func TestCounterDec1(t *testing.T) { + c := NewCounter() + c.Dec(1) + if count := c.Count(); -1 != count { + t.Errorf("c.Count(): -1 != %v\n", count) + } +} + +func TestCounterDec2(t *testing.T) { + c := NewCounter() + c.Dec(2) + if count := c.Count(); -2 != count { + t.Errorf("c.Count(): -2 != %v\n", count) + } +} + +func TestCounterInc1(t *testing.T) { + c := NewCounter() + c.Inc(1) + if count := c.Count(); 1 != count { + t.Errorf("c.Count(): 1 != %v\n", count) + } +} + +func TestCounterInc2(t *testing.T) { + c := NewCounter() + c.Inc(2) + if count := c.Count(); 2 != count { + t.Errorf("c.Count(): 2 != %v\n", count) + } +} + +func TestCounterSnapshot(t *testing.T) { + c := NewCounter() + c.Inc(1) + snapshot := c.Snapshot() + c.Inc(1) + if count := snapshot.Count(); 1 != count { + t.Errorf("c.Count(): 1 != %v\n", count) + } +} + +func TestCounterZero(t *testing.T) { + c := NewCounter() + if count := c.Count(); 0 != count { + t.Errorf("c.Count(): 0 != %v\n", count) + } +} + +func TestGetOrRegisterCounter(t *testing.T) { + r := NewRegistry() + NewRegisteredCounter("foo", r).Inc(47) + if c := GetOrRegisterCounter("foo", r); 47 != c.Count() { + t.Fatal(c) + } +} diff --git a/vendor/github.com/rcrowley/go-metrics/debug.go b/metrics/debug.go index 043ccefab..de4a2739f 100644 --- a/vendor/github.com/rcrowley/go-metrics/debug.go +++ b/metrics/debug.go @@ -22,7 +22,7 @@ var ( // Capture new values for the Go garbage collector statistics exported in // debug.GCStats. This is designed to be called as a goroutine. func CaptureDebugGCStats(r Registry, d time.Duration) { - for _ = range time.Tick(d) { + for range time.Tick(d) { CaptureDebugGCStatsOnce(r) } } @@ -41,8 +41,8 @@ func CaptureDebugGCStatsOnce(r Registry) { debug.ReadGCStats(&gcStats) debugMetrics.ReadGCStats.UpdateSince(t) - debugMetrics.GCStats.LastGC.Update(int64(gcStats.LastGC.UnixNano())) - debugMetrics.GCStats.NumGC.Update(int64(gcStats.NumGC)) + debugMetrics.GCStats.LastGC.Update(gcStats.LastGC.UnixNano()) + debugMetrics.GCStats.NumGC.Update(gcStats.NumGC) if lastGC != gcStats.LastGC && 0 < len(gcStats.Pause) { debugMetrics.GCStats.Pause.Update(int64(gcStats.Pause[0])) } diff --git a/metrics/debug_test.go b/metrics/debug_test.go new file mode 100644 index 000000000..07eb86784 --- /dev/null +++ b/metrics/debug_test.go @@ -0,0 +1,48 @@ +package metrics + +import ( + "runtime" + "runtime/debug" + "testing" + "time" +) + +func BenchmarkDebugGCStats(b *testing.B) { + r := NewRegistry() + RegisterDebugGCStats(r) + b.ResetTimer() + for i := 0; i < b.N; i++ { + CaptureDebugGCStatsOnce(r) + } +} + +func TestDebugGCStatsBlocking(t *testing.T) { + if g := runtime.GOMAXPROCS(0); g < 2 { + t.Skipf("skipping TestDebugGCMemStatsBlocking with GOMAXPROCS=%d\n", g) + return + } + ch := make(chan int) + go testDebugGCStatsBlocking(ch) + var gcStats debug.GCStats + t0 := time.Now() + debug.ReadGCStats(&gcStats) + t1 := time.Now() + t.Log("i++ during debug.ReadGCStats:", <-ch) + go testDebugGCStatsBlocking(ch) + d := t1.Sub(t0) + t.Log(d) + time.Sleep(d) + t.Log("i++ during time.Sleep:", <-ch) +} + +func testDebugGCStatsBlocking(ch chan int) { + i := 0 + for { + select { + case ch <- i: + return + default: + i++ + } + } +} diff --git a/vendor/github.com/rcrowley/go-metrics/ewma.go b/metrics/ewma.go index 694a1d033..3aecd4fa3 100644 --- a/vendor/github.com/rcrowley/go-metrics/ewma.go +++ b/metrics/ewma.go @@ -17,7 +17,7 @@ type EWMA interface { // NewEWMA constructs a new EWMA with the given alpha. func NewEWMA(alpha float64) EWMA { - if UseNilMetrics { + if !Enabled { return NilEWMA{} } return &StandardEWMA{alpha: alpha} diff --git a/metrics/ewma_test.go b/metrics/ewma_test.go new file mode 100644 index 000000000..0430fbd24 --- /dev/null +++ b/metrics/ewma_test.go @@ -0,0 +1,225 @@ +package metrics + +import "testing" + +func BenchmarkEWMA(b *testing.B) { + a := NewEWMA1() + b.ResetTimer() + for i := 0; i < b.N; i++ { + a.Update(1) + a.Tick() + } +} + +func TestEWMA1(t *testing.T) { + a := NewEWMA1() + a.Update(3) + a.Tick() + if rate := a.Rate(); 0.6 != rate { + t.Errorf("initial a.Rate(): 0.6 != %v\n", rate) + } + elapseMinute(a) + if rate := a.Rate(); 0.22072766470286553 != rate { + t.Errorf("1 minute a.Rate(): 0.22072766470286553 != %v\n", rate) + } + elapseMinute(a) + if rate := a.Rate(); 0.08120116994196772 != rate { + t.Errorf("2 minute a.Rate(): 0.08120116994196772 != %v\n", rate) + } + elapseMinute(a) + if rate := a.Rate(); 0.029872241020718428 != rate { + t.Errorf("3 minute a.Rate(): 0.029872241020718428 != %v\n", rate) + } + elapseMinute(a) + if rate := a.Rate(); 0.01098938333324054 != rate { + t.Errorf("4 minute a.Rate(): 0.01098938333324054 != %v\n", rate) + } + elapseMinute(a) + if rate := a.Rate(); 0.004042768199451294 != rate { + t.Errorf("5 minute a.Rate(): 0.004042768199451294 != %v\n", rate) + } + elapseMinute(a) + if rate := a.Rate(); 0.0014872513059998212 != rate { + t.Errorf("6 minute a.Rate(): 0.0014872513059998212 != %v\n", rate) + } + elapseMinute(a) + if rate := a.Rate(); 0.0005471291793327122 != rate { + t.Errorf("7 minute a.Rate(): 0.0005471291793327122 != %v\n", rate) + } + elapseMinute(a) + if rate := a.Rate(); 0.00020127757674150815 != rate { + t.Errorf("8 minute a.Rate(): 0.00020127757674150815 != %v\n", rate) + } + elapseMinute(a) + if rate := a.Rate(); 7.404588245200814e-05 != rate { + t.Errorf("9 minute a.Rate(): 7.404588245200814e-05 != %v\n", rate) + } + elapseMinute(a) + if rate := a.Rate(); 2.7239957857491083e-05 != rate { + t.Errorf("10 minute a.Rate(): 2.7239957857491083e-05 != %v\n", rate) + } + elapseMinute(a) + if rate := a.Rate(); 1.0021020474147462e-05 != rate { + t.Errorf("11 minute a.Rate(): 1.0021020474147462e-05 != %v\n", rate) + } + elapseMinute(a) + if rate := a.Rate(); 3.6865274119969525e-06 != rate { + t.Errorf("12 minute a.Rate(): 3.6865274119969525e-06 != %v\n", rate) + } + elapseMinute(a) + if rate := a.Rate(); 1.3561976441886433e-06 != rate { + t.Errorf("13 minute a.Rate(): 1.3561976441886433e-06 != %v\n", rate) + } + elapseMinute(a) + if rate := a.Rate(); 4.989172314621449e-07 != rate { + t.Errorf("14 minute a.Rate(): 4.989172314621449e-07 != %v\n", rate) + } + elapseMinute(a) + if rate := a.Rate(); 1.8354139230109722e-07 != rate { + t.Errorf("15 minute a.Rate(): 1.8354139230109722e-07 != %v\n", rate) + } +} + +func TestEWMA5(t *testing.T) { + a := NewEWMA5() + a.Update(3) + a.Tick() + if rate := a.Rate(); 0.6 != rate { + t.Errorf("initial a.Rate(): 0.6 != %v\n", rate) + } + elapseMinute(a) + if rate := a.Rate(); 0.49123845184678905 != rate { + t.Errorf("1 minute a.Rate(): 0.49123845184678905 != %v\n", rate) + } + elapseMinute(a) + if rate := a.Rate(); 0.4021920276213837 != rate { + t.Errorf("2 minute a.Rate(): 0.4021920276213837 != %v\n", rate) + } + elapseMinute(a) + if rate := a.Rate(); 0.32928698165641596 != rate { + t.Errorf("3 minute a.Rate(): 0.32928698165641596 != %v\n", rate) + } + elapseMinute(a) + if rate := a.Rate(); 0.269597378470333 != rate { + t.Errorf("4 minute a.Rate(): 0.269597378470333 != %v\n", rate) + } + elapseMinute(a) + if rate := a.Rate(); 0.2207276647028654 != rate { + t.Errorf("5 minute a.Rate(): 0.2207276647028654 != %v\n", rate) + } + elapseMinute(a) + if rate := a.Rate(); 0.18071652714732128 != rate { + t.Errorf("6 minute a.Rate(): 0.18071652714732128 != %v\n", rate) + } + elapseMinute(a) + if rate := a.Rate(); 0.14795817836496392 != rate { + t.Errorf("7 minute a.Rate(): 0.14795817836496392 != %v\n", rate) + } + elapseMinute(a) + if rate := a.Rate(); 0.12113791079679326 != rate { + t.Errorf("8 minute a.Rate(): 0.12113791079679326 != %v\n", rate) + } + elapseMinute(a) + if rate := a.Rate(); 0.09917933293295193 != rate { + t.Errorf("9 minute a.Rate(): 0.09917933293295193 != %v\n", rate) + } + elapseMinute(a) + if rate := a.Rate(); 0.08120116994196763 != rate { + t.Errorf("10 minute a.Rate(): 0.08120116994196763 != %v\n", rate) + } + elapseMinute(a) + if rate := a.Rate(); 0.06648189501740036 != rate { + t.Errorf("11 minute a.Rate(): 0.06648189501740036 != %v\n", rate) + } + elapseMinute(a) + if rate := a.Rate(); 0.05443077197364752 != rate { + t.Errorf("12 minute a.Rate(): 0.05443077197364752 != %v\n", rate) + } + elapseMinute(a) + if rate := a.Rate(); 0.04456414692860035 != rate { + t.Errorf("13 minute a.Rate(): 0.04456414692860035 != %v\n", rate) + } + elapseMinute(a) + if rate := a.Rate(); 0.03648603757513079 != rate { + t.Errorf("14 minute a.Rate(): 0.03648603757513079 != %v\n", rate) + } + elapseMinute(a) + if rate := a.Rate(); 0.0298722410207183831020718428 != rate { + t.Errorf("15 minute a.Rate(): 0.0298722410207183831020718428 != %v\n", rate) + } +} + +func TestEWMA15(t *testing.T) { + a := NewEWMA15() + a.Update(3) + a.Tick() + if rate := a.Rate(); 0.6 != rate { + t.Errorf("initial a.Rate(): 0.6 != %v\n", rate) + } + elapseMinute(a) + if rate := a.Rate(); 0.5613041910189706 != rate { + t.Errorf("1 minute a.Rate(): 0.5613041910189706 != %v\n", rate) + } + elapseMinute(a) + if rate := a.Rate(); 0.5251039914257684 != rate { + t.Errorf("2 minute a.Rate(): 0.5251039914257684 != %v\n", rate) + } + elapseMinute(a) + if rate := a.Rate(); 0.4912384518467888184678905 != rate { + t.Errorf("3 minute a.Rate(): 0.4912384518467888184678905 != %v\n", rate) + } + elapseMinute(a) + if rate := a.Rate(); 0.459557003018789 != rate { + t.Errorf("4 minute a.Rate(): 0.459557003018789 != %v\n", rate) + } + elapseMinute(a) + if rate := a.Rate(); 0.4299187863442732 != rate { + t.Errorf("5 minute a.Rate(): 0.4299187863442732 != %v\n", rate) + } + elapseMinute(a) + if rate := a.Rate(); 0.4021920276213831 != rate { + t.Errorf("6 minute a.Rate(): 0.4021920276213831 != %v\n", rate) + } + elapseMinute(a) + if rate := a.Rate(); 0.37625345116383313 != rate { + t.Errorf("7 minute a.Rate(): 0.37625345116383313 != %v\n", rate) + } + elapseMinute(a) + if rate := a.Rate(); 0.3519877317060185 != rate { + t.Errorf("8 minute a.Rate(): 0.3519877317060185 != %v\n", rate) + } + elapseMinute(a) + if rate := a.Rate(); 0.3292869816564153165641596 != rate { + t.Errorf("9 minute a.Rate(): 0.3292869816564153165641596 != %v\n", rate) + } + elapseMinute(a) + if rate := a.Rate(); 0.3080502714195546 != rate { + t.Errorf("10 minute a.Rate(): 0.3080502714195546 != %v\n", rate) + } + elapseMinute(a) + if rate := a.Rate(); 0.2881831806538789 != rate { + t.Errorf("11 minute a.Rate(): 0.2881831806538789 != %v\n", rate) + } + elapseMinute(a) + if rate := a.Rate(); 0.26959737847033216 != rate { + t.Errorf("12 minute a.Rate(): 0.26959737847033216 != %v\n", rate) + } + elapseMinute(a) + if rate := a.Rate(); 0.2522102307052083 != rate { + t.Errorf("13 minute a.Rate(): 0.2522102307052083 != %v\n", rate) + } + elapseMinute(a) + if rate := a.Rate(); 0.23594443252115815 != rate { + t.Errorf("14 minute a.Rate(): 0.23594443252115815 != %v\n", rate) + } + elapseMinute(a) + if rate := a.Rate(); 0.2207276647028646247028654470286553 != rate { + t.Errorf("15 minute a.Rate(): 0.2207276647028646247028654470286553 != %v\n", rate) + } +} + +func elapseMinute(a EWMA) { + for i := 0; i < 12; i++ { + a.Tick() + } +} diff --git a/vendor/github.com/rcrowley/go-metrics/exp/exp.go b/metrics/exp/exp.go index 11dd3f898..c19d00a94 100644 --- a/vendor/github.com/rcrowley/go-metrics/exp/exp.go +++ b/metrics/exp/exp.go @@ -8,7 +8,7 @@ import ( "net/http" "sync" - "github.com/rcrowley/go-metrics" + "github.com/ethereum/go-ethereum/metrics" ) type exp struct { @@ -97,22 +97,22 @@ func (exp *exp) publishHistogram(name string, metric metrics.Histogram) { exp.getInt(name + ".count").Set(h.Count()) exp.getFloat(name + ".min").Set(float64(h.Min())) exp.getFloat(name + ".max").Set(float64(h.Max())) - exp.getFloat(name + ".mean").Set(float64(h.Mean())) - exp.getFloat(name + ".std-dev").Set(float64(h.StdDev())) - exp.getFloat(name + ".50-percentile").Set(float64(ps[0])) - exp.getFloat(name + ".75-percentile").Set(float64(ps[1])) - exp.getFloat(name + ".95-percentile").Set(float64(ps[2])) - exp.getFloat(name + ".99-percentile").Set(float64(ps[3])) - exp.getFloat(name + ".999-percentile").Set(float64(ps[4])) + exp.getFloat(name + ".mean").Set(h.Mean()) + exp.getFloat(name + ".std-dev").Set(h.StdDev()) + exp.getFloat(name + ".50-percentile").Set(ps[0]) + exp.getFloat(name + ".75-percentile").Set(ps[1]) + exp.getFloat(name + ".95-percentile").Set(ps[2]) + exp.getFloat(name + ".99-percentile").Set(ps[3]) + exp.getFloat(name + ".999-percentile").Set(ps[4]) } func (exp *exp) publishMeter(name string, metric metrics.Meter) { m := metric.Snapshot() exp.getInt(name + ".count").Set(m.Count()) - exp.getFloat(name + ".one-minute").Set(float64(m.Rate1())) - exp.getFloat(name + ".five-minute").Set(float64(m.Rate5())) - exp.getFloat(name + ".fifteen-minute").Set(float64((m.Rate15()))) - exp.getFloat(name + ".mean").Set(float64(m.RateMean())) + exp.getFloat(name + ".one-minute").Set(m.Rate1()) + exp.getFloat(name + ".five-minute").Set(m.Rate5()) + exp.getFloat(name + ".fifteen-minute").Set((m.Rate15())) + exp.getFloat(name + ".mean").Set(m.RateMean()) } func (exp *exp) publishTimer(name string, metric metrics.Timer) { @@ -121,17 +121,17 @@ func (exp *exp) publishTimer(name string, metric metrics.Timer) { exp.getInt(name + ".count").Set(t.Count()) exp.getFloat(name + ".min").Set(float64(t.Min())) exp.getFloat(name + ".max").Set(float64(t.Max())) - exp.getFloat(name + ".mean").Set(float64(t.Mean())) - exp.getFloat(name + ".std-dev").Set(float64(t.StdDev())) - exp.getFloat(name + ".50-percentile").Set(float64(ps[0])) - exp.getFloat(name + ".75-percentile").Set(float64(ps[1])) - exp.getFloat(name + ".95-percentile").Set(float64(ps[2])) - exp.getFloat(name + ".99-percentile").Set(float64(ps[3])) - exp.getFloat(name + ".999-percentile").Set(float64(ps[4])) - exp.getFloat(name + ".one-minute").Set(float64(t.Rate1())) - exp.getFloat(name + ".five-minute").Set(float64(t.Rate5())) - exp.getFloat(name + ".fifteen-minute").Set(float64((t.Rate15()))) - exp.getFloat(name + ".mean-rate").Set(float64(t.RateMean())) + exp.getFloat(name + ".mean").Set(t.Mean()) + exp.getFloat(name + ".std-dev").Set(t.StdDev()) + exp.getFloat(name + ".50-percentile").Set(ps[0]) + exp.getFloat(name + ".75-percentile").Set(ps[1]) + exp.getFloat(name + ".95-percentile").Set(ps[2]) + exp.getFloat(name + ".99-percentile").Set(ps[3]) + exp.getFloat(name + ".999-percentile").Set(ps[4]) + exp.getFloat(name + ".one-minute").Set(t.Rate1()) + exp.getFloat(name + ".five-minute").Set(t.Rate5()) + exp.getFloat(name + ".fifteen-minute").Set(t.Rate15()) + exp.getFloat(name + ".mean-rate").Set(t.RateMean()) } func (exp *exp) syncToExpvar() { diff --git a/vendor/github.com/rcrowley/go-metrics/gauge.go b/metrics/gauge.go index cb57a9388..0fbfdb860 100644 --- a/vendor/github.com/rcrowley/go-metrics/gauge.go +++ b/metrics/gauge.go @@ -20,7 +20,7 @@ func GetOrRegisterGauge(name string, r Registry) Gauge { // NewGauge constructs a new StandardGauge. func NewGauge() Gauge { - if UseNilMetrics { + if !Enabled { return NilGauge{} } return &StandardGauge{0} @@ -38,7 +38,7 @@ func NewRegisteredGauge(name string, r Registry) Gauge { // NewFunctionalGauge constructs a new FunctionalGauge. func NewFunctionalGauge(f func() int64) Gauge { - if UseNilMetrics { + if !Enabled { return NilGauge{} } return &FunctionalGauge{value: f} diff --git a/vendor/github.com/rcrowley/go-metrics/gauge_float64.go b/metrics/gauge_float64.go index 6f93920b2..66819c957 100644 --- a/vendor/github.com/rcrowley/go-metrics/gauge_float64.go +++ b/metrics/gauge_float64.go @@ -20,7 +20,7 @@ func GetOrRegisterGaugeFloat64(name string, r Registry) GaugeFloat64 { // NewGaugeFloat64 constructs a new StandardGaugeFloat64. func NewGaugeFloat64() GaugeFloat64 { - if UseNilMetrics { + if !Enabled { return NilGaugeFloat64{} } return &StandardGaugeFloat64{ @@ -40,7 +40,7 @@ func NewRegisteredGaugeFloat64(name string, r Registry) GaugeFloat64 { // NewFunctionalGauge constructs a new FunctionalGauge. func NewFunctionalGaugeFloat64(f func() float64) GaugeFloat64 { - if UseNilMetrics { + if !Enabled { return NilGaugeFloat64{} } return &FunctionalGaugeFloat64{value: f} diff --git a/metrics/gauge_float64_test.go b/metrics/gauge_float64_test.go new file mode 100644 index 000000000..99e62a403 --- /dev/null +++ b/metrics/gauge_float64_test.go @@ -0,0 +1,59 @@ +package metrics + +import "testing" + +func BenchmarkGuageFloat64(b *testing.B) { + g := NewGaugeFloat64() + b.ResetTimer() + for i := 0; i < b.N; i++ { + g.Update(float64(i)) + } +} + +func TestGaugeFloat64(t *testing.T) { + g := NewGaugeFloat64() + g.Update(float64(47.0)) + if v := g.Value(); float64(47.0) != v { + t.Errorf("g.Value(): 47.0 != %v\n", v) + } +} + +func TestGaugeFloat64Snapshot(t *testing.T) { + g := NewGaugeFloat64() + g.Update(float64(47.0)) + snapshot := g.Snapshot() + g.Update(float64(0)) + if v := snapshot.Value(); float64(47.0) != v { + t.Errorf("g.Value(): 47.0 != %v\n", v) + } +} + +func TestGetOrRegisterGaugeFloat64(t *testing.T) { + r := NewRegistry() + NewRegisteredGaugeFloat64("foo", r).Update(float64(47.0)) + t.Logf("registry: %v", r) + if g := GetOrRegisterGaugeFloat64("foo", r); float64(47.0) != g.Value() { + t.Fatal(g) + } +} + +func TestFunctionalGaugeFloat64(t *testing.T) { + var counter float64 + fg := NewFunctionalGaugeFloat64(func() float64 { + counter++ + return counter + }) + fg.Value() + fg.Value() + if counter != 2 { + t.Error("counter != 2") + } +} + +func TestGetOrRegisterFunctionalGaugeFloat64(t *testing.T) { + r := NewRegistry() + NewRegisteredFunctionalGaugeFloat64("foo", r, func() float64 { return 47 }) + if g := GetOrRegisterGaugeFloat64("foo", r); 47 != g.Value() { + t.Fatal(g) + } +} diff --git a/metrics/gauge_test.go b/metrics/gauge_test.go new file mode 100644 index 000000000..1f2603d33 --- /dev/null +++ b/metrics/gauge_test.go @@ -0,0 +1,68 @@ +package metrics + +import ( + "fmt" + "testing" +) + +func BenchmarkGuage(b *testing.B) { + g := NewGauge() + b.ResetTimer() + for i := 0; i < b.N; i++ { + g.Update(int64(i)) + } +} + +func TestGauge(t *testing.T) { + g := NewGauge() + g.Update(int64(47)) + if v := g.Value(); 47 != v { + t.Errorf("g.Value(): 47 != %v\n", v) + } +} + +func TestGaugeSnapshot(t *testing.T) { + g := NewGauge() + g.Update(int64(47)) + snapshot := g.Snapshot() + g.Update(int64(0)) + if v := snapshot.Value(); 47 != v { + t.Errorf("g.Value(): 47 != %v\n", v) + } +} + +func TestGetOrRegisterGauge(t *testing.T) { + r := NewRegistry() + NewRegisteredGauge("foo", r).Update(47) + if g := GetOrRegisterGauge("foo", r); 47 != g.Value() { + t.Fatal(g) + } +} + +func TestFunctionalGauge(t *testing.T) { + var counter int64 + fg := NewFunctionalGauge(func() int64 { + counter++ + return counter + }) + fg.Value() + fg.Value() + if counter != 2 { + t.Error("counter != 2") + } +} + +func TestGetOrRegisterFunctionalGauge(t *testing.T) { + r := NewRegistry() + NewRegisteredFunctionalGauge("foo", r, func() int64 { return 47 }) + if g := GetOrRegisterGauge("foo", r); 47 != g.Value() { + t.Fatal(g) + } +} + +func ExampleGetOrRegisterGauge() { + m := "server.bytes_sent" + g := GetOrRegisterGauge(m, nil) + g.Update(47) + fmt.Println(g.Value()) // Output: 47 +} diff --git a/vendor/github.com/rcrowley/go-metrics/graphite.go b/metrics/graphite.go index abd0a7d29..142eec86b 100644 --- a/vendor/github.com/rcrowley/go-metrics/graphite.go +++ b/metrics/graphite.go @@ -39,7 +39,7 @@ func Graphite(r Registry, d time.Duration, prefix string, addr *net.TCPAddr) { // but it takes a GraphiteConfig instead. func GraphiteWithConfig(c GraphiteConfig) { log.Printf("WARNING: This go-metrics client has been DEPRECATED! It has been moved to https://github.com/cyberdelia/go-metrics-graphite and will be removed from rcrowley/go-metrics on August 12th 2015") - for _ = range time.Tick(c.FlushInterval) { + for range time.Tick(c.FlushInterval) { if err := graphite(&c); nil != err { log.Println(err) } diff --git a/metrics/graphite_test.go b/metrics/graphite_test.go new file mode 100644 index 000000000..c797c781d --- /dev/null +++ b/metrics/graphite_test.go @@ -0,0 +1,22 @@ +package metrics + +import ( + "net" + "time" +) + +func ExampleGraphite() { + addr, _ := net.ResolveTCPAddr("net", ":2003") + go Graphite(DefaultRegistry, 1*time.Second, "some.prefix", addr) +} + +func ExampleGraphiteWithConfig() { + addr, _ := net.ResolveTCPAddr("net", ":2003") + go GraphiteWithConfig(GraphiteConfig{ + Addr: addr, + Registry: DefaultRegistry, + FlushInterval: 1 * time.Second, + DurationUnit: time.Millisecond, + Percentiles: []float64{0.5, 0.75, 0.99, 0.999}, + }) +} diff --git a/vendor/github.com/rcrowley/go-metrics/healthcheck.go b/metrics/healthcheck.go index 445131cae..f1ae31e34 100644 --- a/vendor/github.com/rcrowley/go-metrics/healthcheck.go +++ b/metrics/healthcheck.go @@ -11,7 +11,7 @@ type Healthcheck interface { // NewHealthcheck constructs a new Healthcheck which will use the given // function to update its status. func NewHealthcheck(f func(Healthcheck)) Healthcheck { - if UseNilMetrics { + if !Enabled { return NilHealthcheck{} } return &StandardHealthcheck{nil, f} diff --git a/vendor/github.com/rcrowley/go-metrics/histogram.go b/metrics/histogram.go index dbc837fe4..46f3bbd2f 100644 --- a/vendor/github.com/rcrowley/go-metrics/histogram.go +++ b/metrics/histogram.go @@ -28,7 +28,7 @@ func GetOrRegisterHistogram(name string, r Registry, s Sample) Histogram { // NewHistogram constructs a new StandardHistogram from a Sample. func NewHistogram(s Sample) Histogram { - if UseNilMetrics { + if !Enabled { return NilHistogram{} } return &StandardHistogram{sample: s} diff --git a/metrics/histogram_test.go b/metrics/histogram_test.go new file mode 100644 index 000000000..d7f4f0171 --- /dev/null +++ b/metrics/histogram_test.go @@ -0,0 +1,95 @@ +package metrics + +import "testing" + +func BenchmarkHistogram(b *testing.B) { + h := NewHistogram(NewUniformSample(100)) + b.ResetTimer() + for i := 0; i < b.N; i++ { + h.Update(int64(i)) + } +} + +func TestGetOrRegisterHistogram(t *testing.T) { + r := NewRegistry() + s := NewUniformSample(100) + NewRegisteredHistogram("foo", r, s).Update(47) + if h := GetOrRegisterHistogram("foo", r, s); 1 != h.Count() { + t.Fatal(h) + } +} + +func TestHistogram10000(t *testing.T) { + h := NewHistogram(NewUniformSample(100000)) + for i := 1; i <= 10000; i++ { + h.Update(int64(i)) + } + testHistogram10000(t, h) +} + +func TestHistogramEmpty(t *testing.T) { + h := NewHistogram(NewUniformSample(100)) + if count := h.Count(); 0 != count { + t.Errorf("h.Count(): 0 != %v\n", count) + } + if min := h.Min(); 0 != min { + t.Errorf("h.Min(): 0 != %v\n", min) + } + if max := h.Max(); 0 != max { + t.Errorf("h.Max(): 0 != %v\n", max) + } + if mean := h.Mean(); 0.0 != mean { + t.Errorf("h.Mean(): 0.0 != %v\n", mean) + } + if stdDev := h.StdDev(); 0.0 != stdDev { + t.Errorf("h.StdDev(): 0.0 != %v\n", stdDev) + } + ps := h.Percentiles([]float64{0.5, 0.75, 0.99}) + if 0.0 != ps[0] { + t.Errorf("median: 0.0 != %v\n", ps[0]) + } + if 0.0 != ps[1] { + t.Errorf("75th percentile: 0.0 != %v\n", ps[1]) + } + if 0.0 != ps[2] { + t.Errorf("99th percentile: 0.0 != %v\n", ps[2]) + } +} + +func TestHistogramSnapshot(t *testing.T) { + h := NewHistogram(NewUniformSample(100000)) + for i := 1; i <= 10000; i++ { + h.Update(int64(i)) + } + snapshot := h.Snapshot() + h.Update(0) + testHistogram10000(t, snapshot) +} + +func testHistogram10000(t *testing.T, h Histogram) { + if count := h.Count(); 10000 != count { + t.Errorf("h.Count(): 10000 != %v\n", count) + } + if min := h.Min(); 1 != min { + t.Errorf("h.Min(): 1 != %v\n", min) + } + if max := h.Max(); 10000 != max { + t.Errorf("h.Max(): 10000 != %v\n", max) + } + if mean := h.Mean(); 5000.5 != mean { + t.Errorf("h.Mean(): 5000.5 != %v\n", mean) + } + if stdDev := h.StdDev(); 2886.751331514372 != stdDev { + t.Errorf("h.StdDev(): 2886.751331514372 != %v\n", stdDev) + } + ps := h.Percentiles([]float64{0.5, 0.75, 0.99}) + if 5000.5 != ps[0] { + t.Errorf("median: 5000.5 != %v\n", ps[0]) + } + if 7500.75 != ps[1] { + t.Errorf("75th percentile: 7500.75 != %v\n", ps[1]) + } + if 9900.99 != ps[2] { + t.Errorf("99th percentile: 9900.99 != %v\n", ps[2]) + } +} diff --git a/metrics/influxdb/LICENSE b/metrics/influxdb/LICENSE new file mode 100644 index 000000000..e5bf20cdb --- /dev/null +++ b/metrics/influxdb/LICENSE @@ -0,0 +1,19 @@ +Copyright (c) 2015 Vincent Rischmann + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. diff --git a/metrics/influxdb/README.md b/metrics/influxdb/README.md new file mode 100644 index 000000000..b76b1a3f9 --- /dev/null +++ b/metrics/influxdb/README.md @@ -0,0 +1,30 @@ +go-metrics-influxdb +=================== + +This is a reporter for the [go-metrics](https://github.com/rcrowley/go-metrics) library which will post the metrics to [InfluxDB](https://influxdb.com/). + +Note +---- + +This is only compatible with InfluxDB 0.9+. + +Usage +----- + +```go +import "github.com/vrischmann/go-metrics-influxdb" + +go influxdb.InfluxDB( + metrics.DefaultRegistry, // metrics registry + time.Second * 10, // interval + "http://localhost:8086", // the InfluxDB url + "mydb", // your InfluxDB database + "myuser", // your InfluxDB user + "mypassword", // your InfluxDB password +) +``` + +License +------- + +go-metrics-influxdb is licensed under the MIT license. See the LICENSE file for details. diff --git a/metrics/influxdb/influxdb.go b/metrics/influxdb/influxdb.go new file mode 100644 index 000000000..d5cb4da66 --- /dev/null +++ b/metrics/influxdb/influxdb.go @@ -0,0 +1,227 @@ +package influxdb + +import ( + "fmt" + "log" + uurl "net/url" + "time" + + "github.com/ethereum/go-ethereum/metrics" + "github.com/influxdata/influxdb/client" +) + +type reporter struct { + reg metrics.Registry + interval time.Duration + + url uurl.URL + database string + username string + password string + namespace string + tags map[string]string + + client *client.Client + + cache map[string]int64 +} + +// InfluxDB starts a InfluxDB reporter which will post the from the given metrics.Registry at each d interval. +func InfluxDB(r metrics.Registry, d time.Duration, url, database, username, password, namespace string) { + InfluxDBWithTags(r, d, url, database, username, password, namespace, nil) +} + +// InfluxDBWithTags starts a InfluxDB reporter which will post the from the given metrics.Registry at each d interval with the specified tags +func InfluxDBWithTags(r metrics.Registry, d time.Duration, url, database, username, password, namespace string, tags map[string]string) { + u, err := uurl.Parse(url) + if err != nil { + log.Printf("unable to parse InfluxDB url %s. err=%v", url, err) + return + } + + rep := &reporter{ + reg: r, + interval: d, + url: *u, + database: database, + username: username, + password: password, + namespace: namespace, + tags: tags, + cache: make(map[string]int64), + } + if err := rep.makeClient(); err != nil { + log.Printf("unable to make InfluxDB client. err=%v", err) + return + } + + rep.run() +} + +func (r *reporter) makeClient() (err error) { + r.client, err = client.NewClient(client.Config{ + URL: r.url, + Username: r.username, + Password: r.password, + }) + + return +} + +func (r *reporter) run() { + intervalTicker := time.Tick(r.interval) + pingTicker := time.Tick(time.Second * 5) + + for { + select { + case <-intervalTicker: + if err := r.send(); err != nil { + log.Printf("unable to send to InfluxDB. err=%v", err) + } + case <-pingTicker: + _, _, err := r.client.Ping() + if err != nil { + log.Printf("got error while sending a ping to InfluxDB, trying to recreate client. err=%v", err) + + if err = r.makeClient(); err != nil { + log.Printf("unable to make InfluxDB client. err=%v", err) + } + } + } + } +} + +func (r *reporter) send() error { + var pts []client.Point + + r.reg.Each(func(name string, i interface{}) { + now := time.Now() + namespace := r.namespace + + switch metric := i.(type) { + case metrics.Counter: + v := metric.Count() + l := r.cache[name] + pts = append(pts, client.Point{ + Measurement: fmt.Sprintf("%s%s.count", namespace, name), + Tags: r.tags, + Fields: map[string]interface{}{ + "value": v - l, + }, + Time: now, + }) + r.cache[name] = v + case metrics.Gauge: + ms := metric.Snapshot() + pts = append(pts, client.Point{ + Measurement: fmt.Sprintf("%s%s.gauge", namespace, name), + Tags: r.tags, + Fields: map[string]interface{}{ + "value": ms.Value(), + }, + Time: now, + }) + case metrics.GaugeFloat64: + ms := metric.Snapshot() + pts = append(pts, client.Point{ + Measurement: fmt.Sprintf("%s%s.gauge", namespace, name), + Tags: r.tags, + Fields: map[string]interface{}{ + "value": ms.Value(), + }, + Time: now, + }) + case metrics.Histogram: + ms := metric.Snapshot() + ps := ms.Percentiles([]float64{0.5, 0.75, 0.95, 0.99, 0.999, 0.9999}) + pts = append(pts, client.Point{ + Measurement: fmt.Sprintf("%s%s.histogram", namespace, name), + Tags: r.tags, + Fields: map[string]interface{}{ + "count": ms.Count(), + "max": ms.Max(), + "mean": ms.Mean(), + "min": ms.Min(), + "stddev": ms.StdDev(), + "variance": ms.Variance(), + "p50": ps[0], + "p75": ps[1], + "p95": ps[2], + "p99": ps[3], + "p999": ps[4], + "p9999": ps[5], + }, + Time: now, + }) + case metrics.Meter: + ms := metric.Snapshot() + pts = append(pts, client.Point{ + Measurement: fmt.Sprintf("%s%s.meter", namespace, name), + Tags: r.tags, + Fields: map[string]interface{}{ + "count": ms.Count(), + "m1": ms.Rate1(), + "m5": ms.Rate5(), + "m15": ms.Rate15(), + "mean": ms.RateMean(), + }, + Time: now, + }) + case metrics.Timer: + ms := metric.Snapshot() + ps := ms.Percentiles([]float64{0.5, 0.75, 0.95, 0.99, 0.999, 0.9999}) + pts = append(pts, client.Point{ + Measurement: fmt.Sprintf("%s%s.timer", namespace, name), + Tags: r.tags, + Fields: map[string]interface{}{ + "count": ms.Count(), + "max": ms.Max(), + "mean": ms.Mean(), + "min": ms.Min(), + "stddev": ms.StdDev(), + "variance": ms.Variance(), + "p50": ps[0], + "p75": ps[1], + "p95": ps[2], + "p99": ps[3], + "p999": ps[4], + "p9999": ps[5], + "m1": ms.Rate1(), + "m5": ms.Rate5(), + "m15": ms.Rate15(), + "meanrate": ms.RateMean(), + }, + Time: now, + }) + case metrics.ResettingTimer: + t := metric.Snapshot() + + if len(t.Values()) > 0 { + ps := t.Percentiles([]float64{50, 95, 99}) + val := t.Values() + pts = append(pts, client.Point{ + Measurement: fmt.Sprintf("%s%s.span", namespace, name), + Tags: r.tags, + Fields: map[string]interface{}{ + "count": len(val), + "max": val[len(val)-1], + "mean": t.Mean(), + "min": val[0], + "p50": ps[0], + "p95": ps[1], + "p99": ps[2], + }, + Time: now, + }) + } + } + }) + + bps := client.BatchPoints{ + Points: pts, + Database: r.database, + } + + _, err := r.client.Write(bps) + return err +} diff --git a/metrics/init_test.go b/metrics/init_test.go new file mode 100644 index 000000000..43401e833 --- /dev/null +++ b/metrics/init_test.go @@ -0,0 +1,5 @@ +package metrics + +func init() { + Enabled = true +} diff --git a/metrics/json.go b/metrics/json.go new file mode 100644 index 000000000..2087d8211 --- /dev/null +++ b/metrics/json.go @@ -0,0 +1,31 @@ +package metrics + +import ( + "encoding/json" + "io" + "time" +) + +// MarshalJSON returns a byte slice containing a JSON representation of all +// the metrics in the Registry. +func (r *StandardRegistry) MarshalJSON() ([]byte, error) { + return json.Marshal(r.GetAll()) +} + +// WriteJSON writes metrics from the given registry periodically to the +// specified io.Writer as JSON. +func WriteJSON(r Registry, d time.Duration, w io.Writer) { + for range time.Tick(d) { + WriteJSONOnce(r, w) + } +} + +// WriteJSONOnce writes metrics from the given registry to the specified +// io.Writer as JSON. +func WriteJSONOnce(r Registry, w io.Writer) { + json.NewEncoder(w).Encode(r) +} + +func (p *PrefixedRegistry) MarshalJSON() ([]byte, error) { + return json.Marshal(p.GetAll()) +} diff --git a/metrics/json_test.go b/metrics/json_test.go new file mode 100644 index 000000000..cf70051f7 --- /dev/null +++ b/metrics/json_test.go @@ -0,0 +1,28 @@ +package metrics + +import ( + "bytes" + "encoding/json" + "testing" +) + +func TestRegistryMarshallJSON(t *testing.T) { + b := &bytes.Buffer{} + enc := json.NewEncoder(b) + r := NewRegistry() + r.Register("counter", NewCounter()) + enc.Encode(r) + if s := b.String(); "{\"counter\":{\"count\":0}}\n" != s { + t.Fatalf(s) + } +} + +func TestRegistryWriteJSONOnce(t *testing.T) { + r := NewRegistry() + r.Register("counter", NewCounter()) + b := &bytes.Buffer{} + WriteJSONOnce(r, b) + if s := b.String(); s != "{\"counter\":{\"count\":0}}\n" { + t.Fail() + } +} diff --git a/metrics/librato/client.go b/metrics/librato/client.go new file mode 100644 index 000000000..8c0c850e3 --- /dev/null +++ b/metrics/librato/client.go @@ -0,0 +1,102 @@ +package librato + +import ( + "bytes" + "encoding/json" + "fmt" + "io/ioutil" + "net/http" +) + +const Operations = "operations" +const OperationsShort = "ops" + +type LibratoClient struct { + Email, Token string +} + +// property strings +const ( + // display attributes + Color = "color" + DisplayMax = "display_max" + DisplayMin = "display_min" + DisplayUnitsLong = "display_units_long" + DisplayUnitsShort = "display_units_short" + DisplayStacked = "display_stacked" + DisplayTransform = "display_transform" + // special gauge display attributes + SummarizeFunction = "summarize_function" + Aggregate = "aggregate" + + // metric keys + Name = "name" + Period = "period" + Description = "description" + DisplayName = "display_name" + Attributes = "attributes" + + // measurement keys + MeasureTime = "measure_time" + Source = "source" + Value = "value" + + // special gauge keys + Count = "count" + Sum = "sum" + Max = "max" + Min = "min" + SumSquares = "sum_squares" + + // batch keys + Counters = "counters" + Gauges = "gauges" + + MetricsPostUrl = "https://metrics-api.librato.com/v1/metrics" +) + +type Measurement map[string]interface{} +type Metric map[string]interface{} + +type Batch struct { + Gauges []Measurement `json:"gauges,omitempty"` + Counters []Measurement `json:"counters,omitempty"` + MeasureTime int64 `json:"measure_time"` + Source string `json:"source"` +} + +func (self *LibratoClient) PostMetrics(batch Batch) (err error) { + var ( + js []byte + req *http.Request + resp *http.Response + ) + + if len(batch.Counters) == 0 && len(batch.Gauges) == 0 { + return nil + } + + if js, err = json.Marshal(batch); err != nil { + return + } + + if req, err = http.NewRequest("POST", MetricsPostUrl, bytes.NewBuffer(js)); err != nil { + return + } + + req.Header.Set("Content-Type", "application/json") + req.SetBasicAuth(self.Email, self.Token) + + if resp, err = http.DefaultClient.Do(req); err != nil { + return + } + + if resp.StatusCode != http.StatusOK { + var body []byte + if body, err = ioutil.ReadAll(resp.Body); err != nil { + body = []byte(fmt.Sprintf("(could not fetch response body for error: %s)", err)) + } + err = fmt.Errorf("Unable to post to Librato: %d %s %s", resp.StatusCode, resp.Status, string(body)) + } + return +} diff --git a/metrics/librato/librato.go b/metrics/librato/librato.go new file mode 100644 index 000000000..f8c8c9ecb --- /dev/null +++ b/metrics/librato/librato.go @@ -0,0 +1,235 @@ +package librato + +import ( + "fmt" + "log" + "math" + "regexp" + "time" + + "github.com/ethereum/go-ethereum/metrics" +) + +// a regexp for extracting the unit from time.Duration.String +var unitRegexp = regexp.MustCompile(`[^\\d]+$`) + +// a helper that turns a time.Duration into librato display attributes for timer metrics +func translateTimerAttributes(d time.Duration) (attrs map[string]interface{}) { + attrs = make(map[string]interface{}) + attrs[DisplayTransform] = fmt.Sprintf("x/%d", int64(d)) + attrs[DisplayUnitsShort] = string(unitRegexp.Find([]byte(d.String()))) + return +} + +type Reporter struct { + Email, Token string + Namespace string + Source string + Interval time.Duration + Registry metrics.Registry + Percentiles []float64 // percentiles to report on histogram metrics + TimerAttributes map[string]interface{} // units in which timers will be displayed + intervalSec int64 +} + +func NewReporter(r metrics.Registry, d time.Duration, e string, t string, s string, p []float64, u time.Duration) *Reporter { + return &Reporter{e, t, "", s, d, r, p, translateTimerAttributes(u), int64(d / time.Second)} +} + +func Librato(r metrics.Registry, d time.Duration, e string, t string, s string, p []float64, u time.Duration) { + NewReporter(r, d, e, t, s, p, u).Run() +} + +func (self *Reporter) Run() { + log.Printf("WARNING: This client has been DEPRECATED! It has been moved to https://github.com/mihasya/go-metrics-librato and will be removed from rcrowley/go-metrics on August 5th 2015") + ticker := time.Tick(self.Interval) + metricsApi := &LibratoClient{self.Email, self.Token} + for now := range ticker { + var metrics Batch + var err error + if metrics, err = self.BuildRequest(now, self.Registry); err != nil { + log.Printf("ERROR constructing librato request body %s", err) + continue + } + if err := metricsApi.PostMetrics(metrics); err != nil { + log.Printf("ERROR sending metrics to librato %s", err) + continue + } + } +} + +// calculate sum of squares from data provided by metrics.Histogram +// see http://en.wikipedia.org/wiki/Standard_deviation#Rapid_calculation_methods +func sumSquares(s metrics.Sample) float64 { + count := float64(s.Count()) + sumSquared := math.Pow(count*s.Mean(), 2) + sumSquares := math.Pow(count*s.StdDev(), 2) + sumSquared/count + if math.IsNaN(sumSquares) { + return 0.0 + } + return sumSquares +} +func sumSquaresTimer(t metrics.Timer) float64 { + count := float64(t.Count()) + sumSquared := math.Pow(count*t.Mean(), 2) + sumSquares := math.Pow(count*t.StdDev(), 2) + sumSquared/count + if math.IsNaN(sumSquares) { + return 0.0 + } + return sumSquares +} + +func (self *Reporter) BuildRequest(now time.Time, r metrics.Registry) (snapshot Batch, err error) { + snapshot = Batch{ + // coerce timestamps to a stepping fn so that they line up in Librato graphs + MeasureTime: (now.Unix() / self.intervalSec) * self.intervalSec, + Source: self.Source, + } + snapshot.Gauges = make([]Measurement, 0) + snapshot.Counters = make([]Measurement, 0) + histogramGaugeCount := 1 + len(self.Percentiles) + r.Each(func(name string, metric interface{}) { + if self.Namespace != "" { + name = fmt.Sprintf("%s.%s", self.Namespace, name) + } + measurement := Measurement{} + measurement[Period] = self.Interval.Seconds() + switch m := metric.(type) { + case metrics.Counter: + if m.Count() > 0 { + measurement[Name] = fmt.Sprintf("%s.%s", name, "count") + measurement[Value] = float64(m.Count()) + measurement[Attributes] = map[string]interface{}{ + DisplayUnitsLong: Operations, + DisplayUnitsShort: OperationsShort, + DisplayMin: "0", + } + snapshot.Counters = append(snapshot.Counters, measurement) + } + case metrics.Gauge: + measurement[Name] = name + measurement[Value] = float64(m.Value()) + snapshot.Gauges = append(snapshot.Gauges, measurement) + case metrics.GaugeFloat64: + measurement[Name] = name + measurement[Value] = m.Value() + snapshot.Gauges = append(snapshot.Gauges, measurement) + case metrics.Histogram: + if m.Count() > 0 { + gauges := make([]Measurement, histogramGaugeCount) + s := m.Sample() + measurement[Name] = fmt.Sprintf("%s.%s", name, "hist") + measurement[Count] = uint64(s.Count()) + measurement[Max] = float64(s.Max()) + measurement[Min] = float64(s.Min()) + measurement[Sum] = float64(s.Sum()) + measurement[SumSquares] = sumSquares(s) + gauges[0] = measurement + for i, p := range self.Percentiles { + gauges[i+1] = Measurement{ + Name: fmt.Sprintf("%s.%.2f", measurement[Name], p), + Value: s.Percentile(p), + Period: measurement[Period], + } + } + snapshot.Gauges = append(snapshot.Gauges, gauges...) + } + case metrics.Meter: + measurement[Name] = name + measurement[Value] = float64(m.Count()) + snapshot.Counters = append(snapshot.Counters, measurement) + snapshot.Gauges = append(snapshot.Gauges, + Measurement{ + Name: fmt.Sprintf("%s.%s", name, "1min"), + Value: m.Rate1(), + Period: int64(self.Interval.Seconds()), + Attributes: map[string]interface{}{ + DisplayUnitsLong: Operations, + DisplayUnitsShort: OperationsShort, + DisplayMin: "0", + }, + }, + Measurement{ + Name: fmt.Sprintf("%s.%s", name, "5min"), + Value: m.Rate5(), + Period: int64(self.Interval.Seconds()), + Attributes: map[string]interface{}{ + DisplayUnitsLong: Operations, + DisplayUnitsShort: OperationsShort, + DisplayMin: "0", + }, + }, + Measurement{ + Name: fmt.Sprintf("%s.%s", name, "15min"), + Value: m.Rate15(), + Period: int64(self.Interval.Seconds()), + Attributes: map[string]interface{}{ + DisplayUnitsLong: Operations, + DisplayUnitsShort: OperationsShort, + DisplayMin: "0", + }, + }, + ) + case metrics.Timer: + measurement[Name] = name + measurement[Value] = float64(m.Count()) + snapshot.Counters = append(snapshot.Counters, measurement) + if m.Count() > 0 { + libratoName := fmt.Sprintf("%s.%s", name, "timer.mean") + gauges := make([]Measurement, histogramGaugeCount) + gauges[0] = Measurement{ + Name: libratoName, + Count: uint64(m.Count()), + Sum: m.Mean() * float64(m.Count()), + Max: float64(m.Max()), + Min: float64(m.Min()), + SumSquares: sumSquaresTimer(m), + Period: int64(self.Interval.Seconds()), + Attributes: self.TimerAttributes, + } + for i, p := range self.Percentiles { + gauges[i+1] = Measurement{ + Name: fmt.Sprintf("%s.timer.%2.0f", name, p*100), + Value: m.Percentile(p), + Period: int64(self.Interval.Seconds()), + Attributes: self.TimerAttributes, + } + } + snapshot.Gauges = append(snapshot.Gauges, gauges...) + snapshot.Gauges = append(snapshot.Gauges, + Measurement{ + Name: fmt.Sprintf("%s.%s", name, "rate.1min"), + Value: m.Rate1(), + Period: int64(self.Interval.Seconds()), + Attributes: map[string]interface{}{ + DisplayUnitsLong: Operations, + DisplayUnitsShort: OperationsShort, + DisplayMin: "0", + }, + }, + Measurement{ + Name: fmt.Sprintf("%s.%s", name, "rate.5min"), + Value: m.Rate5(), + Period: int64(self.Interval.Seconds()), + Attributes: map[string]interface{}{ + DisplayUnitsLong: Operations, + DisplayUnitsShort: OperationsShort, + DisplayMin: "0", + }, + }, + Measurement{ + Name: fmt.Sprintf("%s.%s", name, "rate.15min"), + Value: m.Rate15(), + Period: int64(self.Interval.Seconds()), + Attributes: map[string]interface{}{ + DisplayUnitsLong: Operations, + DisplayUnitsShort: OperationsShort, + DisplayMin: "0", + }, + }, + ) + } + } + }) + return +} diff --git a/vendor/github.com/rcrowley/go-metrics/log.go b/metrics/log.go index f8074c045..0c8ea7c97 100644 --- a/vendor/github.com/rcrowley/go-metrics/log.go +++ b/metrics/log.go @@ -18,7 +18,7 @@ func LogScaled(r Registry, freq time.Duration, scale time.Duration, l Logger) { du := float64(scale) duSuffix := scale.String()[1:] - for _ = range time.Tick(freq) { + for range time.Tick(freq) { r.Each(func(name string, i interface{}) { switch metric := i.(type) { case Counter: diff --git a/vendor/github.com/rcrowley/go-metrics/memory.md b/metrics/memory.md index 47454f54b..47454f54b 100644 --- a/vendor/github.com/rcrowley/go-metrics/memory.md +++ b/metrics/memory.md diff --git a/vendor/github.com/rcrowley/go-metrics/meter.go b/metrics/meter.go index 0389ab0b8..82b2141a6 100644 --- a/vendor/github.com/rcrowley/go-metrics/meter.go +++ b/metrics/meter.go @@ -15,10 +15,13 @@ type Meter interface { Rate15() float64 RateMean() float64 Snapshot() Meter + Stop() } // GetOrRegisterMeter returns an existing Meter or constructs and registers a // new StandardMeter. +// Be sure to unregister the meter from the registry once it is of no use to +// allow for garbage collection. func GetOrRegisterMeter(name string, r Registry) Meter { if nil == r { r = DefaultRegistry @@ -27,14 +30,15 @@ func GetOrRegisterMeter(name string, r Registry) Meter { } // NewMeter constructs a new StandardMeter and launches a goroutine. +// Be sure to call Stop() once the meter is of no use to allow for garbage collection. func NewMeter() Meter { - if UseNilMetrics { + if !Enabled { return NilMeter{} } m := newStandardMeter() arbiter.Lock() defer arbiter.Unlock() - arbiter.meters = append(arbiter.meters, m) + arbiter.meters[m] = struct{}{} if !arbiter.started { arbiter.started = true go arbiter.tick() @@ -44,6 +48,8 @@ func NewMeter() Meter { // NewMeter constructs and registers a new StandardMeter and launches a // goroutine. +// Be sure to unregister the meter from the registry once it is of no use to +// allow for garbage collection. func NewRegisteredMeter(name string, r Registry) Meter { c := NewMeter() if nil == r { @@ -86,6 +92,9 @@ func (m *MeterSnapshot) RateMean() float64 { return m.rateMean } // Snapshot returns the snapshot. func (m *MeterSnapshot) Snapshot() Meter { return m } +// Stop is a no-op. +func (m *MeterSnapshot) Stop() {} + // NilMeter is a no-op Meter. type NilMeter struct{} @@ -110,12 +119,16 @@ func (NilMeter) RateMean() float64 { return 0.0 } // Snapshot is a no-op. func (NilMeter) Snapshot() Meter { return NilMeter{} } +// Stop is a no-op. +func (NilMeter) Stop() {} + // StandardMeter is the standard implementation of a Meter. type StandardMeter struct { lock sync.RWMutex snapshot *MeterSnapshot a1, a5, a15 EWMA startTime time.Time + stopped bool } func newStandardMeter() *StandardMeter { @@ -128,6 +141,19 @@ func newStandardMeter() *StandardMeter { } } +// Stop stops the meter, Mark() will be a no-op if you use it after being stopped. +func (m *StandardMeter) Stop() { + m.lock.Lock() + stopped := m.stopped + m.stopped = true + m.lock.Unlock() + if !stopped { + arbiter.Lock() + delete(arbiter.meters, m) + arbiter.Unlock() + } +} + // Count returns the number of events recorded. func (m *StandardMeter) Count() int64 { m.lock.RLock() @@ -136,10 +162,13 @@ func (m *StandardMeter) Count() int64 { return count } -// Mark records the occurance of n events. +// Mark records the occurrence of n events. func (m *StandardMeter) Mark(n int64) { m.lock.Lock() defer m.lock.Unlock() + if m.stopped { + return + } m.snapshot.count += n m.a1.Update(n) m.a5.Update(n) @@ -205,29 +234,28 @@ func (m *StandardMeter) tick() { m.updateSnapshot() } +// meterArbiter ticks meters every 5s from a single goroutine. +// meters are references in a set for future stopping. type meterArbiter struct { sync.RWMutex started bool - meters []*StandardMeter + meters map[*StandardMeter]struct{} ticker *time.Ticker } -var arbiter = meterArbiter{ticker: time.NewTicker(5e9)} +var arbiter = meterArbiter{ticker: time.NewTicker(5e9), meters: make(map[*StandardMeter]struct{})} // Ticks meters on the scheduled interval func (ma *meterArbiter) tick() { - for { - select { - case <-ma.ticker.C: - ma.tickMeters() - } + for range ma.ticker.C { + ma.tickMeters() } } func (ma *meterArbiter) tickMeters() { ma.RLock() defer ma.RUnlock() - for _, meter := range ma.meters { + for meter := range ma.meters { meter.tick() } } diff --git a/metrics/meter_test.go b/metrics/meter_test.go new file mode 100644 index 000000000..e88922260 --- /dev/null +++ b/metrics/meter_test.go @@ -0,0 +1,73 @@ +package metrics + +import ( + "testing" + "time" +) + +func BenchmarkMeter(b *testing.B) { + m := NewMeter() + b.ResetTimer() + for i := 0; i < b.N; i++ { + m.Mark(1) + } +} + +func TestGetOrRegisterMeter(t *testing.T) { + r := NewRegistry() + NewRegisteredMeter("foo", r).Mark(47) + if m := GetOrRegisterMeter("foo", r); 47 != m.Count() { + t.Fatal(m) + } +} + +func TestMeterDecay(t *testing.T) { + ma := meterArbiter{ + ticker: time.NewTicker(time.Millisecond), + meters: make(map[*StandardMeter]struct{}), + } + m := newStandardMeter() + ma.meters[m] = struct{}{} + go ma.tick() + m.Mark(1) + rateMean := m.RateMean() + time.Sleep(100 * time.Millisecond) + if m.RateMean() >= rateMean { + t.Error("m.RateMean() didn't decrease") + } +} + +func TestMeterNonzero(t *testing.T) { + m := NewMeter() + m.Mark(3) + if count := m.Count(); 3 != count { + t.Errorf("m.Count(): 3 != %v\n", count) + } +} + +func TestMeterStop(t *testing.T) { + l := len(arbiter.meters) + m := NewMeter() + if len(arbiter.meters) != l+1 { + t.Errorf("arbiter.meters: %d != %d\n", l+1, len(arbiter.meters)) + } + m.Stop() + if len(arbiter.meters) != l { + t.Errorf("arbiter.meters: %d != %d\n", l, len(arbiter.meters)) + } +} + +func TestMeterSnapshot(t *testing.T) { + m := NewMeter() + m.Mark(1) + if snapshot := m.Snapshot(); m.RateMean() != snapshot.RateMean() { + t.Fatal(snapshot) + } +} + +func TestMeterZero(t *testing.T) { + m := NewMeter() + if count := m.Count(); 0 != count { + t.Errorf("m.Count(): 0 != %v\n", count) + } +} diff --git a/metrics/metrics.go b/metrics/metrics.go index c82661d80..e24324814 100644 --- a/metrics/metrics.go +++ b/metrics/metrics.go @@ -1,20 +1,8 @@ -// Copyright 2015 The go-ethereum Authors -// This file is part of the go-ethereum library. +// Go port of Coda Hale's Metrics library // -// The go-ethereum library is free software: you can redistribute it and/or modify -// it under the terms of the GNU Lesser General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. +// <https://github.com/rcrowley/go-metrics> // -// The go-ethereum library is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU Lesser General Public License for more details. -// -// You should have received a copy of the GNU Lesser General Public License -// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. - -// Package metrics provides general system and process level metrics collection. +// Coda Hale's original work: <https://github.com/codahale/metrics> package metrics import ( @@ -24,17 +12,19 @@ import ( "time" "github.com/ethereum/go-ethereum/log" - "github.com/rcrowley/go-metrics" - "github.com/rcrowley/go-metrics/exp" ) +// Enabled is checked by the constructor functions for all of the +// standard metrics. If it is true, the metric returned is a stub. +// +// This global kill-switch helps quantify the observer effect and makes +// for less cluttered pprof profiles. +var Enabled bool = false + // MetricsEnabledFlag is the CLI flag name to use to enable metrics collections. const MetricsEnabledFlag = "metrics" const DashboardEnabledFlag = "dashboard" -// Enabled is the flag specifying if metrics are enable or not. -var Enabled = false - // Init enables or disables the metrics system. Since we need this to run before // any other code gets to create meters and timers, we'll actually do an ugly hack // and peek into the command line args for the metrics flag. @@ -45,34 +35,6 @@ func init() { Enabled = true } } - exp.Exp(metrics.DefaultRegistry) -} - -// NewCounter create a new metrics Counter, either a real one of a NOP stub depending -// on the metrics flag. -func NewCounter(name string) metrics.Counter { - if !Enabled { - return new(metrics.NilCounter) - } - return metrics.GetOrRegisterCounter(name, metrics.DefaultRegistry) -} - -// NewMeter create a new metrics Meter, either a real one of a NOP stub depending -// on the metrics flag. -func NewMeter(name string) metrics.Meter { - if !Enabled { - return new(metrics.NilMeter) - } - return metrics.GetOrRegisterMeter(name, metrics.DefaultRegistry) -} - -// NewTimer create a new metrics Timer, either a real one of a NOP stub depending -// on the metrics flag. -func NewTimer(name string) metrics.Timer { - if !Enabled { - return new(metrics.NilTimer) - } - return metrics.GetOrRegisterTimer(name, metrics.DefaultRegistry) } // CollectProcessMetrics periodically collects various metrics about the running @@ -90,17 +52,17 @@ func CollectProcessMetrics(refresh time.Duration) { diskstats[i] = new(DiskStats) } // Define the various metrics to collect - memAllocs := metrics.GetOrRegisterMeter("system/memory/allocs", metrics.DefaultRegistry) - memFrees := metrics.GetOrRegisterMeter("system/memory/frees", metrics.DefaultRegistry) - memInuse := metrics.GetOrRegisterMeter("system/memory/inuse", metrics.DefaultRegistry) - memPauses := metrics.GetOrRegisterMeter("system/memory/pauses", metrics.DefaultRegistry) + memAllocs := GetOrRegisterMeter("system/memory/allocs", DefaultRegistry) + memFrees := GetOrRegisterMeter("system/memory/frees", DefaultRegistry) + memInuse := GetOrRegisterMeter("system/memory/inuse", DefaultRegistry) + memPauses := GetOrRegisterMeter("system/memory/pauses", DefaultRegistry) - var diskReads, diskReadBytes, diskWrites, diskWriteBytes metrics.Meter + var diskReads, diskReadBytes, diskWrites, diskWriteBytes Meter if err := ReadDiskStats(diskstats[0]); err == nil { - diskReads = metrics.GetOrRegisterMeter("system/disk/readcount", metrics.DefaultRegistry) - diskReadBytes = metrics.GetOrRegisterMeter("system/disk/readdata", metrics.DefaultRegistry) - diskWrites = metrics.GetOrRegisterMeter("system/disk/writecount", metrics.DefaultRegistry) - diskWriteBytes = metrics.GetOrRegisterMeter("system/disk/writedata", metrics.DefaultRegistry) + diskReads = GetOrRegisterMeter("system/disk/readcount", DefaultRegistry) + diskReadBytes = GetOrRegisterMeter("system/disk/readdata", DefaultRegistry) + diskWrites = GetOrRegisterMeter("system/disk/writecount", DefaultRegistry) + diskWriteBytes = GetOrRegisterMeter("system/disk/writedata", DefaultRegistry) } else { log.Debug("Failed to read disk metrics", "err", err) } diff --git a/metrics/metrics_test.go b/metrics/metrics_test.go new file mode 100644 index 000000000..df36da0ad --- /dev/null +++ b/metrics/metrics_test.go @@ -0,0 +1,125 @@ +package metrics + +import ( + "fmt" + "io/ioutil" + "log" + "sync" + "testing" + "time" +) + +const FANOUT = 128 + +// Stop the compiler from complaining during debugging. +var ( + _ = ioutil.Discard + _ = log.LstdFlags +) + +func BenchmarkMetrics(b *testing.B) { + r := NewRegistry() + c := NewRegisteredCounter("counter", r) + g := NewRegisteredGauge("gauge", r) + gf := NewRegisteredGaugeFloat64("gaugefloat64", r) + h := NewRegisteredHistogram("histogram", r, NewUniformSample(100)) + m := NewRegisteredMeter("meter", r) + t := NewRegisteredTimer("timer", r) + RegisterDebugGCStats(r) + RegisterRuntimeMemStats(r) + b.ResetTimer() + ch := make(chan bool) + + wgD := &sync.WaitGroup{} + /* + wgD.Add(1) + go func() { + defer wgD.Done() + //log.Println("go CaptureDebugGCStats") + for { + select { + case <-ch: + //log.Println("done CaptureDebugGCStats") + return + default: + CaptureDebugGCStatsOnce(r) + } + } + }() + //*/ + + wgR := &sync.WaitGroup{} + //* + wgR.Add(1) + go func() { + defer wgR.Done() + //log.Println("go CaptureRuntimeMemStats") + for { + select { + case <-ch: + //log.Println("done CaptureRuntimeMemStats") + return + default: + CaptureRuntimeMemStatsOnce(r) + } + } + }() + //*/ + + wgW := &sync.WaitGroup{} + /* + wgW.Add(1) + go func() { + defer wgW.Done() + //log.Println("go Write") + for { + select { + case <-ch: + //log.Println("done Write") + return + default: + WriteOnce(r, ioutil.Discard) + } + } + }() + //*/ + + wg := &sync.WaitGroup{} + wg.Add(FANOUT) + for i := 0; i < FANOUT; i++ { + go func(i int) { + defer wg.Done() + //log.Println("go", i) + for i := 0; i < b.N; i++ { + c.Inc(1) + g.Update(int64(i)) + gf.Update(float64(i)) + h.Update(int64(i)) + m.Mark(1) + t.Update(1) + } + //log.Println("done", i) + }(i) + } + wg.Wait() + close(ch) + wgD.Wait() + wgR.Wait() + wgW.Wait() +} + +func Example() { + c := NewCounter() + Register("money", c) + c.Inc(17) + + // Threadsafe registration + t := GetOrRegisterTimer("db.get.latency", nil) + t.Time(func() { time.Sleep(10 * time.Millisecond) }) + t.Update(1) + + fmt.Println(c.Count()) + fmt.Println(t.Min()) + // Output: 17 + // 1 +} diff --git a/vendor/github.com/rcrowley/go-metrics/opentsdb.go b/metrics/opentsdb.go index 266b6c93d..df7f152ed 100644 --- a/vendor/github.com/rcrowley/go-metrics/opentsdb.go +++ b/metrics/opentsdb.go @@ -38,7 +38,7 @@ func OpenTSDB(r Registry, d time.Duration, prefix string, addr *net.TCPAddr) { // OpenTSDBWithConfig is a blocking exporter function just like OpenTSDB, // but it takes a OpenTSDBConfig instead. func OpenTSDBWithConfig(c OpenTSDBConfig) { - for _ = range time.Tick(c.FlushInterval) { + for range time.Tick(c.FlushInterval) { if err := openTSDB(&c); nil != err { log.Println(err) } diff --git a/metrics/opentsdb_test.go b/metrics/opentsdb_test.go new file mode 100644 index 000000000..c43728960 --- /dev/null +++ b/metrics/opentsdb_test.go @@ -0,0 +1,21 @@ +package metrics + +import ( + "net" + "time" +) + +func ExampleOpenTSDB() { + addr, _ := net.ResolveTCPAddr("net", ":2003") + go OpenTSDB(DefaultRegistry, 1*time.Second, "some.prefix", addr) +} + +func ExampleOpenTSDBWithConfig() { + addr, _ := net.ResolveTCPAddr("net", ":2003") + go OpenTSDBWithConfig(OpenTSDBConfig{ + Addr: addr, + Registry: DefaultRegistry, + FlushInterval: 1 * time.Second, + DurationUnit: time.Millisecond, + }) +} diff --git a/vendor/github.com/rcrowley/go-metrics/registry.go b/metrics/registry.go index 2bb7a1e7d..cc34c9dfd 100644 --- a/vendor/github.com/rcrowley/go-metrics/registry.go +++ b/metrics/registry.go @@ -29,6 +29,9 @@ type Registry interface { // Get the metric by the given name or nil if none is registered. Get(string) interface{} + // GetAll metrics in the Registry. + GetAll() map[string]map[string]interface{} + // Gets an existing metric or registers the given one. // The interface can be the metric to register if not found in registry, // or a function returning the metric for lazy instantiation. @@ -109,10 +112,72 @@ func (r *StandardRegistry) RunHealthchecks() { } } +// GetAll metrics in the Registry +func (r *StandardRegistry) GetAll() map[string]map[string]interface{} { + data := make(map[string]map[string]interface{}) + r.Each(func(name string, i interface{}) { + values := make(map[string]interface{}) + switch metric := i.(type) { + case Counter: + values["count"] = metric.Count() + case Gauge: + values["value"] = metric.Value() + case GaugeFloat64: + values["value"] = metric.Value() + case Healthcheck: + values["error"] = nil + metric.Check() + if err := metric.Error(); nil != err { + values["error"] = metric.Error().Error() + } + case Histogram: + h := metric.Snapshot() + ps := h.Percentiles([]float64{0.5, 0.75, 0.95, 0.99, 0.999}) + values["count"] = h.Count() + values["min"] = h.Min() + values["max"] = h.Max() + values["mean"] = h.Mean() + values["stddev"] = h.StdDev() + values["median"] = ps[0] + values["75%"] = ps[1] + values["95%"] = ps[2] + values["99%"] = ps[3] + values["99.9%"] = ps[4] + case Meter: + m := metric.Snapshot() + values["count"] = m.Count() + values["1m.rate"] = m.Rate1() + values["5m.rate"] = m.Rate5() + values["15m.rate"] = m.Rate15() + values["mean.rate"] = m.RateMean() + case Timer: + t := metric.Snapshot() + ps := t.Percentiles([]float64{0.5, 0.75, 0.95, 0.99, 0.999}) + values["count"] = t.Count() + values["min"] = t.Min() + values["max"] = t.Max() + values["mean"] = t.Mean() + values["stddev"] = t.StdDev() + values["median"] = ps[0] + values["75%"] = ps[1] + values["95%"] = ps[2] + values["99%"] = ps[3] + values["99.9%"] = ps[4] + values["1m.rate"] = t.Rate1() + values["5m.rate"] = t.Rate5() + values["15m.rate"] = t.Rate15() + values["mean.rate"] = t.RateMean() + } + data[name] = values + }) + return data +} + // Unregister the metric with the given name. func (r *StandardRegistry) Unregister(name string) { r.mutex.Lock() defer r.mutex.Unlock() + r.stop(name) delete(r.metrics, name) } @@ -120,7 +185,8 @@ func (r *StandardRegistry) Unregister(name string) { func (r *StandardRegistry) UnregisterAll() { r.mutex.Lock() defer r.mutex.Unlock() - for name, _ := range r.metrics { + for name := range r.metrics { + r.stop(name) delete(r.metrics, name) } } @@ -130,7 +196,7 @@ func (r *StandardRegistry) register(name string, i interface{}) error { return DuplicateMetric(name) } switch i.(type) { - case Counter, Gauge, GaugeFloat64, Healthcheck, Histogram, Meter, Timer: + case Counter, Gauge, GaugeFloat64, Healthcheck, Histogram, Meter, Timer, ResettingTimer: r.metrics[name] = i } return nil @@ -146,6 +212,19 @@ func (r *StandardRegistry) registered() map[string]interface{} { return metrics } +func (r *StandardRegistry) stop(name string) { + if i, ok := r.metrics[name]; ok { + if s, ok := i.(Stoppable); ok { + s.Stop() + } + } +} + +// Stoppable defines the metrics which has to be stopped. +type Stoppable interface { + Stop() +} + type PrefixedRegistry struct { underlying Registry prefix string @@ -216,6 +295,11 @@ func (r *PrefixedRegistry) RunHealthchecks() { r.underlying.RunHealthchecks() } +// GetAll metrics in the Registry +func (r *PrefixedRegistry) GetAll() map[string]map[string]interface{} { + return r.underlying.GetAll() +} + // Unregister the metric with the given name. The name will be prefixed. func (r *PrefixedRegistry) Unregister(name string) { realName := r.prefix + name diff --git a/metrics/registry_test.go b/metrics/registry_test.go new file mode 100644 index 000000000..a63e485fe --- /dev/null +++ b/metrics/registry_test.go @@ -0,0 +1,305 @@ +package metrics + +import ( + "testing" +) + +func BenchmarkRegistry(b *testing.B) { + r := NewRegistry() + r.Register("foo", NewCounter()) + b.ResetTimer() + for i := 0; i < b.N; i++ { + r.Each(func(string, interface{}) {}) + } +} + +func TestRegistry(t *testing.T) { + r := NewRegistry() + r.Register("foo", NewCounter()) + i := 0 + r.Each(func(name string, iface interface{}) { + i++ + if "foo" != name { + t.Fatal(name) + } + if _, ok := iface.(Counter); !ok { + t.Fatal(iface) + } + }) + if 1 != i { + t.Fatal(i) + } + r.Unregister("foo") + i = 0 + r.Each(func(string, interface{}) { i++ }) + if 0 != i { + t.Fatal(i) + } +} + +func TestRegistryDuplicate(t *testing.T) { + r := NewRegistry() + if err := r.Register("foo", NewCounter()); nil != err { + t.Fatal(err) + } + if err := r.Register("foo", NewGauge()); nil == err { + t.Fatal(err) + } + i := 0 + r.Each(func(name string, iface interface{}) { + i++ + if _, ok := iface.(Counter); !ok { + t.Fatal(iface) + } + }) + if 1 != i { + t.Fatal(i) + } +} + +func TestRegistryGet(t *testing.T) { + r := NewRegistry() + r.Register("foo", NewCounter()) + if count := r.Get("foo").(Counter).Count(); 0 != count { + t.Fatal(count) + } + r.Get("foo").(Counter).Inc(1) + if count := r.Get("foo").(Counter).Count(); 1 != count { + t.Fatal(count) + } +} + +func TestRegistryGetOrRegister(t *testing.T) { + r := NewRegistry() + + // First metric wins with GetOrRegister + _ = r.GetOrRegister("foo", NewCounter()) + m := r.GetOrRegister("foo", NewGauge()) + if _, ok := m.(Counter); !ok { + t.Fatal(m) + } + + i := 0 + r.Each(func(name string, iface interface{}) { + i++ + if name != "foo" { + t.Fatal(name) + } + if _, ok := iface.(Counter); !ok { + t.Fatal(iface) + } + }) + if i != 1 { + t.Fatal(i) + } +} + +func TestRegistryGetOrRegisterWithLazyInstantiation(t *testing.T) { + r := NewRegistry() + + // First metric wins with GetOrRegister + _ = r.GetOrRegister("foo", NewCounter) + m := r.GetOrRegister("foo", NewGauge) + if _, ok := m.(Counter); !ok { + t.Fatal(m) + } + + i := 0 + r.Each(func(name string, iface interface{}) { + i++ + if name != "foo" { + t.Fatal(name) + } + if _, ok := iface.(Counter); !ok { + t.Fatal(iface) + } + }) + if i != 1 { + t.Fatal(i) + } +} + +func TestRegistryUnregister(t *testing.T) { + l := len(arbiter.meters) + r := NewRegistry() + r.Register("foo", NewCounter()) + r.Register("bar", NewMeter()) + r.Register("baz", NewTimer()) + if len(arbiter.meters) != l+2 { + t.Errorf("arbiter.meters: %d != %d\n", l+2, len(arbiter.meters)) + } + r.Unregister("foo") + r.Unregister("bar") + r.Unregister("baz") + if len(arbiter.meters) != l { + t.Errorf("arbiter.meters: %d != %d\n", l+2, len(arbiter.meters)) + } +} + +func TestPrefixedChildRegistryGetOrRegister(t *testing.T) { + r := NewRegistry() + pr := NewPrefixedChildRegistry(r, "prefix.") + + _ = pr.GetOrRegister("foo", NewCounter()) + + i := 0 + r.Each(func(name string, m interface{}) { + i++ + if name != "prefix.foo" { + t.Fatal(name) + } + }) + if i != 1 { + t.Fatal(i) + } +} + +func TestPrefixedRegistryGetOrRegister(t *testing.T) { + r := NewPrefixedRegistry("prefix.") + + _ = r.GetOrRegister("foo", NewCounter()) + + i := 0 + r.Each(func(name string, m interface{}) { + i++ + if name != "prefix.foo" { + t.Fatal(name) + } + }) + if i != 1 { + t.Fatal(i) + } +} + +func TestPrefixedRegistryRegister(t *testing.T) { + r := NewPrefixedRegistry("prefix.") + err := r.Register("foo", NewCounter()) + c := NewCounter() + Register("bar", c) + if err != nil { + t.Fatal(err.Error()) + } + + i := 0 + r.Each(func(name string, m interface{}) { + i++ + if name != "prefix.foo" { + t.Fatal(name) + } + }) + if i != 1 { + t.Fatal(i) + } +} + +func TestPrefixedRegistryUnregister(t *testing.T) { + r := NewPrefixedRegistry("prefix.") + + _ = r.Register("foo", NewCounter()) + + i := 0 + r.Each(func(name string, m interface{}) { + i++ + if name != "prefix.foo" { + t.Fatal(name) + } + }) + if i != 1 { + t.Fatal(i) + } + + r.Unregister("foo") + + i = 0 + r.Each(func(name string, m interface{}) { + i++ + }) + + if i != 0 { + t.Fatal(i) + } +} + +func TestPrefixedRegistryGet(t *testing.T) { + pr := NewPrefixedRegistry("prefix.") + name := "foo" + pr.Register(name, NewCounter()) + + fooCounter := pr.Get(name) + if fooCounter == nil { + t.Fatal(name) + } +} + +func TestPrefixedChildRegistryGet(t *testing.T) { + r := NewRegistry() + pr := NewPrefixedChildRegistry(r, "prefix.") + name := "foo" + pr.Register(name, NewCounter()) + fooCounter := pr.Get(name) + if fooCounter == nil { + t.Fatal(name) + } +} + +func TestChildPrefixedRegistryRegister(t *testing.T) { + r := NewPrefixedChildRegistry(DefaultRegistry, "prefix.") + err := r.Register("foo", NewCounter()) + c := NewCounter() + Register("bar", c) + if err != nil { + t.Fatal(err.Error()) + } + + i := 0 + r.Each(func(name string, m interface{}) { + i++ + if name != "prefix.foo" { + t.Fatal(name) + } + }) + if i != 1 { + t.Fatal(i) + } +} + +func TestChildPrefixedRegistryOfChildRegister(t *testing.T) { + r := NewPrefixedChildRegistry(NewRegistry(), "prefix.") + r2 := NewPrefixedChildRegistry(r, "prefix2.") + err := r.Register("foo2", NewCounter()) + if err != nil { + t.Fatal(err.Error()) + } + err = r2.Register("baz", NewCounter()) + c := NewCounter() + Register("bars", c) + + i := 0 + r2.Each(func(name string, m interface{}) { + i++ + if name != "prefix.prefix2.baz" { + //t.Fatal(name) + } + }) + if i != 1 { + t.Fatal(i) + } +} + +func TestWalkRegistries(t *testing.T) { + r := NewPrefixedChildRegistry(NewRegistry(), "prefix.") + r2 := NewPrefixedChildRegistry(r, "prefix2.") + err := r.Register("foo2", NewCounter()) + if err != nil { + t.Fatal(err.Error()) + } + err = r2.Register("baz", NewCounter()) + c := NewCounter() + Register("bars", c) + + _, prefix := findPrefix(r2, "") + if "prefix.prefix2." != prefix { + t.Fatal(prefix) + } + +} diff --git a/metrics/resetting_timer.go b/metrics/resetting_timer.go new file mode 100644 index 000000000..57bcb3134 --- /dev/null +++ b/metrics/resetting_timer.go @@ -0,0 +1,237 @@ +package metrics + +import ( + "math" + "sort" + "sync" + "time" +) + +// Initial slice capacity for the values stored in a ResettingTimer +const InitialResettingTimerSliceCap = 10 + +// ResettingTimer is used for storing aggregated values for timers, which are reset on every flush interval. +type ResettingTimer interface { + Values() []int64 + Snapshot() ResettingTimer + Percentiles([]float64) []int64 + Mean() float64 + Time(func()) + Update(time.Duration) + UpdateSince(time.Time) +} + +// GetOrRegisterResettingTimer returns an existing ResettingTimer or constructs and registers a +// new StandardResettingTimer. +func GetOrRegisterResettingTimer(name string, r Registry) ResettingTimer { + if nil == r { + r = DefaultRegistry + } + return r.GetOrRegister(name, NewResettingTimer).(ResettingTimer) +} + +// NewRegisteredResettingTimer constructs and registers a new StandardResettingTimer. +func NewRegisteredResettingTimer(name string, r Registry) ResettingTimer { + c := NewResettingTimer() + if nil == r { + r = DefaultRegistry + } + r.Register(name, c) + return c +} + +// NewResettingTimer constructs a new StandardResettingTimer +func NewResettingTimer() ResettingTimer { + if !Enabled { + return NilResettingTimer{} + } + return &StandardResettingTimer{ + values: make([]int64, 0, InitialResettingTimerSliceCap), + } +} + +// NilResettingTimer is a no-op ResettingTimer. +type NilResettingTimer struct { +} + +// Values is a no-op. +func (NilResettingTimer) Values() []int64 { return nil } + +// Snapshot is a no-op. +func (NilResettingTimer) Snapshot() ResettingTimer { return NilResettingTimer{} } + +// Time is a no-op. +func (NilResettingTimer) Time(func()) {} + +// Update is a no-op. +func (NilResettingTimer) Update(time.Duration) {} + +// Percentiles panics. +func (NilResettingTimer) Percentiles([]float64) []int64 { + panic("Percentiles called on a NilResettingTimer") +} + +// Mean panics. +func (NilResettingTimer) Mean() float64 { + panic("Mean called on a NilResettingTimer") +} + +// UpdateSince is a no-op. +func (NilResettingTimer) UpdateSince(time.Time) {} + +// StandardResettingTimer is the standard implementation of a ResettingTimer. +// and Meter. +type StandardResettingTimer struct { + values []int64 + mutex sync.Mutex +} + +// Values returns a slice with all measurements. +func (t *StandardResettingTimer) Values() []int64 { + return t.values +} + +// Snapshot resets the timer and returns a read-only copy of its contents. +func (t *StandardResettingTimer) Snapshot() ResettingTimer { + t.mutex.Lock() + defer t.mutex.Unlock() + currentValues := t.values + t.values = make([]int64, 0, InitialResettingTimerSliceCap) + + return &ResettingTimerSnapshot{ + values: currentValues, + } +} + +// Percentiles panics. +func (t *StandardResettingTimer) Percentiles([]float64) []int64 { + panic("Percentiles called on a StandardResettingTimer") +} + +// Mean panics. +func (t *StandardResettingTimer) Mean() float64 { + panic("Mean called on a StandardResettingTimer") +} + +// Record the duration of the execution of the given function. +func (t *StandardResettingTimer) Time(f func()) { + ts := time.Now() + f() + t.Update(time.Since(ts)) +} + +// Record the duration of an event. +func (t *StandardResettingTimer) Update(d time.Duration) { + t.mutex.Lock() + defer t.mutex.Unlock() + t.values = append(t.values, int64(d)) +} + +// Record the duration of an event that started at a time and ends now. +func (t *StandardResettingTimer) UpdateSince(ts time.Time) { + t.mutex.Lock() + defer t.mutex.Unlock() + t.values = append(t.values, int64(time.Since(ts))) +} + +// ResettingTimerSnapshot is a point-in-time copy of another ResettingTimer. +type ResettingTimerSnapshot struct { + values []int64 + mean float64 + thresholdBoundaries []int64 + calculated bool +} + +// Snapshot returns the snapshot. +func (t *ResettingTimerSnapshot) Snapshot() ResettingTimer { return t } + +// Time panics. +func (*ResettingTimerSnapshot) Time(func()) { + panic("Time called on a ResettingTimerSnapshot") +} + +// Update panics. +func (*ResettingTimerSnapshot) Update(time.Duration) { + panic("Update called on a ResettingTimerSnapshot") +} + +// UpdateSince panics. +func (*ResettingTimerSnapshot) UpdateSince(time.Time) { + panic("UpdateSince called on a ResettingTimerSnapshot") +} + +// Values returns all values from snapshot. +func (t *ResettingTimerSnapshot) Values() []int64 { + return t.values +} + +// Percentiles returns the boundaries for the input percentiles. +func (t *ResettingTimerSnapshot) Percentiles(percentiles []float64) []int64 { + t.calc(percentiles) + + return t.thresholdBoundaries +} + +// Mean returns the mean of the snapshotted values +func (t *ResettingTimerSnapshot) Mean() float64 { + if !t.calculated { + t.calc([]float64{}) + } + + return t.mean +} + +func (t *ResettingTimerSnapshot) calc(percentiles []float64) { + sort.Sort(Int64Slice(t.values)) + + count := len(t.values) + if count > 0 { + min := t.values[0] + max := t.values[count-1] + + cumulativeValues := make([]int64, count) + cumulativeValues[0] = min + for i := 1; i < count; i++ { + cumulativeValues[i] = t.values[i] + cumulativeValues[i-1] + } + + t.thresholdBoundaries = make([]int64, len(percentiles)) + + thresholdBoundary := max + + for i, pct := range percentiles { + if count > 1 { + var abs float64 + if pct >= 0 { + abs = pct + } else { + abs = 100 + pct + } + // poor man's math.Round(x): + // math.Floor(x + 0.5) + indexOfPerc := int(math.Floor(((abs / 100.0) * float64(count)) + 0.5)) + if pct >= 0 { + indexOfPerc -= 1 // index offset=0 + } + thresholdBoundary = t.values[indexOfPerc] + } + + t.thresholdBoundaries[i] = thresholdBoundary + } + + sum := cumulativeValues[count-1] + t.mean = float64(sum) / float64(count) + } else { + t.thresholdBoundaries = make([]int64, len(percentiles)) + t.mean = 0 + } + + t.calculated = true +} + +// Int64Slice attaches the methods of sort.Interface to []int64, sorting in increasing order. +type Int64Slice []int64 + +func (s Int64Slice) Len() int { return len(s) } +func (s Int64Slice) Less(i, j int) bool { return s[i] < s[j] } +func (s Int64Slice) Swap(i, j int) { s[i], s[j] = s[j], s[i] } diff --git a/metrics/resetting_timer_test.go b/metrics/resetting_timer_test.go new file mode 100644 index 000000000..58fd47f35 --- /dev/null +++ b/metrics/resetting_timer_test.go @@ -0,0 +1,106 @@ +package metrics + +import ( + "testing" + "time" +) + +func TestResettingTimer(t *testing.T) { + tests := []struct { + values []int64 + start int + end int + wantP50 int64 + wantP95 int64 + wantP99 int64 + wantMean float64 + wantMin int64 + wantMax int64 + }{ + { + values: []int64{}, + start: 1, + end: 11, + wantP50: 5, wantP95: 10, wantP99: 10, + wantMin: 1, wantMax: 10, wantMean: 5.5, + }, + { + values: []int64{}, + start: 1, + end: 101, + wantP50: 50, wantP95: 95, wantP99: 99, + wantMin: 1, wantMax: 100, wantMean: 50.5, + }, + { + values: []int64{1}, + start: 0, + end: 0, + wantP50: 1, wantP95: 1, wantP99: 1, + wantMin: 1, wantMax: 1, wantMean: 1, + }, + { + values: []int64{0}, + start: 0, + end: 0, + wantP50: 0, wantP95: 0, wantP99: 0, + wantMin: 0, wantMax: 0, wantMean: 0, + }, + { + values: []int64{}, + start: 0, + end: 0, + wantP50: 0, wantP95: 0, wantP99: 0, + wantMin: 0, wantMax: 0, wantMean: 0, + }, + { + values: []int64{1, 10}, + start: 0, + end: 0, + wantP50: 1, wantP95: 10, wantP99: 10, + wantMin: 1, wantMax: 10, wantMean: 5.5, + }, + } + for ind, tt := range tests { + timer := NewResettingTimer() + + for i := tt.start; i < tt.end; i++ { + tt.values = append(tt.values, int64(i)) + } + + for _, v := range tt.values { + timer.Update(time.Duration(v)) + } + + snap := timer.Snapshot() + + ps := snap.Percentiles([]float64{50, 95, 99}) + + val := snap.Values() + + if len(val) > 0 { + if tt.wantMin != val[0] { + t.Fatalf("%d: min: got %d, want %d", ind, val[0], tt.wantMin) + } + + if tt.wantMax != val[len(val)-1] { + t.Fatalf("%d: max: got %d, want %d", ind, val[len(val)-1], tt.wantMax) + } + } + + if tt.wantMean != snap.Mean() { + t.Fatalf("%d: mean: got %.2f, want %.2f", ind, snap.Mean(), tt.wantMean) + } + + if tt.wantP50 != ps[0] { + t.Fatalf("%d: p50: got %d, want %d", ind, ps[0], tt.wantP50) + } + + if tt.wantP95 != ps[1] { + t.Fatalf("%d: p95: got %d, want %d", ind, ps[1], tt.wantP95) + } + + if tt.wantP99 != ps[2] { + t.Fatalf("%d: p99: got %d, want %d", ind, ps[2], tt.wantP99) + } + } +} diff --git a/vendor/github.com/rcrowley/go-metrics/runtime.go b/metrics/runtime.go index 11c6b785a..9450c479b 100644 --- a/vendor/github.com/rcrowley/go-metrics/runtime.go +++ b/metrics/runtime.go @@ -55,7 +55,7 @@ var ( // Capture new values for the Go runtime statistics exported in // runtime.MemStats. This is designed to be called as a goroutine. func CaptureRuntimeMemStats(r Registry, d time.Duration) { - for _ = range time.Tick(d) { + for range time.Tick(d) { CaptureRuntimeMemStatsOnce(r) } } diff --git a/vendor/github.com/rcrowley/go-metrics/runtime_cgo.go b/metrics/runtime_cgo.go index e3391f4e8..e3391f4e8 100644 --- a/vendor/github.com/rcrowley/go-metrics/runtime_cgo.go +++ b/metrics/runtime_cgo.go diff --git a/vendor/github.com/rcrowley/go-metrics/runtime_gccpufraction.go b/metrics/runtime_gccpufraction.go index ca12c05ba..ca12c05ba 100644 --- a/vendor/github.com/rcrowley/go-metrics/runtime_gccpufraction.go +++ b/metrics/runtime_gccpufraction.go diff --git a/vendor/github.com/rcrowley/go-metrics/runtime_no_cgo.go b/metrics/runtime_no_cgo.go index 616a3b475..616a3b475 100644 --- a/vendor/github.com/rcrowley/go-metrics/runtime_no_cgo.go +++ b/metrics/runtime_no_cgo.go diff --git a/vendor/github.com/rcrowley/go-metrics/runtime_no_gccpufraction.go b/metrics/runtime_no_gccpufraction.go index be96aa6f1..be96aa6f1 100644 --- a/vendor/github.com/rcrowley/go-metrics/runtime_no_gccpufraction.go +++ b/metrics/runtime_no_gccpufraction.go diff --git a/metrics/runtime_test.go b/metrics/runtime_test.go new file mode 100644 index 000000000..ebbfd501a --- /dev/null +++ b/metrics/runtime_test.go @@ -0,0 +1,88 @@ +package metrics + +import ( + "runtime" + "testing" + "time" +) + +func BenchmarkRuntimeMemStats(b *testing.B) { + r := NewRegistry() + RegisterRuntimeMemStats(r) + b.ResetTimer() + for i := 0; i < b.N; i++ { + CaptureRuntimeMemStatsOnce(r) + } +} + +func TestRuntimeMemStats(t *testing.T) { + r := NewRegistry() + RegisterRuntimeMemStats(r) + CaptureRuntimeMemStatsOnce(r) + zero := runtimeMetrics.MemStats.PauseNs.Count() // Get a "zero" since GC may have run before these tests. + runtime.GC() + CaptureRuntimeMemStatsOnce(r) + if count := runtimeMetrics.MemStats.PauseNs.Count(); 1 != count-zero { + t.Fatal(count - zero) + } + runtime.GC() + runtime.GC() + CaptureRuntimeMemStatsOnce(r) + if count := runtimeMetrics.MemStats.PauseNs.Count(); 3 != count-zero { + t.Fatal(count - zero) + } + for i := 0; i < 256; i++ { + runtime.GC() + } + CaptureRuntimeMemStatsOnce(r) + if count := runtimeMetrics.MemStats.PauseNs.Count(); 259 != count-zero { + t.Fatal(count - zero) + } + for i := 0; i < 257; i++ { + runtime.GC() + } + CaptureRuntimeMemStatsOnce(r) + if count := runtimeMetrics.MemStats.PauseNs.Count(); 515 != count-zero { // We lost one because there were too many GCs between captures. + t.Fatal(count - zero) + } +} + +func TestRuntimeMemStatsNumThread(t *testing.T) { + r := NewRegistry() + RegisterRuntimeMemStats(r) + CaptureRuntimeMemStatsOnce(r) + + if value := runtimeMetrics.NumThread.Value(); value < 1 { + t.Fatalf("got NumThread: %d, wanted at least 1", value) + } +} + +func TestRuntimeMemStatsBlocking(t *testing.T) { + if g := runtime.GOMAXPROCS(0); g < 2 { + t.Skipf("skipping TestRuntimeMemStatsBlocking with GOMAXPROCS=%d\n", g) + } + ch := make(chan int) + go testRuntimeMemStatsBlocking(ch) + var memStats runtime.MemStats + t0 := time.Now() + runtime.ReadMemStats(&memStats) + t1 := time.Now() + t.Log("i++ during runtime.ReadMemStats:", <-ch) + go testRuntimeMemStatsBlocking(ch) + d := t1.Sub(t0) + t.Log(d) + time.Sleep(d) + t.Log("i++ during time.Sleep:", <-ch) +} + +func testRuntimeMemStatsBlocking(ch chan int) { + i := 0 + for { + select { + case ch <- i: + return + default: + i++ + } + } +} diff --git a/vendor/github.com/rcrowley/go-metrics/sample.go b/metrics/sample.go index fecee5ef6..5c4845a4f 100644 --- a/vendor/github.com/rcrowley/go-metrics/sample.go +++ b/metrics/sample.go @@ -46,7 +46,7 @@ type ExpDecaySample struct { // NewExpDecaySample constructs a new exponentially-decaying sample with the // given reservoir size and alpha. func NewExpDecaySample(reservoirSize int, alpha float64) Sample { - if UseNilMetrics { + if !Enabled { return NilSample{} } s := &ExpDecaySample{ @@ -407,7 +407,7 @@ type UniformSample struct { // NewUniformSample constructs a new uniform sample with the given reservoir // size. func NewUniformSample(reservoirSize int) Sample { - if UseNilMetrics { + if !Enabled { return NilSample{} } return &UniformSample{ diff --git a/metrics/sample_test.go b/metrics/sample_test.go new file mode 100644 index 000000000..d60e99c5b --- /dev/null +++ b/metrics/sample_test.go @@ -0,0 +1,363 @@ +package metrics + +import ( + "math/rand" + "runtime" + "testing" + "time" +) + +// Benchmark{Compute,Copy}{1000,1000000} demonstrate that, even for relatively +// expensive computations like Variance, the cost of copying the Sample, as +// approximated by a make and copy, is much greater than the cost of the +// computation for small samples and only slightly less for large samples. +func BenchmarkCompute1000(b *testing.B) { + s := make([]int64, 1000) + for i := 0; i < len(s); i++ { + s[i] = int64(i) + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + SampleVariance(s) + } +} +func BenchmarkCompute1000000(b *testing.B) { + s := make([]int64, 1000000) + for i := 0; i < len(s); i++ { + s[i] = int64(i) + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + SampleVariance(s) + } +} +func BenchmarkCopy1000(b *testing.B) { + s := make([]int64, 1000) + for i := 0; i < len(s); i++ { + s[i] = int64(i) + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + sCopy := make([]int64, len(s)) + copy(sCopy, s) + } +} +func BenchmarkCopy1000000(b *testing.B) { + s := make([]int64, 1000000) + for i := 0; i < len(s); i++ { + s[i] = int64(i) + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + sCopy := make([]int64, len(s)) + copy(sCopy, s) + } +} + +func BenchmarkExpDecaySample257(b *testing.B) { + benchmarkSample(b, NewExpDecaySample(257, 0.015)) +} + +func BenchmarkExpDecaySample514(b *testing.B) { + benchmarkSample(b, NewExpDecaySample(514, 0.015)) +} + +func BenchmarkExpDecaySample1028(b *testing.B) { + benchmarkSample(b, NewExpDecaySample(1028, 0.015)) +} + +func BenchmarkUniformSample257(b *testing.B) { + benchmarkSample(b, NewUniformSample(257)) +} + +func BenchmarkUniformSample514(b *testing.B) { + benchmarkSample(b, NewUniformSample(514)) +} + +func BenchmarkUniformSample1028(b *testing.B) { + benchmarkSample(b, NewUniformSample(1028)) +} + +func TestExpDecaySample10(t *testing.T) { + rand.Seed(1) + s := NewExpDecaySample(100, 0.99) + for i := 0; i < 10; i++ { + s.Update(int64(i)) + } + if size := s.Count(); 10 != size { + t.Errorf("s.Count(): 10 != %v\n", size) + } + if size := s.Size(); 10 != size { + t.Errorf("s.Size(): 10 != %v\n", size) + } + if l := len(s.Values()); 10 != l { + t.Errorf("len(s.Values()): 10 != %v\n", l) + } + for _, v := range s.Values() { + if v > 10 || v < 0 { + t.Errorf("out of range [0, 10): %v\n", v) + } + } +} + +func TestExpDecaySample100(t *testing.T) { + rand.Seed(1) + s := NewExpDecaySample(1000, 0.01) + for i := 0; i < 100; i++ { + s.Update(int64(i)) + } + if size := s.Count(); 100 != size { + t.Errorf("s.Count(): 100 != %v\n", size) + } + if size := s.Size(); 100 != size { + t.Errorf("s.Size(): 100 != %v\n", size) + } + if l := len(s.Values()); 100 != l { + t.Errorf("len(s.Values()): 100 != %v\n", l) + } + for _, v := range s.Values() { + if v > 100 || v < 0 { + t.Errorf("out of range [0, 100): %v\n", v) + } + } +} + +func TestExpDecaySample1000(t *testing.T) { + rand.Seed(1) + s := NewExpDecaySample(100, 0.99) + for i := 0; i < 1000; i++ { + s.Update(int64(i)) + } + if size := s.Count(); 1000 != size { + t.Errorf("s.Count(): 1000 != %v\n", size) + } + if size := s.Size(); 100 != size { + t.Errorf("s.Size(): 100 != %v\n", size) + } + if l := len(s.Values()); 100 != l { + t.Errorf("len(s.Values()): 100 != %v\n", l) + } + for _, v := range s.Values() { + if v > 1000 || v < 0 { + t.Errorf("out of range [0, 1000): %v\n", v) + } + } +} + +// This test makes sure that the sample's priority is not amplified by using +// nanosecond duration since start rather than second duration since start. +// The priority becomes +Inf quickly after starting if this is done, +// effectively freezing the set of samples until a rescale step happens. +func TestExpDecaySampleNanosecondRegression(t *testing.T) { + rand.Seed(1) + s := NewExpDecaySample(100, 0.99) + for i := 0; i < 100; i++ { + s.Update(10) + } + time.Sleep(1 * time.Millisecond) + for i := 0; i < 100; i++ { + s.Update(20) + } + v := s.Values() + avg := float64(0) + for i := 0; i < len(v); i++ { + avg += float64(v[i]) + } + avg /= float64(len(v)) + if avg > 16 || avg < 14 { + t.Errorf("out of range [14, 16]: %v\n", avg) + } +} + +func TestExpDecaySampleRescale(t *testing.T) { + s := NewExpDecaySample(2, 0.001).(*ExpDecaySample) + s.update(time.Now(), 1) + s.update(time.Now().Add(time.Hour+time.Microsecond), 1) + for _, v := range s.values.Values() { + if v.k == 0.0 { + t.Fatal("v.k == 0.0") + } + } +} + +func TestExpDecaySampleSnapshot(t *testing.T) { + now := time.Now() + rand.Seed(1) + s := NewExpDecaySample(100, 0.99) + for i := 1; i <= 10000; i++ { + s.(*ExpDecaySample).update(now.Add(time.Duration(i)), int64(i)) + } + snapshot := s.Snapshot() + s.Update(1) + testExpDecaySampleStatistics(t, snapshot) +} + +func TestExpDecaySampleStatistics(t *testing.T) { + now := time.Now() + rand.Seed(1) + s := NewExpDecaySample(100, 0.99) + for i := 1; i <= 10000; i++ { + s.(*ExpDecaySample).update(now.Add(time.Duration(i)), int64(i)) + } + testExpDecaySampleStatistics(t, s) +} + +func TestUniformSample(t *testing.T) { + rand.Seed(1) + s := NewUniformSample(100) + for i := 0; i < 1000; i++ { + s.Update(int64(i)) + } + if size := s.Count(); 1000 != size { + t.Errorf("s.Count(): 1000 != %v\n", size) + } + if size := s.Size(); 100 != size { + t.Errorf("s.Size(): 100 != %v\n", size) + } + if l := len(s.Values()); 100 != l { + t.Errorf("len(s.Values()): 100 != %v\n", l) + } + for _, v := range s.Values() { + if v > 1000 || v < 0 { + t.Errorf("out of range [0, 100): %v\n", v) + } + } +} + +func TestUniformSampleIncludesTail(t *testing.T) { + rand.Seed(1) + s := NewUniformSample(100) + max := 100 + for i := 0; i < max; i++ { + s.Update(int64(i)) + } + v := s.Values() + sum := 0 + exp := (max - 1) * max / 2 + for i := 0; i < len(v); i++ { + sum += int(v[i]) + } + if exp != sum { + t.Errorf("sum: %v != %v\n", exp, sum) + } +} + +func TestUniformSampleSnapshot(t *testing.T) { + s := NewUniformSample(100) + for i := 1; i <= 10000; i++ { + s.Update(int64(i)) + } + snapshot := s.Snapshot() + s.Update(1) + testUniformSampleStatistics(t, snapshot) +} + +func TestUniformSampleStatistics(t *testing.T) { + rand.Seed(1) + s := NewUniformSample(100) + for i := 1; i <= 10000; i++ { + s.Update(int64(i)) + } + testUniformSampleStatistics(t, s) +} + +func benchmarkSample(b *testing.B, s Sample) { + var memStats runtime.MemStats + runtime.ReadMemStats(&memStats) + pauseTotalNs := memStats.PauseTotalNs + b.ResetTimer() + for i := 0; i < b.N; i++ { + s.Update(1) + } + b.StopTimer() + runtime.GC() + runtime.ReadMemStats(&memStats) + b.Logf("GC cost: %d ns/op", int(memStats.PauseTotalNs-pauseTotalNs)/b.N) +} + +func testExpDecaySampleStatistics(t *testing.T, s Sample) { + if count := s.Count(); 10000 != count { + t.Errorf("s.Count(): 10000 != %v\n", count) + } + if min := s.Min(); 107 != min { + t.Errorf("s.Min(): 107 != %v\n", min) + } + if max := s.Max(); 10000 != max { + t.Errorf("s.Max(): 10000 != %v\n", max) + } + if mean := s.Mean(); 4965.98 != mean { + t.Errorf("s.Mean(): 4965.98 != %v\n", mean) + } + if stdDev := s.StdDev(); 2959.825156930727 != stdDev { + t.Errorf("s.StdDev(): 2959.825156930727 != %v\n", stdDev) + } + ps := s.Percentiles([]float64{0.5, 0.75, 0.99}) + if 4615 != ps[0] { + t.Errorf("median: 4615 != %v\n", ps[0]) + } + if 7672 != ps[1] { + t.Errorf("75th percentile: 7672 != %v\n", ps[1]) + } + if 9998.99 != ps[2] { + t.Errorf("99th percentile: 9998.99 != %v\n", ps[2]) + } +} + +func testUniformSampleStatistics(t *testing.T, s Sample) { + if count := s.Count(); 10000 != count { + t.Errorf("s.Count(): 10000 != %v\n", count) + } + if min := s.Min(); 37 != min { + t.Errorf("s.Min(): 37 != %v\n", min) + } + if max := s.Max(); 9989 != max { + t.Errorf("s.Max(): 9989 != %v\n", max) + } + if mean := s.Mean(); 4748.14 != mean { + t.Errorf("s.Mean(): 4748.14 != %v\n", mean) + } + if stdDev := s.StdDev(); 2826.684117548333 != stdDev { + t.Errorf("s.StdDev(): 2826.684117548333 != %v\n", stdDev) + } + ps := s.Percentiles([]float64{0.5, 0.75, 0.99}) + if 4599 != ps[0] { + t.Errorf("median: 4599 != %v\n", ps[0]) + } + if 7380.5 != ps[1] { + t.Errorf("75th percentile: 7380.5 != %v\n", ps[1]) + } + if 9986.429999999998 != ps[2] { + t.Errorf("99th percentile: 9986.429999999998 != %v\n", ps[2]) + } +} + +// TestUniformSampleConcurrentUpdateCount would expose data race problems with +// concurrent Update and Count calls on Sample when test is called with -race +// argument +func TestUniformSampleConcurrentUpdateCount(t *testing.T) { + if testing.Short() { + t.Skip("skipping in short mode") + } + s := NewUniformSample(100) + for i := 0; i < 100; i++ { + s.Update(int64(i)) + } + quit := make(chan struct{}) + go func() { + t := time.NewTicker(10 * time.Millisecond) + for { + select { + case <-t.C: + s.Update(rand.Int63()) + case <-quit: + t.Stop() + return + } + } + }() + for i := 0; i < 1000; i++ { + s.Count() + time.Sleep(5 * time.Millisecond) + } + quit <- struct{}{} +} diff --git a/vendor/github.com/rcrowley/go-metrics/syslog.go b/metrics/syslog.go index 693f19085..a0ed4b1b2 100644 --- a/vendor/github.com/rcrowley/go-metrics/syslog.go +++ b/metrics/syslog.go @@ -11,7 +11,7 @@ import ( // Output each metric in the given registry to syslog periodically using // the given syslogger. func Syslog(r Registry, d time.Duration, w *syslog.Writer) { - for _ = range time.Tick(d) { + for range time.Tick(d) { r.Each(func(name string, i interface{}) { switch metric := i.(type) { case Counter: diff --git a/vendor/github.com/rcrowley/go-metrics/timer.go b/metrics/timer.go index 17db8f8d2..89e22208f 100644 --- a/vendor/github.com/rcrowley/go-metrics/timer.go +++ b/metrics/timer.go @@ -19,6 +19,7 @@ type Timer interface { RateMean() float64 Snapshot() Timer StdDev() float64 + Stop() Sum() int64 Time(func()) Update(time.Duration) @@ -28,6 +29,8 @@ type Timer interface { // GetOrRegisterTimer returns an existing Timer or constructs and registers a // new StandardTimer. +// Be sure to unregister the meter from the registry once it is of no use to +// allow for garbage collection. func GetOrRegisterTimer(name string, r Registry) Timer { if nil == r { r = DefaultRegistry @@ -36,8 +39,9 @@ func GetOrRegisterTimer(name string, r Registry) Timer { } // NewCustomTimer constructs a new StandardTimer from a Histogram and a Meter. +// Be sure to call Stop() once the timer is of no use to allow for garbage collection. func NewCustomTimer(h Histogram, m Meter) Timer { - if UseNilMetrics { + if !Enabled { return NilTimer{} } return &StandardTimer{ @@ -47,6 +51,8 @@ func NewCustomTimer(h Histogram, m Meter) Timer { } // NewRegisteredTimer constructs and registers a new StandardTimer. +// Be sure to unregister the meter from the registry once it is of no use to +// allow for garbage collection. func NewRegisteredTimer(name string, r Registry) Timer { c := NewTimer() if nil == r { @@ -58,8 +64,9 @@ func NewRegisteredTimer(name string, r Registry) Timer { // NewTimer constructs a new StandardTimer using an exponentially-decaying // sample with the same reservoir size and alpha as UNIX load averages. +// Be sure to call Stop() once the timer is of no use to allow for garbage collection. func NewTimer() Timer { - if UseNilMetrics { + if !Enabled { return NilTimer{} } return &StandardTimer{ @@ -112,6 +119,9 @@ func (NilTimer) Snapshot() Timer { return NilTimer{} } // StdDev is a no-op. func (NilTimer) StdDev() float64 { return 0.0 } +// Stop is a no-op. +func (NilTimer) Stop() {} + // Sum is a no-op. func (NilTimer) Sum() int64 { return 0 } @@ -201,6 +211,11 @@ func (t *StandardTimer) StdDev() float64 { return t.histogram.StdDev() } +// Stop stops the meter. +func (t *StandardTimer) Stop() { + t.meter.Stop() +} + // Sum returns the sum in the sample. func (t *StandardTimer) Sum() int64 { return t.histogram.Sum() @@ -288,6 +303,9 @@ func (t *TimerSnapshot) Snapshot() Timer { return t } // was taken. func (t *TimerSnapshot) StdDev() float64 { return t.histogram.StdDev() } +// Stop is a no-op. +func (t *TimerSnapshot) Stop() {} + // Sum returns the sum at the time the snapshot was taken. func (t *TimerSnapshot) Sum() int64 { return t.histogram.Sum() } diff --git a/metrics/timer_test.go b/metrics/timer_test.go new file mode 100644 index 000000000..c1f0ff938 --- /dev/null +++ b/metrics/timer_test.go @@ -0,0 +1,101 @@ +package metrics + +import ( + "fmt" + "math" + "testing" + "time" +) + +func BenchmarkTimer(b *testing.B) { + tm := NewTimer() + b.ResetTimer() + for i := 0; i < b.N; i++ { + tm.Update(1) + } +} + +func TestGetOrRegisterTimer(t *testing.T) { + r := NewRegistry() + NewRegisteredTimer("foo", r).Update(47) + if tm := GetOrRegisterTimer("foo", r); 1 != tm.Count() { + t.Fatal(tm) + } +} + +func TestTimerExtremes(t *testing.T) { + tm := NewTimer() + tm.Update(math.MaxInt64) + tm.Update(0) + if stdDev := tm.StdDev(); 4.611686018427388e+18 != stdDev { + t.Errorf("tm.StdDev(): 4.611686018427388e+18 != %v\n", stdDev) + } +} + +func TestTimerStop(t *testing.T) { + l := len(arbiter.meters) + tm := NewTimer() + if len(arbiter.meters) != l+1 { + t.Errorf("arbiter.meters: %d != %d\n", l+1, len(arbiter.meters)) + } + tm.Stop() + if len(arbiter.meters) != l { + t.Errorf("arbiter.meters: %d != %d\n", l, len(arbiter.meters)) + } +} + +func TestTimerFunc(t *testing.T) { + tm := NewTimer() + tm.Time(func() { time.Sleep(50e6) }) + if max := tm.Max(); 35e6 > max || max > 95e6 { + t.Errorf("tm.Max(): 35e6 > %v || %v > 95e6\n", max, max) + } +} + +func TestTimerZero(t *testing.T) { + tm := NewTimer() + if count := tm.Count(); 0 != count { + t.Errorf("tm.Count(): 0 != %v\n", count) + } + if min := tm.Min(); 0 != min { + t.Errorf("tm.Min(): 0 != %v\n", min) + } + if max := tm.Max(); 0 != max { + t.Errorf("tm.Max(): 0 != %v\n", max) + } + if mean := tm.Mean(); 0.0 != mean { + t.Errorf("tm.Mean(): 0.0 != %v\n", mean) + } + if stdDev := tm.StdDev(); 0.0 != stdDev { + t.Errorf("tm.StdDev(): 0.0 != %v\n", stdDev) + } + ps := tm.Percentiles([]float64{0.5, 0.75, 0.99}) + if 0.0 != ps[0] { + t.Errorf("median: 0.0 != %v\n", ps[0]) + } + if 0.0 != ps[1] { + t.Errorf("75th percentile: 0.0 != %v\n", ps[1]) + } + if 0.0 != ps[2] { + t.Errorf("99th percentile: 0.0 != %v\n", ps[2]) + } + if rate1 := tm.Rate1(); 0.0 != rate1 { + t.Errorf("tm.Rate1(): 0.0 != %v\n", rate1) + } + if rate5 := tm.Rate5(); 0.0 != rate5 { + t.Errorf("tm.Rate5(): 0.0 != %v\n", rate5) + } + if rate15 := tm.Rate15(); 0.0 != rate15 { + t.Errorf("tm.Rate15(): 0.0 != %v\n", rate15) + } + if rateMean := tm.RateMean(); 0.0 != rateMean { + t.Errorf("tm.RateMean(): 0.0 != %v\n", rateMean) + } +} + +func ExampleGetOrRegisterTimer() { + m := "account.create.latency" + t := GetOrRegisterTimer(m, nil) + t.Update(47) + fmt.Println(t.Max()) // Output: 47 +} diff --git a/vendor/github.com/rcrowley/go-metrics/validate.sh b/metrics/validate.sh index f6499982e..c4ae91e64 100755 --- a/vendor/github.com/rcrowley/go-metrics/validate.sh +++ b/metrics/validate.sh @@ -7,4 +7,4 @@ GOFMT_LINES=`gofmt -l . | wc -l | xargs` test $GOFMT_LINES -eq 0 || echo "gofmt needs to be run, ${GOFMT_LINES} files have issues" # run the tests for the root package -go test . +go test -race . diff --git a/vendor/github.com/rcrowley/go-metrics/writer.go b/metrics/writer.go index 091e971d2..88521a80d 100644 --- a/vendor/github.com/rcrowley/go-metrics/writer.go +++ b/metrics/writer.go @@ -10,7 +10,7 @@ import ( // Write sorts writes each metric in the given registry periodically to the // given io.Writer. func Write(r Registry, d time.Duration, w io.Writer) { - for _ = range time.Tick(d) { + for range time.Tick(d) { WriteOnce(r, w) } } diff --git a/metrics/writer_test.go b/metrics/writer_test.go new file mode 100644 index 000000000..1aacc2871 --- /dev/null +++ b/metrics/writer_test.go @@ -0,0 +1,22 @@ +package metrics + +import ( + "sort" + "testing" +) + +func TestMetricsSorting(t *testing.T) { + var namedMetrics = namedMetricSlice{ + {name: "zzz"}, + {name: "bbb"}, + {name: "fff"}, + {name: "ggg"}, + } + + sort.Sort(namedMetrics) + for i, name := range []string{"bbb", "fff", "ggg", "zzz"} { + if namedMetrics[i].name != name { + t.Fail() + } + } +} diff --git a/mobile/bind.go b/mobile/bind.go index 7a1bf9e60..d6e621a25 100644 --- a/mobile/bind.go +++ b/mobile/bind.go @@ -154,12 +154,20 @@ func (c *BoundContract) GetDeployer() *Transaction { // Call invokes the (constant) contract method with params as input values and // sets the output to result. func (c *BoundContract) Call(opts *CallOpts, out *Interfaces, method string, args *Interfaces) error { - results := make([]interface{}, len(out.objects)) - copy(results, out.objects) - if err := c.contract.Call(&opts.opts, &results, method, args.objects...); err != nil { - return err + if len(out.objects) == 1 { + result := out.objects[0] + if err := c.contract.Call(&opts.opts, result, method, args.objects...); err != nil { + return err + } + out.objects[0] = result + } else { + results := make([]interface{}, len(out.objects)) + copy(results, out.objects) + if err := c.contract.Call(&opts.opts, &results, method, args.objects...); err != nil { + return err + } + copy(out.objects, results) } - copy(out.objects, results) return nil } diff --git a/node/api.go b/node/api.go index 4e9b1edc4..a3b8bc0bb 100644 --- a/node/api.go +++ b/node/api.go @@ -24,10 +24,10 @@ import ( "github.com/ethereum/go-ethereum/common/hexutil" "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/metrics" "github.com/ethereum/go-ethereum/p2p" "github.com/ethereum/go-ethereum/p2p/discover" "github.com/ethereum/go-ethereum/rpc" - "github.com/rcrowley/go-metrics" ) // PrivateAdminAPI is the collection of administrative API methods exposed only @@ -308,6 +308,11 @@ func (api *PublicDebugAPI) Metrics(raw bool) (map[string]interface{}, error) { // Fill the counter with the metric details, formatting if requested if raw { switch metric := metric.(type) { + case metrics.Counter: + root[name] = map[string]interface{}{ + "Overall": float64(metric.Count()), + } + case metrics.Meter: root[name] = map[string]interface{}{ "AvgRate01Min": metric.Rate1(), @@ -338,6 +343,11 @@ func (api *PublicDebugAPI) Metrics(raw bool) (map[string]interface{}, error) { } } else { switch metric := metric.(type) { + case metrics.Counter: + root[name] = map[string]interface{}{ + "Overall": float64(metric.Count()), + } + case metrics.Meter: root[name] = map[string]interface{}{ "Avg01Min": format(metric.Rate1()*60, metric.Rate1()), diff --git a/node/defaults.go b/node/defaults.go index d4e148683..887560580 100644 --- a/node/defaults.go +++ b/node/defaults.go @@ -35,11 +35,12 @@ const ( // DefaultConfig contains reasonable default settings. var DefaultConfig = Config{ - DataDir: DefaultDataDir(), - HTTPPort: DefaultHTTPPort, - HTTPModules: []string{"net", "web3"}, - WSPort: DefaultWSPort, - WSModules: []string{"net", "web3"}, + DataDir: DefaultDataDir(), + HTTPPort: DefaultHTTPPort, + HTTPModules: []string{"net", "web3"}, + HTTPVirtualHosts: []string{"localhost"}, + WSPort: DefaultWSPort, + WSModules: []string{"net", "web3"}, P2P: p2p.Config{ ListenAddr: ":30303", MaxPeers: 25, diff --git a/p2p/dial.go b/p2p/dial.go index f5ff2c211..d8feceb9f 100644 --- a/p2p/dial.go +++ b/p2p/dial.go @@ -154,6 +154,9 @@ func (s *dialstate) addStatic(n *discover.Node) { func (s *dialstate) removeStatic(n *discover.Node) { // This removes a task so future attempts to connect will not be made. delete(s.static, n.ID) + // This removes a previous dial timestamp so that application + // can force a server to reconnect with chosen peer immediately. + s.hist.remove(n.ID) } func (s *dialstate) newTasks(nRunning int, peers map[discover.NodeID]*Peer, now time.Time) []task { @@ -390,6 +393,16 @@ func (h dialHistory) min() pastDial { } func (h *dialHistory) add(id discover.NodeID, exp time.Time) { heap.Push(h, pastDial{id, exp}) + +} +func (h *dialHistory) remove(id discover.NodeID) bool { + for i, v := range *h { + if v.id == id { + heap.Remove(h, i) + return true + } + } + return false } func (h dialHistory) contains(id discover.NodeID) bool { for _, v := range h { diff --git a/p2p/dial_test.go b/p2p/dial_test.go index ad18ef9ab..2a7941fc6 100644 --- a/p2p/dial_test.go +++ b/p2p/dial_test.go @@ -515,6 +515,50 @@ func TestDialStateStaticDial(t *testing.T) { }) } +// This test checks that static peers will be redialed immediately if they were re-added to a static list. +func TestDialStaticAfterReset(t *testing.T) { + wantStatic := []*discover.Node{ + {ID: uintID(1)}, + {ID: uintID(2)}, + } + + rounds := []round{ + // Static dials are launched for the nodes that aren't yet connected. + { + peers: nil, + new: []task{ + &dialTask{flags: staticDialedConn, dest: &discover.Node{ID: uintID(1)}}, + &dialTask{flags: staticDialedConn, dest: &discover.Node{ID: uintID(2)}}, + }, + }, + // No new dial tasks, all peers are connected. + { + peers: []*Peer{ + {rw: &conn{flags: staticDialedConn, id: uintID(1)}}, + {rw: &conn{flags: staticDialedConn, id: uintID(2)}}, + }, + done: []task{ + &dialTask{flags: staticDialedConn, dest: &discover.Node{ID: uintID(1)}}, + &dialTask{flags: staticDialedConn, dest: &discover.Node{ID: uintID(2)}}, + }, + new: []task{ + &waitExpireTask{Duration: 30 * time.Second}, + }, + }, + } + dTest := dialtest{ + init: newDialState(wantStatic, nil, fakeTable{}, 0, nil), + rounds: rounds, + } + runDialTest(t, dTest) + for _, n := range wantStatic { + dTest.init.removeStatic(n) + dTest.init.addStatic(n) + } + // without removing peers they will be considered recently dialed + runDialTest(t, dTest) +} + // This test checks that past dials are not retried for some time. func TestDialStateCache(t *testing.T) { wantStatic := []*discover.Node{ diff --git a/p2p/discover/database.go b/p2p/discover/database.go index b136609f2..6f98de9b4 100644 --- a/p2p/discover/database.go +++ b/p2p/discover/database.go @@ -257,7 +257,7 @@ func (db *nodeDB) expireNodes() error { } // Skip the node if not expired yet (and not self) if !bytes.Equal(id[:], db.self[:]) { - if seen := db.lastPong(id); seen.After(threshold) { + if seen := db.bondTime(id); seen.After(threshold) { continue } } @@ -278,13 +278,18 @@ func (db *nodeDB) updateLastPing(id NodeID, instance time.Time) error { return db.storeInt64(makeKey(id, nodeDBDiscoverPing), instance.Unix()) } -// lastPong retrieves the time of the last successful contact from remote node. -func (db *nodeDB) lastPong(id NodeID) time.Time { +// bondTime retrieves the time of the last successful pong from remote node. +func (db *nodeDB) bondTime(id NodeID) time.Time { return time.Unix(db.fetchInt64(makeKey(id, nodeDBDiscoverPong)), 0) } -// updateLastPong updates the last time a remote node successfully contacted. -func (db *nodeDB) updateLastPong(id NodeID, instance time.Time) error { +// hasBond reports whether the given node is considered bonded. +func (db *nodeDB) hasBond(id NodeID) bool { + return time.Since(db.bondTime(id)) < nodeDBNodeExpiration +} + +// updateBondTime updates the last pong time of a node. +func (db *nodeDB) updateBondTime(id NodeID, instance time.Time) error { return db.storeInt64(makeKey(id, nodeDBDiscoverPong), instance.Unix()) } @@ -327,7 +332,7 @@ seek: if n.ID == db.self { continue seek } - if now.Sub(db.lastPong(n.ID)) > maxAge { + if now.Sub(db.bondTime(n.ID)) > maxAge { continue seek } for i := range nodes { diff --git a/p2p/discover/database_test.go b/p2p/discover/database_test.go index be972fd2c..c4fa44d09 100644 --- a/p2p/discover/database_test.go +++ b/p2p/discover/database_test.go @@ -125,13 +125,13 @@ func TestNodeDBFetchStore(t *testing.T) { t.Errorf("ping: value mismatch: have %v, want %v", stored, inst) } // Check fetch/store operations on a node pong object - if stored := db.lastPong(node.ID); stored.Unix() != 0 { + if stored := db.bondTime(node.ID); stored.Unix() != 0 { t.Errorf("pong: non-existing object: %v", stored) } - if err := db.updateLastPong(node.ID, inst); err != nil { + if err := db.updateBondTime(node.ID, inst); err != nil { t.Errorf("pong: failed to update: %v", err) } - if stored := db.lastPong(node.ID); stored.Unix() != inst.Unix() { + if stored := db.bondTime(node.ID); stored.Unix() != inst.Unix() { t.Errorf("pong: value mismatch: have %v, want %v", stored, inst) } // Check fetch/store operations on a node findnode-failure object @@ -224,8 +224,8 @@ func TestNodeDBSeedQuery(t *testing.T) { if err := db.updateNode(seed.node); err != nil { t.Fatalf("node %d: failed to insert: %v", i, err) } - if err := db.updateLastPong(seed.node.ID, seed.pong); err != nil { - t.Fatalf("node %d: failed to insert lastPong: %v", i, err) + if err := db.updateBondTime(seed.node.ID, seed.pong); err != nil { + t.Fatalf("node %d: failed to insert bondTime: %v", i, err) } } @@ -332,8 +332,8 @@ func TestNodeDBExpiration(t *testing.T) { if err := db.updateNode(seed.node); err != nil { t.Fatalf("node %d: failed to insert: %v", i, err) } - if err := db.updateLastPong(seed.node.ID, seed.pong); err != nil { - t.Fatalf("node %d: failed to update pong: %v", i, err) + if err := db.updateBondTime(seed.node.ID, seed.pong); err != nil { + t.Fatalf("node %d: failed to update bondTime: %v", i, err) } } // Expire some of them, and check the rest @@ -365,8 +365,8 @@ func TestNodeDBSelfExpiration(t *testing.T) { if err := db.updateNode(seed.node); err != nil { t.Fatalf("node %d: failed to insert: %v", i, err) } - if err := db.updateLastPong(seed.node.ID, seed.pong); err != nil { - t.Fatalf("node %d: failed to update pong: %v", i, err) + if err := db.updateBondTime(seed.node.ID, seed.pong); err != nil { + t.Fatalf("node %d: failed to update bondTime: %v", i, err) } } // Expire the nodes and make sure self has been evacuated too diff --git a/p2p/discover/table.go b/p2p/discover/table.go index 17c9db777..6509326e6 100644 --- a/p2p/discover/table.go +++ b/p2p/discover/table.go @@ -455,7 +455,7 @@ func (tab *Table) loadSeedNodes(bond bool) { } for i := range seeds { seed := seeds[i] - age := log.Lazy{Fn: func() interface{} { return time.Since(tab.db.lastPong(seed.ID)) }} + age := log.Lazy{Fn: func() interface{} { return time.Since(tab.db.bondTime(seed.ID)) }} log.Debug("Found seed node in database", "id", seed.ID, "addr", seed.addr(), "age", age) tab.add(seed) } @@ -596,7 +596,7 @@ func (tab *Table) bond(pinged bool, id NodeID, addr *net.UDPAddr, tcpPort uint16 } // Start bonding if we haven't seen this node for a while or if it failed findnode too often. node, fails := tab.db.node(id), tab.db.findFails(id) - age := time.Since(tab.db.lastPong(id)) + age := time.Since(tab.db.bondTime(id)) var result error if fails > 0 || age > nodeDBNodeExpiration { log.Trace("Starting bonding ping/pong", "id", id, "known", node != nil, "failcount", fails, "age", age) @@ -663,7 +663,7 @@ func (tab *Table) ping(id NodeID, addr *net.UDPAddr) error { if err := tab.net.ping(id, addr); err != nil { return err } - tab.db.updateLastPong(id, time.Now()) + tab.db.updateBondTime(id, time.Now()) return nil } diff --git a/p2p/discover/udp.go b/p2p/discover/udp.go index e40de2c36..524c6e498 100644 --- a/p2p/discover/udp.go +++ b/p2p/discover/udp.go @@ -613,7 +613,7 @@ func (req *findnode) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte if expired(req.Expiration) { return errExpired } - if t.db.node(fromID) == nil { + if !t.db.hasBond(fromID) { // No bond exists, we don't process the packet. This prevents // an attack vector where the discovery protocol could be used // to amplify traffic in a DDOS attack. A malicious actor diff --git a/p2p/discover/udp_test.go b/p2p/discover/udp_test.go index 3ffa5c4dd..db9804f7b 100644 --- a/p2p/discover/udp_test.go +++ b/p2p/discover/udp_test.go @@ -247,12 +247,8 @@ func TestUDP_findnode(t *testing.T) { // ensure there's a bond with the test node, // findnode won't be accepted otherwise. - test.table.db.updateNode(NewNode( - PubkeyID(&test.remotekey.PublicKey), - test.remoteaddr.IP, - uint16(test.remoteaddr.Port), - 99, - )) + test.table.db.updateBondTime(PubkeyID(&test.remotekey.PublicKey), time.Now()) + // check that closest neighbors are returned. test.packetIn(nil, findnodePacket, &findnode{Target: testTarget, Expiration: futureExp}) expected := test.table.closest(targetHash, bucketSize) diff --git a/p2p/message.go b/p2p/message.go index 5690494bf..50b419970 100644 --- a/p2p/message.go +++ b/p2p/message.go @@ -22,8 +22,6 @@ import ( "fmt" "io" "io/ioutil" - "net" - "sync" "sync/atomic" "time" @@ -112,30 +110,6 @@ func SendItems(w MsgWriter, msgcode uint64, elems ...interface{}) error { return Send(w, msgcode, elems) } -// netWrapper wraps a MsgReadWriter with locks around -// ReadMsg/WriteMsg and applies read/write deadlines. -type netWrapper struct { - rmu, wmu sync.Mutex - - rtimeout, wtimeout time.Duration - conn net.Conn - wrapped MsgReadWriter -} - -func (rw *netWrapper) ReadMsg() (Msg, error) { - rw.rmu.Lock() - defer rw.rmu.Unlock() - rw.conn.SetReadDeadline(time.Now().Add(rw.rtimeout)) - return rw.wrapped.ReadMsg() -} - -func (rw *netWrapper) WriteMsg(msg Msg) error { - rw.wmu.Lock() - defer rw.wmu.Unlock() - rw.conn.SetWriteDeadline(time.Now().Add(rw.wtimeout)) - return rw.wrapped.WriteMsg(msg) -} - // eofSignal wraps a reader with eof signaling. the eof channel is // closed when the wrapped reader returns an error or when count bytes // have been read. diff --git a/p2p/metrics.go b/p2p/metrics.go index 98b61901d..4cbff90ac 100644 --- a/p2p/metrics.go +++ b/p2p/metrics.go @@ -25,10 +25,10 @@ import ( ) var ( - ingressConnectMeter = metrics.NewMeter("p2p/InboundConnects") - ingressTrafficMeter = metrics.NewMeter("p2p/InboundTraffic") - egressConnectMeter = metrics.NewMeter("p2p/OutboundConnects") - egressTrafficMeter = metrics.NewMeter("p2p/OutboundTraffic") + ingressConnectMeter = metrics.NewRegisteredMeter("p2p/InboundConnects", nil) + ingressTrafficMeter = metrics.NewRegisteredMeter("p2p/InboundTraffic", nil) + egressConnectMeter = metrics.NewRegisteredMeter("p2p/OutboundConnects", nil) + egressTrafficMeter = metrics.NewRegisteredMeter("p2p/OutboundTraffic", nil) ) // meteredConn is a wrapper around a network TCP connection that meters both the diff --git a/p2p/protocols/protocol.go b/p2p/protocols/protocol.go new file mode 100644 index 000000000..9914c9958 --- /dev/null +++ b/p2p/protocols/protocol.go @@ -0,0 +1,311 @@ +// Copyright 2017 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. + +/* +Package protocols is an extension to p2p. It offers a user friendly simple way to define +devp2p subprotocols by abstracting away code standardly shared by protocols. + +* automate assigments of code indexes to messages +* automate RLP decoding/encoding based on reflecting +* provide the forever loop to read incoming messages +* standardise error handling related to communication +* standardised handshake negotiation +* TODO: automatic generation of wire protocol specification for peers + +*/ +package protocols + +import ( + "context" + "fmt" + "reflect" + "sync" + + "github.com/ethereum/go-ethereum/p2p" +) + +// error codes used by this protocol scheme +const ( + ErrMsgTooLong = iota + ErrDecode + ErrWrite + ErrInvalidMsgCode + ErrInvalidMsgType + ErrHandshake + ErrNoHandler + ErrHandler +) + +// error description strings associated with the codes +var errorToString = map[int]string{ + ErrMsgTooLong: "Message too long", + ErrDecode: "Invalid message (RLP error)", + ErrWrite: "Error sending message", + ErrInvalidMsgCode: "Invalid message code", + ErrInvalidMsgType: "Invalid message type", + ErrHandshake: "Handshake error", + ErrNoHandler: "No handler registered error", + ErrHandler: "Message handler error", +} + +/* +Error implements the standard go error interface. +Use: + + errorf(code, format, params ...interface{}) + +Prints as: + + <description>: <details> + +where description is given by code in errorToString +and details is fmt.Sprintf(format, params...) + +exported field Code can be checked +*/ +type Error struct { + Code int + message string + format string + params []interface{} +} + +func (e Error) Error() (message string) { + if len(e.message) == 0 { + name, ok := errorToString[e.Code] + if !ok { + panic("invalid message code") + } + e.message = name + if e.format != "" { + e.message += ": " + fmt.Sprintf(e.format, e.params...) + } + } + return e.message +} + +func errorf(code int, format string, params ...interface{}) *Error { + return &Error{ + Code: code, + format: format, + params: params, + } +} + +// Spec is a protocol specification including its name and version as well as +// the types of messages which are exchanged +type Spec struct { + // Name is the name of the protocol, often a three-letter word + Name string + + // Version is the version number of the protocol + Version uint + + // MaxMsgSize is the maximum accepted length of the message payload + MaxMsgSize uint32 + + // Messages is a list of message data types which this protocol uses, with + // each message type being sent with its array index as the code (so + // [&foo{}, &bar{}, &baz{}] would send foo, bar and baz with codes + // 0, 1 and 2 respectively) + // each message must have a single unique data type + Messages []interface{} + + initOnce sync.Once + codes map[reflect.Type]uint64 + types map[uint64]reflect.Type +} + +func (s *Spec) init() { + s.initOnce.Do(func() { + s.codes = make(map[reflect.Type]uint64, len(s.Messages)) + s.types = make(map[uint64]reflect.Type, len(s.Messages)) + for i, msg := range s.Messages { + code := uint64(i) + typ := reflect.TypeOf(msg) + if typ.Kind() == reflect.Ptr { + typ = typ.Elem() + } + s.codes[typ] = code + s.types[code] = typ + } + }) +} + +// Length returns the number of message types in the protocol +func (s *Spec) Length() uint64 { + return uint64(len(s.Messages)) +} + +// GetCode returns the message code of a type, and boolean second argument is +// false if the message type is not found +func (s *Spec) GetCode(msg interface{}) (uint64, bool) { + s.init() + typ := reflect.TypeOf(msg) + if typ.Kind() == reflect.Ptr { + typ = typ.Elem() + } + code, ok := s.codes[typ] + return code, ok +} + +// NewMsg construct a new message type given the code +func (s *Spec) NewMsg(code uint64) (interface{}, bool) { + s.init() + typ, ok := s.types[code] + if !ok { + return nil, false + } + return reflect.New(typ).Interface(), true +} + +// Peer represents a remote peer or protocol instance that is running on a peer connection with +// a remote peer +type Peer struct { + *p2p.Peer // the p2p.Peer object representing the remote + rw p2p.MsgReadWriter // p2p.MsgReadWriter to send messages to and read messages from + spec *Spec +} + +// NewPeer constructs a new peer +// this constructor is called by the p2p.Protocol#Run function +// the first two arguments are the arguments passed to p2p.Protocol.Run function +// the third argument is the Spec describing the protocol +func NewPeer(p *p2p.Peer, rw p2p.MsgReadWriter, spec *Spec) *Peer { + return &Peer{ + Peer: p, + rw: rw, + spec: spec, + } +} + +// Run starts the forever loop that handles incoming messages +// called within the p2p.Protocol#Run function +// the handler argument is a function which is called for each message received +// from the remote peer, a returned error causes the loop to exit +// resulting in disconnection +func (p *Peer) Run(handler func(msg interface{}) error) error { + for { + if err := p.handleIncoming(handler); err != nil { + return err + } + } +} + +// Drop disconnects a peer. +// TODO: may need to implement protocol drop only? don't want to kick off the peer +// if they are useful for other protocols +func (p *Peer) Drop(err error) { + p.Disconnect(p2p.DiscSubprotocolError) +} + +// Send takes a message, encodes it in RLP, finds the right message code and sends the +// message off to the peer +// this low level call will be wrapped by libraries providing routed or broadcast sends +// but often just used to forward and push messages to directly connected peers +func (p *Peer) Send(msg interface{}) error { + code, found := p.spec.GetCode(msg) + if !found { + return errorf(ErrInvalidMsgType, "%v", code) + } + return p2p.Send(p.rw, code, msg) +} + +// handleIncoming(code) +// is called each cycle of the main forever loop that dispatches incoming messages +// if this returns an error the loop returns and the peer is disconnected with the error +// this generic handler +// * checks message size, +// * checks for out-of-range message codes, +// * handles decoding with reflection, +// * call handlers as callbacks +func (p *Peer) handleIncoming(handle func(msg interface{}) error) error { + msg, err := p.rw.ReadMsg() + if err != nil { + return err + } + // make sure that the payload has been fully consumed + defer msg.Discard() + + if msg.Size > p.spec.MaxMsgSize { + return errorf(ErrMsgTooLong, "%v > %v", msg.Size, p.spec.MaxMsgSize) + } + + val, ok := p.spec.NewMsg(msg.Code) + if !ok { + return errorf(ErrInvalidMsgCode, "%v", msg.Code) + } + if err := msg.Decode(val); err != nil { + return errorf(ErrDecode, "<= %v: %v", msg, err) + } + + // call the registered handler callbacks + // a registered callback take the decoded message as argument as an interface + // which the handler is supposed to cast to the appropriate type + // it is entirely safe not to check the cast in the handler since the handler is + // chosen based on the proper type in the first place + if err := handle(val); err != nil { + return errorf(ErrHandler, "(msg code %v): %v", msg.Code, err) + } + return nil +} + +// Handshake negotiates a handshake on the peer connection +// * arguments +// * context +// * the local handshake to be sent to the remote peer +// * funcion to be called on the remote handshake (can be nil) +// * expects a remote handshake back of the same type +// * the dialing peer needs to send the handshake first and then waits for remote +// * the listening peer waits for the remote handshake and then sends it +// returns the remote handshake and an error +func (p *Peer) Handshake(ctx context.Context, hs interface{}, verify func(interface{}) error) (rhs interface{}, err error) { + if _, ok := p.spec.GetCode(hs); !ok { + return nil, errorf(ErrHandshake, "unknown handshake message type: %T", hs) + } + errc := make(chan error, 2) + handle := func(msg interface{}) error { + rhs = msg + if verify != nil { + return verify(rhs) + } + return nil + } + send := func() { errc <- p.Send(hs) } + receive := func() { errc <- p.handleIncoming(handle) } + + go func() { + if p.Inbound() { + receive() + send() + } else { + send() + receive() + } + }() + + for i := 0; i < 2; i++ { + select { + case err = <-errc: + case <-ctx.Done(): + err = ctx.Err() + } + if err != nil { + return nil, errorf(ErrHandshake, err.Error()) + } + } + return rhs, nil +} diff --git a/p2p/protocols/protocol_test.go b/p2p/protocols/protocol_test.go new file mode 100644 index 000000000..053f537a6 --- /dev/null +++ b/p2p/protocols/protocol_test.go @@ -0,0 +1,389 @@ +// Copyright 2017 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. + +package protocols + +import ( + "context" + "errors" + "fmt" + "testing" + "time" + + "github.com/ethereum/go-ethereum/p2p" + "github.com/ethereum/go-ethereum/p2p/discover" + "github.com/ethereum/go-ethereum/p2p/simulations/adapters" + p2ptest "github.com/ethereum/go-ethereum/p2p/testing" +) + +// handshake message type +type hs0 struct { + C uint +} + +// message to kill/drop the peer with nodeID +type kill struct { + C discover.NodeID +} + +// message to drop connection +type drop struct { +} + +/// protoHandshake represents module-independent aspects of the protocol and is +// the first message peers send and receive as part the initial exchange +type protoHandshake struct { + Version uint // local and remote peer should have identical version + NetworkID string // local and remote peer should have identical network id +} + +// checkProtoHandshake verifies local and remote protoHandshakes match +func checkProtoHandshake(testVersion uint, testNetworkID string) func(interface{}) error { + return func(rhs interface{}) error { + remote := rhs.(*protoHandshake) + if remote.NetworkID != testNetworkID { + return fmt.Errorf("%s (!= %s)", remote.NetworkID, testNetworkID) + } + + if remote.Version != testVersion { + return fmt.Errorf("%d (!= %d)", remote.Version, testVersion) + } + return nil + } +} + +// newProtocol sets up a protocol +// the run function here demonstrates a typical protocol using peerPool, handshake +// and messages registered to handlers +func newProtocol(pp *p2ptest.TestPeerPool) func(*p2p.Peer, p2p.MsgReadWriter) error { + spec := &Spec{ + Name: "test", + Version: 42, + MaxMsgSize: 10 * 1024, + Messages: []interface{}{ + protoHandshake{}, + hs0{}, + kill{}, + drop{}, + }, + } + return func(p *p2p.Peer, rw p2p.MsgReadWriter) error { + peer := NewPeer(p, rw, spec) + + // initiate one-off protohandshake and check validity + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + phs := &protoHandshake{42, "420"} + hsCheck := checkProtoHandshake(phs.Version, phs.NetworkID) + _, err := peer.Handshake(ctx, phs, hsCheck) + if err != nil { + return err + } + + lhs := &hs0{42} + // module handshake demonstrating a simple repeatable exchange of same-type message + hs, err := peer.Handshake(ctx, lhs, nil) + if err != nil { + return err + } + + if rmhs := hs.(*hs0); rmhs.C > lhs.C { + return fmt.Errorf("handshake mismatch remote %v > local %v", rmhs.C, lhs.C) + } + + handle := func(msg interface{}) error { + switch msg := msg.(type) { + + case *protoHandshake: + return errors.New("duplicate handshake") + + case *hs0: + rhs := msg + if rhs.C > lhs.C { + return fmt.Errorf("handshake mismatch remote %v > local %v", rhs.C, lhs.C) + } + lhs.C += rhs.C + return peer.Send(lhs) + + case *kill: + // demonstrates use of peerPool, killing another peer connection as a response to a message + id := msg.C + pp.Get(id).Drop(errors.New("killed")) + return nil + + case *drop: + // for testing we can trigger self induced disconnect upon receiving drop message + return errors.New("dropped") + + default: + return fmt.Errorf("unknown message type: %T", msg) + } + } + + pp.Add(peer) + defer pp.Remove(peer) + return peer.Run(handle) + } +} + +func protocolTester(t *testing.T, pp *p2ptest.TestPeerPool) *p2ptest.ProtocolTester { + conf := adapters.RandomNodeConfig() + return p2ptest.NewProtocolTester(t, conf.ID, 2, newProtocol(pp)) +} + +func protoHandshakeExchange(id discover.NodeID, proto *protoHandshake) []p2ptest.Exchange { + + return []p2ptest.Exchange{ + { + Expects: []p2ptest.Expect{ + { + Code: 0, + Msg: &protoHandshake{42, "420"}, + Peer: id, + }, + }, + }, + { + Triggers: []p2ptest.Trigger{ + { + Code: 0, + Msg: proto, + Peer: id, + }, + }, + }, + } +} + +func runProtoHandshake(t *testing.T, proto *protoHandshake, errs ...error) { + pp := p2ptest.NewTestPeerPool() + s := protocolTester(t, pp) + // TODO: make this more than one handshake + id := s.IDs[0] + if err := s.TestExchanges(protoHandshakeExchange(id, proto)...); err != nil { + t.Fatal(err) + } + var disconnects []*p2ptest.Disconnect + for i, err := range errs { + disconnects = append(disconnects, &p2ptest.Disconnect{Peer: s.IDs[i], Error: err}) + } + if err := s.TestDisconnected(disconnects...); err != nil { + t.Fatal(err) + } +} + +func TestProtoHandshakeVersionMismatch(t *testing.T) { + runProtoHandshake(t, &protoHandshake{41, "420"}, errorf(ErrHandshake, errorf(ErrHandler, "(msg code 0): 41 (!= 42)").Error())) +} + +func TestProtoHandshakeNetworkIDMismatch(t *testing.T) { + runProtoHandshake(t, &protoHandshake{42, "421"}, errorf(ErrHandshake, errorf(ErrHandler, "(msg code 0): 421 (!= 420)").Error())) +} + +func TestProtoHandshakeSuccess(t *testing.T) { + runProtoHandshake(t, &protoHandshake{42, "420"}) +} + +func moduleHandshakeExchange(id discover.NodeID, resp uint) []p2ptest.Exchange { + + return []p2ptest.Exchange{ + { + Expects: []p2ptest.Expect{ + { + Code: 1, + Msg: &hs0{42}, + Peer: id, + }, + }, + }, + { + Triggers: []p2ptest.Trigger{ + { + Code: 1, + Msg: &hs0{resp}, + Peer: id, + }, + }, + }, + } +} + +func runModuleHandshake(t *testing.T, resp uint, errs ...error) { + pp := p2ptest.NewTestPeerPool() + s := protocolTester(t, pp) + id := s.IDs[0] + if err := s.TestExchanges(protoHandshakeExchange(id, &protoHandshake{42, "420"})...); err != nil { + t.Fatal(err) + } + if err := s.TestExchanges(moduleHandshakeExchange(id, resp)...); err != nil { + t.Fatal(err) + } + var disconnects []*p2ptest.Disconnect + for i, err := range errs { + disconnects = append(disconnects, &p2ptest.Disconnect{Peer: s.IDs[i], Error: err}) + } + if err := s.TestDisconnected(disconnects...); err != nil { + t.Fatal(err) + } +} + +func TestModuleHandshakeError(t *testing.T) { + runModuleHandshake(t, 43, fmt.Errorf("handshake mismatch remote 43 > local 42")) +} + +func TestModuleHandshakeSuccess(t *testing.T) { + runModuleHandshake(t, 42) +} + +// testing complex interactions over multiple peers, relaying, dropping +func testMultiPeerSetup(a, b discover.NodeID) []p2ptest.Exchange { + + return []p2ptest.Exchange{ + { + Label: "primary handshake", + Expects: []p2ptest.Expect{ + { + Code: 0, + Msg: &protoHandshake{42, "420"}, + Peer: a, + }, + { + Code: 0, + Msg: &protoHandshake{42, "420"}, + Peer: b, + }, + }, + }, + { + Label: "module handshake", + Triggers: []p2ptest.Trigger{ + { + Code: 0, + Msg: &protoHandshake{42, "420"}, + Peer: a, + }, + { + Code: 0, + Msg: &protoHandshake{42, "420"}, + Peer: b, + }, + }, + Expects: []p2ptest.Expect{ + { + Code: 1, + Msg: &hs0{42}, + Peer: a, + }, + { + Code: 1, + Msg: &hs0{42}, + Peer: b, + }, + }, + }, + + {Label: "alternative module handshake", Triggers: []p2ptest.Trigger{{Code: 1, Msg: &hs0{41}, Peer: a}, + {Code: 1, Msg: &hs0{41}, Peer: b}}}, + {Label: "repeated module handshake", Triggers: []p2ptest.Trigger{{Code: 1, Msg: &hs0{1}, Peer: a}}}, + {Label: "receiving repeated module handshake", Expects: []p2ptest.Expect{{Code: 1, Msg: &hs0{43}, Peer: a}}}} +} + +func runMultiplePeers(t *testing.T, peer int, errs ...error) { + pp := p2ptest.NewTestPeerPool() + s := protocolTester(t, pp) + + if err := s.TestExchanges(testMultiPeerSetup(s.IDs[0], s.IDs[1])...); err != nil { + t.Fatal(err) + } + // after some exchanges of messages, we can test state changes + // here this is simply demonstrated by the peerPool + // after the handshake negotiations peers must be added to the pool + // time.Sleep(1) + tick := time.NewTicker(10 * time.Millisecond) + timeout := time.NewTimer(1 * time.Second) +WAIT: + for { + select { + case <-tick.C: + if pp.Has(s.IDs[0]) { + break WAIT + } + case <-timeout.C: + t.Fatal("timeout") + } + } + if !pp.Has(s.IDs[1]) { + t.Fatalf("missing peer test-1: %v (%v)", pp, s.IDs) + } + + // peer 0 sends kill request for peer with index <peer> + err := s.TestExchanges(p2ptest.Exchange{ + Triggers: []p2ptest.Trigger{ + { + Code: 2, + Msg: &kill{s.IDs[peer]}, + Peer: s.IDs[0], + }, + }, + }) + + if err != nil { + t.Fatal(err) + } + + // the peer not killed sends a drop request + err = s.TestExchanges(p2ptest.Exchange{ + Triggers: []p2ptest.Trigger{ + { + Code: 3, + Msg: &drop{}, + Peer: s.IDs[(peer+1)%2], + }, + }, + }) + + if err != nil { + t.Fatal(err) + } + + // check the actual discconnect errors on the individual peers + var disconnects []*p2ptest.Disconnect + for i, err := range errs { + disconnects = append(disconnects, &p2ptest.Disconnect{Peer: s.IDs[i], Error: err}) + } + if err := s.TestDisconnected(disconnects...); err != nil { + t.Fatal(err) + } + // test if disconnected peers have been removed from peerPool + if pp.Has(s.IDs[peer]) { + t.Fatalf("peer test-%v not dropped: %v (%v)", peer, pp, s.IDs) + } + +} + +func TestMultiplePeersDropSelf(t *testing.T) { + runMultiplePeers(t, 0, + fmt.Errorf("subprotocol error"), + fmt.Errorf("Message handler error: (msg code 3): dropped"), + ) +} + +func TestMultiplePeersDropOther(t *testing.T) { + runMultiplePeers(t, 1, + fmt.Errorf("Message handler error: (msg code 3): dropped"), + fmt.Errorf("subprotocol error"), + ) +} diff --git a/p2p/rlpx.go b/p2p/rlpx.go index 24037ecc1..1889edac9 100644 --- a/p2p/rlpx.go +++ b/p2p/rlpx.go @@ -108,17 +108,19 @@ func (t *rlpx) close(err error) { // Tell the remote end why we're disconnecting if possible. if t.rw != nil { if r, ok := err.(DiscReason); ok && r != DiscNetworkError { - t.fd.SetWriteDeadline(time.Now().Add(discWriteTimeout)) - SendItems(t.rw, discMsg, r) + // rlpx tries to send DiscReason to disconnected peer + // if the connection is net.Pipe (in-memory simulation) + // it hangs forever, since net.Pipe does not implement + // a write deadline. Because of this only try to send + // the disconnect reason message if there is no error. + if err := t.fd.SetWriteDeadline(time.Now().Add(discWriteTimeout)); err == nil { + SendItems(t.rw, discMsg, r) + } } } t.fd.Close() } -// doEncHandshake runs the protocol handshake using authenticated -// messages. the protocol handshake is the first authenticated message -// and also verifies whether the encryption handshake 'worked' and the -// remote side actually provided the right public key. func (t *rlpx) doProtoHandshake(our *protoHandshake) (their *protoHandshake, err error) { // Writing our handshake happens concurrently, we prefer // returning the handshake read error. If the remote side @@ -169,6 +171,10 @@ func readProtocolHandshake(rw MsgReader, our *protoHandshake) (*protoHandshake, return &hs, nil } +// doEncHandshake runs the protocol handshake using authenticated +// messages. the protocol handshake is the first authenticated message +// and also verifies whether the encryption handshake 'worked' and the +// remote side actually provided the right public key. func (t *rlpx) doEncHandshake(prv *ecdsa.PrivateKey, dial *discover.Node) (discover.NodeID, error) { var ( sec secrets diff --git a/p2p/rlpx_test.go b/p2p/rlpx_test.go index f4cefa650..bca460402 100644 --- a/p2p/rlpx_test.go +++ b/p2p/rlpx_test.go @@ -156,14 +156,18 @@ func TestProtocolHandshake(t *testing.T) { node1 = &discover.Node{ID: discover.PubkeyID(&prv1.PublicKey), IP: net.IP{5, 6, 7, 8}, TCP: 44} hs1 = &protoHandshake{Version: 3, ID: node1.ID, Caps: []Cap{{"c", 1}, {"d", 3}}} - fd0, fd1 = net.Pipe() - wg sync.WaitGroup + wg sync.WaitGroup ) + fd0, fd1, err := tcpPipe() + if err != nil { + t.Fatal(err) + } + wg.Add(2) go func() { defer wg.Done() - defer fd1.Close() + defer fd0.Close() rlpx := newRLPX(fd0) remid, err := rlpx.doEncHandshake(prv0, node1) if err != nil { @@ -597,3 +601,31 @@ func TestHandshakeForwardCompatibility(t *testing.T) { t.Errorf("ingress-mac('foo') mismatch:\ngot %x\nwant %x", fooIngressHash, wantFooIngressHash) } } + +// tcpPipe creates an in process full duplex pipe based on a localhost TCP socket +func tcpPipe() (net.Conn, net.Conn, error) { + l, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + return nil, nil, err + } + defer l.Close() + + var aconn net.Conn + aerr := make(chan error, 1) + go func() { + var err error + aconn, err = l.Accept() + aerr <- err + }() + + dconn, err := net.Dial("tcp", l.Addr().String()) + if err != nil { + <-aerr + return nil, nil, err + } + if err := <-aerr; err != nil { + dconn.Close() + return nil, nil, err + } + return aconn, dconn, nil +} diff --git a/p2p/server.go b/p2p/server.go index 90e92dc05..c41d1dc15 100644 --- a/p2p/server.go +++ b/p2p/server.go @@ -36,9 +36,7 @@ import ( ) const ( - defaultDialTimeout = 15 * time.Second - refreshPeersInterval = 30 * time.Second - staticPeerCheckInterval = 15 * time.Second + defaultDialTimeout = 15 * time.Second // Connectivity defaults. maxActiveDialTasks = 16 diff --git a/p2p/testing/peerpool.go b/p2p/testing/peerpool.go new file mode 100644 index 000000000..45c6e6142 --- /dev/null +++ b/p2p/testing/peerpool.go @@ -0,0 +1,67 @@ +// Copyright 2017 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. + +package testing + +import ( + "fmt" + "sync" + + "github.com/ethereum/go-ethereum/log" + "github.com/ethereum/go-ethereum/p2p/discover" +) + +type TestPeer interface { + ID() discover.NodeID + Drop(error) +} + +// TestPeerPool is an example peerPool to demonstrate registration of peer connections +type TestPeerPool struct { + lock sync.Mutex + peers map[discover.NodeID]TestPeer +} + +func NewTestPeerPool() *TestPeerPool { + return &TestPeerPool{peers: make(map[discover.NodeID]TestPeer)} +} + +func (self *TestPeerPool) Add(p TestPeer) { + self.lock.Lock() + defer self.lock.Unlock() + log.Trace(fmt.Sprintf("pp add peer %v", p.ID())) + self.peers[p.ID()] = p + +} + +func (self *TestPeerPool) Remove(p TestPeer) { + self.lock.Lock() + defer self.lock.Unlock() + delete(self.peers, p.ID()) +} + +func (self *TestPeerPool) Has(id discover.NodeID) bool { + self.lock.Lock() + defer self.lock.Unlock() + _, ok := self.peers[id] + return ok +} + +func (self *TestPeerPool) Get(id discover.NodeID) TestPeer { + self.lock.Lock() + defer self.lock.Unlock() + return self.peers[id] +} diff --git a/p2p/testing/protocolsession.go b/p2p/testing/protocolsession.go new file mode 100644 index 000000000..361285f06 --- /dev/null +++ b/p2p/testing/protocolsession.go @@ -0,0 +1,280 @@ +// Copyright 2017 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. + +package testing + +import ( + "errors" + "fmt" + "sync" + "time" + + "github.com/ethereum/go-ethereum/log" + "github.com/ethereum/go-ethereum/p2p" + "github.com/ethereum/go-ethereum/p2p/discover" + "github.com/ethereum/go-ethereum/p2p/simulations/adapters" +) + +var errTimedOut = errors.New("timed out") + +// ProtocolSession is a quasi simulation of a pivot node running +// a service and a number of dummy peers that can send (trigger) or +// receive (expect) messages +type ProtocolSession struct { + Server *p2p.Server + IDs []discover.NodeID + adapter *adapters.SimAdapter + events chan *p2p.PeerEvent +} + +// Exchange is the basic units of protocol tests +// the triggers and expects in the arrays are run immediately and asynchronously +// thus one cannot have multiple expects for the SAME peer with DIFFERENT message types +// because it's unpredictable which expect will receive which message +// (with expect #1 and #2, messages might be sent #2 and #1, and both expects will complain about wrong message code) +// an exchange is defined on a session +type Exchange struct { + Label string + Triggers []Trigger + Expects []Expect + Timeout time.Duration +} + +// Trigger is part of the exchange, incoming message for the pivot node +// sent by a peer +type Trigger struct { + Msg interface{} // type of message to be sent + Code uint64 // code of message is given + Peer discover.NodeID // the peer to send the message to + Timeout time.Duration // timeout duration for the sending +} + +// Expect is part of an exchange, outgoing message from the pivot node +// received by a peer +type Expect struct { + Msg interface{} // type of message to expect + Code uint64 // code of message is now given + Peer discover.NodeID // the peer that expects the message + Timeout time.Duration // timeout duration for receiving +} + +// Disconnect represents a disconnect event, used and checked by TestDisconnected +type Disconnect struct { + Peer discover.NodeID // discconnected peer + Error error // disconnect reason +} + +// trigger sends messages from peers +func (self *ProtocolSession) trigger(trig Trigger) error { + simNode, ok := self.adapter.GetNode(trig.Peer) + if !ok { + return fmt.Errorf("trigger: peer %v does not exist (1- %v)", trig.Peer, len(self.IDs)) + } + mockNode, ok := simNode.Services()[0].(*mockNode) + if !ok { + return fmt.Errorf("trigger: peer %v is not a mock", trig.Peer) + } + + errc := make(chan error) + + go func() { + errc <- mockNode.Trigger(&trig) + }() + + t := trig.Timeout + if t == time.Duration(0) { + t = 1000 * time.Millisecond + } + select { + case err := <-errc: + return err + case <-time.After(t): + return fmt.Errorf("timout expecting %v to send to peer %v", trig.Msg, trig.Peer) + } +} + +// expect checks an expectation of a message sent out by the pivot node +func (self *ProtocolSession) expect(exps []Expect) error { + // construct a map of expectations for each node + peerExpects := make(map[discover.NodeID][]Expect) + for _, exp := range exps { + if exp.Msg == nil { + return errors.New("no message to expect") + } + peerExpects[exp.Peer] = append(peerExpects[exp.Peer], exp) + } + + // construct a map of mockNodes for each node + mockNodes := make(map[discover.NodeID]*mockNode) + for nodeID := range peerExpects { + simNode, ok := self.adapter.GetNode(nodeID) + if !ok { + return fmt.Errorf("trigger: peer %v does not exist (1- %v)", nodeID, len(self.IDs)) + } + mockNode, ok := simNode.Services()[0].(*mockNode) + if !ok { + return fmt.Errorf("trigger: peer %v is not a mock", nodeID) + } + mockNodes[nodeID] = mockNode + } + + // done chanell cancels all created goroutines when function returns + done := make(chan struct{}) + defer close(done) + // errc catches the first error from + errc := make(chan error) + + wg := &sync.WaitGroup{} + wg.Add(len(mockNodes)) + for nodeID, mockNode := range mockNodes { + nodeID := nodeID + mockNode := mockNode + go func() { + defer wg.Done() + + // Sum all Expect timeouts to give the maximum + // time for all expectations to finish. + // mockNode.Expect checks all received messages against + // a list of expected messages and timeout for each + // of them can not be checked separately. + var t time.Duration + for _, exp := range peerExpects[nodeID] { + if exp.Timeout == time.Duration(0) { + t += 2000 * time.Millisecond + } else { + t += exp.Timeout + } + } + alarm := time.NewTimer(t) + defer alarm.Stop() + + // expectErrc is used to check if error returned + // from mockNode.Expect is not nil and to send it to + // errc only in that case. + // done channel will be closed when function + expectErrc := make(chan error) + go func() { + select { + case expectErrc <- mockNode.Expect(peerExpects[nodeID]...): + case <-done: + case <-alarm.C: + } + }() + + select { + case err := <-expectErrc: + if err != nil { + select { + case errc <- err: + case <-done: + case <-alarm.C: + errc <- errTimedOut + } + } + case <-done: + case <-alarm.C: + errc <- errTimedOut + } + + }() + } + + go func() { + wg.Wait() + // close errc when all goroutines finish to return nill err from errc + close(errc) + }() + + return <-errc +} + +// TestExchanges tests a series of exchanges against the session +func (self *ProtocolSession) TestExchanges(exchanges ...Exchange) error { + for i, e := range exchanges { + if err := self.testExchange(e); err != nil { + return fmt.Errorf("exchange #%d %q: %v", i, e.Label, err) + } + log.Trace(fmt.Sprintf("exchange #%d %q: run successfully", i, e.Label)) + } + return nil +} + +// testExchange tests a single Exchange. +// Default timeout value is 2 seconds. +func (self *ProtocolSession) testExchange(e Exchange) error { + errc := make(chan error) + done := make(chan struct{}) + defer close(done) + + go func() { + for _, trig := range e.Triggers { + err := self.trigger(trig) + if err != nil { + errc <- err + return + } + } + + select { + case errc <- self.expect(e.Expects): + case <-done: + } + }() + + // time out globally or finish when all expectations satisfied + t := e.Timeout + if t == 0 { + t = 2000 * time.Millisecond + } + alarm := time.NewTimer(t) + select { + case err := <-errc: + return err + case <-alarm.C: + return errTimedOut + } +} + +// TestDisconnected tests the disconnections given as arguments +// the disconnect structs describe what disconnect error is expected on which peer +func (self *ProtocolSession) TestDisconnected(disconnects ...*Disconnect) error { + expects := make(map[discover.NodeID]error) + for _, disconnect := range disconnects { + expects[disconnect.Peer] = disconnect.Error + } + + timeout := time.After(time.Second) + for len(expects) > 0 { + select { + case event := <-self.events: + if event.Type != p2p.PeerEventTypeDrop { + continue + } + expectErr, ok := expects[event.Peer] + if !ok { + continue + } + + if !(expectErr == nil && event.Error == "" || expectErr != nil && expectErr.Error() == event.Error) { + return fmt.Errorf("unexpected error on peer %v. expected '%v', got '%v'", event.Peer, expectErr, event.Error) + } + delete(expects, event.Peer) + case <-timeout: + return fmt.Errorf("timed out waiting for peers to disconnect") + } + } + return nil +} diff --git a/p2p/testing/protocoltester.go b/p2p/testing/protocoltester.go new file mode 100644 index 000000000..a797412d6 --- /dev/null +++ b/p2p/testing/protocoltester.go @@ -0,0 +1,269 @@ +// Copyright 2017 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. + +/* +the p2p/testing package provides a unit test scheme to check simple +protocol message exchanges with one pivot node and a number of dummy peers +The pivot test node runs a node.Service, the dummy peers run a mock node +that can be used to send and receive messages +*/ + +package testing + +import ( + "bytes" + "fmt" + "io" + "io/ioutil" + "strings" + "sync" + "testing" + + "github.com/ethereum/go-ethereum/log" + "github.com/ethereum/go-ethereum/node" + "github.com/ethereum/go-ethereum/p2p" + "github.com/ethereum/go-ethereum/p2p/discover" + "github.com/ethereum/go-ethereum/p2p/simulations" + "github.com/ethereum/go-ethereum/p2p/simulations/adapters" + "github.com/ethereum/go-ethereum/rlp" + "github.com/ethereum/go-ethereum/rpc" +) + +// ProtocolTester is the tester environment used for unit testing protocol +// message exchanges. It uses p2p/simulations framework +type ProtocolTester struct { + *ProtocolSession + network *simulations.Network +} + +// NewProtocolTester constructs a new ProtocolTester +// it takes as argument the pivot node id, the number of dummy peers and the +// protocol run function called on a peer connection by the p2p server +func NewProtocolTester(t *testing.T, id discover.NodeID, n int, run func(*p2p.Peer, p2p.MsgReadWriter) error) *ProtocolTester { + services := adapters.Services{ + "test": func(ctx *adapters.ServiceContext) (node.Service, error) { + return &testNode{run}, nil + }, + "mock": func(ctx *adapters.ServiceContext) (node.Service, error) { + return newMockNode(), nil + }, + } + adapter := adapters.NewSimAdapter(services) + net := simulations.NewNetwork(adapter, &simulations.NetworkConfig{}) + if _, err := net.NewNodeWithConfig(&adapters.NodeConfig{ + ID: id, + EnableMsgEvents: true, + Services: []string{"test"}, + }); err != nil { + panic(err.Error()) + } + if err := net.Start(id); err != nil { + panic(err.Error()) + } + + node := net.GetNode(id).Node.(*adapters.SimNode) + peers := make([]*adapters.NodeConfig, n) + peerIDs := make([]discover.NodeID, n) + for i := 0; i < n; i++ { + peers[i] = adapters.RandomNodeConfig() + peers[i].Services = []string{"mock"} + peerIDs[i] = peers[i].ID + } + events := make(chan *p2p.PeerEvent, 1000) + node.SubscribeEvents(events) + ps := &ProtocolSession{ + Server: node.Server(), + IDs: peerIDs, + adapter: adapter, + events: events, + } + self := &ProtocolTester{ + ProtocolSession: ps, + network: net, + } + + self.Connect(id, peers...) + + return self +} + +// Stop stops the p2p server +func (self *ProtocolTester) Stop() error { + self.Server.Stop() + return nil +} + +// Connect brings up the remote peer node and connects it using the +// p2p/simulations network connection with the in memory network adapter +func (self *ProtocolTester) Connect(selfID discover.NodeID, peers ...*adapters.NodeConfig) { + for _, peer := range peers { + log.Trace(fmt.Sprintf("start node %v", peer.ID)) + if _, err := self.network.NewNodeWithConfig(peer); err != nil { + panic(fmt.Sprintf("error starting peer %v: %v", peer.ID, err)) + } + if err := self.network.Start(peer.ID); err != nil { + panic(fmt.Sprintf("error starting peer %v: %v", peer.ID, err)) + } + log.Trace(fmt.Sprintf("connect to %v", peer.ID)) + if err := self.network.Connect(selfID, peer.ID); err != nil { + panic(fmt.Sprintf("error connecting to peer %v: %v", peer.ID, err)) + } + } + +} + +// testNode wraps a protocol run function and implements the node.Service +// interface +type testNode struct { + run func(*p2p.Peer, p2p.MsgReadWriter) error +} + +func (t *testNode) Protocols() []p2p.Protocol { + return []p2p.Protocol{{ + Length: 100, + Run: t.run, + }} +} + +func (t *testNode) APIs() []rpc.API { + return nil +} + +func (t *testNode) Start(server *p2p.Server) error { + return nil +} + +func (t *testNode) Stop() error { + return nil +} + +// mockNode is a testNode which doesn't actually run a protocol, instead +// exposing channels so that tests can manually trigger and expect certain +// messages +type mockNode struct { + testNode + + trigger chan *Trigger + expect chan []Expect + err chan error + stop chan struct{} + stopOnce sync.Once +} + +func newMockNode() *mockNode { + mock := &mockNode{ + trigger: make(chan *Trigger), + expect: make(chan []Expect), + err: make(chan error), + stop: make(chan struct{}), + } + mock.testNode.run = mock.Run + return mock +} + +// Run is a protocol run function which just loops waiting for tests to +// instruct it to either trigger or expect a message from the peer +func (m *mockNode) Run(peer *p2p.Peer, rw p2p.MsgReadWriter) error { + for { + select { + case trig := <-m.trigger: + m.err <- p2p.Send(rw, trig.Code, trig.Msg) + case exps := <-m.expect: + m.err <- expectMsgs(rw, exps) + case <-m.stop: + return nil + } + } +} + +func (m *mockNode) Trigger(trig *Trigger) error { + m.trigger <- trig + return <-m.err +} + +func (m *mockNode) Expect(exp ...Expect) error { + m.expect <- exp + return <-m.err +} + +func (m *mockNode) Stop() error { + m.stopOnce.Do(func() { close(m.stop) }) + return nil +} + +func expectMsgs(rw p2p.MsgReadWriter, exps []Expect) error { + matched := make([]bool, len(exps)) + for { + msg, err := rw.ReadMsg() + if err != nil { + if err == io.EOF { + break + } + return err + } + actualContent, err := ioutil.ReadAll(msg.Payload) + if err != nil { + return err + } + var found bool + for i, exp := range exps { + if exp.Code == msg.Code && bytes.Equal(actualContent, mustEncodeMsg(exp.Msg)) { + if matched[i] { + return fmt.Errorf("message #%d received two times", i) + } + matched[i] = true + found = true + break + } + } + if !found { + expected := make([]string, 0) + for i, exp := range exps { + if matched[i] { + continue + } + expected = append(expected, fmt.Sprintf("code %d payload %x", exp.Code, mustEncodeMsg(exp.Msg))) + } + return fmt.Errorf("unexpected message code %d payload %x, expected %s", msg.Code, actualContent, strings.Join(expected, " or ")) + } + done := true + for _, m := range matched { + if !m { + done = false + break + } + } + if done { + return nil + } + } + for i, m := range matched { + if !m { + return fmt.Errorf("expected message #%d not received", i) + } + } + return nil +} + +// mustEncodeMsg uses rlp to encode a message. +// In case of error it panics. +func mustEncodeMsg(msg interface{}) []byte { + contentEnc, err := rlp.EncodeToBytes(msg) + if err != nil { + panic("content encode error: " + err.Error()) + } + return contentEnc +} diff --git a/params/config.go b/params/config.go index 345f6394a..dc02c7ca3 100644 --- a/params/config.go +++ b/params/config.go @@ -31,46 +31,46 @@ var ( var ( // MainnetChainConfig is the chain parameters to run a node on the main network. MainnetChainConfig = &ChainConfig{ - ChainId: big.NewInt(1), - HomesteadBlock: big.NewInt(1150000), - DAOForkBlock: big.NewInt(1920000), - DAOForkSupport: true, - EIP150Block: big.NewInt(2463000), - EIP150Hash: common.HexToHash("0x2086799aeebeae135c246c65021c82b4e15a2c451340993aacfd2751886514f0"), - EIP155Block: big.NewInt(2675000), - EIP158Block: big.NewInt(2675000), - ByzantiumBlock: big.NewInt(4370000), - - Ethash: new(EthashConfig), + ChainId: big.NewInt(1), + HomesteadBlock: big.NewInt(1150000), + DAOForkBlock: big.NewInt(1920000), + DAOForkSupport: true, + EIP150Block: big.NewInt(2463000), + EIP150Hash: common.HexToHash("0x2086799aeebeae135c246c65021c82b4e15a2c451340993aacfd2751886514f0"), + EIP155Block: big.NewInt(2675000), + EIP158Block: big.NewInt(2675000), + ByzantiumBlock: big.NewInt(4370000), + ConstantinopleBlock: nil, + Ethash: new(EthashConfig), } // TestnetChainConfig contains the chain parameters to run a node on the Ropsten test network. TestnetChainConfig = &ChainConfig{ - ChainId: big.NewInt(3), - HomesteadBlock: big.NewInt(0), - DAOForkBlock: nil, - DAOForkSupport: true, - EIP150Block: big.NewInt(0), - EIP150Hash: common.HexToHash("0x41941023680923e0fe4d74a34bdac8141f2540e3ae90623718e47d66d1ca4a2d"), - EIP155Block: big.NewInt(10), - EIP158Block: big.NewInt(10), - ByzantiumBlock: big.NewInt(1700000), - - Ethash: new(EthashConfig), + ChainId: big.NewInt(3), + HomesteadBlock: big.NewInt(0), + DAOForkBlock: nil, + DAOForkSupport: true, + EIP150Block: big.NewInt(0), + EIP150Hash: common.HexToHash("0x41941023680923e0fe4d74a34bdac8141f2540e3ae90623718e47d66d1ca4a2d"), + EIP155Block: big.NewInt(10), + EIP158Block: big.NewInt(10), + ByzantiumBlock: big.NewInt(1700000), + ConstantinopleBlock: nil, + Ethash: new(EthashConfig), } // RinkebyChainConfig contains the chain parameters to run a node on the Rinkeby test network. RinkebyChainConfig = &ChainConfig{ - ChainId: big.NewInt(4), - HomesteadBlock: big.NewInt(1), - DAOForkBlock: nil, - DAOForkSupport: true, - EIP150Block: big.NewInt(2), - EIP150Hash: common.HexToHash("0x9b095b36c15eaf13044373aef8ee0bd3a382a5abb92e402afa44b8249c3a90e9"), - EIP155Block: big.NewInt(3), - EIP158Block: big.NewInt(3), - ByzantiumBlock: big.NewInt(1035301), - + ChainId: big.NewInt(4), + HomesteadBlock: big.NewInt(1), + DAOForkBlock: nil, + DAOForkSupport: true, + EIP150Block: big.NewInt(2), + EIP150Hash: common.HexToHash("0x9b095b36c15eaf13044373aef8ee0bd3a382a5abb92e402afa44b8249c3a90e9"), + EIP155Block: big.NewInt(3), + EIP158Block: big.NewInt(3), + ByzantiumBlock: big.NewInt(1035301), + ConstantinopleBlock: nil, Clique: &CliqueConfig{ Period: 15, Epoch: 30000, @@ -82,16 +82,16 @@ var ( // // This configuration is intentionally not using keyed fields to force anyone // adding flags to the config to also have to set these fields. - AllEthashProtocolChanges = &ChainConfig{big.NewInt(1337), big.NewInt(0), nil, false, big.NewInt(0), common.Hash{}, big.NewInt(0), big.NewInt(0), big.NewInt(0), new(EthashConfig), nil} + AllEthashProtocolChanges = &ChainConfig{big.NewInt(1337), big.NewInt(0), nil, false, big.NewInt(0), common.Hash{}, big.NewInt(0), big.NewInt(0), big.NewInt(0), nil, new(EthashConfig), nil} // AllCliqueProtocolChanges contains every protocol change (EIPs) introduced // and accepted by the Ethereum core developers into the Clique consensus. // // This configuration is intentionally not using keyed fields to force anyone // adding flags to the config to also have to set these fields. - AllCliqueProtocolChanges = &ChainConfig{big.NewInt(1337), big.NewInt(0), nil, false, big.NewInt(0), common.Hash{}, big.NewInt(0), big.NewInt(0), big.NewInt(0), nil, &CliqueConfig{Period: 0, Epoch: 30000}} + AllCliqueProtocolChanges = &ChainConfig{big.NewInt(1337), big.NewInt(0), nil, false, big.NewInt(0), common.Hash{}, big.NewInt(0), big.NewInt(0), big.NewInt(0), nil, nil, &CliqueConfig{Period: 0, Epoch: 30000}} - TestChainConfig = &ChainConfig{big.NewInt(1), big.NewInt(0), nil, false, big.NewInt(0), common.Hash{}, big.NewInt(0), big.NewInt(0), big.NewInt(0), new(EthashConfig), nil} + TestChainConfig = &ChainConfig{big.NewInt(1), big.NewInt(0), nil, false, big.NewInt(0), common.Hash{}, big.NewInt(0), big.NewInt(0), big.NewInt(0), nil, new(EthashConfig), nil} TestRules = TestChainConfig.Rules(new(big.Int)) ) @@ -115,7 +115,8 @@ type ChainConfig struct { EIP155Block *big.Int `json:"eip155Block,omitempty"` // EIP155 HF block EIP158Block *big.Int `json:"eip158Block,omitempty"` // EIP158 HF block - ByzantiumBlock *big.Int `json:"byzantiumBlock,omitempty"` // Byzantium switch block (nil = no fork, 0 = already on byzantium) + ByzantiumBlock *big.Int `json:"byzantiumBlock,omitempty"` // Byzantium switch block (nil = no fork, 0 = already on byzantium) + ConstantinopleBlock *big.Int `json:"constantinopleBlock,omitempty"` // Constantinople switch block (nil = no fork, 0 = already activated) // Various consensus engines Ethash *EthashConfig `json:"ethash,omitempty"` @@ -152,7 +153,7 @@ func (c *ChainConfig) String() string { default: engine = "unknown" } - return fmt.Sprintf("{ChainID: %v Homestead: %v DAO: %v DAOSupport: %v EIP150: %v EIP155: %v EIP158: %v Byzantium: %v Engine: %v}", + return fmt.Sprintf("{ChainID: %v Homestead: %v DAO: %v DAOSupport: %v EIP150: %v EIP155: %v EIP158: %v Byzantium: %v Constantinople: %v Engine: %v}", c.ChainId, c.HomesteadBlock, c.DAOForkBlock, @@ -161,6 +162,7 @@ func (c *ChainConfig) String() string { c.EIP155Block, c.EIP158Block, c.ByzantiumBlock, + c.ConstantinopleBlock, engine, ) } @@ -191,6 +193,10 @@ func (c *ChainConfig) IsByzantium(num *big.Int) bool { return isForked(c.ByzantiumBlock, num) } +func (c *ChainConfig) IsConstantinople(num *big.Int) bool { + return isForked(c.ConstantinopleBlock, num) +} + // GasTable returns the gas table corresponding to the current phase (homestead or homestead reprice). // // The returned GasTable's fields shouldn't, under any circumstances, be changed. @@ -251,6 +257,9 @@ func (c *ChainConfig) checkCompatible(newcfg *ChainConfig, head *big.Int) *Confi if isForkIncompatible(c.ByzantiumBlock, newcfg.ByzantiumBlock, head) { return newCompatError("Byzantium fork block", c.ByzantiumBlock, newcfg.ByzantiumBlock) } + if isForkIncompatible(c.ConstantinopleBlock, newcfg.ConstantinopleBlock, head) { + return newCompatError("Constantinople fork block", c.ConstantinopleBlock, newcfg.ConstantinopleBlock) + } return nil } diff --git a/params/version.go b/params/version.go index 277585934..921d07599 100644 --- a/params/version.go +++ b/params/version.go @@ -23,7 +23,7 @@ import ( const ( VersionMajor = 1 // Major version component of the current release VersionMinor = 8 // Minor version component of the current release - VersionPatch = 1 // Patch version component of the current release + VersionPatch = 3 // Patch version component of the current release VersionMeta = "unstable" // Version metadata to append to the version string ) diff --git a/rpc/server.go b/rpc/server.go index 30c288349..11373b504 100644 --- a/rpc/server.go +++ b/rpc/server.go @@ -421,7 +421,7 @@ func (s *Server) readRequest(codec ServerCodec) ([]*serverRequest, bool, Error) } } } else { - requests[i] = &serverRequest{id: r.id, err: &methodNotFoundError{r.method, r.method}} + requests[i] = &serverRequest{id: r.id, err: &methodNotFoundError{r.service, r.method}} } continue } diff --git a/swarm/api/api.go b/swarm/api/api.go index 8c4bca2ec..0cf12fdbe 100644 --- a/swarm/api/api.go +++ b/swarm/api/api.go @@ -20,6 +20,7 @@ import ( "fmt" "io" "net/http" + "path" "regexp" "strings" "sync" @@ -31,15 +32,110 @@ import ( "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/log" + "github.com/ethereum/go-ethereum/metrics" "github.com/ethereum/go-ethereum/swarm/storage" ) var hashMatcher = regexp.MustCompile("^[0-9A-Fa-f]{64}") +//setup metrics +var ( + apiResolveCount = metrics.NewRegisteredCounter("api.resolve.count", nil) + apiResolveFail = metrics.NewRegisteredCounter("api.resolve.fail", nil) + apiPutCount = metrics.NewRegisteredCounter("api.put.count", nil) + apiPutFail = metrics.NewRegisteredCounter("api.put.fail", nil) + apiGetCount = metrics.NewRegisteredCounter("api.get.count", nil) + apiGetNotFound = metrics.NewRegisteredCounter("api.get.notfound", nil) + apiGetHttp300 = metrics.NewRegisteredCounter("api.get.http.300", nil) + apiModifyCount = metrics.NewRegisteredCounter("api.modify.count", nil) + apiModifyFail = metrics.NewRegisteredCounter("api.modify.fail", nil) + apiAddFileCount = metrics.NewRegisteredCounter("api.addfile.count", nil) + apiAddFileFail = metrics.NewRegisteredCounter("api.addfile.fail", nil) + apiRmFileCount = metrics.NewRegisteredCounter("api.removefile.count", nil) + apiRmFileFail = metrics.NewRegisteredCounter("api.removefile.fail", nil) + apiAppendFileCount = metrics.NewRegisteredCounter("api.appendfile.count", nil) + apiAppendFileFail = metrics.NewRegisteredCounter("api.appendfile.fail", nil) +) + type Resolver interface { Resolve(string) (common.Hash, error) } +// NoResolverError is returned by MultiResolver.Resolve if no resolver +// can be found for the address. +type NoResolverError struct { + TLD string +} + +func NewNoResolverError(tld string) *NoResolverError { + return &NoResolverError{TLD: tld} +} + +func (e *NoResolverError) Error() string { + if e.TLD == "" { + return "no ENS resolver" + } + return fmt.Sprintf("no ENS endpoint configured to resolve .%s TLD names", e.TLD) +} + +// MultiResolver is used to resolve URL addresses based on their TLDs. +// Each TLD can have multiple resolvers, and the resoluton from the +// first one in the sequence will be returned. +type MultiResolver struct { + resolvers map[string][]Resolver +} + +// MultiResolverOption sets options for MultiResolver and is used as +// arguments for its constructor. +type MultiResolverOption func(*MultiResolver) + +// MultiResolverOptionWithResolver adds a Resolver to a list of resolvers +// for a specific TLD. If TLD is an empty string, the resolver will be added +// to the list of default resolver, the ones that will be used for resolution +// of addresses which do not have their TLD resolver specified. +func MultiResolverOptionWithResolver(r Resolver, tld string) MultiResolverOption { + return func(m *MultiResolver) { + m.resolvers[tld] = append(m.resolvers[tld], r) + } +} + +// NewMultiResolver creates a new instance of MultiResolver. +func NewMultiResolver(opts ...MultiResolverOption) (m *MultiResolver) { + m = &MultiResolver{ + resolvers: make(map[string][]Resolver), + } + for _, o := range opts { + o(m) + } + return m +} + +// Resolve resolves address by choosing a Resolver by TLD. +// If there are more default Resolvers, or for a specific TLD, +// the Hash from the the first one which does not return error +// will be returned. +func (m MultiResolver) Resolve(addr string) (h common.Hash, err error) { + rs := m.resolvers[""] + tld := path.Ext(addr) + if tld != "" { + tld = tld[1:] + rstld, ok := m.resolvers[tld] + if ok { + rs = rstld + } + } + if rs == nil { + return h, NewNoResolverError(tld) + } + for _, r := range rs { + h, err = r.Resolve(addr) + if err == nil { + return + } + } + return +} + /* Api implements webserver/file system related content storage and retrieval on top of the dpa @@ -79,6 +175,7 @@ type ErrResolve error // DNS Resolver func (self *Api) Resolve(uri *URI) (storage.Key, error) { + apiResolveCount.Inc(1) log.Trace(fmt.Sprintf("Resolving : %v", uri.Addr)) // if the URI is immutable, check if the address is a hash @@ -93,6 +190,7 @@ func (self *Api) Resolve(uri *URI) (storage.Key, error) { // if DNS is not configured, check if the address is a hash if self.dns == nil { if !isHash { + apiResolveFail.Inc(1) return nil, fmt.Errorf("no DNS to resolve name: %q", uri.Addr) } return common.Hex2Bytes(uri.Addr), nil @@ -103,6 +201,7 @@ func (self *Api) Resolve(uri *URI) (storage.Key, error) { if err == nil { return resolved[:], nil } else if !isHash { + apiResolveFail.Inc(1) return nil, err } return common.Hex2Bytes(uri.Addr), nil @@ -110,16 +209,19 @@ func (self *Api) Resolve(uri *URI) (storage.Key, error) { // Put provides singleton manifest creation on top of dpa store func (self *Api) Put(content, contentType string) (storage.Key, error) { + apiPutCount.Inc(1) r := strings.NewReader(content) wg := &sync.WaitGroup{} key, err := self.dpa.Store(r, int64(len(content)), wg, nil) if err != nil { + apiPutFail.Inc(1) return nil, err } manifest := fmt.Sprintf(`{"entries":[{"hash":"%v","contentType":"%s"}]}`, key, contentType) r = strings.NewReader(manifest) key, err = self.dpa.Store(r, int64(len(manifest)), wg, nil) if err != nil { + apiPutFail.Inc(1) return nil, err } wg.Wait() @@ -130,8 +232,10 @@ func (self *Api) Put(content, contentType string) (storage.Key, error) { // to resolve basePath to content using dpa retrieve // it returns a section reader, mimeType, status and an error func (self *Api) Get(key storage.Key, path string) (reader storage.LazySectionReader, mimeType string, status int, err error) { + apiGetCount.Inc(1) trie, err := loadManifest(self.dpa, key, nil) if err != nil { + apiGetNotFound.Inc(1) status = http.StatusNotFound log.Warn(fmt.Sprintf("loadManifestTrie error: %v", err)) return @@ -145,6 +249,7 @@ func (self *Api) Get(key storage.Key, path string) (reader storage.LazySectionRe key = common.Hex2Bytes(entry.Hash) status = entry.Status if status == http.StatusMultipleChoices { + apiGetHttp300.Inc(1) return } else { mimeType = entry.ContentType @@ -153,6 +258,7 @@ func (self *Api) Get(key storage.Key, path string) (reader storage.LazySectionRe } } else { status = http.StatusNotFound + apiGetNotFound.Inc(1) err = fmt.Errorf("manifest entry for '%s' not found", path) log.Warn(fmt.Sprintf("%v", err)) } @@ -160,9 +266,11 @@ func (self *Api) Get(key storage.Key, path string) (reader storage.LazySectionRe } func (self *Api) Modify(key storage.Key, path, contentHash, contentType string) (storage.Key, error) { + apiModifyCount.Inc(1) quitC := make(chan bool) trie, err := loadManifest(self.dpa, key, quitC) if err != nil { + apiModifyFail.Inc(1) return nil, err } if contentHash != "" { @@ -177,19 +285,23 @@ func (self *Api) Modify(key storage.Key, path, contentHash, contentType string) } if err := trie.recalcAndStore(); err != nil { + apiModifyFail.Inc(1) return nil, err } return trie.hash, nil } func (self *Api) AddFile(mhash, path, fname string, content []byte, nameresolver bool) (storage.Key, string, error) { + apiAddFileCount.Inc(1) uri, err := Parse("bzz:/" + mhash) if err != nil { + apiAddFileFail.Inc(1) return nil, "", err } mkey, err := self.Resolve(uri) if err != nil { + apiAddFileFail.Inc(1) return nil, "", err } @@ -208,16 +320,19 @@ func (self *Api) AddFile(mhash, path, fname string, content []byte, nameresolver mw, err := self.NewManifestWriter(mkey, nil) if err != nil { + apiAddFileFail.Inc(1) return nil, "", err } fkey, err := mw.AddEntry(bytes.NewReader(content), entry) if err != nil { + apiAddFileFail.Inc(1) return nil, "", err } newMkey, err := mw.Store() if err != nil { + apiAddFileFail.Inc(1) return nil, "", err } @@ -227,13 +342,16 @@ func (self *Api) AddFile(mhash, path, fname string, content []byte, nameresolver } func (self *Api) RemoveFile(mhash, path, fname string, nameresolver bool) (string, error) { + apiRmFileCount.Inc(1) uri, err := Parse("bzz:/" + mhash) if err != nil { + apiRmFileFail.Inc(1) return "", err } mkey, err := self.Resolve(uri) if err != nil { + apiRmFileFail.Inc(1) return "", err } @@ -244,16 +362,19 @@ func (self *Api) RemoveFile(mhash, path, fname string, nameresolver bool) (strin mw, err := self.NewManifestWriter(mkey, nil) if err != nil { + apiRmFileFail.Inc(1) return "", err } err = mw.RemoveEntry(filepath.Join(path, fname)) if err != nil { + apiRmFileFail.Inc(1) return "", err } newMkey, err := mw.Store() if err != nil { + apiRmFileFail.Inc(1) return "", err } @@ -262,6 +383,7 @@ func (self *Api) RemoveFile(mhash, path, fname string, nameresolver bool) (strin } func (self *Api) AppendFile(mhash, path, fname string, existingSize int64, content []byte, oldKey storage.Key, offset int64, addSize int64, nameresolver bool) (storage.Key, string, error) { + apiAppendFileCount.Inc(1) buffSize := offset + addSize if buffSize < existingSize { @@ -290,10 +412,12 @@ func (self *Api) AppendFile(mhash, path, fname string, existingSize int64, conte uri, err := Parse("bzz:/" + mhash) if err != nil { + apiAppendFileFail.Inc(1) return nil, "", err } mkey, err := self.Resolve(uri) if err != nil { + apiAppendFileFail.Inc(1) return nil, "", err } @@ -304,11 +428,13 @@ func (self *Api) AppendFile(mhash, path, fname string, existingSize int64, conte mw, err := self.NewManifestWriter(mkey, nil) if err != nil { + apiAppendFileFail.Inc(1) return nil, "", err } err = mw.RemoveEntry(filepath.Join(path, fname)) if err != nil { + apiAppendFileFail.Inc(1) return nil, "", err } @@ -322,11 +448,13 @@ func (self *Api) AppendFile(mhash, path, fname string, existingSize int64, conte fkey, err := mw.AddEntry(io.Reader(combinedReader), entry) if err != nil { + apiAppendFileFail.Inc(1) return nil, "", err } newMkey, err := mw.Store() if err != nil { + apiAppendFileFail.Inc(1) return nil, "", err } @@ -336,6 +464,7 @@ func (self *Api) AppendFile(mhash, path, fname string, existingSize int64, conte } func (self *Api) BuildDirectoryTree(mhash string, nameresolver bool) (key storage.Key, manifestEntryMap map[string]*manifestTrieEntry, err error) { + uri, err := Parse("bzz:/" + mhash) if err != nil { return nil, nil, err diff --git a/swarm/api/api_test.go b/swarm/api/api_test.go index e673f76c4..4ee26bd8a 100644 --- a/swarm/api/api_test.go +++ b/swarm/api/api_test.go @@ -237,3 +237,128 @@ func TestAPIResolve(t *testing.T) { }) } } + +func TestMultiResolver(t *testing.T) { + doesntResolve := newTestResolver("") + + ethAddr := "swarm.eth" + ethHash := "0x2222222222222222222222222222222222222222222222222222222222222222" + ethResolve := newTestResolver(ethHash) + + testAddr := "swarm.test" + testHash := "0x1111111111111111111111111111111111111111111111111111111111111111" + testResolve := newTestResolver(testHash) + + tests := []struct { + desc string + r Resolver + addr string + result string + err error + }{ + { + desc: "No resolvers, returns error", + r: NewMultiResolver(), + err: NewNoResolverError(""), + }, + { + desc: "One default resolver, returns resolved address", + r: NewMultiResolver(MultiResolverOptionWithResolver(ethResolve, "")), + addr: ethAddr, + result: ethHash, + }, + { + desc: "Two default resolvers, returns resolved address", + r: NewMultiResolver( + MultiResolverOptionWithResolver(ethResolve, ""), + MultiResolverOptionWithResolver(ethResolve, ""), + ), + addr: ethAddr, + result: ethHash, + }, + { + desc: "Two default resolvers, first doesn't resolve, returns resolved address", + r: NewMultiResolver( + MultiResolverOptionWithResolver(doesntResolve, ""), + MultiResolverOptionWithResolver(ethResolve, ""), + ), + addr: ethAddr, + result: ethHash, + }, + { + desc: "Default resolver doesn't resolve, tld resolver resolve, returns resolved address", + r: NewMultiResolver( + MultiResolverOptionWithResolver(doesntResolve, ""), + MultiResolverOptionWithResolver(ethResolve, "eth"), + ), + addr: ethAddr, + result: ethHash, + }, + { + desc: "Three TLD resolvers, third resolves, returns resolved address", + r: NewMultiResolver( + MultiResolverOptionWithResolver(doesntResolve, "eth"), + MultiResolverOptionWithResolver(doesntResolve, "eth"), + MultiResolverOptionWithResolver(ethResolve, "eth"), + ), + addr: ethAddr, + result: ethHash, + }, + { + desc: "One TLD resolver doesn't resolve, returns error", + r: NewMultiResolver( + MultiResolverOptionWithResolver(doesntResolve, ""), + MultiResolverOptionWithResolver(ethResolve, "eth"), + ), + addr: ethAddr, + result: ethHash, + }, + { + desc: "One defautl and one TLD resolver, all doesn't resolve, returns error", + r: NewMultiResolver( + MultiResolverOptionWithResolver(doesntResolve, ""), + MultiResolverOptionWithResolver(doesntResolve, "eth"), + ), + addr: ethAddr, + result: ethHash, + err: errors.New(`DNS name not found: "swarm.eth"`), + }, + { + desc: "Two TLD resolvers, both resolve, returns resolved address", + r: NewMultiResolver( + MultiResolverOptionWithResolver(ethResolve, "eth"), + MultiResolverOptionWithResolver(testResolve, "test"), + ), + addr: testAddr, + result: testHash, + }, + { + desc: "One TLD resolver, no default resolver, returns error for different TLD", + r: NewMultiResolver( + MultiResolverOptionWithResolver(ethResolve, "eth"), + ), + addr: testAddr, + err: NewNoResolverError("test"), + }, + } + for _, x := range tests { + t.Run(x.desc, func(t *testing.T) { + res, err := x.r.Resolve(x.addr) + if err == nil { + if x.err != nil { + t.Fatalf("expected error %q, got result %q", x.err, res.Hex()) + } + if res.Hex() != x.result { + t.Fatalf("expected result %q, got %q", x.result, res.Hex()) + } + } else { + if x.err == nil { + t.Fatalf("expected no error, got %q", err) + } + if err.Error() != x.err.Error() { + t.Fatalf("expected error %q, got %q", x.err, err) + } + } + }) + } +} diff --git a/swarm/api/config.go b/swarm/api/config.go index 140c938ae..6b224140a 100644 --- a/swarm/api/config.go +++ b/swarm/api/config.go @@ -48,7 +48,7 @@ type Config struct { *network.SyncParams Contract common.Address EnsRoot common.Address - EnsApi string + EnsAPIs []string Path string ListenAddr string Port string @@ -75,7 +75,7 @@ func NewDefaultConfig() (self *Config) { ListenAddr: DefaultHTTPListenAddr, Port: DefaultHTTPPort, Path: node.DefaultDataDir(), - EnsApi: node.DefaultIPCEndpoint("geth"), + EnsAPIs: nil, EnsRoot: ens.TestNetAddress, NetworkId: network.NetworkId, SwapEnabled: false, diff --git a/swarm/api/http/error.go b/swarm/api/http/error.go index dbd97182f..9a65412cf 100644 --- a/swarm/api/http/error.go +++ b/swarm/api/http/error.go @@ -29,11 +29,19 @@ import ( "time" "github.com/ethereum/go-ethereum/log" + "github.com/ethereum/go-ethereum/metrics" "github.com/ethereum/go-ethereum/swarm/api" ) //templateMap holds a mapping of an HTTP error code to a template var templateMap map[int]*template.Template +var caseErrors []CaseError + +//metrics variables +var ( + htmlCounter = metrics.NewRegisteredCounter("api.http.errorpage.html.count", nil) + jsonCounter = metrics.NewRegisteredCounter("api.http.errorpage.json.count", nil) +) //parameters needed for formatting the correct HTML page type ErrorParams struct { @@ -44,6 +52,13 @@ type ErrorParams struct { Details template.HTML } +//a custom error case struct that would be used to store validators and +//additional error info to display with client responses. +type CaseError struct { + Validator func(*Request) bool + Msg func(*Request) string +} + //we init the error handling right on boot time, so lookup and http response is fast func init() { initErrHandling() @@ -67,6 +82,29 @@ func initErrHandling() { //assign formatted HTML to the code templateMap[code] = template.Must(template.New(fmt.Sprintf("%d", code)).Parse(tname)) } + + caseErrors = []CaseError{ + { + Validator: func(r *Request) bool { return r.uri != nil && r.uri.Addr != "" && strings.HasPrefix(r.uri.Addr, "0x") }, + Msg: func(r *Request) string { + uriCopy := r.uri + uriCopy.Addr = strings.TrimPrefix(uriCopy.Addr, "0x") + return fmt.Sprintf(`The requested hash seems to be prefixed with '0x'. You will be redirected to the correct URL within 5 seconds.<br/> + Please click <a href='%[1]s'>here</a> if your browser does not redirect you.<script>setTimeout("location.href='%[1]s';",5000);</script>`, "/"+uriCopy.String()) + }, + }} +} + +//ValidateCaseErrors is a method that process the request object through certain validators +//that assert if certain conditions are met for further information to log as an error +func ValidateCaseErrors(r *Request) string { + for _, err := range caseErrors { + if err.Validator(r) { + return err.Msg(r) + } + } + + return "" } //ShowMultipeChoices is used when a user requests a resource in a manifest which results @@ -75,10 +113,10 @@ func initErrHandling() { //For example, if the user requests bzz:/<hash>/read and that manifest contains entries //"readme.md" and "readinglist.txt", a HTML page is returned with this two links. //This only applies if the manifest has no default entry -func ShowMultipleChoices(w http.ResponseWriter, r *http.Request, list api.ManifestList) { +func ShowMultipleChoices(w http.ResponseWriter, r *Request, list api.ManifestList) { msg := "" if list.Entries == nil { - ShowError(w, r, "Internal Server Error", http.StatusInternalServerError) + ShowError(w, r, "Could not resolve", http.StatusInternalServerError) return } //make links relative @@ -95,7 +133,7 @@ func ShowMultipleChoices(w http.ResponseWriter, r *http.Request, list api.Manife //create clickable link for each entry msg += "<a href='" + base + e.Path + "'>" + e.Path + "</a><br/>" } - respond(w, r, &ErrorParams{ + respond(w, &r.Request, &ErrorParams{ Code: http.StatusMultipleChoices, Details: template.HTML(msg), Timestamp: time.Now().Format(time.RFC1123), @@ -108,13 +146,15 @@ func ShowMultipleChoices(w http.ResponseWriter, r *http.Request, list api.Manife //The function just takes a string message which will be displayed in the error page. //The code is used to evaluate which template will be displayed //(and return the correct HTTP status code) -func ShowError(w http.ResponseWriter, r *http.Request, msg string, code int) { +func ShowError(w http.ResponseWriter, r *Request, msg string, code int) { + additionalMessage := ValidateCaseErrors(r) if code == http.StatusInternalServerError { log.Error(msg) } - respond(w, r, &ErrorParams{ + respond(w, &r.Request, &ErrorParams{ Code: code, Msg: msg, + Details: template.HTML(additionalMessage), Timestamp: time.Now().Format(time.RFC1123), template: getTemplate(code), }) @@ -132,6 +172,7 @@ func respond(w http.ResponseWriter, r *http.Request, params *ErrorParams) { //return a HTML page func respondHtml(w http.ResponseWriter, params *ErrorParams) { + htmlCounter.Inc(1) err := params.template.Execute(w, params) if err != nil { log.Error(err.Error()) @@ -140,6 +181,7 @@ func respondHtml(w http.ResponseWriter, params *ErrorParams) { //return JSON func respondJson(w http.ResponseWriter, params *ErrorParams) { + jsonCounter.Inc(1) w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(params) } diff --git a/swarm/api/http/error_templates.go b/swarm/api/http/error_templates.go index 0457cb8a7..cc9b996ba 100644 --- a/swarm/api/http/error_templates.go +++ b/swarm/api/http/error_templates.go @@ -168,6 +168,11 @@ func GetGenericErrorPage() string { {{.Msg}} </td> </tr> + <tr> + <td class="value"> + {{.Details}} + </td> + </tr> <tr> <td class="key"> @@ -342,6 +347,12 @@ func GetNotFoundErrorPage() string { {{.Msg}} </td> </tr> + <tr> + <td class="value"> + {{.Details}} + </td> + </tr> + <tr> <td class="key"> diff --git a/swarm/api/http/error_test.go b/swarm/api/http/error_test.go index c2c8b908b..dc545722e 100644 --- a/swarm/api/http/error_test.go +++ b/swarm/api/http/error_test.go @@ -18,12 +18,13 @@ package http_test import ( "encoding/json" - "golang.org/x/net/html" "io/ioutil" "net/http" "strings" "testing" + "golang.org/x/net/html" + "github.com/ethereum/go-ethereum/swarm/testutil" ) @@ -96,8 +97,37 @@ func Test500Page(t *testing.T) { defer resp.Body.Close() respbody, err = ioutil.ReadAll(resp.Body) - if resp.StatusCode != 500 || !strings.Contains(string(respbody), "500") { - t.Fatalf("Invalid Status Code received, expected 500, got %d", resp.StatusCode) + if resp.StatusCode != 404 { + t.Fatalf("Invalid Status Code received, expected 404, got %d", resp.StatusCode) + } + + _, err = html.Parse(strings.NewReader(string(respbody))) + if err != nil { + t.Fatalf("HTML validation failed for error page returned!") + } +} +func Test500PageWith0xHashPrefix(t *testing.T) { + srv := testutil.NewTestSwarmServer(t) + defer srv.Close() + + var resp *http.Response + var respbody []byte + + url := srv.URL + "/bzz:/0xthisShouldFailWith500CodeAndAHelpfulMessage" + resp, err := http.Get(url) + + if err != nil { + t.Fatalf("Request failed: %v", err) + } + defer resp.Body.Close() + respbody, err = ioutil.ReadAll(resp.Body) + + if resp.StatusCode != 404 { + t.Fatalf("Invalid Status Code received, expected 404, got %d", resp.StatusCode) + } + + if !strings.Contains(string(respbody), "The requested hash seems to be prefixed with") { + t.Fatalf("Did not receive the expected error message") } _, err = html.Parse(strings.NewReader(string(respbody))) @@ -127,8 +157,8 @@ func TestJsonResponse(t *testing.T) { defer resp.Body.Close() respbody, err = ioutil.ReadAll(resp.Body) - if resp.StatusCode != 500 { - t.Fatalf("Invalid Status Code received, expected 500, got %d", resp.StatusCode) + if resp.StatusCode != 404 { + t.Fatalf("Invalid Status Code received, expected 404, got %d", resp.StatusCode) } if !isJSON(string(respbody)) { diff --git a/swarm/api/http/server.go b/swarm/api/http/server.go index 74341899d..b8e7436cf 100644 --- a/swarm/api/http/server.go +++ b/swarm/api/http/server.go @@ -37,11 +37,35 @@ import ( "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/log" + "github.com/ethereum/go-ethereum/metrics" "github.com/ethereum/go-ethereum/swarm/api" "github.com/ethereum/go-ethereum/swarm/storage" "github.com/rs/cors" ) +//setup metrics +var ( + postRawCount = metrics.NewRegisteredCounter("api.http.post.raw.count", nil) + postRawFail = metrics.NewRegisteredCounter("api.http.post.raw.fail", nil) + postFilesCount = metrics.NewRegisteredCounter("api.http.post.files.count", nil) + postFilesFail = metrics.NewRegisteredCounter("api.http.post.files.fail", nil) + deleteCount = metrics.NewRegisteredCounter("api.http.delete.count", nil) + deleteFail = metrics.NewRegisteredCounter("api.http.delete.fail", nil) + getCount = metrics.NewRegisteredCounter("api.http.get.count", nil) + getFail = metrics.NewRegisteredCounter("api.http.get.fail", nil) + getFileCount = metrics.NewRegisteredCounter("api.http.get.file.count", nil) + getFileNotFound = metrics.NewRegisteredCounter("api.http.get.file.notfound", nil) + getFileFail = metrics.NewRegisteredCounter("api.http.get.file.fail", nil) + getFilesCount = metrics.NewRegisteredCounter("api.http.get.files.count", nil) + getFilesFail = metrics.NewRegisteredCounter("api.http.get.files.fail", nil) + getListCount = metrics.NewRegisteredCounter("api.http.get.list.count", nil) + getListFail = metrics.NewRegisteredCounter("api.http.get.list.fail", nil) + requestCount = metrics.NewRegisteredCounter("http.request.count", nil) + htmlRequestCount = metrics.NewRegisteredCounter("http.request.html.count", nil) + jsonRequestCount = metrics.NewRegisteredCounter("http.request.json.count", nil) + requestTimer = metrics.NewRegisteredResettingTimer("http.request.time", nil) +) + // ServerConfig is the basic configuration needed for the HTTP server and also // includes CORS settings. type ServerConfig struct { @@ -89,18 +113,22 @@ type Request struct { // HandlePostRaw handles a POST request to a raw bzz-raw:/ URI, stores the request // body in swarm and returns the resulting storage key as a text/plain response func (s *Server) HandlePostRaw(w http.ResponseWriter, r *Request) { + postRawCount.Inc(1) if r.uri.Path != "" { + postRawFail.Inc(1) s.BadRequest(w, r, "raw POST request cannot contain a path") return } if r.Header.Get("Content-Length") == "" { + postRawFail.Inc(1) s.BadRequest(w, r, "missing Content-Length header in request") return } key, err := s.api.Store(r.Body, r.ContentLength, nil) if err != nil { + postRawFail.Inc(1) s.Error(w, r, err) return } @@ -117,8 +145,10 @@ func (s *Server) HandlePostRaw(w http.ResponseWriter, r *Request) { // existing manifest or to a new manifest under <path> and returns the // resulting manifest hash as a text/plain response func (s *Server) HandlePostFiles(w http.ResponseWriter, r *Request) { + postFilesCount.Inc(1) contentType, params, err := mime.ParseMediaType(r.Header.Get("Content-Type")) if err != nil { + postFilesFail.Inc(1) s.BadRequest(w, r, err.Error()) return } @@ -127,12 +157,14 @@ func (s *Server) HandlePostFiles(w http.ResponseWriter, r *Request) { if r.uri.Addr != "" { key, err = s.api.Resolve(r.uri) if err != nil { + postFilesFail.Inc(1) s.Error(w, r, fmt.Errorf("error resolving %s: %s", r.uri.Addr, err)) return } } else { key, err = s.api.NewManifest() if err != nil { + postFilesFail.Inc(1) s.Error(w, r, err) return } @@ -152,6 +184,7 @@ func (s *Server) HandlePostFiles(w http.ResponseWriter, r *Request) { } }) if err != nil { + postFilesFail.Inc(1) s.Error(w, r, fmt.Errorf("error creating manifest: %s", err)) return } @@ -270,8 +303,10 @@ func (s *Server) handleDirectUpload(req *Request, mw *api.ManifestWriter) error // <path> from <manifest> and returns the resulting manifest hash as a // text/plain response func (s *Server) HandleDelete(w http.ResponseWriter, r *Request) { + deleteCount.Inc(1) key, err := s.api.Resolve(r.uri) if err != nil { + deleteFail.Inc(1) s.Error(w, r, fmt.Errorf("error resolving %s: %s", r.uri.Addr, err)) return } @@ -281,6 +316,7 @@ func (s *Server) HandleDelete(w http.ResponseWriter, r *Request) { return mw.RemoveEntry(r.uri.Path) }) if err != nil { + deleteFail.Inc(1) s.Error(w, r, fmt.Errorf("error updating manifest: %s", err)) return } @@ -296,9 +332,11 @@ func (s *Server) HandleDelete(w http.ResponseWriter, r *Request) { // - bzz-hash://<key> and responds with the hash of the content stored // at the given storage key as a text/plain response func (s *Server) HandleGet(w http.ResponseWriter, r *Request) { + getCount.Inc(1) key, err := s.api.Resolve(r.uri) if err != nil { - s.Error(w, r, fmt.Errorf("error resolving %s: %s", r.uri.Addr, err)) + getFail.Inc(1) + s.NotFound(w, r, fmt.Errorf("error resolving %s: %s", r.uri.Addr, err)) return } @@ -307,6 +345,7 @@ func (s *Server) HandleGet(w http.ResponseWriter, r *Request) { if r.uri.Path != "" { walker, err := s.api.NewManifestWalker(key, nil) if err != nil { + getFail.Inc(1) s.BadRequest(w, r, fmt.Sprintf("%s is not a manifest", key)) return } @@ -335,6 +374,7 @@ func (s *Server) HandleGet(w http.ResponseWriter, r *Request) { return api.SkipManifest }) if entry == nil { + getFail.Inc(1) s.NotFound(w, r, fmt.Errorf("Manifest entry could not be loaded")) return } @@ -344,12 +384,13 @@ func (s *Server) HandleGet(w http.ResponseWriter, r *Request) { // check the root chunk exists by retrieving the file's size reader := s.api.Retrieve(key) if _, err := reader.Size(nil); err != nil { + getFail.Inc(1) s.NotFound(w, r, fmt.Errorf("Root chunk not found %s: %s", key, err)) return } switch { - case r.uri.Raw(): + case r.uri.Raw() || r.uri.DeprecatedRaw(): // allow the request to overwrite the content type using a query // parameter contentType := "application/octet-stream" @@ -370,19 +411,23 @@ func (s *Server) HandleGet(w http.ResponseWriter, r *Request) { // header of "application/x-tar" and returns a tar stream of all files // contained in the manifest func (s *Server) HandleGetFiles(w http.ResponseWriter, r *Request) { + getFilesCount.Inc(1) if r.uri.Path != "" { + getFilesFail.Inc(1) s.BadRequest(w, r, "files request cannot contain a path") return } key, err := s.api.Resolve(r.uri) if err != nil { - s.Error(w, r, fmt.Errorf("error resolving %s: %s", r.uri.Addr, err)) + getFilesFail.Inc(1) + s.NotFound(w, r, fmt.Errorf("error resolving %s: %s", r.uri.Addr, err)) return } walker, err := s.api.NewManifestWalker(key, nil) if err != nil { + getFilesFail.Inc(1) s.Error(w, r, err) return } @@ -430,6 +475,7 @@ func (s *Server) HandleGetFiles(w http.ResponseWriter, r *Request) { return nil }) if err != nil { + getFilesFail.Inc(1) s.logError("error generating tar stream: %s", err) } } @@ -438,6 +484,7 @@ func (s *Server) HandleGetFiles(w http.ResponseWriter, r *Request) { // a list of all files contained in <manifest> under <path> grouped into // common prefixes using "/" as a delimiter func (s *Server) HandleGetList(w http.ResponseWriter, r *Request) { + getListCount.Inc(1) // ensure the root path has a trailing slash so that relative URLs work if r.uri.Path == "" && !strings.HasSuffix(r.URL.Path, "/") { http.Redirect(w, &r.Request, r.URL.Path+"/", http.StatusMovedPermanently) @@ -446,13 +493,15 @@ func (s *Server) HandleGetList(w http.ResponseWriter, r *Request) { key, err := s.api.Resolve(r.uri) if err != nil { - s.Error(w, r, fmt.Errorf("error resolving %s: %s", r.uri.Addr, err)) + getListFail.Inc(1) + s.NotFound(w, r, fmt.Errorf("error resolving %s: %s", r.uri.Addr, err)) return } list, err := s.getManifestList(key, r.uri.Path) if err != nil { + getListFail.Inc(1) s.Error(w, r, err) return } @@ -470,6 +519,7 @@ func (s *Server) HandleGetList(w http.ResponseWriter, r *Request) { List: &list, }) if err != nil { + getListFail.Inc(1) s.logError("error rendering list HTML: %s", err) } return @@ -538,6 +588,7 @@ func (s *Server) getManifestList(key storage.Key, prefix string) (list api.Manif // HandleGetFile handles a GET request to bzz://<manifest>/<path> and responds // with the content of the file at <path> from the given <manifest> func (s *Server) HandleGetFile(w http.ResponseWriter, r *Request) { + getFileCount.Inc(1) // ensure the root path has a trailing slash so that relative URLs work if r.uri.Path == "" && !strings.HasSuffix(r.URL.Path, "/") { http.Redirect(w, &r.Request, r.URL.Path+"/", http.StatusMovedPermanently) @@ -546,7 +597,8 @@ func (s *Server) HandleGetFile(w http.ResponseWriter, r *Request) { key, err := s.api.Resolve(r.uri) if err != nil { - s.Error(w, r, fmt.Errorf("error resolving %s: %s", r.uri.Addr, err)) + getFileFail.Inc(1) + s.NotFound(w, r, fmt.Errorf("error resolving %s: %s", r.uri.Addr, err)) return } @@ -554,8 +606,10 @@ func (s *Server) HandleGetFile(w http.ResponseWriter, r *Request) { if err != nil { switch status { case http.StatusNotFound: + getFileNotFound.Inc(1) s.NotFound(w, r, err) default: + getFileFail.Inc(1) s.Error(w, r, err) } return @@ -567,18 +621,20 @@ func (s *Server) HandleGetFile(w http.ResponseWriter, r *Request) { list, err := s.getManifestList(key, r.uri.Path) if err != nil { + getFileFail.Inc(1) s.Error(w, r, err) return } s.logDebug(fmt.Sprintf("Multiple choices! --> %v", list)) //show a nice page links to available entries - ShowMultipleChoices(w, &r.Request, list) + ShowMultipleChoices(w, r, list) return } // check the root chunk exists by retrieving the file's size if _, err := reader.Size(nil); err != nil { + getFileNotFound.Inc(1) s.NotFound(w, r, fmt.Errorf("File not found %s: %s", r.uri, err)) return } @@ -589,8 +645,30 @@ func (s *Server) HandleGetFile(w http.ResponseWriter, r *Request) { } func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if metrics.Enabled { + //The increment for request count and request timer themselves have a flag check + //for metrics.Enabled. Nevertheless, we introduce the if here because we + //are looking into the header just to see what request type it is (json/html). + //So let's take advantage and add all metrics related stuff here + requestCount.Inc(1) + defer requestTimer.UpdateSince(time.Now()) + if r.Header.Get("Accept") == "application/json" { + jsonRequestCount.Inc(1) + } else { + htmlRequestCount.Inc(1) + } + } s.logDebug("HTTP %s request URL: '%s', Host: '%s', Path: '%s', Referer: '%s', Accept: '%s'", r.Method, r.RequestURI, r.URL.Host, r.URL.Path, r.Referer(), r.Header.Get("Accept")) + if r.RequestURI == "/" && strings.Contains(r.Header.Get("Accept"), "text/html") { + + err := landingPageTemplate.Execute(w, nil) + if err != nil { + s.logError("error rendering landing page: %s", err) + } + return + } + uri, err := api.Parse(strings.TrimLeft(r.URL.Path, "/")) req := &Request{Request: *r, uri: uri} if err != nil { @@ -615,7 +693,7 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { // strictly a traditional PUT request which replaces content // at a URI, and POST is more ubiquitous) if uri.Raw() || uri.DeprecatedRaw() { - ShowError(w, r, fmt.Sprintf("No PUT to %s allowed.", uri), http.StatusBadRequest) + ShowError(w, req, fmt.Sprintf("No PUT to %s allowed.", uri), http.StatusBadRequest) return } else { s.HandlePostFiles(w, req) @@ -623,7 +701,7 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { case "DELETE": if uri.Raw() || uri.DeprecatedRaw() { - ShowError(w, r, fmt.Sprintf("No DELETE to %s allowed.", uri), http.StatusBadRequest) + ShowError(w, req, fmt.Sprintf("No DELETE to %s allowed.", uri), http.StatusBadRequest) return } s.HandleDelete(w, req) @@ -647,7 +725,7 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { s.HandleGetFile(w, req) default: - ShowError(w, r, fmt.Sprintf("Method "+r.Method+" is not supported.", uri), http.StatusMethodNotAllowed) + ShowError(w, req, fmt.Sprintf("Method "+r.Method+" is not supported.", uri), http.StatusMethodNotAllowed) } } @@ -679,13 +757,13 @@ func (s *Server) logError(format string, v ...interface{}) { } func (s *Server) BadRequest(w http.ResponseWriter, r *Request, reason string) { - ShowError(w, &r.Request, fmt.Sprintf("Bad request %s %s: %s", r.Method, r.uri, reason), http.StatusBadRequest) + ShowError(w, r, fmt.Sprintf("Bad request %s %s: %s", r.Request.Method, r.uri, reason), http.StatusBadRequest) } func (s *Server) Error(w http.ResponseWriter, r *Request, err error) { - ShowError(w, &r.Request, fmt.Sprintf("Error serving %s %s: %s", r.Method, r.uri, err), http.StatusInternalServerError) + ShowError(w, r, fmt.Sprintf("Error serving %s %s: %s", r.Request.Method, r.uri, err), http.StatusInternalServerError) } func (s *Server) NotFound(w http.ResponseWriter, r *Request, err error) { - ShowError(w, &r.Request, fmt.Sprintf("NOT FOUND error serving %s %s: %s", r.Method, r.uri, err), http.StatusNotFound) + ShowError(w, r, fmt.Sprintf("NOT FOUND error serving %s %s: %s", r.Request.Method, r.uri, err), http.StatusNotFound) } diff --git a/swarm/api/http/templates.go b/swarm/api/http/templates.go index 189a99912..cd9d21289 100644 --- a/swarm/api/http/templates.go +++ b/swarm/api/http/templates.go @@ -70,3 +70,146 @@ var htmlListTemplate = template.Must(template.New("html-list").Funcs(template.Fu <hr> </body> `[1:])) + +var landingPageTemplate = template.Must(template.New("landingPage").Parse(` +<html> + <head> + <meta http-equiv="Content-Type" content="text/html; charset=UTF-8"> + <meta name="viewport" content="width=device-width, initial-scale=1.0, maximum-scale=1.0"> + <meta http-equiv="X-UA-Compatible" ww="chrome=1"> + <meta name="description" content="Ethereum/Swarm Landing page"> + <meta property="og:url" content="https://swarm-gateways.net/bzz:/theswarm.eth"> + <style> + + body, div, header, footer { + margin: 0; + padding: 0; + } + + body { + overflow: hidden; + } + + .container { + min-width: 100%; + min-height: 100%; + max-height: 100%; + } + + header { + display: flex; + align-items: center; + background-color: #ffa500; + /* height: 20vh; */ + padding: 5px; + } + + .header-left, .header-right { + width: 20%; + } + + .header-left { + padding-left: 40px; + float: left; + } + + .header-right { + padding-right: 40px; + float: right; + } + + .page-title { + /* margin-top: 4.5vh; */ + text-align: center; + float: left; + width: 60%; + color: white; + } + + content-body { + display: block; + margin: 0 auto; + text-align: center; + /* width: 50%; */ + min-height: 60vh; + max-height: 60vh; + padding: 50px 20px; + opacity: 0.6; + background-color: #A9F5BF; + } + + table { + font-size: 1.2em; + margin: 0 auto; + } + + tr { + height: 60px; + } + + td { + text-align: center; + } + + .key { + color: #111; + font-weight: bold; + width: 200px; + } + + .value { + color: red; + font-weight: bold + } + + footer { + height: 20vh; + background-color: #ffa500; + font-size: 1em; + text-align: center; + padding: 20px; + } + + </style> + <title>Swarm :: Welcome to Swarm</title> + </head> + <body> + + + <header> + <div class="header-left"> + <img style="height:18vh;margin-left:40px" src=""/> + </div> + <div class="page-title"> + <h1>Welcome to Swarm</h1> + </div> + </header> + + <script type="text/javascript"> + function goToPage() { + var page = document.getElementById('page').value; + if (page == "") { + var page = "theswarm.eth" + } + var address = "/bzz:/" + page; + location.href = address; + console.log(address) + } + </script> + <content-body> + + <h1>Enter the hash or ENS of a Swarm-hosted file below:</h1> + <input type="text" id="page" size="64"/> + <input type="submit" value="submit" onclick="goToPage();" /> + + </content-body> + <footer> + <p> + Swarm: Serverless Hosting Incentivised Peer-To-Peer Storage And Content Distribution<br/> + <a href="http://swarm-gateways.net/bzz:/theswarm.eth">Swarm</a> + </p> + </footer> + + </body> +</html> +`[1:])) diff --git a/swarm/fuse/swarmfs_util.go b/swarm/fuse/swarmfs_util.go index d39966c0e..169b67487 100644 --- a/swarm/fuse/swarmfs_util.go +++ b/swarm/fuse/swarmfs_util.go @@ -47,7 +47,6 @@ func externalUnmount(mountPoint string) error { } func addFileToSwarm(sf *SwarmFile, content []byte, size int) error { - fkey, mhash, err := sf.mountInfo.swarmApi.AddFile(sf.mountInfo.LatestManifest, sf.path, sf.name, content, true) if err != nil { return err @@ -64,11 +63,9 @@ func addFileToSwarm(sf *SwarmFile, content []byte, size int) error { log.Info("Added new file:", "fname", sf.name, "New Manifest hash", mhash) return nil - } func removeFileFromSwarm(sf *SwarmFile) error { - mkey, err := sf.mountInfo.swarmApi.RemoveFile(sf.mountInfo.LatestManifest, sf.path, sf.name, true) if err != nil { return err @@ -83,7 +80,6 @@ func removeFileFromSwarm(sf *SwarmFile) error { } func removeDirectoryFromSwarm(sd *SwarmDir) error { - if len(sd.directories) == 0 && len(sd.files) == 0 { return nil } @@ -103,11 +99,9 @@ func removeDirectoryFromSwarm(sd *SwarmDir) error { } return nil - } func appendToExistingFileInSwarm(sf *SwarmFile, content []byte, offset int64, length int64) error { - fkey, mhash, err := sf.mountInfo.swarmApi.AppendFile(sf.mountInfo.LatestManifest, sf.path, sf.name, sf.fileSize, content, sf.key, offset, length, true) if err != nil { return err @@ -124,5 +118,4 @@ func appendToExistingFileInSwarm(sf *SwarmFile, content []byte, offset int64, le log.Info("Appended file:", "fname", sf.name, "New Manifest hash", mhash) return nil - } diff --git a/swarm/metrics/flags.go b/swarm/metrics/flags.go new file mode 100644 index 000000000..48b231b21 --- /dev/null +++ b/swarm/metrics/flags.go @@ -0,0 +1,91 @@ +// Copyright 2018 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. + +package metrics + +import ( + "time" + + "github.com/ethereum/go-ethereum/cmd/utils" + "github.com/ethereum/go-ethereum/log" + gethmetrics "github.com/ethereum/go-ethereum/metrics" + "github.com/ethereum/go-ethereum/metrics/influxdb" + "gopkg.in/urfave/cli.v1" +) + +var ( + metricsEnableInfluxDBExportFlag = cli.BoolFlag{ + Name: "metrics.influxdb.export", + Usage: "Enable metrics export/push to an external InfluxDB database", + } + metricsInfluxDBEndpointFlag = cli.StringFlag{ + Name: "metrics.influxdb.endpoint", + Usage: "Metrics InfluxDB endpoint", + Value: "http://127.0.0.1:8086", + } + metricsInfluxDBDatabaseFlag = cli.StringFlag{ + Name: "metrics.influxdb.database", + Usage: "Metrics InfluxDB database", + Value: "metrics", + } + metricsInfluxDBUsernameFlag = cli.StringFlag{ + Name: "metrics.influxdb.username", + Usage: "Metrics InfluxDB username", + Value: "", + } + metricsInfluxDBPasswordFlag = cli.StringFlag{ + Name: "metrics.influxdb.password", + Usage: "Metrics InfluxDB password", + Value: "", + } + // The `host` tag is part of every measurement sent to InfluxDB. Queries on tags are faster in InfluxDB. + // It is used so that we can group all nodes and average a measurement across all of them, but also so + // that we can select a specific node and inspect its measurements. + // https://docs.influxdata.com/influxdb/v1.4/concepts/key_concepts/#tag-key + metricsInfluxDBHostTagFlag = cli.StringFlag{ + Name: "metrics.influxdb.host.tag", + Usage: "Metrics InfluxDB `host` tag attached to all measurements", + Value: "localhost", + } +) + +// Flags holds all command-line flags required for metrics collection. +var Flags = []cli.Flag{ + utils.MetricsEnabledFlag, + metricsEnableInfluxDBExportFlag, + metricsInfluxDBEndpointFlag, metricsInfluxDBDatabaseFlag, metricsInfluxDBUsernameFlag, metricsInfluxDBPasswordFlag, metricsInfluxDBHostTagFlag, +} + +func Setup(ctx *cli.Context) { + if gethmetrics.Enabled { + log.Info("Enabling swarm metrics collection") + var ( + enableExport = ctx.GlobalBool(metricsEnableInfluxDBExportFlag.Name) + endpoint = ctx.GlobalString(metricsInfluxDBEndpointFlag.Name) + database = ctx.GlobalString(metricsInfluxDBDatabaseFlag.Name) + username = ctx.GlobalString(metricsInfluxDBUsernameFlag.Name) + password = ctx.GlobalString(metricsInfluxDBPasswordFlag.Name) + hosttag = ctx.GlobalString(metricsInfluxDBHostTagFlag.Name) + ) + + if enableExport { + log.Info("Enabling swarm metrics export to InfluxDB") + go influxdb.InfluxDBWithTags(gethmetrics.DefaultRegistry, 10*time.Second, endpoint, database, username, password, "swarm.", map[string]string{ + "host": hosttag, + }) + } + } +} diff --git a/swarm/network/depo.go b/swarm/network/depo.go index 17540d2f9..5ffbf8be1 100644 --- a/swarm/network/depo.go +++ b/swarm/network/depo.go @@ -23,9 +23,19 @@ import ( "time" "github.com/ethereum/go-ethereum/log" + "github.com/ethereum/go-ethereum/metrics" "github.com/ethereum/go-ethereum/swarm/storage" ) +//metrics variables +var ( + syncReceiveCount = metrics.NewRegisteredCounter("network.sync.recv.count", nil) + syncReceiveIgnore = metrics.NewRegisteredCounter("network.sync.recv.ignore", nil) + syncSendCount = metrics.NewRegisteredCounter("network.sync.send.count", nil) + syncSendRefused = metrics.NewRegisteredCounter("network.sync.send.refused", nil) + syncSendNotFound = metrics.NewRegisteredCounter("network.sync.send.notfound", nil) +) + // Handler for storage/retrieval related protocol requests // implements the StorageHandler interface used by the bzz protocol type Depo struct { @@ -107,6 +117,7 @@ func (self *Depo) HandleStoreRequestMsg(req *storeRequestMsgData, p *peer) { log.Trace(fmt.Sprintf("Depo.handleStoreRequest: %v not found locally. create new chunk/request", req.Key)) // not found in memory cache, ie., a genuine store request // create chunk + syncReceiveCount.Inc(1) chunk = storage.NewChunk(req.Key, nil) case chunk.SData == nil: @@ -116,6 +127,7 @@ func (self *Depo) HandleStoreRequestMsg(req *storeRequestMsgData, p *peer) { default: // data is found, store request ignored // this should update access count? + syncReceiveIgnore.Inc(1) log.Trace(fmt.Sprintf("Depo.HandleStoreRequest: %v found locally. ignore.", req)) islocal = true //return @@ -172,11 +184,14 @@ func (self *Depo) HandleRetrieveRequestMsg(req *retrieveRequestMsgData, p *peer) SData: chunk.SData, requestTimeout: req.timeout, // } + syncSendCount.Inc(1) p.syncer.addRequest(sreq, DeliverReq) } else { + syncSendRefused.Inc(1) log.Trace(fmt.Sprintf("Depo.HandleRetrieveRequest: %v - content found, not wanted", req.Key.Log())) } } else { + syncSendNotFound.Inc(1) log.Trace(fmt.Sprintf("Depo.HandleRetrieveRequest: %v - content not found locally. asked swarm for help. will get back", req.Key.Log())) } } diff --git a/swarm/network/hive.go b/swarm/network/hive.go index 2504a4610..8404ffcc2 100644 --- a/swarm/network/hive.go +++ b/swarm/network/hive.go @@ -24,6 +24,7 @@ import ( "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/log" + "github.com/ethereum/go-ethereum/metrics" "github.com/ethereum/go-ethereum/p2p/discover" "github.com/ethereum/go-ethereum/p2p/netutil" "github.com/ethereum/go-ethereum/swarm/network/kademlia" @@ -39,6 +40,12 @@ import ( // connections and disconnections are reported and relayed // to keep the nodetable uptodate +var ( + peersNumGauge = metrics.NewRegisteredGauge("network.peers.num", nil) + addPeerCounter = metrics.NewRegisteredCounter("network.addpeer.count", nil) + removePeerCounter = metrics.NewRegisteredCounter("network.removepeer.count", nil) +) + type Hive struct { listenAddr func() string callInterval uint64 @@ -192,6 +199,7 @@ func (self *Hive) Start(id discover.NodeID, listenAddr func() string, connectPee func (self *Hive) keepAlive() { alarm := time.NewTicker(time.Duration(self.callInterval)).C for { + peersNumGauge.Update(int64(self.kad.Count())) select { case <-alarm: if self.kad.DBCount() > 0 { @@ -223,6 +231,7 @@ func (self *Hive) Stop() error { // called at the end of a successful protocol handshake func (self *Hive) addPeer(p *peer) error { + addPeerCounter.Inc(1) defer func() { select { case self.more <- true: @@ -247,6 +256,7 @@ func (self *Hive) addPeer(p *peer) error { // called after peer disconnected func (self *Hive) removePeer(p *peer) { + removePeerCounter.Inc(1) log.Debug(fmt.Sprintf("bee %v removed", p)) self.kad.Off(p, saveSync) select { diff --git a/swarm/network/kademlia/kademlia.go b/swarm/network/kademlia/kademlia.go index 0abc42a19..b5999b52d 100644 --- a/swarm/network/kademlia/kademlia.go +++ b/swarm/network/kademlia/kademlia.go @@ -24,6 +24,16 @@ import ( "time" "github.com/ethereum/go-ethereum/log" + "github.com/ethereum/go-ethereum/metrics" +) + +//metrics variables +//For metrics, we want to count how many times peers are added/removed +//at a certain index. Thus we do that with an array of counters with +//entry for each index +var ( + bucketAddIndexCount []metrics.Counter + bucketRmIndexCount []metrics.Counter ) const ( @@ -88,12 +98,14 @@ type Node interface { // params is KadParams configuration func New(addr Address, params *KadParams) *Kademlia { buckets := make([][]Node, params.MaxProx+1) - return &Kademlia{ + kad := &Kademlia{ addr: addr, KadParams: params, buckets: buckets, db: newKadDb(addr, params), } + kad.initMetricsVariables() + return kad } // accessor for KAD base address @@ -138,6 +150,7 @@ func (self *Kademlia) On(node Node, cb func(*NodeRecord, Node) error) (err error // TODO: give priority to peers with active traffic if len(bucket) < self.BucketSize { // >= allows us to add peers beyond the bucketsize limitation self.buckets[index] = append(bucket, node) + bucketAddIndexCount[index].Inc(1) log.Debug(fmt.Sprintf("add node %v to table", node)) self.setProxLimit(index, true) record.node = node @@ -178,6 +191,7 @@ func (self *Kademlia) Off(node Node, cb func(*NodeRecord, Node)) (err error) { defer self.lock.Unlock() index := self.proximityBin(node.Addr()) + bucketRmIndexCount[index].Inc(1) bucket := self.buckets[index] for i := 0; i < len(bucket); i++ { if node.Addr() == bucket[i].Addr() { @@ -426,3 +440,15 @@ func (self *Kademlia) String() string { rows = append(rows, "=========================================================================") return strings.Join(rows, "\n") } + +//We have to build up the array of counters for each index +func (self *Kademlia) initMetricsVariables() { + //create the arrays + bucketAddIndexCount = make([]metrics.Counter, self.MaxProx+1) + bucketRmIndexCount = make([]metrics.Counter, self.MaxProx+1) + //at each index create a metrics counter + for i := 0; i < (self.KadParams.MaxProx + 1); i++ { + bucketAddIndexCount[i] = metrics.NewRegisteredCounter(fmt.Sprintf("network.kademlia.bucket.add.%d.index", i), nil) + bucketRmIndexCount[i] = metrics.NewRegisteredCounter(fmt.Sprintf("network.kademlia.bucket.rm.%d.index", i), nil) + } +} diff --git a/swarm/network/protocol.go b/swarm/network/protocol.go index a418c1dbb..1cbe00a97 100644 --- a/swarm/network/protocol.go +++ b/swarm/network/protocol.go @@ -39,12 +39,26 @@ import ( "github.com/ethereum/go-ethereum/contracts/chequebook" "github.com/ethereum/go-ethereum/log" + "github.com/ethereum/go-ethereum/metrics" "github.com/ethereum/go-ethereum/p2p" bzzswap "github.com/ethereum/go-ethereum/swarm/services/swap" "github.com/ethereum/go-ethereum/swarm/services/swap/swap" "github.com/ethereum/go-ethereum/swarm/storage" ) +//metrics variables +var ( + storeRequestMsgCounter = metrics.NewRegisteredCounter("network.protocol.msg.storerequest.count", nil) + retrieveRequestMsgCounter = metrics.NewRegisteredCounter("network.protocol.msg.retrieverequest.count", nil) + peersMsgCounter = metrics.NewRegisteredCounter("network.protocol.msg.peers.count", nil) + syncRequestMsgCounter = metrics.NewRegisteredCounter("network.protocol.msg.syncrequest.count", nil) + unsyncedKeysMsgCounter = metrics.NewRegisteredCounter("network.protocol.msg.unsyncedkeys.count", nil) + deliverRequestMsgCounter = metrics.NewRegisteredCounter("network.protocol.msg.deliverrequest.count", nil) + paymentMsgCounter = metrics.NewRegisteredCounter("network.protocol.msg.payment.count", nil) + invalidMsgCounter = metrics.NewRegisteredCounter("network.protocol.msg.invalid.count", nil) + handleStatusMsgCounter = metrics.NewRegisteredCounter("network.protocol.msg.handlestatus.count", nil) +) + const ( Version = 0 ProtocolLength = uint64(8) @@ -206,6 +220,7 @@ func (self *bzz) handle() error { case storeRequestMsg: // store requests are dispatched to netStore + storeRequestMsgCounter.Inc(1) var req storeRequestMsgData if err := msg.Decode(&req); err != nil { return fmt.Errorf("<- %v: %v", msg, err) @@ -221,6 +236,7 @@ func (self *bzz) handle() error { case retrieveRequestMsg: // retrieve Requests are dispatched to netStore + retrieveRequestMsgCounter.Inc(1) var req retrieveRequestMsgData if err := msg.Decode(&req); err != nil { return fmt.Errorf("<- %v: %v", msg, err) @@ -241,6 +257,7 @@ func (self *bzz) handle() error { case peersMsg: // response to lookups and immediate response to retrieve requests // dispatches new peer data to the hive that adds them to KADDB + peersMsgCounter.Inc(1) var req peersMsgData if err := msg.Decode(&req); err != nil { return fmt.Errorf("<- %v: %v", msg, err) @@ -250,6 +267,7 @@ func (self *bzz) handle() error { self.hive.HandlePeersMsg(&req, &peer{bzz: self}) case syncRequestMsg: + syncRequestMsgCounter.Inc(1) var req syncRequestMsgData if err := msg.Decode(&req); err != nil { return fmt.Errorf("<- %v: %v", msg, err) @@ -260,6 +278,7 @@ func (self *bzz) handle() error { case unsyncedKeysMsg: // coming from parent node offering + unsyncedKeysMsgCounter.Inc(1) var req unsyncedKeysMsgData if err := msg.Decode(&req); err != nil { return fmt.Errorf("<- %v: %v", msg, err) @@ -274,6 +293,7 @@ func (self *bzz) handle() error { case deliveryRequestMsg: // response to syncKeysMsg hashes filtered not existing in db // also relays the last synced state to the source + deliverRequestMsgCounter.Inc(1) var req deliveryRequestMsgData if err := msg.Decode(&req); err != nil { return fmt.Errorf("<-msg %v: %v", msg, err) @@ -287,6 +307,7 @@ func (self *bzz) handle() error { case paymentMsg: // swap protocol message for payment, Units paid for, Cheque paid with + paymentMsgCounter.Inc(1) if self.swapEnabled { var req paymentMsgData if err := msg.Decode(&req); err != nil { @@ -298,6 +319,7 @@ func (self *bzz) handle() error { default: // no other message is allowed + invalidMsgCounter.Inc(1) return fmt.Errorf("invalid message code: %v", msg.Code) } return nil @@ -332,6 +354,8 @@ func (self *bzz) handleStatus() (err error) { return fmt.Errorf("first msg has code %x (!= %x)", msg.Code, statusMsg) } + handleStatusMsgCounter.Inc(1) + if msg.Size > ProtocolMaxMsgSize { return fmt.Errorf("message too long: %v > %v", msg.Size, ProtocolMaxMsgSize) } diff --git a/swarm/storage/chunker.go b/swarm/storage/chunker.go index 98cd6e75e..2b397f801 100644 --- a/swarm/storage/chunker.go +++ b/swarm/storage/chunker.go @@ -23,6 +23,8 @@ import ( "io" "sync" "time" + + "github.com/ethereum/go-ethereum/metrics" ) /* @@ -63,6 +65,11 @@ var ( errOperationTimedOut = errors.New("operation timed out") ) +//metrics variables +var ( + newChunkCounter = metrics.NewRegisteredCounter("storage.chunks.new", nil) +) + type TreeChunker struct { branches int64 hashFunc SwarmHasher @@ -298,6 +305,13 @@ func (self *TreeChunker) hashChunk(hasher SwarmHash, job *hashJob, chunkC chan * job.parentWg.Done() if chunkC != nil { + //NOTE: this increases the chunk count even if the local node already has this chunk; + //on file upload the node will increase this counter even if the same file has already been uploaded + //So it should be evaluated whether it is worth keeping this counter + //and/or actually better track when the chunk is Put to the local database + //(which may question the need for disambiguation when a completely new chunk has been created + //and/or a chunk is being put to the local DB; for chunk tracking it may be worth distinguishing + newChunkCounter.Inc(1) chunkC <- newChunk } } diff --git a/swarm/storage/dbstore.go b/swarm/storage/dbstore.go index 46a5c16cc..421bb061d 100644 --- a/swarm/storage/dbstore.go +++ b/swarm/storage/dbstore.go @@ -33,11 +33,18 @@ import ( "sync" "github.com/ethereum/go-ethereum/log" + "github.com/ethereum/go-ethereum/metrics" "github.com/ethereum/go-ethereum/rlp" "github.com/syndtr/goleveldb/leveldb" "github.com/syndtr/goleveldb/leveldb/iterator" ) +//metrics variables +var ( + gcCounter = metrics.NewRegisteredCounter("storage.db.dbstore.gc.count", nil) + dbStoreDeleteCounter = metrics.NewRegisteredCounter("storage.db.dbstore.rm.count", nil) +) + const ( defaultDbCapacity = 5000000 defaultRadius = 0 // not yet used @@ -255,6 +262,7 @@ func (s *DbStore) collectGarbage(ratio float32) { // actual gc for i := 0; i < gcnt; i++ { if s.gcArray[i].value <= cutval { + gcCounter.Inc(1) s.delete(s.gcArray[i].idx, s.gcArray[i].idxKey) } } @@ -383,6 +391,7 @@ func (s *DbStore) delete(idx uint64, idxKey []byte) { batch := new(leveldb.Batch) batch.Delete(idxKey) batch.Delete(getDataKey(idx)) + dbStoreDeleteCounter.Inc(1) s.entryCnt-- batch.Put(keyEntryCnt, U64ToBytes(s.entryCnt)) s.db.Write(batch) diff --git a/swarm/storage/localstore.go b/swarm/storage/localstore.go index b442e6cc5..ece0c8615 100644 --- a/swarm/storage/localstore.go +++ b/swarm/storage/localstore.go @@ -18,6 +18,13 @@ package storage import ( "encoding/binary" + + "github.com/ethereum/go-ethereum/metrics" +) + +//metrics variables +var ( + dbStorePutCounter = metrics.NewRegisteredCounter("storage.db.dbstore.put.count", nil) ) // LocalStore is a combination of inmemory db over a disk persisted db @@ -39,6 +46,14 @@ func NewLocalStore(hash SwarmHasher, params *StoreParams) (*LocalStore, error) { }, nil } +func (self *LocalStore) CacheCounter() uint64 { + return uint64(self.memStore.(*MemStore).Counter()) +} + +func (self *LocalStore) DbCounter() uint64 { + return self.DbStore.(*DbStore).Counter() +} + // LocalStore is itself a chunk store // unsafe, in that the data is not integrity checked func (self *LocalStore) Put(chunk *Chunk) { @@ -48,6 +63,7 @@ func (self *LocalStore) Put(chunk *Chunk) { chunk.wg.Add(1) } go func() { + dbStorePutCounter.Inc(1) self.DbStore.Put(chunk) if chunk.wg != nil { chunk.wg.Done() diff --git a/swarm/storage/memstore.go b/swarm/storage/memstore.go index 3cb25ac62..d6be54220 100644 --- a/swarm/storage/memstore.go +++ b/swarm/storage/memstore.go @@ -23,6 +23,13 @@ import ( "sync" "github.com/ethereum/go-ethereum/log" + "github.com/ethereum/go-ethereum/metrics" +) + +//metrics variables +var ( + memstorePutCounter = metrics.NewRegisteredCounter("storage.db.memstore.put.count", nil) + memstoreRemoveCounter = metrics.NewRegisteredCounter("storage.db.memstore.rm.count", nil) ) const ( @@ -130,6 +137,10 @@ func (s *MemStore) setCapacity(c uint) { s.capacity = c } +func (s *MemStore) Counter() uint { + return s.entryCnt +} + // entry (not its copy) is going to be in MemStore func (s *MemStore) Put(entry *Chunk) { if s.capacity == 0 { @@ -145,6 +156,8 @@ func (s *MemStore) Put(entry *Chunk) { s.accessCnt++ + memstorePutCounter.Inc(1) + node := s.memtree bitpos := uint(0) for node.entry == nil { @@ -289,6 +302,7 @@ func (s *MemStore) removeOldest() { } if node.entry.SData != nil { + memstoreRemoveCounter.Inc(1) node.entry = nil s.entryCnt-- } diff --git a/swarm/swarm.go b/swarm/swarm.go index 3be3660b5..0a120db1f 100644 --- a/swarm/swarm.go +++ b/swarm/swarm.go @@ -21,7 +21,11 @@ import ( "context" "crypto/ecdsa" "fmt" + "math/big" "net" + "strings" + "time" + "unicode" "github.com/ethereum/go-ethereum/accounts/abi/bind" "github.com/ethereum/go-ethereum/common" @@ -30,9 +34,11 @@ import ( "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/ethclient" "github.com/ethereum/go-ethereum/log" + "github.com/ethereum/go-ethereum/metrics" "github.com/ethereum/go-ethereum/node" "github.com/ethereum/go-ethereum/p2p" "github.com/ethereum/go-ethereum/p2p/discover" + "github.com/ethereum/go-ethereum/params" "github.com/ethereum/go-ethereum/rpc" "github.com/ethereum/go-ethereum/swarm/api" httpapi "github.com/ethereum/go-ethereum/swarm/api/http" @@ -41,6 +47,16 @@ import ( "github.com/ethereum/go-ethereum/swarm/storage" ) +var ( + startTime time.Time + updateGaugesPeriod = 5 * time.Second + startCounter = metrics.NewRegisteredCounter("stack,start", nil) + stopCounter = metrics.NewRegisteredCounter("stack,stop", nil) + uptimeGauge = metrics.NewRegisteredGauge("stack.uptime", nil) + dbSizeGauge = metrics.NewRegisteredGauge("storage.db.chunks.size", nil) + cacheSizeGauge = metrics.NewRegisteredGauge("storage.db.cache.size", nil) +) + // the swarm stack type Swarm struct { config *api.Config // swarm configuration @@ -76,7 +92,7 @@ func (self *Swarm) API() *SwarmAPI { // creates a new swarm service instance // implements node.Service -func NewSwarm(ctx *node.ServiceContext, backend chequebook.Backend, ensClient *ethclient.Client, config *api.Config, swapEnabled, syncEnabled bool, cors string) (self *Swarm, err error) { +func NewSwarm(ctx *node.ServiceContext, backend chequebook.Backend, config *api.Config) (self *Swarm, err error) { if bytes.Equal(common.FromHex(config.PublicKey), storage.ZeroKey) { return nil, fmt.Errorf("empty public key") } @@ -86,10 +102,10 @@ func NewSwarm(ctx *node.ServiceContext, backend chequebook.Backend, ensClient *e self = &Swarm{ config: config, - swapEnabled: swapEnabled, + swapEnabled: config.SwapEnabled, backend: backend, privateKey: config.Swap.PrivateKey(), - corsString: cors, + corsString: config.Cors, } log.Debug(fmt.Sprintf("Setting up Swarm service components")) @@ -109,8 +125,8 @@ func NewSwarm(ctx *node.ServiceContext, backend chequebook.Backend, ensClient *e self.hive = network.NewHive( common.HexToHash(self.config.BzzKey), // key to hive (kademlia base address) config.HiveParams, // configuration parameters - swapEnabled, // SWAP enabled - syncEnabled, // syncronisation enabled + config.SwapEnabled, // SWAP enabled + config.SyncEnabled, // syncronisation enabled ) log.Debug(fmt.Sprintf("Set up swarm network with Kademlia hive")) @@ -133,18 +149,18 @@ func NewSwarm(ctx *node.ServiceContext, backend chequebook.Backend, ensClient *e self.dpa = storage.NewDPA(dpaChunkStore, self.config.ChunkerParams) log.Debug(fmt.Sprintf("-> Content Store API")) - // set up high level api - transactOpts := bind.NewKeyedTransactor(self.privateKey) - - if ensClient == nil { - log.Warn("No ENS, please specify non-empty --ens-api to use domain name resolution") - } else { - self.dns, err = ens.NewENS(transactOpts, config.EnsRoot, ensClient) - if err != nil { - return nil, err + if len(config.EnsAPIs) > 0 { + opts := []api.MultiResolverOption{} + for _, c := range config.EnsAPIs { + tld, endpoint, addr := parseEnsAPIAddress(c) + r, err := newEnsClient(endpoint, addr, config) + if err != nil { + return nil, err + } + opts = append(opts, api.MultiResolverOptionWithResolver(r, tld)) } + self.dns = api.NewMultiResolver(opts...) } - log.Debug(fmt.Sprintf("-> Swarm Domain Name Registrar @ address %v", config.EnsRoot.Hex())) self.api = api.NewApi(self.dpa, self.dns) // Manifests for Smart Hosting @@ -156,6 +172,95 @@ func NewSwarm(ctx *node.ServiceContext, backend chequebook.Backend, ensClient *e return self, nil } +// parseEnsAPIAddress parses string according to format +// [tld:][contract-addr@]url and returns ENSClientConfig structure +// with endpoint, contract address and TLD. +func parseEnsAPIAddress(s string) (tld, endpoint string, addr common.Address) { + isAllLetterString := func(s string) bool { + for _, r := range s { + if !unicode.IsLetter(r) { + return false + } + } + return true + } + endpoint = s + if i := strings.Index(endpoint, ":"); i > 0 { + if isAllLetterString(endpoint[:i]) && len(endpoint) > i+2 && endpoint[i+1:i+3] != "//" { + tld = endpoint[:i] + endpoint = endpoint[i+1:] + } + } + if i := strings.Index(endpoint, "@"); i > 0 { + addr = common.HexToAddress(endpoint[:i]) + endpoint = endpoint[i+1:] + } + return +} + +// newEnsClient creates a new ENS client for that is a consumer of +// a ENS API on a specific endpoint. It is used as a helper function +// for creating multiple resolvers in NewSwarm function. +func newEnsClient(endpoint string, addr common.Address, config *api.Config) (*ens.ENS, error) { + log.Info("connecting to ENS API", "url", endpoint) + client, err := rpc.Dial(endpoint) + if err != nil { + return nil, fmt.Errorf("error connecting to ENS API %s: %s", endpoint, err) + } + ensClient := ethclient.NewClient(client) + + ensRoot := config.EnsRoot + if addr != (common.Address{}) { + ensRoot = addr + } else { + a, err := detectEnsAddr(client) + if err == nil { + ensRoot = a + } else { + log.Warn(fmt.Sprintf("could not determine ENS contract address, using default %s", ensRoot), "err", err) + } + } + transactOpts := bind.NewKeyedTransactor(config.Swap.PrivateKey()) + dns, err := ens.NewENS(transactOpts, ensRoot, ensClient) + if err != nil { + return nil, err + } + log.Debug(fmt.Sprintf("-> Swarm Domain Name Registrar %v @ address %v", endpoint, ensRoot.Hex())) + return dns, err +} + +// detectEnsAddr determines the ENS contract address by getting both the +// version and genesis hash using the client and matching them to either +// mainnet or testnet addresses +func detectEnsAddr(client *rpc.Client) (common.Address, error) { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + var version string + if err := client.CallContext(ctx, &version, "net_version"); err != nil { + return common.Address{}, err + } + + block, err := ethclient.NewClient(client).BlockByNumber(ctx, big.NewInt(0)) + if err != nil { + return common.Address{}, err + } + + switch { + + case version == "1" && block.Hash() == params.MainnetGenesisHash: + log.Info("using Mainnet ENS contract address", "addr", ens.MainNetAddress) + return ens.MainNetAddress, nil + + case version == "3" && block.Hash() == params.TestnetGenesisHash: + log.Info("using Testnet ENS contract address", "addr", ens.TestNetAddress) + return ens.TestNetAddress, nil + + default: + return common.Address{}, fmt.Errorf("unknown version and genesis hash: %s %s", version, block.Hash()) + } +} + /* Start is called when the stack is started * starts the network kademlia hive peer management @@ -168,6 +273,7 @@ Start is called when the stack is started */ // implements the node.Service interface func (self *Swarm) Start(srv *p2p.Server) error { + startTime = time.Now() connectPeer := func(url string) error { node, err := discover.ParseNode(url) if err != nil { @@ -213,9 +319,28 @@ func (self *Swarm) Start(srv *p2p.Server) error { } } + self.periodicallyUpdateGauges() + + startCounter.Inc(1) return nil } +func (self *Swarm) periodicallyUpdateGauges() { + ticker := time.NewTicker(updateGaugesPeriod) + + go func() { + for range ticker.C { + self.updateGauges() + } + }() +} + +func (self *Swarm) updateGauges() { + dbSizeGauge.Update(int64(self.lstore.DbCounter())) + cacheSizeGauge.Update(int64(self.lstore.CacheCounter())) + uptimeGauge.Update(time.Since(startTime).Nanoseconds()) +} + // implements the node.Service interface // stops all component services. func (self *Swarm) Stop() error { @@ -230,6 +355,7 @@ func (self *Swarm) Stop() error { self.lstore.DbStore.Close() } self.sfs.Stop() + stopCounter.Inc(1) return err } diff --git a/swarm/swarm_test.go b/swarm/swarm_test.go new file mode 100644 index 000000000..8b1ae2888 --- /dev/null +++ b/swarm/swarm_test.go @@ -0,0 +1,119 @@ +// Copyright 2017 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. + +package swarm + +import ( + "testing" + + "github.com/ethereum/go-ethereum/common" +) + +func TestParseEnsAPIAddress(t *testing.T) { + for _, x := range []struct { + description string + value string + tld string + endpoint string + addr common.Address + }{ + { + description: "IPC endpoint", + value: "/data/testnet/geth.ipc", + endpoint: "/data/testnet/geth.ipc", + }, + { + description: "HTTP endpoint", + value: "http://127.0.0.1:1234", + endpoint: "http://127.0.0.1:1234", + }, + { + description: "WS endpoint", + value: "ws://127.0.0.1:1234", + endpoint: "ws://127.0.0.1:1234", + }, + { + description: "IPC Endpoint and TLD", + value: "test:/data/testnet/geth.ipc", + endpoint: "/data/testnet/geth.ipc", + tld: "test", + }, + { + description: "HTTP endpoint and TLD", + value: "test:http://127.0.0.1:1234", + endpoint: "http://127.0.0.1:1234", + tld: "test", + }, + { + description: "WS endpoint and TLD", + value: "test:ws://127.0.0.1:1234", + endpoint: "ws://127.0.0.1:1234", + tld: "test", + }, + { + description: "IPC Endpoint and contract address", + value: "314159265dD8dbb310642f98f50C066173C1259b@/data/testnet/geth.ipc", + endpoint: "/data/testnet/geth.ipc", + addr: common.HexToAddress("314159265dD8dbb310642f98f50C066173C1259b"), + }, + { + description: "HTTP endpoint and contract address", + value: "314159265dD8dbb310642f98f50C066173C1259b@http://127.0.0.1:1234", + endpoint: "http://127.0.0.1:1234", + addr: common.HexToAddress("314159265dD8dbb310642f98f50C066173C1259b"), + }, + { + description: "WS endpoint and contract address", + value: "314159265dD8dbb310642f98f50C066173C1259b@ws://127.0.0.1:1234", + endpoint: "ws://127.0.0.1:1234", + addr: common.HexToAddress("314159265dD8dbb310642f98f50C066173C1259b"), + }, + { + description: "IPC Endpoint, TLD and contract address", + value: "test:314159265dD8dbb310642f98f50C066173C1259b@/data/testnet/geth.ipc", + endpoint: "/data/testnet/geth.ipc", + addr: common.HexToAddress("314159265dD8dbb310642f98f50C066173C1259b"), + tld: "test", + }, + { + description: "HTTP endpoint, TLD and contract address", + value: "eth:314159265dD8dbb310642f98f50C066173C1259b@http://127.0.0.1:1234", + endpoint: "http://127.0.0.1:1234", + addr: common.HexToAddress("314159265dD8dbb310642f98f50C066173C1259b"), + tld: "eth", + }, + { + description: "WS endpoint, TLD and contract address", + value: "eth:314159265dD8dbb310642f98f50C066173C1259b@ws://127.0.0.1:1234", + endpoint: "ws://127.0.0.1:1234", + addr: common.HexToAddress("314159265dD8dbb310642f98f50C066173C1259b"), + tld: "eth", + }, + } { + t.Run(x.description, func(t *testing.T) { + tld, endpoint, addr := parseEnsAPIAddress(x.value) + if endpoint != x.endpoint { + t.Errorf("expected Endpoint %q, got %q", x.endpoint, endpoint) + } + if addr != x.addr { + t.Errorf("expected ContractAddress %q, got %q", x.addr.String(), addr.String()) + } + if tld != x.tld { + t.Errorf("expected TLD %q, got %q", x.tld, tld) + } + }) + } +} diff --git a/trie/trie.go b/trie/trie.go index e37a1ae10..31a404e3a 100644 --- a/trie/trie.go +++ b/trie/trie.go @@ -24,7 +24,7 @@ import ( "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/log" - "github.com/rcrowley/go-metrics" + "github.com/ethereum/go-ethereum/metrics" ) var ( diff --git a/vendor/github.com/influxdata/influxdb/LICENSE b/vendor/github.com/influxdata/influxdb/LICENSE new file mode 100644 index 000000000..63cef79ba --- /dev/null +++ b/vendor/github.com/influxdata/influxdb/LICENSE @@ -0,0 +1,20 @@ +The MIT License (MIT) + +Copyright (c) 2013-2016 Errplane Inc. + +Permission is hereby granted, free of charge, to any person obtaining a copy of +this software and associated documentation files (the "Software"), to deal in +the Software without restriction, including without limitation the rights to +use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +the Software, and to permit persons to whom the Software is furnished to do so, +subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/vendor/github.com/influxdata/influxdb/LICENSE_OF_DEPENDENCIES.md b/vendor/github.com/influxdata/influxdb/LICENSE_OF_DEPENDENCIES.md new file mode 100644 index 000000000..ea6fc69f3 --- /dev/null +++ b/vendor/github.com/influxdata/influxdb/LICENSE_OF_DEPENDENCIES.md @@ -0,0 +1,62 @@ +- # List +- bootstrap 3.3.5 [MIT LICENSE](https://github.com/twbs/bootstrap/blob/master/LICENSE) +- collectd.org [ISC LICENSE](https://github.com/collectd/go-collectd/blob/master/LICENSE) +- github.com/BurntSushi/toml [MIT LICENSE](https://github.com/BurntSushi/toml/blob/master/COPYING) +- github.com/RoaringBitmap/roaring [APACHE LICENSE](https://github.com/RoaringBitmap/roaring/blob/master/LICENSE) +- github.com/beorn7/perks [MIT LICENSE](https://github.com/beorn7/perks/blob/master/LICENSE) +- github.com/bmizerany/pat [MIT LICENSE](https://github.com/bmizerany/pat#license) +- github.com/boltdb/bolt [MIT LICENSE](https://github.com/boltdb/bolt/blob/master/LICENSE) +- github.com/cespare/xxhash [MIT LICENSE](https://github.com/cespare/xxhash/blob/master/LICENSE.txt) +- github.com/clarkduvall/hyperloglog [MIT LICENSE](https://github.com/clarkduvall/hyperloglog/blob/master/LICENSE) +- github.com/davecgh/go-spew/spew [ISC LICENSE](https://github.com/davecgh/go-spew/blob/master/LICENSE) +- github.com/dgrijalva/jwt-go [MIT LICENSE](https://github.com/dgrijalva/jwt-go/blob/master/LICENSE) +- github.com/dgryski/go-bits [MIT LICENSE](https://github.com/dgryski/go-bits/blob/master/LICENSE) +- github.com/dgryski/go-bitstream [MIT LICENSE](https://github.com/dgryski/go-bitstream/blob/master/LICENSE) +- github.com/glycerine/go-unsnap-stream [MIT LICENSE](https://github.com/glycerine/go-unsnap-stream/blob/master/LICENSE) +- github.com/gogo/protobuf/proto [BSD LICENSE](https://github.com/gogo/protobuf/blob/master/LICENSE) +- github.com/golang/protobuf [BSD LICENSE](https://github.com/golang/protobuf/blob/master/LICENSE) +- github.com/golang/snappy [BSD LICENSE](https://github.com/golang/snappy/blob/master/LICENSE) +- github.com/google/go-cmp [BSD LICENSE](https://github.com/google/go-cmp/blob/master/LICENSE) +- github.com/influxdata/influxql [MIT LICENSE](https://github.com/influxdata/influxql/blob/master/LICENSE) +- github.com/influxdata/usage-client [MIT LICENSE](https://github.com/influxdata/usage-client/blob/master/LICENSE.txt) +- github.com/influxdata/yamux [MOZILLA PUBLIC LICENSE](https://github.com/influxdata/yamux/blob/master/LICENSE) +- github.com/influxdata/yarpc [MIT LICENSE](https://github.com/influxdata/yarpc/blob/master/LICENSE) +- github.com/jsternberg/zap-logfmt [MIT LICENSE](https://github.com/jsternberg/zap-logfmt/blob/master/LICENSE) +- github.com/jwilder/encoding [MIT LICENSE](https://github.com/jwilder/encoding/blob/master/LICENSE) +- github.com/mattn/go-isatty [MIT LICENSE](https://github.com/mattn/go-isatty/blob/master/LICENSE) +- github.com/matttproud/golang_protobuf_extensions [APACHE LICENSE](https://github.com/matttproud/golang_protobuf_extensions/blob/master/LICENSE) +- github.com/opentracing/opentracing-go [MIT LICENSE](https://github.com/opentracing/opentracing-go/blob/master/LICENSE) +- github.com/paulbellamy/ratecounter [MIT LICENSE](https://github.com/paulbellamy/ratecounter/blob/master/LICENSE) +- github.com/peterh/liner [MIT LICENSE](https://github.com/peterh/liner/blob/master/COPYING) +- github.com/philhofer/fwd [MIT LICENSE](https://github.com/philhofer/fwd/blob/master/LICENSE.md) +- github.com/prometheus/client_golang [MIT LICENSE](https://github.com/prometheus/client_golang/blob/master/LICENSE) +- github.com/prometheus/client_model [MIT LICENSE](https://github.com/prometheus/client_model/blob/master/LICENSE) +- github.com/prometheus/common [APACHE LICENSE](https://github.com/prometheus/common/blob/master/LICENSE) +- github.com/prometheus/procfs [APACHE LICENSE](https://github.com/prometheus/procfs/blob/master/LICENSE) +- github.com/rakyll/statik [APACHE LICENSE](https://github.com/rakyll/statik/blob/master/LICENSE) +- github.com/retailnext/hllpp [BSD LICENSE](https://github.com/retailnext/hllpp/blob/master/LICENSE) +- github.com/tinylib/msgp [MIT LICENSE](https://github.com/tinylib/msgp/blob/master/LICENSE) +- go.uber.org/atomic [MIT LICENSE](https://github.com/uber-go/atomic/blob/master/LICENSE.txt) +- go.uber.org/multierr [MIT LICENSE](https://github.com/uber-go/multierr/blob/master/LICENSE.txt) +- go.uber.org/zap [MIT LICENSE](https://github.com/uber-go/zap/blob/master/LICENSE.txt) +- golang.org/x/crypto [BSD LICENSE](https://github.com/golang/crypto/blob/master/LICENSE) +- golang.org/x/net [BSD LICENSE](https://github.com/golang/net/blob/master/LICENSE) +- golang.org/x/sys [BSD LICENSE](https://github.com/golang/sys/blob/master/LICENSE) +- golang.org/x/text [BSD LICENSE](https://github.com/golang/text/blob/master/LICENSE) +- golang.org/x/time [BSD LICENSE](https://github.com/golang/time/blob/master/LICENSE) +- jquery 2.1.4 [MIT LICENSE](https://github.com/jquery/jquery/blob/master/LICENSE.txt) +- github.com/xlab/treeprint [MIT LICENSE](https://github.com/xlab/treeprint/blob/master/LICENSE) + + + + + + + + + + + + + + diff --git a/vendor/github.com/influxdata/influxdb/client/README.md b/vendor/github.com/influxdata/influxdb/client/README.md new file mode 100644 index 000000000..773a11122 --- /dev/null +++ b/vendor/github.com/influxdata/influxdb/client/README.md @@ -0,0 +1,306 @@ +# InfluxDB Client + +[![GoDoc](https://godoc.org/github.com/influxdata/influxdb?status.svg)](http://godoc.org/github.com/influxdata/influxdb/client/v2) + +## Description + +**NOTE:** The Go client library now has a "v2" version, with the old version +being deprecated. The new version can be imported at +`import "github.com/influxdata/influxdb/client/v2"`. It is not backwards-compatible. + +A Go client library written and maintained by the **InfluxDB** team. +This package provides convenience functions to read and write time series data. +It uses the HTTP protocol to communicate with your **InfluxDB** cluster. + + +## Getting Started + +### Connecting To Your Database + +Connecting to an **InfluxDB** database is straightforward. You will need a host +name, a port and the cluster user credentials if applicable. The default port is +8086. You can customize these settings to your specific installation via the +**InfluxDB** configuration file. + +Though not necessary for experimentation, you may want to create a new user +and authenticate the connection to your database. + +For more information please check out the +[Admin Docs](https://docs.influxdata.com/influxdb/latest/administration/). + +For the impatient, you can create a new admin user _bubba_ by firing off the +[InfluxDB CLI](https://github.com/influxdata/influxdb/blob/master/cmd/influx/main.go). + +```shell +influx +> create user bubba with password 'bumblebeetuna' +> grant all privileges to bubba +``` + +And now for good measure set the credentials in you shell environment. +In the example below we will use $INFLUX_USER and $INFLUX_PWD + +Now with the administrivia out of the way, let's connect to our database. + +NOTE: If you've opted out of creating a user, you can omit Username and Password in +the configuration below. + +```go +package main + +import ( + "log" + "time" + + "github.com/influxdata/influxdb/client/v2" +) + +const ( + MyDB = "square_holes" + username = "bubba" + password = "bumblebeetuna" +) + + +func main() { + // Create a new HTTPClient + c, err := client.NewHTTPClient(client.HTTPConfig{ + Addr: "http://localhost:8086", + Username: username, + Password: password, + }) + if err != nil { + log.Fatal(err) + } + + // Create a new point batch + bp, err := client.NewBatchPoints(client.BatchPointsConfig{ + Database: MyDB, + Precision: "s", + }) + if err != nil { + log.Fatal(err) + } + + // Create a point and add to batch + tags := map[string]string{"cpu": "cpu-total"} + fields := map[string]interface{}{ + "idle": 10.1, + "system": 53.3, + "user": 46.6, + } + + pt, err := client.NewPoint("cpu_usage", tags, fields, time.Now()) + if err != nil { + log.Fatal(err) + } + bp.AddPoint(pt) + + // Write the batch + if err := c.Write(bp); err != nil { + log.Fatal(err) + } +} + +``` + +### Inserting Data + +Time series data aka *points* are written to the database using batch inserts. +The mechanism is to create one or more points and then create a batch aka +*batch points* and write these to a given database and series. A series is a +combination of a measurement (time/values) and a set of tags. + +In this sample we will create a batch of a 1,000 points. Each point has a time and +a single value as well as 2 tags indicating a shape and color. We write these points +to a database called _square_holes_ using a measurement named _shapes_. + +NOTE: You can specify a RetentionPolicy as part of the batch points. If not +provided InfluxDB will use the database _default_ retention policy. + +```go + +func writePoints(clnt client.Client) { + sampleSize := 1000 + + bp, err := client.NewBatchPoints(client.BatchPointsConfig{ + Database: "systemstats", + Precision: "us", + }) + if err != nil { + log.Fatal(err) + } + + rand.Seed(time.Now().UnixNano()) + for i := 0; i < sampleSize; i++ { + regions := []string{"us-west1", "us-west2", "us-west3", "us-east1"} + tags := map[string]string{ + "cpu": "cpu-total", + "host": fmt.Sprintf("host%d", rand.Intn(1000)), + "region": regions[rand.Intn(len(regions))], + } + + idle := rand.Float64() * 100.0 + fields := map[string]interface{}{ + "idle": idle, + "busy": 100.0 - idle, + } + + pt, err := client.NewPoint( + "cpu_usage", + tags, + fields, + time.Now(), + ) + if err != nil { + log.Fatal(err) + } + bp.AddPoint(pt) + } + + if err := clnt.Write(bp); err != nil { + log.Fatal(err) + } +} +``` + +#### Uint64 Support + +The `uint64` data type is supported if your server is version `1.4.0` or +greater. To write a data point as an unsigned integer, you must insert +the point as `uint64`. You cannot use `uint` or any of the other +derivatives because previous versions of the client have supported +writing those types as an integer. + +### Querying Data + +One nice advantage of using **InfluxDB** the ability to query your data using familiar +SQL constructs. In this example we can create a convenience function to query the database +as follows: + +```go +// queryDB convenience function to query the database +func queryDB(clnt client.Client, cmd string) (res []client.Result, err error) { + q := client.Query{ + Command: cmd, + Database: MyDB, + } + if response, err := clnt.Query(q); err == nil { + if response.Error() != nil { + return res, response.Error() + } + res = response.Results + } else { + return res, err + } + return res, nil +} +``` + +#### Creating a Database + +```go +_, err := queryDB(clnt, fmt.Sprintf("CREATE DATABASE %s", MyDB)) +if err != nil { + log.Fatal(err) +} +``` + +#### Count Records + +```go +q := fmt.Sprintf("SELECT count(%s) FROM %s", "value", MyMeasurement) +res, err := queryDB(clnt, q) +if err != nil { + log.Fatal(err) +} +count := res[0].Series[0].Values[0][1] +log.Printf("Found a total of %v records\n", count) +``` + +#### Find the last 10 _shapes_ records + +```go +q := fmt.Sprintf("SELECT * FROM %s LIMIT %d", MyMeasurement, 10) +res, err = queryDB(clnt, q) +if err != nil { + log.Fatal(err) +} + +for i, row := range res[0].Series[0].Values { + t, err := time.Parse(time.RFC3339, row[0].(string)) + if err != nil { + log.Fatal(err) + } + val := row[1].(string) + log.Printf("[%2d] %s: %s\n", i, t.Format(time.Stamp), val) +} +``` + +### Using the UDP Client + +The **InfluxDB** client also supports writing over UDP. + +```go +func WriteUDP() { + // Make client + c, err := client.NewUDPClient("localhost:8089") + if err != nil { + panic(err.Error()) + } + + // Create a new point batch + bp, _ := client.NewBatchPoints(client.BatchPointsConfig{ + Precision: "s", + }) + + // Create a point and add to batch + tags := map[string]string{"cpu": "cpu-total"} + fields := map[string]interface{}{ + "idle": 10.1, + "system": 53.3, + "user": 46.6, + } + pt, err := client.NewPoint("cpu_usage", tags, fields, time.Now()) + if err != nil { + panic(err.Error()) + } + bp.AddPoint(pt) + + // Write the batch + c.Write(bp) +} +``` + +### Point Splitting + +The UDP client now supports splitting single points that exceed the configured +payload size. The logic for processing each point is listed here, starting with +an empty payload. + +1. If adding the point to the current (non-empty) payload would exceed the + configured size, send the current payload. Otherwise, add it to the current + payload. +1. If the point is smaller than the configured size, add it to the payload. +1. If the point has no timestamp, just try to send the entire point as a single + UDP payload, and process the next point. +1. Since the point has a timestamp, re-use the existing measurement name, + tagset, and timestamp and create multiple new points by splitting up the + fields. The per-point length will be kept close to the configured size, + staying under it if possible. This does mean that one large field, maybe a + long string, could be sent as a larger-than-configured payload. + +The above logic attempts to respect configured payload sizes, but not sacrifice +any data integrity. Points without a timestamp can't be split, as that may +cause fields to have differing timestamps when processed by the server. + +## Go Docs + +Please refer to +[http://godoc.org/github.com/influxdata/influxdb/client/v2](http://godoc.org/github.com/influxdata/influxdb/client/v2) +for documentation. + +## See Also + +You can also examine how the client library is used by the +[InfluxDB CLI](https://github.com/influxdata/influxdb/blob/master/cmd/influx/main.go). diff --git a/vendor/github.com/influxdata/influxdb/client/influxdb.go b/vendor/github.com/influxdata/influxdb/client/influxdb.go new file mode 100644 index 000000000..98d362d50 --- /dev/null +++ b/vendor/github.com/influxdata/influxdb/client/influxdb.go @@ -0,0 +1,840 @@ +// Package client implements a now-deprecated client for InfluxDB; +// use github.com/influxdata/influxdb/client/v2 instead. +package client // import "github.com/influxdata/influxdb/client" + +import ( + "bytes" + "context" + "crypto/tls" + "encoding/json" + "errors" + "fmt" + "io" + "io/ioutil" + "net" + "net/http" + "net/url" + "strconv" + "strings" + "time" + + "github.com/influxdata/influxdb/models" +) + +const ( + // DefaultHost is the default host used to connect to an InfluxDB instance + DefaultHost = "localhost" + + // DefaultPort is the default port used to connect to an InfluxDB instance + DefaultPort = 8086 + + // DefaultTimeout is the default connection timeout used to connect to an InfluxDB instance + DefaultTimeout = 0 +) + +// Query is used to send a command to the server. Both Command and Database are required. +type Query struct { + Command string + Database string + + // Chunked tells the server to send back chunked responses. This places + // less load on the server by sending back chunks of the response rather + // than waiting for the entire response all at once. + Chunked bool + + // ChunkSize sets the maximum number of rows that will be returned per + // chunk. Chunks are either divided based on their series or if they hit + // the chunk size limit. + // + // Chunked must be set to true for this option to be used. + ChunkSize int +} + +// ParseConnectionString will parse a string to create a valid connection URL +func ParseConnectionString(path string, ssl bool) (url.URL, error) { + var host string + var port int + + h, p, err := net.SplitHostPort(path) + if err != nil { + if path == "" { + host = DefaultHost + } else { + host = path + } + // If they didn't specify a port, always use the default port + port = DefaultPort + } else { + host = h + port, err = strconv.Atoi(p) + if err != nil { + return url.URL{}, fmt.Errorf("invalid port number %q: %s\n", path, err) + } + } + + u := url.URL{ + Scheme: "http", + } + if ssl { + u.Scheme = "https" + } + + u.Host = net.JoinHostPort(host, strconv.Itoa(port)) + + return u, nil +} + +// Config is used to specify what server to connect to. +// URL: The URL of the server connecting to. +// Username/Password are optional. They will be passed via basic auth if provided. +// UserAgent: If not provided, will default "InfluxDBClient", +// Timeout: If not provided, will default to 0 (no timeout) +type Config struct { + URL url.URL + UnixSocket string + Username string + Password string + UserAgent string + Timeout time.Duration + Precision string + WriteConsistency string + UnsafeSsl bool +} + +// NewConfig will create a config to be used in connecting to the client +func NewConfig() Config { + return Config{ + Timeout: DefaultTimeout, + } +} + +// Client is used to make calls to the server. +type Client struct { + url url.URL + unixSocket string + username string + password string + httpClient *http.Client + userAgent string + precision string +} + +const ( + // ConsistencyOne requires at least one data node acknowledged a write. + ConsistencyOne = "one" + + // ConsistencyAll requires all data nodes to acknowledge a write. + ConsistencyAll = "all" + + // ConsistencyQuorum requires a quorum of data nodes to acknowledge a write. + ConsistencyQuorum = "quorum" + + // ConsistencyAny allows for hinted hand off, potentially no write happened yet. + ConsistencyAny = "any" +) + +// NewClient will instantiate and return a connected client to issue commands to the server. +func NewClient(c Config) (*Client, error) { + tlsConfig := &tls.Config{ + InsecureSkipVerify: c.UnsafeSsl, + } + + tr := &http.Transport{ + TLSClientConfig: tlsConfig, + } + + if c.UnixSocket != "" { + // No need for compression in local communications. + tr.DisableCompression = true + + tr.DialContext = func(_ context.Context, _, _ string) (net.Conn, error) { + return net.Dial("unix", c.UnixSocket) + } + } + + client := Client{ + url: c.URL, + unixSocket: c.UnixSocket, + username: c.Username, + password: c.Password, + httpClient: &http.Client{Timeout: c.Timeout, Transport: tr}, + userAgent: c.UserAgent, + precision: c.Precision, + } + if client.userAgent == "" { + client.userAgent = "InfluxDBClient" + } + return &client, nil +} + +// SetAuth will update the username and passwords +func (c *Client) SetAuth(u, p string) { + c.username = u + c.password = p +} + +// SetPrecision will update the precision +func (c *Client) SetPrecision(precision string) { + c.precision = precision +} + +// Query sends a command to the server and returns the Response +func (c *Client) Query(q Query) (*Response, error) { + return c.QueryContext(context.Background(), q) +} + +// QueryContext sends a command to the server and returns the Response +// It uses a context that can be cancelled by the command line client +func (c *Client) QueryContext(ctx context.Context, q Query) (*Response, error) { + u := c.url + + u.Path = "query" + values := u.Query() + values.Set("q", q.Command) + values.Set("db", q.Database) + if q.Chunked { + values.Set("chunked", "true") + if q.ChunkSize > 0 { + values.Set("chunk_size", strconv.Itoa(q.ChunkSize)) + } + } + if c.precision != "" { + values.Set("epoch", c.precision) + } + u.RawQuery = values.Encode() + + req, err := http.NewRequest("POST", u.String(), nil) + if err != nil { + return nil, err + } + req.Header.Set("User-Agent", c.userAgent) + if c.username != "" { + req.SetBasicAuth(c.username, c.password) + } + + req = req.WithContext(ctx) + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + var response Response + if q.Chunked { + cr := NewChunkedResponse(resp.Body) + for { + r, err := cr.NextResponse() + if err != nil { + // If we got an error while decoding the response, send that back. + return nil, err + } + + if r == nil { + break + } + + response.Results = append(response.Results, r.Results...) + if r.Err != nil { + response.Err = r.Err + break + } + } + } else { + dec := json.NewDecoder(resp.Body) + dec.UseNumber() + if err := dec.Decode(&response); err != nil { + // Ignore EOF errors if we got an invalid status code. + if !(err == io.EOF && resp.StatusCode != http.StatusOK) { + return nil, err + } + } + } + + // If we don't have an error in our json response, and didn't get StatusOK, + // then send back an error. + if resp.StatusCode != http.StatusOK && response.Error() == nil { + return &response, fmt.Errorf("received status code %d from server", resp.StatusCode) + } + return &response, nil +} + +// Write takes BatchPoints and allows for writing of multiple points with defaults +// If successful, error is nil and Response is nil +// If an error occurs, Response may contain additional information if populated. +func (c *Client) Write(bp BatchPoints) (*Response, error) { + u := c.url + u.Path = "write" + + var b bytes.Buffer + for _, p := range bp.Points { + err := checkPointTypes(p) + if err != nil { + return nil, err + } + if p.Raw != "" { + if _, err := b.WriteString(p.Raw); err != nil { + return nil, err + } + } else { + for k, v := range bp.Tags { + if p.Tags == nil { + p.Tags = make(map[string]string, len(bp.Tags)) + } + p.Tags[k] = v + } + + if _, err := b.WriteString(p.MarshalString()); err != nil { + return nil, err + } + } + + if err := b.WriteByte('\n'); err != nil { + return nil, err + } + } + + req, err := http.NewRequest("POST", u.String(), &b) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "") + req.Header.Set("User-Agent", c.userAgent) + if c.username != "" { + req.SetBasicAuth(c.username, c.password) + } + + precision := bp.Precision + if precision == "" { + precision = "ns" + } + + params := req.URL.Query() + params.Set("db", bp.Database) + params.Set("rp", bp.RetentionPolicy) + params.Set("precision", precision) + params.Set("consistency", bp.WriteConsistency) + req.URL.RawQuery = params.Encode() + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + var response Response + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + return nil, err + } + + if resp.StatusCode != http.StatusNoContent && resp.StatusCode != http.StatusOK { + var err = fmt.Errorf(string(body)) + response.Err = err + return &response, err + } + + return nil, nil +} + +// WriteLineProtocol takes a string with line returns to delimit each write +// If successful, error is nil and Response is nil +// If an error occurs, Response may contain additional information if populated. +func (c *Client) WriteLineProtocol(data, database, retentionPolicy, precision, writeConsistency string) (*Response, error) { + u := c.url + u.Path = "write" + + r := strings.NewReader(data) + + req, err := http.NewRequest("POST", u.String(), r) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "") + req.Header.Set("User-Agent", c.userAgent) + if c.username != "" { + req.SetBasicAuth(c.username, c.password) + } + params := req.URL.Query() + params.Set("db", database) + params.Set("rp", retentionPolicy) + params.Set("precision", precision) + params.Set("consistency", writeConsistency) + req.URL.RawQuery = params.Encode() + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + var response Response + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + return nil, err + } + + if resp.StatusCode != http.StatusNoContent && resp.StatusCode != http.StatusOK { + err := fmt.Errorf(string(body)) + response.Err = err + return &response, err + } + + return nil, nil +} + +// Ping will check to see if the server is up +// Ping returns how long the request took, the version of the server it connected to, and an error if one occurred. +func (c *Client) Ping() (time.Duration, string, error) { + now := time.Now() + u := c.url + u.Path = "ping" + + req, err := http.NewRequest("GET", u.String(), nil) + if err != nil { + return 0, "", err + } + req.Header.Set("User-Agent", c.userAgent) + if c.username != "" { + req.SetBasicAuth(c.username, c.password) + } + + resp, err := c.httpClient.Do(req) + if err != nil { + return 0, "", err + } + defer resp.Body.Close() + + version := resp.Header.Get("X-Influxdb-Version") + return time.Since(now), version, nil +} + +// Structs + +// Message represents a user message. +type Message struct { + Level string `json:"level,omitempty"` + Text string `json:"text,omitempty"` +} + +// Result represents a resultset returned from a single statement. +type Result struct { + Series []models.Row + Messages []*Message + Err error +} + +// MarshalJSON encodes the result into JSON. +func (r *Result) MarshalJSON() ([]byte, error) { + // Define a struct that outputs "error" as a string. + var o struct { + Series []models.Row `json:"series,omitempty"` + Messages []*Message `json:"messages,omitempty"` + Err string `json:"error,omitempty"` + } + + // Copy fields to output struct. + o.Series = r.Series + o.Messages = r.Messages + if r.Err != nil { + o.Err = r.Err.Error() + } + + return json.Marshal(&o) +} + +// UnmarshalJSON decodes the data into the Result struct +func (r *Result) UnmarshalJSON(b []byte) error { + var o struct { + Series []models.Row `json:"series,omitempty"` + Messages []*Message `json:"messages,omitempty"` + Err string `json:"error,omitempty"` + } + + dec := json.NewDecoder(bytes.NewBuffer(b)) + dec.UseNumber() + err := dec.Decode(&o) + if err != nil { + return err + } + r.Series = o.Series + r.Messages = o.Messages + if o.Err != "" { + r.Err = errors.New(o.Err) + } + return nil +} + +// Response represents a list of statement results. +type Response struct { + Results []Result + Err error +} + +// MarshalJSON encodes the response into JSON. +func (r *Response) MarshalJSON() ([]byte, error) { + // Define a struct that outputs "error" as a string. + var o struct { + Results []Result `json:"results,omitempty"` + Err string `json:"error,omitempty"` + } + + // Copy fields to output struct. + o.Results = r.Results + if r.Err != nil { + o.Err = r.Err.Error() + } + + return json.Marshal(&o) +} + +// UnmarshalJSON decodes the data into the Response struct +func (r *Response) UnmarshalJSON(b []byte) error { + var o struct { + Results []Result `json:"results,omitempty"` + Err string `json:"error,omitempty"` + } + + dec := json.NewDecoder(bytes.NewBuffer(b)) + dec.UseNumber() + err := dec.Decode(&o) + if err != nil { + return err + } + r.Results = o.Results + if o.Err != "" { + r.Err = errors.New(o.Err) + } + return nil +} + +// Error returns the first error from any statement. +// Returns nil if no errors occurred on any statements. +func (r *Response) Error() error { + if r.Err != nil { + return r.Err + } + for _, result := range r.Results { + if result.Err != nil { + return result.Err + } + } + return nil +} + +// duplexReader reads responses and writes it to another writer while +// satisfying the reader interface. +type duplexReader struct { + r io.Reader + w io.Writer +} + +func (r *duplexReader) Read(p []byte) (n int, err error) { + n, err = r.r.Read(p) + if err == nil { + r.w.Write(p[:n]) + } + return n, err +} + +// ChunkedResponse represents a response from the server that +// uses chunking to stream the output. +type ChunkedResponse struct { + dec *json.Decoder + duplex *duplexReader + buf bytes.Buffer +} + +// NewChunkedResponse reads a stream and produces responses from the stream. +func NewChunkedResponse(r io.Reader) *ChunkedResponse { + resp := &ChunkedResponse{} + resp.duplex = &duplexReader{r: r, w: &resp.buf} + resp.dec = json.NewDecoder(resp.duplex) + resp.dec.UseNumber() + return resp +} + +// NextResponse reads the next line of the stream and returns a response. +func (r *ChunkedResponse) NextResponse() (*Response, error) { + var response Response + if err := r.dec.Decode(&response); err != nil { + if err == io.EOF { + return nil, nil + } + // A decoding error happened. This probably means the server crashed + // and sent a last-ditch error message to us. Ensure we have read the + // entirety of the connection to get any remaining error text. + io.Copy(ioutil.Discard, r.duplex) + return nil, errors.New(strings.TrimSpace(r.buf.String())) + } + r.buf.Reset() + return &response, nil +} + +// Point defines the fields that will be written to the database +// Measurement, Time, and Fields are required +// Precision can be specified if the time is in epoch format (integer). +// Valid values for Precision are n, u, ms, s, m, and h +type Point struct { + Measurement string + Tags map[string]string + Time time.Time + Fields map[string]interface{} + Precision string + Raw string +} + +// MarshalJSON will format the time in RFC3339Nano +// Precision is also ignored as it is only used for writing, not reading +// Or another way to say it is we always send back in nanosecond precision +func (p *Point) MarshalJSON() ([]byte, error) { + point := struct { + Measurement string `json:"measurement,omitempty"` + Tags map[string]string `json:"tags,omitempty"` + Time string `json:"time,omitempty"` + Fields map[string]interface{} `json:"fields,omitempty"` + Precision string `json:"precision,omitempty"` + }{ + Measurement: p.Measurement, + Tags: p.Tags, + Fields: p.Fields, + Precision: p.Precision, + } + // Let it omit empty if it's really zero + if !p.Time.IsZero() { + point.Time = p.Time.UTC().Format(time.RFC3339Nano) + } + return json.Marshal(&point) +} + +// MarshalString renders string representation of a Point with specified +// precision. The default precision is nanoseconds. +func (p *Point) MarshalString() string { + pt, err := models.NewPoint(p.Measurement, models.NewTags(p.Tags), p.Fields, p.Time) + if err != nil { + return "# ERROR: " + err.Error() + " " + p.Measurement + } + if p.Precision == "" || p.Precision == "ns" || p.Precision == "n" { + return pt.String() + } + return pt.PrecisionString(p.Precision) +} + +// UnmarshalJSON decodes the data into the Point struct +func (p *Point) UnmarshalJSON(b []byte) error { + var normal struct { + Measurement string `json:"measurement"` + Tags map[string]string `json:"tags"` + Time time.Time `json:"time"` + Precision string `json:"precision"` + Fields map[string]interface{} `json:"fields"` + } + var epoch struct { + Measurement string `json:"measurement"` + Tags map[string]string `json:"tags"` + Time *int64 `json:"time"` + Precision string `json:"precision"` + Fields map[string]interface{} `json:"fields"` + } + + if err := func() error { + var err error + dec := json.NewDecoder(bytes.NewBuffer(b)) + dec.UseNumber() + if err = dec.Decode(&epoch); err != nil { + return err + } + // Convert from epoch to time.Time, but only if Time + // was actually set. + var ts time.Time + if epoch.Time != nil { + ts, err = EpochToTime(*epoch.Time, epoch.Precision) + if err != nil { + return err + } + } + p.Measurement = epoch.Measurement + p.Tags = epoch.Tags + p.Time = ts + p.Precision = epoch.Precision + p.Fields = normalizeFields(epoch.Fields) + return nil + }(); err == nil { + return nil + } + + dec := json.NewDecoder(bytes.NewBuffer(b)) + dec.UseNumber() + if err := dec.Decode(&normal); err != nil { + return err + } + normal.Time = SetPrecision(normal.Time, normal.Precision) + p.Measurement = normal.Measurement + p.Tags = normal.Tags + p.Time = normal.Time + p.Precision = normal.Precision + p.Fields = normalizeFields(normal.Fields) + + return nil +} + +// Remove any notion of json.Number +func normalizeFields(fields map[string]interface{}) map[string]interface{} { + newFields := map[string]interface{}{} + + for k, v := range fields { + switch v := v.(type) { + case json.Number: + jv, e := v.Float64() + if e != nil { + panic(fmt.Sprintf("unable to convert json.Number to float64: %s", e)) + } + newFields[k] = jv + default: + newFields[k] = v + } + } + return newFields +} + +// BatchPoints is used to send batched data in a single write. +// Database and Points are required +// If no retention policy is specified, it will use the databases default retention policy. +// If tags are specified, they will be "merged" with all points. If a point already has that tag, it will be ignored. +// If time is specified, it will be applied to any point with an empty time. +// Precision can be specified if the time is in epoch format (integer). +// Valid values for Precision are n, u, ms, s, m, and h +type BatchPoints struct { + Points []Point `json:"points,omitempty"` + Database string `json:"database,omitempty"` + RetentionPolicy string `json:"retentionPolicy,omitempty"` + Tags map[string]string `json:"tags,omitempty"` + Time time.Time `json:"time,omitempty"` + Precision string `json:"precision,omitempty"` + WriteConsistency string `json:"-"` +} + +// UnmarshalJSON decodes the data into the BatchPoints struct +func (bp *BatchPoints) UnmarshalJSON(b []byte) error { + var normal struct { + Points []Point `json:"points"` + Database string `json:"database"` + RetentionPolicy string `json:"retentionPolicy"` + Tags map[string]string `json:"tags"` + Time time.Time `json:"time"` + Precision string `json:"precision"` + } + var epoch struct { + Points []Point `json:"points"` + Database string `json:"database"` + RetentionPolicy string `json:"retentionPolicy"` + Tags map[string]string `json:"tags"` + Time *int64 `json:"time"` + Precision string `json:"precision"` + } + + if err := func() error { + var err error + if err = json.Unmarshal(b, &epoch); err != nil { + return err + } + // Convert from epoch to time.Time + var ts time.Time + if epoch.Time != nil { + ts, err = EpochToTime(*epoch.Time, epoch.Precision) + if err != nil { + return err + } + } + bp.Points = epoch.Points + bp.Database = epoch.Database + bp.RetentionPolicy = epoch.RetentionPolicy + bp.Tags = epoch.Tags + bp.Time = ts + bp.Precision = epoch.Precision + return nil + }(); err == nil { + return nil + } + + if err := json.Unmarshal(b, &normal); err != nil { + return err + } + normal.Time = SetPrecision(normal.Time, normal.Precision) + bp.Points = normal.Points + bp.Database = normal.Database + bp.RetentionPolicy = normal.RetentionPolicy + bp.Tags = normal.Tags + bp.Time = normal.Time + bp.Precision = normal.Precision + + return nil +} + +// utility functions + +// Addr provides the current url as a string of the server the client is connected to. +func (c *Client) Addr() string { + if c.unixSocket != "" { + return c.unixSocket + } + return c.url.String() +} + +// checkPointTypes ensures no unsupported types are submitted to influxdb, returning error if they are found. +func checkPointTypes(p Point) error { + for _, v := range p.Fields { + switch v.(type) { + case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, float32, float64, bool, string, nil: + return nil + default: + return fmt.Errorf("unsupported point type: %T", v) + } + } + return nil +} + +// helper functions + +// EpochToTime takes a unix epoch time and uses precision to return back a time.Time +func EpochToTime(epoch int64, precision string) (time.Time, error) { + if precision == "" { + precision = "s" + } + var t time.Time + switch precision { + case "h": + t = time.Unix(0, epoch*int64(time.Hour)) + case "m": + t = time.Unix(0, epoch*int64(time.Minute)) + case "s": + t = time.Unix(0, epoch*int64(time.Second)) + case "ms": + t = time.Unix(0, epoch*int64(time.Millisecond)) + case "u": + t = time.Unix(0, epoch*int64(time.Microsecond)) + case "n": + t = time.Unix(0, epoch) + default: + return time.Time{}, fmt.Errorf("Unknown precision %q", precision) + } + return t, nil +} + +// SetPrecision will round a time to the specified precision +func SetPrecision(t time.Time, precision string) time.Time { + switch precision { + case "n": + case "u": + return t.Round(time.Microsecond) + case "ms": + return t.Round(time.Millisecond) + case "s": + return t.Round(time.Second) + case "m": + return t.Round(time.Minute) + case "h": + return t.Round(time.Hour) + } + return t +} diff --git a/vendor/github.com/influxdata/influxdb/client/v2/client.go b/vendor/github.com/influxdata/influxdb/client/v2/client.go new file mode 100644 index 000000000..77d44f2b3 --- /dev/null +++ b/vendor/github.com/influxdata/influxdb/client/v2/client.go @@ -0,0 +1,635 @@ +// Package client (v2) is the current official Go client for InfluxDB. +package client // import "github.com/influxdata/influxdb/client/v2" + +import ( + "bytes" + "crypto/tls" + "encoding/json" + "errors" + "fmt" + "io" + "io/ioutil" + "mime" + "net/http" + "net/url" + "strconv" + "strings" + "time" + + "github.com/influxdata/influxdb/models" +) + +// HTTPConfig is the config data needed to create an HTTP Client. +type HTTPConfig struct { + // Addr should be of the form "http://host:port" + // or "http://[ipv6-host%zone]:port". + Addr string + + // Username is the influxdb username, optional. + Username string + + // Password is the influxdb password, optional. + Password string + + // UserAgent is the http User Agent, defaults to "InfluxDBClient". + UserAgent string + + // Timeout for influxdb writes, defaults to no timeout. + Timeout time.Duration + + // InsecureSkipVerify gets passed to the http client, if true, it will + // skip https certificate verification. Defaults to false. + InsecureSkipVerify bool + + // TLSConfig allows the user to set their own TLS config for the HTTP + // Client. If set, this option overrides InsecureSkipVerify. + TLSConfig *tls.Config +} + +// BatchPointsConfig is the config data needed to create an instance of the BatchPoints struct. +type BatchPointsConfig struct { + // Precision is the write precision of the points, defaults to "ns". + Precision string + + // Database is the database to write points to. + Database string + + // RetentionPolicy is the retention policy of the points. + RetentionPolicy string + + // Write consistency is the number of servers required to confirm write. + WriteConsistency string +} + +// Client is a client interface for writing & querying the database. +type Client interface { + // Ping checks that status of cluster, and will always return 0 time and no + // error for UDP clients. + Ping(timeout time.Duration) (time.Duration, string, error) + + // Write takes a BatchPoints object and writes all Points to InfluxDB. + Write(bp BatchPoints) error + + // Query makes an InfluxDB Query on the database. This will fail if using + // the UDP client. + Query(q Query) (*Response, error) + + // Close releases any resources a Client may be using. + Close() error +} + +// NewHTTPClient returns a new Client from the provided config. +// Client is safe for concurrent use by multiple goroutines. +func NewHTTPClient(conf HTTPConfig) (Client, error) { + if conf.UserAgent == "" { + conf.UserAgent = "InfluxDBClient" + } + + u, err := url.Parse(conf.Addr) + if err != nil { + return nil, err + } else if u.Scheme != "http" && u.Scheme != "https" { + m := fmt.Sprintf("Unsupported protocol scheme: %s, your address"+ + " must start with http:// or https://", u.Scheme) + return nil, errors.New(m) + } + + tr := &http.Transport{ + TLSClientConfig: &tls.Config{ + InsecureSkipVerify: conf.InsecureSkipVerify, + }, + } + if conf.TLSConfig != nil { + tr.TLSClientConfig = conf.TLSConfig + } + return &client{ + url: *u, + username: conf.Username, + password: conf.Password, + useragent: conf.UserAgent, + httpClient: &http.Client{ + Timeout: conf.Timeout, + Transport: tr, + }, + transport: tr, + }, nil +} + +// Ping will check to see if the server is up with an optional timeout on waiting for leader. +// Ping returns how long the request took, the version of the server it connected to, and an error if one occurred. +func (c *client) Ping(timeout time.Duration) (time.Duration, string, error) { + now := time.Now() + u := c.url + u.Path = "ping" + + req, err := http.NewRequest("GET", u.String(), nil) + if err != nil { + return 0, "", err + } + + req.Header.Set("User-Agent", c.useragent) + + if c.username != "" { + req.SetBasicAuth(c.username, c.password) + } + + if timeout > 0 { + params := req.URL.Query() + params.Set("wait_for_leader", fmt.Sprintf("%.0fs", timeout.Seconds())) + req.URL.RawQuery = params.Encode() + } + + resp, err := c.httpClient.Do(req) + if err != nil { + return 0, "", err + } + defer resp.Body.Close() + + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + return 0, "", err + } + + if resp.StatusCode != http.StatusNoContent { + var err = fmt.Errorf(string(body)) + return 0, "", err + } + + version := resp.Header.Get("X-Influxdb-Version") + return time.Since(now), version, nil +} + +// Close releases the client's resources. +func (c *client) Close() error { + c.transport.CloseIdleConnections() + return nil +} + +// client is safe for concurrent use as the fields are all read-only +// once the client is instantiated. +type client struct { + // N.B - if url.UserInfo is accessed in future modifications to the + // methods on client, you will need to syncronise access to url. + url url.URL + username string + password string + useragent string + httpClient *http.Client + transport *http.Transport +} + +// BatchPoints is an interface into a batched grouping of points to write into +// InfluxDB together. BatchPoints is NOT thread-safe, you must create a separate +// batch for each goroutine. +type BatchPoints interface { + // AddPoint adds the given point to the Batch of points. + AddPoint(p *Point) + // AddPoints adds the given points to the Batch of points. + AddPoints(ps []*Point) + // Points lists the points in the Batch. + Points() []*Point + + // Precision returns the currently set precision of this Batch. + Precision() string + // SetPrecision sets the precision of this batch. + SetPrecision(s string) error + + // Database returns the currently set database of this Batch. + Database() string + // SetDatabase sets the database of this Batch. + SetDatabase(s string) + + // WriteConsistency returns the currently set write consistency of this Batch. + WriteConsistency() string + // SetWriteConsistency sets the write consistency of this Batch. + SetWriteConsistency(s string) + + // RetentionPolicy returns the currently set retention policy of this Batch. + RetentionPolicy() string + // SetRetentionPolicy sets the retention policy of this Batch. + SetRetentionPolicy(s string) +} + +// NewBatchPoints returns a BatchPoints interface based on the given config. +func NewBatchPoints(conf BatchPointsConfig) (BatchPoints, error) { + if conf.Precision == "" { + conf.Precision = "ns" + } + if _, err := time.ParseDuration("1" + conf.Precision); err != nil { + return nil, err + } + bp := &batchpoints{ + database: conf.Database, + precision: conf.Precision, + retentionPolicy: conf.RetentionPolicy, + writeConsistency: conf.WriteConsistency, + } + return bp, nil +} + +type batchpoints struct { + points []*Point + database string + precision string + retentionPolicy string + writeConsistency string +} + +func (bp *batchpoints) AddPoint(p *Point) { + bp.points = append(bp.points, p) +} + +func (bp *batchpoints) AddPoints(ps []*Point) { + bp.points = append(bp.points, ps...) +} + +func (bp *batchpoints) Points() []*Point { + return bp.points +} + +func (bp *batchpoints) Precision() string { + return bp.precision +} + +func (bp *batchpoints) Database() string { + return bp.database +} + +func (bp *batchpoints) WriteConsistency() string { + return bp.writeConsistency +} + +func (bp *batchpoints) RetentionPolicy() string { + return bp.retentionPolicy +} + +func (bp *batchpoints) SetPrecision(p string) error { + if _, err := time.ParseDuration("1" + p); err != nil { + return err + } + bp.precision = p + return nil +} + +func (bp *batchpoints) SetDatabase(db string) { + bp.database = db +} + +func (bp *batchpoints) SetWriteConsistency(wc string) { + bp.writeConsistency = wc +} + +func (bp *batchpoints) SetRetentionPolicy(rp string) { + bp.retentionPolicy = rp +} + +// Point represents a single data point. +type Point struct { + pt models.Point +} + +// NewPoint returns a point with the given timestamp. If a timestamp is not +// given, then data is sent to the database without a timestamp, in which case +// the server will assign local time upon reception. NOTE: it is recommended to +// send data with a timestamp. +func NewPoint( + name string, + tags map[string]string, + fields map[string]interface{}, + t ...time.Time, +) (*Point, error) { + var T time.Time + if len(t) > 0 { + T = t[0] + } + + pt, err := models.NewPoint(name, models.NewTags(tags), fields, T) + if err != nil { + return nil, err + } + return &Point{ + pt: pt, + }, nil +} + +// String returns a line-protocol string of the Point. +func (p *Point) String() string { + return p.pt.String() +} + +// PrecisionString returns a line-protocol string of the Point, +// with the timestamp formatted for the given precision. +func (p *Point) PrecisionString(precison string) string { + return p.pt.PrecisionString(precison) +} + +// Name returns the measurement name of the point. +func (p *Point) Name() string { + return string(p.pt.Name()) +} + +// Tags returns the tags associated with the point. +func (p *Point) Tags() map[string]string { + return p.pt.Tags().Map() +} + +// Time return the timestamp for the point. +func (p *Point) Time() time.Time { + return p.pt.Time() +} + +// UnixNano returns timestamp of the point in nanoseconds since Unix epoch. +func (p *Point) UnixNano() int64 { + return p.pt.UnixNano() +} + +// Fields returns the fields for the point. +func (p *Point) Fields() (map[string]interface{}, error) { + return p.pt.Fields() +} + +// NewPointFrom returns a point from the provided models.Point. +func NewPointFrom(pt models.Point) *Point { + return &Point{pt: pt} +} + +func (c *client) Write(bp BatchPoints) error { + var b bytes.Buffer + + for _, p := range bp.Points() { + if _, err := b.WriteString(p.pt.PrecisionString(bp.Precision())); err != nil { + return err + } + + if err := b.WriteByte('\n'); err != nil { + return err + } + } + + u := c.url + u.Path = "write" + req, err := http.NewRequest("POST", u.String(), &b) + if err != nil { + return err + } + req.Header.Set("Content-Type", "") + req.Header.Set("User-Agent", c.useragent) + if c.username != "" { + req.SetBasicAuth(c.username, c.password) + } + + params := req.URL.Query() + params.Set("db", bp.Database()) + params.Set("rp", bp.RetentionPolicy()) + params.Set("precision", bp.Precision()) + params.Set("consistency", bp.WriteConsistency()) + req.URL.RawQuery = params.Encode() + + resp, err := c.httpClient.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + return err + } + + if resp.StatusCode != http.StatusNoContent && resp.StatusCode != http.StatusOK { + var err = fmt.Errorf(string(body)) + return err + } + + return nil +} + +// Query defines a query to send to the server. +type Query struct { + Command string + Database string + Precision string + Chunked bool + ChunkSize int + Parameters map[string]interface{} +} + +// NewQuery returns a query object. +// The database and precision arguments can be empty strings if they are not needed for the query. +func NewQuery(command, database, precision string) Query { + return Query{ + Command: command, + Database: database, + Precision: precision, + Parameters: make(map[string]interface{}), + } +} + +// NewQueryWithParameters returns a query object. +// The database and precision arguments can be empty strings if they are not needed for the query. +// parameters is a map of the parameter names used in the command to their values. +func NewQueryWithParameters(command, database, precision string, parameters map[string]interface{}) Query { + return Query{ + Command: command, + Database: database, + Precision: precision, + Parameters: parameters, + } +} + +// Response represents a list of statement results. +type Response struct { + Results []Result + Err string `json:"error,omitempty"` +} + +// Error returns the first error from any statement. +// It returns nil if no errors occurred on any statements. +func (r *Response) Error() error { + if r.Err != "" { + return fmt.Errorf(r.Err) + } + for _, result := range r.Results { + if result.Err != "" { + return fmt.Errorf(result.Err) + } + } + return nil +} + +// Message represents a user message. +type Message struct { + Level string + Text string +} + +// Result represents a resultset returned from a single statement. +type Result struct { + Series []models.Row + Messages []*Message + Err string `json:"error,omitempty"` +} + +// Query sends a command to the server and returns the Response. +func (c *client) Query(q Query) (*Response, error) { + u := c.url + u.Path = "query" + + jsonParameters, err := json.Marshal(q.Parameters) + + if err != nil { + return nil, err + } + + req, err := http.NewRequest("POST", u.String(), nil) + if err != nil { + return nil, err + } + + req.Header.Set("Content-Type", "") + req.Header.Set("User-Agent", c.useragent) + + if c.username != "" { + req.SetBasicAuth(c.username, c.password) + } + + params := req.URL.Query() + params.Set("q", q.Command) + params.Set("db", q.Database) + params.Set("params", string(jsonParameters)) + if q.Chunked { + params.Set("chunked", "true") + if q.ChunkSize > 0 { + params.Set("chunk_size", strconv.Itoa(q.ChunkSize)) + } + } + + if q.Precision != "" { + params.Set("epoch", q.Precision) + } + req.URL.RawQuery = params.Encode() + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + // If we lack a X-Influxdb-Version header, then we didn't get a response from influxdb + // but instead some other service. If the error code is also a 500+ code, then some + // downstream loadbalancer/proxy/etc had an issue and we should report that. + if resp.Header.Get("X-Influxdb-Version") == "" && resp.StatusCode >= http.StatusInternalServerError { + body, err := ioutil.ReadAll(resp.Body) + if err != nil || len(body) == 0 { + return nil, fmt.Errorf("received status code %d from downstream server", resp.StatusCode) + } + + return nil, fmt.Errorf("received status code %d from downstream server, with response body: %q", resp.StatusCode, body) + } + + // If we get an unexpected content type, then it is also not from influx direct and therefore + // we want to know what we received and what status code was returned for debugging purposes. + if cType, _, _ := mime.ParseMediaType(resp.Header.Get("Content-Type")); cType != "application/json" { + // Read up to 1kb of the body to help identify downstream errors and limit the impact of things + // like downstream serving a large file + body, err := ioutil.ReadAll(io.LimitReader(resp.Body, 1024)) + if err != nil || len(body) == 0 { + return nil, fmt.Errorf("expected json response, got empty body, with status: %v", resp.StatusCode) + } + + return nil, fmt.Errorf("expected json response, got %q, with status: %v and response body: %q", cType, resp.StatusCode, body) + } + + var response Response + if q.Chunked { + cr := NewChunkedResponse(resp.Body) + for { + r, err := cr.NextResponse() + if err != nil { + // If we got an error while decoding the response, send that back. + return nil, err + } + + if r == nil { + break + } + + response.Results = append(response.Results, r.Results...) + if r.Err != "" { + response.Err = r.Err + break + } + } + } else { + dec := json.NewDecoder(resp.Body) + dec.UseNumber() + decErr := dec.Decode(&response) + + // ignore this error if we got an invalid status code + if decErr != nil && decErr.Error() == "EOF" && resp.StatusCode != http.StatusOK { + decErr = nil + } + // If we got a valid decode error, send that back + if decErr != nil { + return nil, fmt.Errorf("unable to decode json: received status code %d err: %s", resp.StatusCode, decErr) + } + } + + // If we don't have an error in our json response, and didn't get statusOK + // then send back an error + if resp.StatusCode != http.StatusOK && response.Error() == nil { + return &response, fmt.Errorf("received status code %d from server", resp.StatusCode) + } + return &response, nil +} + +// duplexReader reads responses and writes it to another writer while +// satisfying the reader interface. +type duplexReader struct { + r io.Reader + w io.Writer +} + +func (r *duplexReader) Read(p []byte) (n int, err error) { + n, err = r.r.Read(p) + if err == nil { + r.w.Write(p[:n]) + } + return n, err +} + +// ChunkedResponse represents a response from the server that +// uses chunking to stream the output. +type ChunkedResponse struct { + dec *json.Decoder + duplex *duplexReader + buf bytes.Buffer +} + +// NewChunkedResponse reads a stream and produces responses from the stream. +func NewChunkedResponse(r io.Reader) *ChunkedResponse { + resp := &ChunkedResponse{} + resp.duplex = &duplexReader{r: r, w: &resp.buf} + resp.dec = json.NewDecoder(resp.duplex) + resp.dec.UseNumber() + return resp +} + +// NextResponse reads the next line of the stream and returns a response. +func (r *ChunkedResponse) NextResponse() (*Response, error) { + var response Response + + if err := r.dec.Decode(&response); err != nil { + if err == io.EOF { + return nil, nil + } + // A decoding error happened. This probably means the server crashed + // and sent a last-ditch error message to us. Ensure we have read the + // entirety of the connection to get any remaining error text. + io.Copy(ioutil.Discard, r.duplex) + return nil, errors.New(strings.TrimSpace(r.buf.String())) + } + + r.buf.Reset() + return &response, nil +} diff --git a/vendor/github.com/influxdata/influxdb/client/v2/udp.go b/vendor/github.com/influxdata/influxdb/client/v2/udp.go new file mode 100644 index 000000000..779a28b33 --- /dev/null +++ b/vendor/github.com/influxdata/influxdb/client/v2/udp.go @@ -0,0 +1,112 @@ +package client + +import ( + "fmt" + "io" + "net" + "time" +) + +const ( + // UDPPayloadSize is a reasonable default payload size for UDP packets that + // could be travelling over the internet. + UDPPayloadSize = 512 +) + +// UDPConfig is the config data needed to create a UDP Client. +type UDPConfig struct { + // Addr should be of the form "host:port" + // or "[ipv6-host%zone]:port". + Addr string + + // PayloadSize is the maximum size of a UDP client message, optional + // Tune this based on your network. Defaults to UDPPayloadSize. + PayloadSize int +} + +// NewUDPClient returns a client interface for writing to an InfluxDB UDP +// service from the given config. +func NewUDPClient(conf UDPConfig) (Client, error) { + var udpAddr *net.UDPAddr + udpAddr, err := net.ResolveUDPAddr("udp", conf.Addr) + if err != nil { + return nil, err + } + + conn, err := net.DialUDP("udp", nil, udpAddr) + if err != nil { + return nil, err + } + + payloadSize := conf.PayloadSize + if payloadSize == 0 { + payloadSize = UDPPayloadSize + } + + return &udpclient{ + conn: conn, + payloadSize: payloadSize, + }, nil +} + +// Close releases the udpclient's resources. +func (uc *udpclient) Close() error { + return uc.conn.Close() +} + +type udpclient struct { + conn io.WriteCloser + payloadSize int +} + +func (uc *udpclient) Write(bp BatchPoints) error { + var b = make([]byte, 0, uc.payloadSize) // initial buffer size, it will grow as needed + var d, _ = time.ParseDuration("1" + bp.Precision()) + + var delayedError error + + var checkBuffer = func(n int) { + if len(b) > 0 && len(b)+n > uc.payloadSize { + if _, err := uc.conn.Write(b); err != nil { + delayedError = err + } + b = b[:0] + } + } + + for _, p := range bp.Points() { + p.pt.Round(d) + pointSize := p.pt.StringSize() + 1 // include newline in size + //point := p.pt.RoundedString(d) + "\n" + + checkBuffer(pointSize) + + if p.Time().IsZero() || pointSize <= uc.payloadSize { + b = p.pt.AppendString(b) + b = append(b, '\n') + continue + } + + points := p.pt.Split(uc.payloadSize - 1) // account for newline character + for _, sp := range points { + checkBuffer(sp.StringSize() + 1) + b = sp.AppendString(b) + b = append(b, '\n') + } + } + + if len(b) > 0 { + if _, err := uc.conn.Write(b); err != nil { + return err + } + } + return delayedError +} + +func (uc *udpclient) Query(q Query) (*Response, error) { + return nil, fmt.Errorf("Querying via UDP is not supported") +} + +func (uc *udpclient) Ping(timeout time.Duration) (time.Duration, string, error) { + return 0, "", nil +} diff --git a/vendor/github.com/influxdata/influxdb/models/consistency.go b/vendor/github.com/influxdata/influxdb/models/consistency.go new file mode 100644 index 000000000..2a3269bca --- /dev/null +++ b/vendor/github.com/influxdata/influxdb/models/consistency.go @@ -0,0 +1,48 @@ +package models + +import ( + "errors" + "strings" +) + +// ConsistencyLevel represent a required replication criteria before a write can +// be returned as successful. +// +// The consistency level is handled in open-source InfluxDB but only applicable to clusters. +type ConsistencyLevel int + +const ( + // ConsistencyLevelAny allows for hinted handoff, potentially no write happened yet. + ConsistencyLevelAny ConsistencyLevel = iota + + // ConsistencyLevelOne requires at least one data node acknowledged a write. + ConsistencyLevelOne + + // ConsistencyLevelQuorum requires a quorum of data nodes to acknowledge a write. + ConsistencyLevelQuorum + + // ConsistencyLevelAll requires all data nodes to acknowledge a write. + ConsistencyLevelAll +) + +var ( + // ErrInvalidConsistencyLevel is returned when parsing the string version + // of a consistency level. + ErrInvalidConsistencyLevel = errors.New("invalid consistency level") +) + +// ParseConsistencyLevel converts a consistency level string to the corresponding ConsistencyLevel const. +func ParseConsistencyLevel(level string) (ConsistencyLevel, error) { + switch strings.ToLower(level) { + case "any": + return ConsistencyLevelAny, nil + case "one": + return ConsistencyLevelOne, nil + case "quorum": + return ConsistencyLevelQuorum, nil + case "all": + return ConsistencyLevelAll, nil + default: + return 0, ErrInvalidConsistencyLevel + } +} diff --git a/vendor/github.com/influxdata/influxdb/models/inline_fnv.go b/vendor/github.com/influxdata/influxdb/models/inline_fnv.go new file mode 100644 index 000000000..eec1ae8b0 --- /dev/null +++ b/vendor/github.com/influxdata/influxdb/models/inline_fnv.go @@ -0,0 +1,32 @@ +package models // import "github.com/influxdata/influxdb/models" + +// from stdlib hash/fnv/fnv.go +const ( + prime64 = 1099511628211 + offset64 = 14695981039346656037 +) + +// InlineFNV64a is an alloc-free port of the standard library's fnv64a. +// See https://en.wikipedia.org/wiki/Fowler%E2%80%93Noll%E2%80%93Vo_hash_function. +type InlineFNV64a uint64 + +// NewInlineFNV64a returns a new instance of InlineFNV64a. +func NewInlineFNV64a() InlineFNV64a { + return offset64 +} + +// Write adds data to the running hash. +func (s *InlineFNV64a) Write(data []byte) (int, error) { + hash := uint64(*s) + for _, c := range data { + hash ^= uint64(c) + hash *= prime64 + } + *s = InlineFNV64a(hash) + return len(data), nil +} + +// Sum64 returns the uint64 of the current resulting hash. +func (s *InlineFNV64a) Sum64() uint64 { + return uint64(*s) +} diff --git a/vendor/github.com/influxdata/influxdb/models/inline_strconv_parse.go b/vendor/github.com/influxdata/influxdb/models/inline_strconv_parse.go new file mode 100644 index 000000000..8db483738 --- /dev/null +++ b/vendor/github.com/influxdata/influxdb/models/inline_strconv_parse.go @@ -0,0 +1,44 @@ +package models // import "github.com/influxdata/influxdb/models" + +import ( + "reflect" + "strconv" + "unsafe" +) + +// parseIntBytes is a zero-alloc wrapper around strconv.ParseInt. +func parseIntBytes(b []byte, base int, bitSize int) (i int64, err error) { + s := unsafeBytesToString(b) + return strconv.ParseInt(s, base, bitSize) +} + +// parseUintBytes is a zero-alloc wrapper around strconv.ParseUint. +func parseUintBytes(b []byte, base int, bitSize int) (i uint64, err error) { + s := unsafeBytesToString(b) + return strconv.ParseUint(s, base, bitSize) +} + +// parseFloatBytes is a zero-alloc wrapper around strconv.ParseFloat. +func parseFloatBytes(b []byte, bitSize int) (float64, error) { + s := unsafeBytesToString(b) + return strconv.ParseFloat(s, bitSize) +} + +// parseBoolBytes is a zero-alloc wrapper around strconv.ParseBool. +func parseBoolBytes(b []byte) (bool, error) { + return strconv.ParseBool(unsafeBytesToString(b)) +} + +// unsafeBytesToString converts a []byte to a string without a heap allocation. +// +// It is unsafe, and is intended to prepare input to short-lived functions +// that require strings. +func unsafeBytesToString(in []byte) string { + src := *(*reflect.SliceHeader)(unsafe.Pointer(&in)) + dst := reflect.StringHeader{ + Data: src.Data, + Len: src.Len, + } + s := *(*string)(unsafe.Pointer(&dst)) + return s +} diff --git a/vendor/github.com/influxdata/influxdb/models/points.go b/vendor/github.com/influxdata/influxdb/models/points.go new file mode 100644 index 000000000..ad80a816b --- /dev/null +++ b/vendor/github.com/influxdata/influxdb/models/points.go @@ -0,0 +1,2337 @@ +// Package models implements basic objects used throughout the TICK stack. +package models // import "github.com/influxdata/influxdb/models" + +import ( + "bytes" + "encoding/binary" + "errors" + "fmt" + "io" + "math" + "sort" + "strconv" + "strings" + "time" + + "github.com/influxdata/influxdb/pkg/escape" +) + +var ( + measurementEscapeCodes = map[byte][]byte{ + ',': []byte(`\,`), + ' ': []byte(`\ `), + } + + tagEscapeCodes = map[byte][]byte{ + ',': []byte(`\,`), + ' ': []byte(`\ `), + '=': []byte(`\=`), + } + + // ErrPointMustHaveAField is returned when operating on a point that does not have any fields. + ErrPointMustHaveAField = errors.New("point without fields is unsupported") + + // ErrInvalidNumber is returned when a number is expected but not provided. + ErrInvalidNumber = errors.New("invalid number") + + // ErrInvalidPoint is returned when a point cannot be parsed correctly. + ErrInvalidPoint = errors.New("point is invalid") +) + +const ( + // MaxKeyLength is the largest allowed size of the combined measurement and tag keys. + MaxKeyLength = 65535 +) + +// enableUint64Support will enable uint64 support if set to true. +var enableUint64Support = false + +// EnableUintSupport manually enables uint support for the point parser. +// This function will be removed in the future and only exists for unit tests during the +// transition. +func EnableUintSupport() { + enableUint64Support = true +} + +// Point defines the values that will be written to the database. +type Point interface { + // Name return the measurement name for the point. + Name() []byte + + // SetName updates the measurement name for the point. + SetName(string) + + // Tags returns the tag set for the point. + Tags() Tags + + // AddTag adds or replaces a tag value for a point. + AddTag(key, value string) + + // SetTags replaces the tags for the point. + SetTags(tags Tags) + + // HasTag returns true if the tag exists for the point. + HasTag(tag []byte) bool + + // Fields returns the fields for the point. + Fields() (Fields, error) + + // Time return the timestamp for the point. + Time() time.Time + + // SetTime updates the timestamp for the point. + SetTime(t time.Time) + + // UnixNano returns the timestamp of the point as nanoseconds since Unix epoch. + UnixNano() int64 + + // HashID returns a non-cryptographic checksum of the point's key. + HashID() uint64 + + // Key returns the key (measurement joined with tags) of the point. + Key() []byte + + // String returns a string representation of the point. If there is a + // timestamp associated with the point then it will be specified with the default + // precision of nanoseconds. + String() string + + // MarshalBinary returns a binary representation of the point. + MarshalBinary() ([]byte, error) + + // PrecisionString returns a string representation of the point. If there + // is a timestamp associated with the point then it will be specified in the + // given unit. + PrecisionString(precision string) string + + // RoundedString returns a string representation of the point. If there + // is a timestamp associated with the point, then it will be rounded to the + // given duration. + RoundedString(d time.Duration) string + + // Split will attempt to return multiple points with the same timestamp whose + // string representations are no longer than size. Points with a single field or + // a point without a timestamp may exceed the requested size. + Split(size int) []Point + + // Round will round the timestamp of the point to the given duration. + Round(d time.Duration) + + // StringSize returns the length of the string that would be returned by String(). + StringSize() int + + // AppendString appends the result of String() to the provided buffer and returns + // the result, potentially reducing string allocations. + AppendString(buf []byte) []byte + + // FieldIterator retuns a FieldIterator that can be used to traverse the + // fields of a point without constructing the in-memory map. + FieldIterator() FieldIterator +} + +// FieldType represents the type of a field. +type FieldType int + +const ( + // Integer indicates the field's type is integer. + Integer FieldType = iota + + // Float indicates the field's type is float. + Float + + // Boolean indicates the field's type is boolean. + Boolean + + // String indicates the field's type is string. + String + + // Empty is used to indicate that there is no field. + Empty + + // Unsigned indicates the field's type is an unsigned integer. + Unsigned +) + +// FieldIterator provides a low-allocation interface to iterate through a point's fields. +type FieldIterator interface { + // Next indicates whether there any fields remaining. + Next() bool + + // FieldKey returns the key of the current field. + FieldKey() []byte + + // Type returns the FieldType of the current field. + Type() FieldType + + // StringValue returns the string value of the current field. + StringValue() string + + // IntegerValue returns the integer value of the current field. + IntegerValue() (int64, error) + + // UnsignedValue returns the unsigned value of the current field. + UnsignedValue() (uint64, error) + + // BooleanValue returns the boolean value of the current field. + BooleanValue() (bool, error) + + // FloatValue returns the float value of the current field. + FloatValue() (float64, error) + + // Reset resets the iterator to its initial state. + Reset() +} + +// Points represents a sortable list of points by timestamp. +type Points []Point + +// Len implements sort.Interface. +func (a Points) Len() int { return len(a) } + +// Less implements sort.Interface. +func (a Points) Less(i, j int) bool { return a[i].Time().Before(a[j].Time()) } + +// Swap implements sort.Interface. +func (a Points) Swap(i, j int) { a[i], a[j] = a[j], a[i] } + +// point is the default implementation of Point. +type point struct { + time time.Time + + // text encoding of measurement and tags + // key must always be stored sorted by tags, if the original line was not sorted, + // we need to resort it + key []byte + + // text encoding of field data + fields []byte + + // text encoding of timestamp + ts []byte + + // cached version of parsed fields from data + cachedFields map[string]interface{} + + // cached version of parsed name from key + cachedName string + + // cached version of parsed tags + cachedTags Tags + + it fieldIterator +} + +// type assertions +var ( + _ Point = (*point)(nil) + _ FieldIterator = (*point)(nil) +) + +const ( + // the number of characters for the largest possible int64 (9223372036854775807) + maxInt64Digits = 19 + + // the number of characters for the smallest possible int64 (-9223372036854775808) + minInt64Digits = 20 + + // the number of characters for the largest possible uint64 (18446744073709551615) + maxUint64Digits = 20 + + // the number of characters required for the largest float64 before a range check + // would occur during parsing + maxFloat64Digits = 25 + + // the number of characters required for smallest float64 before a range check occur + // would occur during parsing + minFloat64Digits = 27 +) + +// ParsePoints returns a slice of Points from a text representation of a point +// with each point separated by newlines. If any points fail to parse, a non-nil error +// will be returned in addition to the points that parsed successfully. +func ParsePoints(buf []byte) ([]Point, error) { + return ParsePointsWithPrecision(buf, time.Now().UTC(), "n") +} + +// ParsePointsString is identical to ParsePoints but accepts a string. +func ParsePointsString(buf string) ([]Point, error) { + return ParsePoints([]byte(buf)) +} + +// ParseKey returns the measurement name and tags from a point. +// +// NOTE: to minimize heap allocations, the returned Tags will refer to subslices of buf. +// This can have the unintended effect preventing buf from being garbage collected. +func ParseKey(buf []byte) (string, Tags) { + meas, tags := ParseKeyBytes(buf) + return string(meas), tags +} + +func ParseKeyBytes(buf []byte) ([]byte, Tags) { + // Ignore the error because scanMeasurement returns "missing fields" which we ignore + // when just parsing a key + state, i, _ := scanMeasurement(buf, 0) + + var tags Tags + if state == tagKeyState { + tags = parseTags(buf) + // scanMeasurement returns the location of the comma if there are tags, strip that off + return buf[:i-1], tags + } + return buf[:i], tags +} + +func ParseTags(buf []byte) Tags { + return parseTags(buf) +} + +func ParseName(buf []byte) ([]byte, error) { + // Ignore the error because scanMeasurement returns "missing fields" which we ignore + // when just parsing a key + state, i, _ := scanMeasurement(buf, 0) + if state == tagKeyState { + return buf[:i-1], nil + } + return buf[:i], nil +} + +// ParsePointsWithPrecision is similar to ParsePoints, but allows the +// caller to provide a precision for time. +// +// NOTE: to minimize heap allocations, the returned Points will refer to subslices of buf. +// This can have the unintended effect preventing buf from being garbage collected. +func ParsePointsWithPrecision(buf []byte, defaultTime time.Time, precision string) ([]Point, error) { + points := make([]Point, 0, bytes.Count(buf, []byte{'\n'})+1) + var ( + pos int + block []byte + failed []string + ) + for pos < len(buf) { + pos, block = scanLine(buf, pos) + pos++ + + if len(block) == 0 { + continue + } + + // lines which start with '#' are comments + start := skipWhitespace(block, 0) + + // If line is all whitespace, just skip it + if start >= len(block) { + continue + } + + if block[start] == '#' { + continue + } + + // strip the newline if one is present + if block[len(block)-1] == '\n' { + block = block[:len(block)-1] + } + + pt, err := parsePoint(block[start:], defaultTime, precision) + if err != nil { + failed = append(failed, fmt.Sprintf("unable to parse '%s': %v", string(block[start:]), err)) + } else { + points = append(points, pt) + } + + } + if len(failed) > 0 { + return points, fmt.Errorf("%s", strings.Join(failed, "\n")) + } + return points, nil + +} + +func parsePoint(buf []byte, defaultTime time.Time, precision string) (Point, error) { + // scan the first block which is measurement[,tag1=value1,tag2=value=2...] + pos, key, err := scanKey(buf, 0) + if err != nil { + return nil, err + } + + // measurement name is required + if len(key) == 0 { + return nil, fmt.Errorf("missing measurement") + } + + if len(key) > MaxKeyLength { + return nil, fmt.Errorf("max key length exceeded: %v > %v", len(key), MaxKeyLength) + } + + // scan the second block is which is field1=value1[,field2=value2,...] + pos, fields, err := scanFields(buf, pos) + if err != nil { + return nil, err + } + + // at least one field is required + if len(fields) == 0 { + return nil, fmt.Errorf("missing fields") + } + + var maxKeyErr error + walkFields(fields, func(k, v []byte) bool { + if sz := seriesKeySize(key, k); sz > MaxKeyLength { + maxKeyErr = fmt.Errorf("max key length exceeded: %v > %v", sz, MaxKeyLength) + return false + } + return true + }) + + if maxKeyErr != nil { + return nil, maxKeyErr + } + + // scan the last block which is an optional integer timestamp + pos, ts, err := scanTime(buf, pos) + if err != nil { + return nil, err + } + + pt := &point{ + key: key, + fields: fields, + ts: ts, + } + + if len(ts) == 0 { + pt.time = defaultTime + pt.SetPrecision(precision) + } else { + ts, err := parseIntBytes(ts, 10, 64) + if err != nil { + return nil, err + } + pt.time, err = SafeCalcTime(ts, precision) + if err != nil { + return nil, err + } + + // Determine if there are illegal non-whitespace characters after the + // timestamp block. + for pos < len(buf) { + if buf[pos] != ' ' { + return nil, ErrInvalidPoint + } + pos++ + } + } + return pt, nil +} + +// GetPrecisionMultiplier will return a multiplier for the precision specified. +func GetPrecisionMultiplier(precision string) int64 { + d := time.Nanosecond + switch precision { + case "u": + d = time.Microsecond + case "ms": + d = time.Millisecond + case "s": + d = time.Second + case "m": + d = time.Minute + case "h": + d = time.Hour + } + return int64(d) +} + +// scanKey scans buf starting at i for the measurement and tag portion of the point. +// It returns the ending position and the byte slice of key within buf. If there +// are tags, they will be sorted if they are not already. +func scanKey(buf []byte, i int) (int, []byte, error) { + start := skipWhitespace(buf, i) + + i = start + + // Determines whether the tags are sort, assume they are + sorted := true + + // indices holds the indexes within buf of the start of each tag. For example, + // a buf of 'cpu,host=a,region=b,zone=c' would have indices slice of [4,11,20] + // which indicates that the first tag starts at buf[4], seconds at buf[11], and + // last at buf[20] + indices := make([]int, 100) + + // tracks how many commas we've seen so we know how many values are indices. + // Since indices is an arbitrarily large slice, + // we need to know how many values in the buffer are in use. + commas := 0 + + // First scan the Point's measurement. + state, i, err := scanMeasurement(buf, i) + if err != nil { + return i, buf[start:i], err + } + + // Optionally scan tags if needed. + if state == tagKeyState { + i, commas, indices, err = scanTags(buf, i, indices) + if err != nil { + return i, buf[start:i], err + } + } + + // Now we know where the key region is within buf, and the location of tags, we + // need to determine if duplicate tags exist and if the tags are sorted. This iterates + // over the list comparing each tag in the sequence with each other. + for j := 0; j < commas-1; j++ { + // get the left and right tags + _, left := scanTo(buf[indices[j]:indices[j+1]-1], 0, '=') + _, right := scanTo(buf[indices[j+1]:indices[j+2]-1], 0, '=') + + // If left is greater than right, the tags are not sorted. We do not have to + // continue because the short path no longer works. + // If the tags are equal, then there are duplicate tags, and we should abort. + // If the tags are not sorted, this pass may not find duplicate tags and we + // need to do a more exhaustive search later. + if cmp := bytes.Compare(left, right); cmp > 0 { + sorted = false + break + } else if cmp == 0 { + return i, buf[start:i], fmt.Errorf("duplicate tags") + } + } + + // If the tags are not sorted, then sort them. This sort is inline and + // uses the tag indices we created earlier. The actual buffer is not sorted, the + // indices are using the buffer for value comparison. After the indices are sorted, + // the buffer is reconstructed from the sorted indices. + if !sorted && commas > 0 { + // Get the measurement name for later + measurement := buf[start : indices[0]-1] + + // Sort the indices + indices := indices[:commas] + insertionSort(0, commas, buf, indices) + + // Create a new key using the measurement and sorted indices + b := make([]byte, len(buf[start:i])) + pos := copy(b, measurement) + for _, i := range indices { + b[pos] = ',' + pos++ + _, v := scanToSpaceOr(buf, i, ',') + pos += copy(b[pos:], v) + } + + // Check again for duplicate tags now that the tags are sorted. + for j := 0; j < commas-1; j++ { + // get the left and right tags + _, left := scanTo(buf[indices[j]:], 0, '=') + _, right := scanTo(buf[indices[j+1]:], 0, '=') + + // If the tags are equal, then there are duplicate tags, and we should abort. + // If the tags are not sorted, this pass may not find duplicate tags and we + // need to do a more exhaustive search later. + if bytes.Equal(left, right) { + return i, b, fmt.Errorf("duplicate tags") + } + } + + return i, b, nil + } + + return i, buf[start:i], nil +} + +// The following constants allow us to specify which state to move to +// next, when scanning sections of a Point. +const ( + tagKeyState = iota + tagValueState + fieldsState +) + +// scanMeasurement examines the measurement part of a Point, returning +// the next state to move to, and the current location in the buffer. +func scanMeasurement(buf []byte, i int) (int, int, error) { + // Check first byte of measurement, anything except a comma is fine. + // It can't be a space, since whitespace is stripped prior to this + // function call. + if i >= len(buf) || buf[i] == ',' { + return -1, i, fmt.Errorf("missing measurement") + } + + for { + i++ + if i >= len(buf) { + // cpu + return -1, i, fmt.Errorf("missing fields") + } + + if buf[i-1] == '\\' { + // Skip character (it's escaped). + continue + } + + // Unescaped comma; move onto scanning the tags. + if buf[i] == ',' { + return tagKeyState, i + 1, nil + } + + // Unescaped space; move onto scanning the fields. + if buf[i] == ' ' { + // cpu value=1.0 + return fieldsState, i, nil + } + } +} + +// scanTags examines all the tags in a Point, keeping track of and +// returning the updated indices slice, number of commas and location +// in buf where to start examining the Point fields. +func scanTags(buf []byte, i int, indices []int) (int, int, []int, error) { + var ( + err error + commas int + state = tagKeyState + ) + + for { + switch state { + case tagKeyState: + // Grow our indices slice if we have too many tags. + if commas >= len(indices) { + newIndics := make([]int, cap(indices)*2) + copy(newIndics, indices) + indices = newIndics + } + indices[commas] = i + commas++ + + i, err = scanTagsKey(buf, i) + state = tagValueState // tag value always follows a tag key + case tagValueState: + state, i, err = scanTagsValue(buf, i) + case fieldsState: + indices[commas] = i + 1 + return i, commas, indices, nil + } + + if err != nil { + return i, commas, indices, err + } + } +} + +// scanTagsKey scans each character in a tag key. +func scanTagsKey(buf []byte, i int) (int, error) { + // First character of the key. + if i >= len(buf) || buf[i] == ' ' || buf[i] == ',' || buf[i] == '=' { + // cpu,{'', ' ', ',', '='} + return i, fmt.Errorf("missing tag key") + } + + // Examine each character in the tag key until we hit an unescaped + // equals (the tag value), or we hit an error (i.e., unescaped + // space or comma). + for { + i++ + + // Either we reached the end of the buffer or we hit an + // unescaped comma or space. + if i >= len(buf) || + ((buf[i] == ' ' || buf[i] == ',') && buf[i-1] != '\\') { + // cpu,tag{'', ' ', ','} + return i, fmt.Errorf("missing tag value") + } + + if buf[i] == '=' && buf[i-1] != '\\' { + // cpu,tag= + return i + 1, nil + } + } +} + +// scanTagsValue scans each character in a tag value. +func scanTagsValue(buf []byte, i int) (int, int, error) { + // Tag value cannot be empty. + if i >= len(buf) || buf[i] == ',' || buf[i] == ' ' { + // cpu,tag={',', ' '} + return -1, i, fmt.Errorf("missing tag value") + } + + // Examine each character in the tag value until we hit an unescaped + // comma (move onto next tag key), an unescaped space (move onto + // fields), or we error out. + for { + i++ + if i >= len(buf) { + // cpu,tag=value + return -1, i, fmt.Errorf("missing fields") + } + + // An unescaped equals sign is an invalid tag value. + if buf[i] == '=' && buf[i-1] != '\\' { + // cpu,tag={'=', 'fo=o'} + return -1, i, fmt.Errorf("invalid tag format") + } + + if buf[i] == ',' && buf[i-1] != '\\' { + // cpu,tag=foo, + return tagKeyState, i + 1, nil + } + + // cpu,tag=foo value=1.0 + // cpu, tag=foo\= value=1.0 + if buf[i] == ' ' && buf[i-1] != '\\' { + return fieldsState, i, nil + } + } +} + +func insertionSort(l, r int, buf []byte, indices []int) { + for i := l + 1; i < r; i++ { + for j := i; j > l && less(buf, indices, j, j-1); j-- { + indices[j], indices[j-1] = indices[j-1], indices[j] + } + } +} + +func less(buf []byte, indices []int, i, j int) bool { + // This grabs the tag names for i & j, it ignores the values + _, a := scanTo(buf, indices[i], '=') + _, b := scanTo(buf, indices[j], '=') + return bytes.Compare(a, b) < 0 +} + +// scanFields scans buf, starting at i for the fields section of a point. It returns +// the ending position and the byte slice of the fields within buf. +func scanFields(buf []byte, i int) (int, []byte, error) { + start := skipWhitespace(buf, i) + i = start + quoted := false + + // tracks how many '=' we've seen + equals := 0 + + // tracks how many commas we've seen + commas := 0 + + for { + // reached the end of buf? + if i >= len(buf) { + break + } + + // escaped characters? + if buf[i] == '\\' && i+1 < len(buf) { + i += 2 + continue + } + + // If the value is quoted, scan until we get to the end quote + // Only quote values in the field value since quotes are not significant + // in the field key + if buf[i] == '"' && equals > commas { + quoted = !quoted + i++ + continue + } + + // If we see an =, ensure that there is at least on char before and after it + if buf[i] == '=' && !quoted { + equals++ + + // check for "... =123" but allow "a\ =123" + if buf[i-1] == ' ' && buf[i-2] != '\\' { + return i, buf[start:i], fmt.Errorf("missing field key") + } + + // check for "...a=123,=456" but allow "a=123,a\,=456" + if buf[i-1] == ',' && buf[i-2] != '\\' { + return i, buf[start:i], fmt.Errorf("missing field key") + } + + // check for "... value=" + if i+1 >= len(buf) { + return i, buf[start:i], fmt.Errorf("missing field value") + } + + // check for "... value=,value2=..." + if buf[i+1] == ',' || buf[i+1] == ' ' { + return i, buf[start:i], fmt.Errorf("missing field value") + } + + if isNumeric(buf[i+1]) || buf[i+1] == '-' || buf[i+1] == 'N' || buf[i+1] == 'n' { + var err error + i, err = scanNumber(buf, i+1) + if err != nil { + return i, buf[start:i], err + } + continue + } + // If next byte is not a double-quote, the value must be a boolean + if buf[i+1] != '"' { + var err error + i, _, err = scanBoolean(buf, i+1) + if err != nil { + return i, buf[start:i], err + } + continue + } + } + + if buf[i] == ',' && !quoted { + commas++ + } + + // reached end of block? + if buf[i] == ' ' && !quoted { + break + } + i++ + } + + if quoted { + return i, buf[start:i], fmt.Errorf("unbalanced quotes") + } + + // check that all field sections had key and values (e.g. prevent "a=1,b" + if equals == 0 || commas != equals-1 { + return i, buf[start:i], fmt.Errorf("invalid field format") + } + + return i, buf[start:i], nil +} + +// scanTime scans buf, starting at i for the time section of a point. It +// returns the ending position and the byte slice of the timestamp within buf +// and and error if the timestamp is not in the correct numeric format. +func scanTime(buf []byte, i int) (int, []byte, error) { + start := skipWhitespace(buf, i) + i = start + + for { + // reached the end of buf? + if i >= len(buf) { + break + } + + // Reached end of block or trailing whitespace? + if buf[i] == '\n' || buf[i] == ' ' { + break + } + + // Handle negative timestamps + if i == start && buf[i] == '-' { + i++ + continue + } + + // Timestamps should be integers, make sure they are so we don't need + // to actually parse the timestamp until needed. + if buf[i] < '0' || buf[i] > '9' { + return i, buf[start:i], fmt.Errorf("bad timestamp") + } + i++ + } + return i, buf[start:i], nil +} + +func isNumeric(b byte) bool { + return (b >= '0' && b <= '9') || b == '.' +} + +// scanNumber returns the end position within buf, start at i after +// scanning over buf for an integer, or float. It returns an +// error if a invalid number is scanned. +func scanNumber(buf []byte, i int) (int, error) { + start := i + var isInt, isUnsigned bool + + // Is negative number? + if i < len(buf) && buf[i] == '-' { + i++ + // There must be more characters now, as just '-' is illegal. + if i == len(buf) { + return i, ErrInvalidNumber + } + } + + // how many decimal points we've see + decimal := false + + // indicates the number is float in scientific notation + scientific := false + + for { + if i >= len(buf) { + break + } + + if buf[i] == ',' || buf[i] == ' ' { + break + } + + if buf[i] == 'i' && i > start && !(isInt || isUnsigned) { + isInt = true + i++ + continue + } else if buf[i] == 'u' && i > start && !(isInt || isUnsigned) { + isUnsigned = true + i++ + continue + } + + if buf[i] == '.' { + // Can't have more than 1 decimal (e.g. 1.1.1 should fail) + if decimal { + return i, ErrInvalidNumber + } + decimal = true + } + + // `e` is valid for floats but not as the first char + if i > start && (buf[i] == 'e' || buf[i] == 'E') { + scientific = true + i++ + continue + } + + // + and - are only valid at this point if they follow an e (scientific notation) + if (buf[i] == '+' || buf[i] == '-') && (buf[i-1] == 'e' || buf[i-1] == 'E') { + i++ + continue + } + + // NaN is an unsupported value + if i+2 < len(buf) && (buf[i] == 'N' || buf[i] == 'n') { + return i, ErrInvalidNumber + } + + if !isNumeric(buf[i]) { + return i, ErrInvalidNumber + } + i++ + } + + if (isInt || isUnsigned) && (decimal || scientific) { + return i, ErrInvalidNumber + } + + numericDigits := i - start + if isInt { + numericDigits-- + } + if decimal { + numericDigits-- + } + if buf[start] == '-' { + numericDigits-- + } + + if numericDigits == 0 { + return i, ErrInvalidNumber + } + + // It's more common that numbers will be within min/max range for their type but we need to prevent + // out or range numbers from being parsed successfully. This uses some simple heuristics to decide + // if we should parse the number to the actual type. It does not do it all the time because it incurs + // extra allocations and we end up converting the type again when writing points to disk. + if isInt { + // Make sure the last char is an 'i' for integers (e.g. 9i10 is not valid) + if buf[i-1] != 'i' { + return i, ErrInvalidNumber + } + // Parse the int to check bounds the number of digits could be larger than the max range + // We subtract 1 from the index to remove the `i` from our tests + if len(buf[start:i-1]) >= maxInt64Digits || len(buf[start:i-1]) >= minInt64Digits { + if _, err := parseIntBytes(buf[start:i-1], 10, 64); err != nil { + return i, fmt.Errorf("unable to parse integer %s: %s", buf[start:i-1], err) + } + } + } else if isUnsigned { + // Return an error if uint64 support has not been enabled. + if !enableUint64Support { + return i, ErrInvalidNumber + } + // Make sure the last char is a 'u' for unsigned + if buf[i-1] != 'u' { + return i, ErrInvalidNumber + } + // Make sure the first char is not a '-' for unsigned + if buf[start] == '-' { + return i, ErrInvalidNumber + } + // Parse the uint to check bounds the number of digits could be larger than the max range + // We subtract 1 from the index to remove the `u` from our tests + if len(buf[start:i-1]) >= maxUint64Digits { + if _, err := parseUintBytes(buf[start:i-1], 10, 64); err != nil { + return i, fmt.Errorf("unable to parse unsigned %s: %s", buf[start:i-1], err) + } + } + } else { + // Parse the float to check bounds if it's scientific or the number of digits could be larger than the max range + if scientific || len(buf[start:i]) >= maxFloat64Digits || len(buf[start:i]) >= minFloat64Digits { + if _, err := parseFloatBytes(buf[start:i], 10); err != nil { + return i, fmt.Errorf("invalid float") + } + } + } + + return i, nil +} + +// scanBoolean returns the end position within buf, start at i after +// scanning over buf for boolean. Valid values for a boolean are +// t, T, true, TRUE, f, F, false, FALSE. It returns an error if a invalid boolean +// is scanned. +func scanBoolean(buf []byte, i int) (int, []byte, error) { + start := i + + if i < len(buf) && (buf[i] != 't' && buf[i] != 'f' && buf[i] != 'T' && buf[i] != 'F') { + return i, buf[start:i], fmt.Errorf("invalid boolean") + } + + i++ + for { + if i >= len(buf) { + break + } + + if buf[i] == ',' || buf[i] == ' ' { + break + } + i++ + } + + // Single char bool (t, T, f, F) is ok + if i-start == 1 { + return i, buf[start:i], nil + } + + // length must be 4 for true or TRUE + if (buf[start] == 't' || buf[start] == 'T') && i-start != 4 { + return i, buf[start:i], fmt.Errorf("invalid boolean") + } + + // length must be 5 for false or FALSE + if (buf[start] == 'f' || buf[start] == 'F') && i-start != 5 { + return i, buf[start:i], fmt.Errorf("invalid boolean") + } + + // Otherwise + valid := false + switch buf[start] { + case 't': + valid = bytes.Equal(buf[start:i], []byte("true")) + case 'f': + valid = bytes.Equal(buf[start:i], []byte("false")) + case 'T': + valid = bytes.Equal(buf[start:i], []byte("TRUE")) || bytes.Equal(buf[start:i], []byte("True")) + case 'F': + valid = bytes.Equal(buf[start:i], []byte("FALSE")) || bytes.Equal(buf[start:i], []byte("False")) + } + + if !valid { + return i, buf[start:i], fmt.Errorf("invalid boolean") + } + + return i, buf[start:i], nil + +} + +// skipWhitespace returns the end position within buf, starting at i after +// scanning over spaces in tags. +func skipWhitespace(buf []byte, i int) int { + for i < len(buf) { + if buf[i] != ' ' && buf[i] != '\t' && buf[i] != 0 { + break + } + i++ + } + return i +} + +// scanLine returns the end position in buf and the next line found within +// buf. +func scanLine(buf []byte, i int) (int, []byte) { + start := i + quoted := false + fields := false + + // tracks how many '=' and commas we've seen + // this duplicates some of the functionality in scanFields + equals := 0 + commas := 0 + for { + // reached the end of buf? + if i >= len(buf) { + break + } + + // skip past escaped characters + if buf[i] == '\\' && i+2 < len(buf) { + i += 2 + continue + } + + if buf[i] == ' ' { + fields = true + } + + // If we see a double quote, makes sure it is not escaped + if fields { + if !quoted && buf[i] == '=' { + i++ + equals++ + continue + } else if !quoted && buf[i] == ',' { + i++ + commas++ + continue + } else if buf[i] == '"' && equals > commas { + i++ + quoted = !quoted + continue + } + } + + if buf[i] == '\n' && !quoted { + break + } + + i++ + } + + return i, buf[start:i] +} + +// scanTo returns the end position in buf and the next consecutive block +// of bytes, starting from i and ending with stop byte, where stop byte +// has not been escaped. +// +// If there are leading spaces, they are skipped. +func scanTo(buf []byte, i int, stop byte) (int, []byte) { + start := i + for { + // reached the end of buf? + if i >= len(buf) { + break + } + + // Reached unescaped stop value? + if buf[i] == stop && (i == 0 || buf[i-1] != '\\') { + break + } + i++ + } + + return i, buf[start:i] +} + +// scanTo returns the end position in buf and the next consecutive block +// of bytes, starting from i and ending with stop byte. If there are leading +// spaces, they are skipped. +func scanToSpaceOr(buf []byte, i int, stop byte) (int, []byte) { + start := i + if buf[i] == stop || buf[i] == ' ' { + return i, buf[start:i] + } + + for { + i++ + if buf[i-1] == '\\' { + continue + } + + // reached the end of buf? + if i >= len(buf) { + return i, buf[start:i] + } + + // reached end of block? + if buf[i] == stop || buf[i] == ' ' { + return i, buf[start:i] + } + } +} + +func scanTagValue(buf []byte, i int) (int, []byte) { + start := i + for { + if i >= len(buf) { + break + } + + if buf[i] == ',' && buf[i-1] != '\\' { + break + } + i++ + } + if i > len(buf) { + return i, nil + } + return i, buf[start:i] +} + +func scanFieldValue(buf []byte, i int) (int, []byte) { + start := i + quoted := false + for i < len(buf) { + // Only escape char for a field value is a double-quote and backslash + if buf[i] == '\\' && i+1 < len(buf) && (buf[i+1] == '"' || buf[i+1] == '\\') { + i += 2 + continue + } + + // Quoted value? (e.g. string) + if buf[i] == '"' { + i++ + quoted = !quoted + continue + } + + if buf[i] == ',' && !quoted { + break + } + i++ + } + return i, buf[start:i] +} + +func EscapeMeasurement(in []byte) []byte { + for b, esc := range measurementEscapeCodes { + in = bytes.Replace(in, []byte{b}, esc, -1) + } + return in +} + +func unescapeMeasurement(in []byte) []byte { + for b, esc := range measurementEscapeCodes { + in = bytes.Replace(in, esc, []byte{b}, -1) + } + return in +} + +func escapeTag(in []byte) []byte { + for b, esc := range tagEscapeCodes { + if bytes.IndexByte(in, b) != -1 { + in = bytes.Replace(in, []byte{b}, esc, -1) + } + } + return in +} + +func unescapeTag(in []byte) []byte { + if bytes.IndexByte(in, '\\') == -1 { + return in + } + + for b, esc := range tagEscapeCodes { + if bytes.IndexByte(in, b) != -1 { + in = bytes.Replace(in, esc, []byte{b}, -1) + } + } + return in +} + +// escapeStringFieldReplacer replaces double quotes and backslashes +// with the same character preceded by a backslash. +// As of Go 1.7 this benchmarked better in allocations and CPU time +// compared to iterating through a string byte-by-byte and appending to a new byte slice, +// calling strings.Replace twice, and better than (*Regex).ReplaceAllString. +var escapeStringFieldReplacer = strings.NewReplacer(`"`, `\"`, `\`, `\\`) + +// EscapeStringField returns a copy of in with any double quotes or +// backslashes with escaped values. +func EscapeStringField(in string) string { + return escapeStringFieldReplacer.Replace(in) +} + +// unescapeStringField returns a copy of in with any escaped double-quotes +// or backslashes unescaped. +func unescapeStringField(in string) string { + if strings.IndexByte(in, '\\') == -1 { + return in + } + + var out []byte + i := 0 + for { + if i >= len(in) { + break + } + // unescape backslashes + if in[i] == '\\' && i+1 < len(in) && in[i+1] == '\\' { + out = append(out, '\\') + i += 2 + continue + } + // unescape double-quotes + if in[i] == '\\' && i+1 < len(in) && in[i+1] == '"' { + out = append(out, '"') + i += 2 + continue + } + out = append(out, in[i]) + i++ + + } + return string(out) +} + +// NewPoint returns a new point with the given measurement name, tags, fields and timestamp. If +// an unsupported field value (NaN) or out of range time is passed, this function returns an error. +func NewPoint(name string, tags Tags, fields Fields, t time.Time) (Point, error) { + key, err := pointKey(name, tags, fields, t) + if err != nil { + return nil, err + } + + return &point{ + key: key, + time: t, + fields: fields.MarshalBinary(), + }, nil +} + +// pointKey checks some basic requirements for valid points, and returns the +// key, along with an possible error. +func pointKey(measurement string, tags Tags, fields Fields, t time.Time) ([]byte, error) { + if len(fields) == 0 { + return nil, ErrPointMustHaveAField + } + + if !t.IsZero() { + if err := CheckTime(t); err != nil { + return nil, err + } + } + + for key, value := range fields { + switch value := value.(type) { + case float64: + // Ensure the caller validates and handles invalid field values + if math.IsNaN(value) { + return nil, fmt.Errorf("NaN is an unsupported value for field %s", key) + } + case float32: + // Ensure the caller validates and handles invalid field values + if math.IsNaN(float64(value)) { + return nil, fmt.Errorf("NaN is an unsupported value for field %s", key) + } + } + if len(key) == 0 { + return nil, fmt.Errorf("all fields must have non-empty names") + } + } + + key := MakeKey([]byte(measurement), tags) + for field := range fields { + sz := seriesKeySize(key, []byte(field)) + if sz > MaxKeyLength { + return nil, fmt.Errorf("max key length exceeded: %v > %v", sz, MaxKeyLength) + } + } + + return key, nil +} + +func seriesKeySize(key, field []byte) int { + // 4 is the length of the tsm1.fieldKeySeparator constant. It's inlined here to avoid a circular + // dependency. + return len(key) + 4 + len(field) +} + +// NewPointFromBytes returns a new Point from a marshalled Point. +func NewPointFromBytes(b []byte) (Point, error) { + p := &point{} + if err := p.UnmarshalBinary(b); err != nil { + return nil, err + } + + // This does some basic validation to ensure there are fields and they + // can be unmarshalled as well. + iter := p.FieldIterator() + var hasField bool + for iter.Next() { + if len(iter.FieldKey()) == 0 { + continue + } + hasField = true + switch iter.Type() { + case Float: + _, err := iter.FloatValue() + if err != nil { + return nil, fmt.Errorf("unable to unmarshal field %s: %s", string(iter.FieldKey()), err) + } + case Integer: + _, err := iter.IntegerValue() + if err != nil { + return nil, fmt.Errorf("unable to unmarshal field %s: %s", string(iter.FieldKey()), err) + } + case Unsigned: + _, err := iter.UnsignedValue() + if err != nil { + return nil, fmt.Errorf("unable to unmarshal field %s: %s", string(iter.FieldKey()), err) + } + case String: + // Skip since this won't return an error + case Boolean: + _, err := iter.BooleanValue() + if err != nil { + return nil, fmt.Errorf("unable to unmarshal field %s: %s", string(iter.FieldKey()), err) + } + } + } + + if !hasField { + return nil, ErrPointMustHaveAField + } + + return p, nil +} + +// MustNewPoint returns a new point with the given measurement name, tags, fields and timestamp. If +// an unsupported field value (NaN) is passed, this function panics. +func MustNewPoint(name string, tags Tags, fields Fields, time time.Time) Point { + pt, err := NewPoint(name, tags, fields, time) + if err != nil { + panic(err.Error()) + } + return pt +} + +// Key returns the key (measurement joined with tags) of the point. +func (p *point) Key() []byte { + return p.key +} + +func (p *point) name() []byte { + _, name := scanTo(p.key, 0, ',') + return name +} + +func (p *point) Name() []byte { + return escape.Unescape(p.name()) +} + +// SetName updates the measurement name for the point. +func (p *point) SetName(name string) { + p.cachedName = "" + p.key = MakeKey([]byte(name), p.Tags()) +} + +// Time return the timestamp for the point. +func (p *point) Time() time.Time { + return p.time +} + +// SetTime updates the timestamp for the point. +func (p *point) SetTime(t time.Time) { + p.time = t +} + +// Round will round the timestamp of the point to the given duration. +func (p *point) Round(d time.Duration) { + p.time = p.time.Round(d) +} + +// Tags returns the tag set for the point. +func (p *point) Tags() Tags { + if p.cachedTags != nil { + return p.cachedTags + } + p.cachedTags = parseTags(p.key) + return p.cachedTags +} + +func (p *point) HasTag(tag []byte) bool { + if len(p.key) == 0 { + return false + } + + var exists bool + walkTags(p.key, func(key, value []byte) bool { + if bytes.Equal(tag, key) { + exists = true + return false + } + return true + }) + + return exists +} + +func walkTags(buf []byte, fn func(key, value []byte) bool) { + if len(buf) == 0 { + return + } + + pos, name := scanTo(buf, 0, ',') + + // it's an empty key, so there are no tags + if len(name) == 0 { + return + } + + hasEscape := bytes.IndexByte(buf, '\\') != -1 + i := pos + 1 + var key, value []byte + for { + if i >= len(buf) { + break + } + i, key = scanTo(buf, i, '=') + i, value = scanTagValue(buf, i+1) + + if len(value) == 0 { + continue + } + + if hasEscape { + if !fn(unescapeTag(key), unescapeTag(value)) { + return + } + } else { + if !fn(key, value) { + return + } + } + + i++ + } +} + +// walkFields walks each field key and value via fn. If fn returns false, the iteration +// is stopped. The values are the raw byte slices and not the converted types. +func walkFields(buf []byte, fn func(key, value []byte) bool) { + var i int + var key, val []byte + for len(buf) > 0 { + i, key = scanTo(buf, 0, '=') + buf = buf[i+1:] + i, val = scanFieldValue(buf, 0) + buf = buf[i:] + if !fn(key, val) { + break + } + + // slice off comma + if len(buf) > 0 { + buf = buf[1:] + } + } +} + +func parseTags(buf []byte) Tags { + if len(buf) == 0 { + return nil + } + + tags := make(Tags, bytes.Count(buf, []byte(","))) + p := 0 + walkTags(buf, func(key, value []byte) bool { + tags[p].Key = key + tags[p].Value = value + p++ + return true + }) + return tags +} + +// MakeKey creates a key for a set of tags. +func MakeKey(name []byte, tags Tags) []byte { + // unescape the name and then re-escape it to avoid double escaping. + // The key should always be stored in escaped form. + return append(EscapeMeasurement(unescapeMeasurement(name)), tags.HashKey()...) +} + +// SetTags replaces the tags for the point. +func (p *point) SetTags(tags Tags) { + p.key = MakeKey(p.Name(), tags) + p.cachedTags = tags +} + +// AddTag adds or replaces a tag value for a point. +func (p *point) AddTag(key, value string) { + tags := p.Tags() + tags = append(tags, Tag{Key: []byte(key), Value: []byte(value)}) + sort.Sort(tags) + p.cachedTags = tags + p.key = MakeKey(p.Name(), tags) +} + +// Fields returns the fields for the point. +func (p *point) Fields() (Fields, error) { + if p.cachedFields != nil { + return p.cachedFields, nil + } + cf, err := p.unmarshalBinary() + if err != nil { + return nil, err + } + p.cachedFields = cf + return p.cachedFields, nil +} + +// SetPrecision will round a time to the specified precision. +func (p *point) SetPrecision(precision string) { + switch precision { + case "n": + case "u": + p.SetTime(p.Time().Truncate(time.Microsecond)) + case "ms": + p.SetTime(p.Time().Truncate(time.Millisecond)) + case "s": + p.SetTime(p.Time().Truncate(time.Second)) + case "m": + p.SetTime(p.Time().Truncate(time.Minute)) + case "h": + p.SetTime(p.Time().Truncate(time.Hour)) + } +} + +// String returns the string representation of the point. +func (p *point) String() string { + if p.Time().IsZero() { + return string(p.Key()) + " " + string(p.fields) + } + return string(p.Key()) + " " + string(p.fields) + " " + strconv.FormatInt(p.UnixNano(), 10) +} + +// AppendString appends the string representation of the point to buf. +func (p *point) AppendString(buf []byte) []byte { + buf = append(buf, p.key...) + buf = append(buf, ' ') + buf = append(buf, p.fields...) + + if !p.time.IsZero() { + buf = append(buf, ' ') + buf = strconv.AppendInt(buf, p.UnixNano(), 10) + } + + return buf +} + +// StringSize returns the length of the string that would be returned by String(). +func (p *point) StringSize() int { + size := len(p.key) + len(p.fields) + 1 + + if !p.time.IsZero() { + digits := 1 // even "0" has one digit + t := p.UnixNano() + if t < 0 { + // account for negative sign, then negate + digits++ + t = -t + } + for t > 9 { // already accounted for one digit + digits++ + t /= 10 + } + size += digits + 1 // digits and a space + } + + return size +} + +// MarshalBinary returns a binary representation of the point. +func (p *point) MarshalBinary() ([]byte, error) { + if len(p.fields) == 0 { + return nil, ErrPointMustHaveAField + } + + tb, err := p.time.MarshalBinary() + if err != nil { + return nil, err + } + + b := make([]byte, 8+len(p.key)+len(p.fields)+len(tb)) + i := 0 + + binary.BigEndian.PutUint32(b[i:], uint32(len(p.key))) + i += 4 + + i += copy(b[i:], p.key) + + binary.BigEndian.PutUint32(b[i:i+4], uint32(len(p.fields))) + i += 4 + + i += copy(b[i:], p.fields) + + copy(b[i:], tb) + return b, nil +} + +// UnmarshalBinary decodes a binary representation of the point into a point struct. +func (p *point) UnmarshalBinary(b []byte) error { + var n int + + // Read key length. + if len(b) < 4 { + return io.ErrShortBuffer + } + n, b = int(binary.BigEndian.Uint32(b[:4])), b[4:] + + // Read key. + if len(b) < n { + return io.ErrShortBuffer + } + p.key, b = b[:n], b[n:] + + // Read fields length. + if len(b) < 4 { + return io.ErrShortBuffer + } + n, b = int(binary.BigEndian.Uint32(b[:4])), b[4:] + + // Read fields. + if len(b) < n { + return io.ErrShortBuffer + } + p.fields, b = b[:n], b[n:] + + // Read timestamp. + if err := p.time.UnmarshalBinary(b); err != nil { + return err + } + return nil +} + +// PrecisionString returns a string representation of the point. If there +// is a timestamp associated with the point then it will be specified in the +// given unit. +func (p *point) PrecisionString(precision string) string { + if p.Time().IsZero() { + return fmt.Sprintf("%s %s", p.Key(), string(p.fields)) + } + return fmt.Sprintf("%s %s %d", p.Key(), string(p.fields), + p.UnixNano()/GetPrecisionMultiplier(precision)) +} + +// RoundedString returns a string representation of the point. If there +// is a timestamp associated with the point, then it will be rounded to the +// given duration. +func (p *point) RoundedString(d time.Duration) string { + if p.Time().IsZero() { + return fmt.Sprintf("%s %s", p.Key(), string(p.fields)) + } + return fmt.Sprintf("%s %s %d", p.Key(), string(p.fields), + p.time.Round(d).UnixNano()) +} + +func (p *point) unmarshalBinary() (Fields, error) { + iter := p.FieldIterator() + fields := make(Fields, 8) + for iter.Next() { + if len(iter.FieldKey()) == 0 { + continue + } + switch iter.Type() { + case Float: + v, err := iter.FloatValue() + if err != nil { + return nil, fmt.Errorf("unable to unmarshal field %s: %s", string(iter.FieldKey()), err) + } + fields[string(iter.FieldKey())] = v + case Integer: + v, err := iter.IntegerValue() + if err != nil { + return nil, fmt.Errorf("unable to unmarshal field %s: %s", string(iter.FieldKey()), err) + } + fields[string(iter.FieldKey())] = v + case Unsigned: + v, err := iter.UnsignedValue() + if err != nil { + return nil, fmt.Errorf("unable to unmarshal field %s: %s", string(iter.FieldKey()), err) + } + fields[string(iter.FieldKey())] = v + case String: + fields[string(iter.FieldKey())] = iter.StringValue() + case Boolean: + v, err := iter.BooleanValue() + if err != nil { + return nil, fmt.Errorf("unable to unmarshal field %s: %s", string(iter.FieldKey()), err) + } + fields[string(iter.FieldKey())] = v + } + } + return fields, nil +} + +// HashID returns a non-cryptographic checksum of the point's key. +func (p *point) HashID() uint64 { + h := NewInlineFNV64a() + h.Write(p.key) + sum := h.Sum64() + return sum +} + +// UnixNano returns the timestamp of the point as nanoseconds since Unix epoch. +func (p *point) UnixNano() int64 { + return p.Time().UnixNano() +} + +// Split will attempt to return multiple points with the same timestamp whose +// string representations are no longer than size. Points with a single field or +// a point without a timestamp may exceed the requested size. +func (p *point) Split(size int) []Point { + if p.time.IsZero() || p.StringSize() <= size { + return []Point{p} + } + + // key string, timestamp string, spaces + size -= len(p.key) + len(strconv.FormatInt(p.time.UnixNano(), 10)) + 2 + + var points []Point + var start, cur int + + for cur < len(p.fields) { + end, _ := scanTo(p.fields, cur, '=') + end, _ = scanFieldValue(p.fields, end+1) + + if cur > start && end-start > size { + points = append(points, &point{ + key: p.key, + time: p.time, + fields: p.fields[start : cur-1], + }) + start = cur + } + + cur = end + 1 + } + + points = append(points, &point{ + key: p.key, + time: p.time, + fields: p.fields[start:], + }) + + return points +} + +// Tag represents a single key/value tag pair. +type Tag struct { + Key []byte + Value []byte +} + +// NewTag returns a new Tag. +func NewTag(key, value []byte) Tag { + return Tag{ + Key: key, + Value: value, + } +} + +// Size returns the size of the key and value. +func (t Tag) Size() int { return len(t.Key) + len(t.Value) } + +// Clone returns a shallow copy of Tag. +// +// Tags associated with a Point created by ParsePointsWithPrecision will hold references to the byte slice that was parsed. +// Use Clone to create a Tag with new byte slices that do not refer to the argument to ParsePointsWithPrecision. +func (t Tag) Clone() Tag { + other := Tag{ + Key: make([]byte, len(t.Key)), + Value: make([]byte, len(t.Value)), + } + + copy(other.Key, t.Key) + copy(other.Value, t.Value) + + return other +} + +// String returns the string reprsentation of the tag. +func (t *Tag) String() string { + var buf bytes.Buffer + buf.WriteByte('{') + buf.WriteString(string(t.Key)) + buf.WriteByte(' ') + buf.WriteString(string(t.Value)) + buf.WriteByte('}') + return buf.String() +} + +// Tags represents a sorted list of tags. +type Tags []Tag + +// NewTags returns a new Tags from a map. +func NewTags(m map[string]string) Tags { + if len(m) == 0 { + return nil + } + a := make(Tags, 0, len(m)) + for k, v := range m { + a = append(a, NewTag([]byte(k), []byte(v))) + } + sort.Sort(a) + return a +} + +// Keys returns the list of keys for a tag set. +func (a Tags) Keys() []string { + if len(a) == 0 { + return nil + } + keys := make([]string, len(a)) + for i, tag := range a { + keys[i] = string(tag.Key) + } + return keys +} + +// Values returns the list of values for a tag set. +func (a Tags) Values() []string { + if len(a) == 0 { + return nil + } + values := make([]string, len(a)) + for i, tag := range a { + values[i] = string(tag.Value) + } + return values +} + +// String returns the string representation of the tags. +func (a Tags) String() string { + var buf bytes.Buffer + buf.WriteByte('[') + for i := range a { + buf.WriteString(a[i].String()) + if i < len(a)-1 { + buf.WriteByte(' ') + } + } + buf.WriteByte(']') + return buf.String() +} + +// Size returns the number of bytes needed to store all tags. Note, this is +// the number of bytes needed to store all keys and values and does not account +// for data structures or delimiters for example. +func (a Tags) Size() int { + var total int + for _, t := range a { + total += t.Size() + } + return total +} + +// Clone returns a copy of the slice where the elements are a result of calling `Clone` on the original elements +// +// Tags associated with a Point created by ParsePointsWithPrecision will hold references to the byte slice that was parsed. +// Use Clone to create Tags with new byte slices that do not refer to the argument to ParsePointsWithPrecision. +func (a Tags) Clone() Tags { + if len(a) == 0 { + return nil + } + + others := make(Tags, len(a)) + for i := range a { + others[i] = a[i].Clone() + } + + return others +} + +func (a Tags) Len() int { return len(a) } +func (a Tags) Less(i, j int) bool { return bytes.Compare(a[i].Key, a[j].Key) == -1 } +func (a Tags) Swap(i, j int) { a[i], a[j] = a[j], a[i] } + +// Equal returns true if a equals other. +func (a Tags) Equal(other Tags) bool { + if len(a) != len(other) { + return false + } + for i := range a { + if !bytes.Equal(a[i].Key, other[i].Key) || !bytes.Equal(a[i].Value, other[i].Value) { + return false + } + } + return true +} + +// CompareTags returns -1 if a < b, 1 if a > b, and 0 if a == b. +func CompareTags(a, b Tags) int { + // Compare each key & value until a mismatch. + for i := 0; i < len(a) && i < len(b); i++ { + if cmp := bytes.Compare(a[i].Key, b[i].Key); cmp != 0 { + return cmp + } + if cmp := bytes.Compare(a[i].Value, b[i].Value); cmp != 0 { + return cmp + } + } + + // If all tags are equal up to this point then return shorter tagset. + if len(a) < len(b) { + return -1 + } else if len(a) > len(b) { + return 1 + } + + // All tags are equal. + return 0 +} + +// Get returns the value for a key. +func (a Tags) Get(key []byte) []byte { + // OPTIMIZE: Use sort.Search if tagset is large. + + for _, t := range a { + if bytes.Equal(t.Key, key) { + return t.Value + } + } + return nil +} + +// GetString returns the string value for a string key. +func (a Tags) GetString(key string) string { + return string(a.Get([]byte(key))) +} + +// Set sets the value for a key. +func (a *Tags) Set(key, value []byte) { + for i, t := range *a { + if bytes.Equal(t.Key, key) { + (*a)[i].Value = value + return + } + } + *a = append(*a, Tag{Key: key, Value: value}) + sort.Sort(*a) +} + +// SetString sets the string value for a string key. +func (a *Tags) SetString(key, value string) { + a.Set([]byte(key), []byte(value)) +} + +// Delete removes a tag by key. +func (a *Tags) Delete(key []byte) { + for i, t := range *a { + if bytes.Equal(t.Key, key) { + copy((*a)[i:], (*a)[i+1:]) + (*a)[len(*a)-1] = Tag{} + *a = (*a)[:len(*a)-1] + return + } + } +} + +// Map returns a map representation of the tags. +func (a Tags) Map() map[string]string { + m := make(map[string]string, len(a)) + for _, t := range a { + m[string(t.Key)] = string(t.Value) + } + return m +} + +// Merge merges the tags combining the two. If both define a tag with the +// same key, the merged value overwrites the old value. +// A new map is returned. +func (a Tags) Merge(other map[string]string) Tags { + merged := make(map[string]string, len(a)+len(other)) + for _, t := range a { + merged[string(t.Key)] = string(t.Value) + } + for k, v := range other { + merged[k] = v + } + return NewTags(merged) +} + +// HashKey hashes all of a tag's keys. +func (a Tags) HashKey() []byte { + // Empty maps marshal to empty bytes. + if len(a) == 0 { + return nil + } + + // Type invariant: Tags are sorted + + escaped := make(Tags, 0, len(a)) + sz := 0 + for _, t := range a { + ek := escapeTag(t.Key) + ev := escapeTag(t.Value) + + if len(ev) > 0 { + escaped = append(escaped, Tag{Key: ek, Value: ev}) + sz += len(ek) + len(ev) + } + } + + sz += len(escaped) + (len(escaped) * 2) // separators + + // Generate marshaled bytes. + b := make([]byte, sz) + buf := b + idx := 0 + for _, k := range escaped { + buf[idx] = ',' + idx++ + copy(buf[idx:idx+len(k.Key)], k.Key) + idx += len(k.Key) + buf[idx] = '=' + idx++ + copy(buf[idx:idx+len(k.Value)], k.Value) + idx += len(k.Value) + } + return b[:idx] +} + +// CopyTags returns a shallow copy of tags. +func CopyTags(a Tags) Tags { + other := make(Tags, len(a)) + copy(other, a) + return other +} + +// DeepCopyTags returns a deep copy of tags. +func DeepCopyTags(a Tags) Tags { + // Calculate size of keys/values in bytes. + var n int + for _, t := range a { + n += len(t.Key) + len(t.Value) + } + + // Build single allocation for all key/values. + buf := make([]byte, n) + + // Copy tags to new set. + other := make(Tags, len(a)) + for i, t := range a { + copy(buf, t.Key) + other[i].Key, buf = buf[:len(t.Key)], buf[len(t.Key):] + + copy(buf, t.Value) + other[i].Value, buf = buf[:len(t.Value)], buf[len(t.Value):] + } + + return other +} + +// Fields represents a mapping between a Point's field names and their +// values. +type Fields map[string]interface{} + +// FieldIterator retuns a FieldIterator that can be used to traverse the +// fields of a point without constructing the in-memory map. +func (p *point) FieldIterator() FieldIterator { + p.Reset() + return p +} + +type fieldIterator struct { + start, end int + key, keybuf []byte + valueBuf []byte + fieldType FieldType +} + +// Next indicates whether there any fields remaining. +func (p *point) Next() bool { + p.it.start = p.it.end + if p.it.start >= len(p.fields) { + return false + } + + p.it.end, p.it.key = scanTo(p.fields, p.it.start, '=') + if escape.IsEscaped(p.it.key) { + p.it.keybuf = escape.AppendUnescaped(p.it.keybuf[:0], p.it.key) + p.it.key = p.it.keybuf + } + + p.it.end, p.it.valueBuf = scanFieldValue(p.fields, p.it.end+1) + p.it.end++ + + if len(p.it.valueBuf) == 0 { + p.it.fieldType = Empty + return true + } + + c := p.it.valueBuf[0] + + if c == '"' { + p.it.fieldType = String + return true + } + + if strings.IndexByte(`0123456789-.nNiIu`, c) >= 0 { + if p.it.valueBuf[len(p.it.valueBuf)-1] == 'i' { + p.it.fieldType = Integer + p.it.valueBuf = p.it.valueBuf[:len(p.it.valueBuf)-1] + } else if p.it.valueBuf[len(p.it.valueBuf)-1] == 'u' { + p.it.fieldType = Unsigned + p.it.valueBuf = p.it.valueBuf[:len(p.it.valueBuf)-1] + } else { + p.it.fieldType = Float + } + return true + } + + // to keep the same behavior that currently exists, default to boolean + p.it.fieldType = Boolean + return true +} + +// FieldKey returns the key of the current field. +func (p *point) FieldKey() []byte { + return p.it.key +} + +// Type returns the FieldType of the current field. +func (p *point) Type() FieldType { + return p.it.fieldType +} + +// StringValue returns the string value of the current field. +func (p *point) StringValue() string { + return unescapeStringField(string(p.it.valueBuf[1 : len(p.it.valueBuf)-1])) +} + +// IntegerValue returns the integer value of the current field. +func (p *point) IntegerValue() (int64, error) { + n, err := parseIntBytes(p.it.valueBuf, 10, 64) + if err != nil { + return 0, fmt.Errorf("unable to parse integer value %q: %v", p.it.valueBuf, err) + } + return n, nil +} + +// UnsignedValue returns the unsigned value of the current field. +func (p *point) UnsignedValue() (uint64, error) { + n, err := parseUintBytes(p.it.valueBuf, 10, 64) + if err != nil { + return 0, fmt.Errorf("unable to parse unsigned value %q: %v", p.it.valueBuf, err) + } + return n, nil +} + +// BooleanValue returns the boolean value of the current field. +func (p *point) BooleanValue() (bool, error) { + b, err := parseBoolBytes(p.it.valueBuf) + if err != nil { + return false, fmt.Errorf("unable to parse bool value %q: %v", p.it.valueBuf, err) + } + return b, nil +} + +// FloatValue returns the float value of the current field. +func (p *point) FloatValue() (float64, error) { + f, err := parseFloatBytes(p.it.valueBuf, 64) + if err != nil { + return 0, fmt.Errorf("unable to parse floating point value %q: %v", p.it.valueBuf, err) + } + return f, nil +} + +// Reset resets the iterator to its initial state. +func (p *point) Reset() { + p.it.fieldType = Empty + p.it.key = nil + p.it.valueBuf = nil + p.it.start = 0 + p.it.end = 0 +} + +// MarshalBinary encodes all the fields to their proper type and returns the binary +// represenation +// NOTE: uint64 is specifically not supported due to potential overflow when we decode +// again later to an int64 +// NOTE2: uint is accepted, and may be 64 bits, and is for some reason accepted... +func (p Fields) MarshalBinary() []byte { + var b []byte + keys := make([]string, 0, len(p)) + + for k := range p { + keys = append(keys, k) + } + + // Not really necessary, can probably be removed. + sort.Strings(keys) + + for i, k := range keys { + if i > 0 { + b = append(b, ',') + } + b = appendField(b, k, p[k]) + } + + return b +} + +func appendField(b []byte, k string, v interface{}) []byte { + b = append(b, []byte(escape.String(k))...) + b = append(b, '=') + + // check popular types first + switch v := v.(type) { + case float64: + b = strconv.AppendFloat(b, v, 'f', -1, 64) + case int64: + b = strconv.AppendInt(b, v, 10) + b = append(b, 'i') + case string: + b = append(b, '"') + b = append(b, []byte(EscapeStringField(v))...) + b = append(b, '"') + case bool: + b = strconv.AppendBool(b, v) + case int32: + b = strconv.AppendInt(b, int64(v), 10) + b = append(b, 'i') + case int16: + b = strconv.AppendInt(b, int64(v), 10) + b = append(b, 'i') + case int8: + b = strconv.AppendInt(b, int64(v), 10) + b = append(b, 'i') + case int: + b = strconv.AppendInt(b, int64(v), 10) + b = append(b, 'i') + case uint64: + b = strconv.AppendUint(b, v, 10) + b = append(b, 'u') + case uint32: + b = strconv.AppendInt(b, int64(v), 10) + b = append(b, 'i') + case uint16: + b = strconv.AppendInt(b, int64(v), 10) + b = append(b, 'i') + case uint8: + b = strconv.AppendInt(b, int64(v), 10) + b = append(b, 'i') + case uint: + // TODO: 'uint' should be converted to writing as an unsigned integer, + // but we cannot since that would break backwards compatibility. + b = strconv.AppendInt(b, int64(v), 10) + b = append(b, 'i') + case float32: + b = strconv.AppendFloat(b, float64(v), 'f', -1, 32) + case []byte: + b = append(b, v...) + case nil: + // skip + default: + // Can't determine the type, so convert to string + b = append(b, '"') + b = append(b, []byte(EscapeStringField(fmt.Sprintf("%v", v)))...) + b = append(b, '"') + + } + + return b +} + +type byteSlices [][]byte + +func (a byteSlices) Len() int { return len(a) } +func (a byteSlices) Less(i, j int) bool { return bytes.Compare(a[i], a[j]) == -1 } +func (a byteSlices) Swap(i, j int) { a[i], a[j] = a[j], a[i] } diff --git a/vendor/github.com/influxdata/influxdb/models/rows.go b/vendor/github.com/influxdata/influxdb/models/rows.go new file mode 100644 index 000000000..c087a4882 --- /dev/null +++ b/vendor/github.com/influxdata/influxdb/models/rows.go @@ -0,0 +1,62 @@ +package models + +import ( + "sort" +) + +// Row represents a single row returned from the execution of a statement. +type Row struct { + Name string `json:"name,omitempty"` + Tags map[string]string `json:"tags,omitempty"` + Columns []string `json:"columns,omitempty"` + Values [][]interface{} `json:"values,omitempty"` + Partial bool `json:"partial,omitempty"` +} + +// SameSeries returns true if r contains values for the same series as o. +func (r *Row) SameSeries(o *Row) bool { + return r.tagsHash() == o.tagsHash() && r.Name == o.Name +} + +// tagsHash returns a hash of tag key/value pairs. +func (r *Row) tagsHash() uint64 { + h := NewInlineFNV64a() + keys := r.tagsKeys() + for _, k := range keys { + h.Write([]byte(k)) + h.Write([]byte(r.Tags[k])) + } + return h.Sum64() +} + +// tagKeys returns a sorted list of tag keys. +func (r *Row) tagsKeys() []string { + a := make([]string, 0, len(r.Tags)) + for k := range r.Tags { + a = append(a, k) + } + sort.Strings(a) + return a +} + +// Rows represents a collection of rows. Rows implements sort.Interface. +type Rows []*Row + +// Len implements sort.Interface. +func (p Rows) Len() int { return len(p) } + +// Less implements sort.Interface. +func (p Rows) Less(i, j int) bool { + // Sort by name first. + if p[i].Name != p[j].Name { + return p[i].Name < p[j].Name + } + + // Sort by tag set hash. Tags don't have a meaningful sort order so we + // just compute a hash and sort by that instead. This allows the tests + // to receive rows in a predictable order every time. + return p[i].tagsHash() < p[j].tagsHash() +} + +// Swap implements sort.Interface. +func (p Rows) Swap(i, j int) { p[i], p[j] = p[j], p[i] } diff --git a/vendor/github.com/influxdata/influxdb/models/statistic.go b/vendor/github.com/influxdata/influxdb/models/statistic.go new file mode 100644 index 000000000..553e9d09f --- /dev/null +++ b/vendor/github.com/influxdata/influxdb/models/statistic.go @@ -0,0 +1,42 @@ +package models + +// Statistic is the representation of a statistic used by the monitoring service. +type Statistic struct { + Name string `json:"name"` + Tags map[string]string `json:"tags"` + Values map[string]interface{} `json:"values"` +} + +// NewStatistic returns an initialized Statistic. +func NewStatistic(name string) Statistic { + return Statistic{ + Name: name, + Tags: make(map[string]string), + Values: make(map[string]interface{}), + } +} + +// StatisticTags is a map that can be merged with others without causing +// mutations to either map. +type StatisticTags map[string]string + +// Merge creates a new map containing the merged contents of tags and t. +// If both tags and the receiver map contain the same key, the value in tags +// is used in the resulting map. +// +// Merge always returns a usable map. +func (t StatisticTags) Merge(tags map[string]string) map[string]string { + // Add everything in tags to the result. + out := make(map[string]string, len(tags)) + for k, v := range tags { + out[k] = v + } + + // Only add values from t that don't appear in tags. + for k, v := range t { + if _, ok := tags[k]; !ok { + out[k] = v + } + } + return out +} diff --git a/vendor/github.com/influxdata/influxdb/models/time.go b/vendor/github.com/influxdata/influxdb/models/time.go new file mode 100644 index 000000000..e98f2cb33 --- /dev/null +++ b/vendor/github.com/influxdata/influxdb/models/time.go @@ -0,0 +1,74 @@ +package models + +// Helper time methods since parsing time can easily overflow and we only support a +// specific time range. + +import ( + "fmt" + "math" + "time" +) + +const ( + // MinNanoTime is the minumum time that can be represented. + // + // 1677-09-21 00:12:43.145224194 +0000 UTC + // + // The two lowest minimum integers are used as sentinel values. The + // minimum value needs to be used as a value lower than any other value for + // comparisons and another separate value is needed to act as a sentinel + // default value that is unusable by the user, but usable internally. + // Because these two values need to be used for a special purpose, we do + // not allow users to write points at these two times. + MinNanoTime = int64(math.MinInt64) + 2 + + // MaxNanoTime is the maximum time that can be represented. + // + // 2262-04-11 23:47:16.854775806 +0000 UTC + // + // The highest time represented by a nanosecond needs to be used for an + // exclusive range in the shard group, so the maximum time needs to be one + // less than the possible maximum number of nanoseconds representable by an + // int64 so that we don't lose a point at that one time. + MaxNanoTime = int64(math.MaxInt64) - 1 +) + +var ( + minNanoTime = time.Unix(0, MinNanoTime).UTC() + maxNanoTime = time.Unix(0, MaxNanoTime).UTC() + + // ErrTimeOutOfRange gets returned when time is out of the representable range using int64 nanoseconds since the epoch. + ErrTimeOutOfRange = fmt.Errorf("time outside range %d - %d", MinNanoTime, MaxNanoTime) +) + +// SafeCalcTime safely calculates the time given. Will return error if the time is outside the +// supported range. +func SafeCalcTime(timestamp int64, precision string) (time.Time, error) { + mult := GetPrecisionMultiplier(precision) + if t, ok := safeSignedMult(timestamp, mult); ok { + tme := time.Unix(0, t).UTC() + return tme, CheckTime(tme) + } + + return time.Time{}, ErrTimeOutOfRange +} + +// CheckTime checks that a time is within the safe range. +func CheckTime(t time.Time) error { + if t.Before(minNanoTime) || t.After(maxNanoTime) { + return ErrTimeOutOfRange + } + return nil +} + +// Perform the multiplication and check to make sure it didn't overflow. +func safeSignedMult(a, b int64) (int64, bool) { + if a == 0 || b == 0 || a == 1 || b == 1 { + return a * b, true + } + if a == MinNanoTime || b == MaxNanoTime { + return 0, false + } + c := a * b + return c, c/b == a +} diff --git a/vendor/github.com/influxdata/influxdb/models/uint_support.go b/vendor/github.com/influxdata/influxdb/models/uint_support.go new file mode 100644 index 000000000..18d1ca06e --- /dev/null +++ b/vendor/github.com/influxdata/influxdb/models/uint_support.go @@ -0,0 +1,7 @@ +// +build uint uint64 + +package models + +func init() { + EnableUintSupport() +} diff --git a/vendor/github.com/influxdata/influxdb/pkg/escape/bytes.go b/vendor/github.com/influxdata/influxdb/pkg/escape/bytes.go new file mode 100644 index 000000000..f3b31f42d --- /dev/null +++ b/vendor/github.com/influxdata/influxdb/pkg/escape/bytes.go @@ -0,0 +1,115 @@ +// Package escape contains utilities for escaping parts of InfluxQL +// and InfluxDB line protocol. +package escape // import "github.com/influxdata/influxdb/pkg/escape" + +import ( + "bytes" + "strings" +) + +// Codes is a map of bytes to be escaped. +var Codes = map[byte][]byte{ + ',': []byte(`\,`), + '"': []byte(`\"`), + ' ': []byte(`\ `), + '=': []byte(`\=`), +} + +// Bytes escapes characters on the input slice, as defined by Codes. +func Bytes(in []byte) []byte { + for b, esc := range Codes { + in = bytes.Replace(in, []byte{b}, esc, -1) + } + return in +} + +const escapeChars = `," =` + +// IsEscaped returns whether b has any escaped characters, +// i.e. whether b seems to have been processed by Bytes. +func IsEscaped(b []byte) bool { + for len(b) > 0 { + i := bytes.IndexByte(b, '\\') + if i < 0 { + return false + } + + if i+1 < len(b) && strings.IndexByte(escapeChars, b[i+1]) >= 0 { + return true + } + b = b[i+1:] + } + return false +} + +// AppendUnescaped appends the unescaped version of src to dst +// and returns the resulting slice. +func AppendUnescaped(dst, src []byte) []byte { + var pos int + for len(src) > 0 { + next := bytes.IndexByte(src[pos:], '\\') + if next < 0 || pos+next+1 >= len(src) { + return append(dst, src...) + } + + if pos+next+1 < len(src) && strings.IndexByte(escapeChars, src[pos+next+1]) >= 0 { + if pos+next > 0 { + dst = append(dst, src[:pos+next]...) + } + src = src[pos+next+1:] + pos = 0 + } else { + pos += next + 1 + } + } + + return dst +} + +// Unescape returns a new slice containing the unescaped version of in. +func Unescape(in []byte) []byte { + if len(in) == 0 { + return nil + } + + if bytes.IndexByte(in, '\\') == -1 { + return in + } + + i := 0 + inLen := len(in) + + // The output size will be no more than inLen. Preallocating the + // capacity of the output is faster and uses less memory than + // letting append() do its own (over)allocation. + out := make([]byte, 0, inLen) + + for { + if i >= inLen { + break + } + if in[i] == '\\' && i+1 < inLen { + switch in[i+1] { + case ',': + out = append(out, ',') + i += 2 + continue + case '"': + out = append(out, '"') + i += 2 + continue + case ' ': + out = append(out, ' ') + i += 2 + continue + case '=': + out = append(out, '=') + i += 2 + continue + } + } + out = append(out, in[i]) + i += 1 + } + return out +} diff --git a/vendor/github.com/influxdata/influxdb/pkg/escape/strings.go b/vendor/github.com/influxdata/influxdb/pkg/escape/strings.go new file mode 100644 index 000000000..db98033b0 --- /dev/null +++ b/vendor/github.com/influxdata/influxdb/pkg/escape/strings.go @@ -0,0 +1,21 @@ +package escape + +import "strings" + +var ( + escaper = strings.NewReplacer(`,`, `\,`, `"`, `\"`, ` `, `\ `, `=`, `\=`) + unescaper = strings.NewReplacer(`\,`, `,`, `\"`, `"`, `\ `, ` `, `\=`, `=`) +) + +// UnescapeString returns unescaped version of in. +func UnescapeString(in string) string { + if strings.IndexByte(in, '\\') == -1 { + return in + } + return unescaper.Replace(in) +} + +// String returns the escaped version of in. +func String(in string) string { + return escaper.Replace(in) +} diff --git a/vendor/github.com/rcrowley/go-metrics/json.go b/vendor/github.com/rcrowley/go-metrics/json.go deleted file mode 100644 index 2fdcbcfbf..000000000 --- a/vendor/github.com/rcrowley/go-metrics/json.go +++ /dev/null @@ -1,87 +0,0 @@ -package metrics - -import ( - "encoding/json" - "io" - "time" -) - -// MarshalJSON returns a byte slice containing a JSON representation of all -// the metrics in the Registry. -func (r *StandardRegistry) MarshalJSON() ([]byte, error) { - data := make(map[string]map[string]interface{}) - r.Each(func(name string, i interface{}) { - values := make(map[string]interface{}) - switch metric := i.(type) { - case Counter: - values["count"] = metric.Count() - case Gauge: - values["value"] = metric.Value() - case GaugeFloat64: - values["value"] = metric.Value() - case Healthcheck: - values["error"] = nil - metric.Check() - if err := metric.Error(); nil != err { - values["error"] = metric.Error().Error() - } - case Histogram: - h := metric.Snapshot() - ps := h.Percentiles([]float64{0.5, 0.75, 0.95, 0.99, 0.999}) - values["count"] = h.Count() - values["min"] = h.Min() - values["max"] = h.Max() - values["mean"] = h.Mean() - values["stddev"] = h.StdDev() - values["median"] = ps[0] - values["75%"] = ps[1] - values["95%"] = ps[2] - values["99%"] = ps[3] - values["99.9%"] = ps[4] - case Meter: - m := metric.Snapshot() - values["count"] = m.Count() - values["1m.rate"] = m.Rate1() - values["5m.rate"] = m.Rate5() - values["15m.rate"] = m.Rate15() - values["mean.rate"] = m.RateMean() - case Timer: - t := metric.Snapshot() - ps := t.Percentiles([]float64{0.5, 0.75, 0.95, 0.99, 0.999}) - values["count"] = t.Count() - values["min"] = t.Min() - values["max"] = t.Max() - values["mean"] = t.Mean() - values["stddev"] = t.StdDev() - values["median"] = ps[0] - values["75%"] = ps[1] - values["95%"] = ps[2] - values["99%"] = ps[3] - values["99.9%"] = ps[4] - values["1m.rate"] = t.Rate1() - values["5m.rate"] = t.Rate5() - values["15m.rate"] = t.Rate15() - values["mean.rate"] = t.RateMean() - } - data[name] = values - }) - return json.Marshal(data) -} - -// WriteJSON writes metrics from the given registry periodically to the -// specified io.Writer as JSON. -func WriteJSON(r Registry, d time.Duration, w io.Writer) { - for _ = range time.Tick(d) { - WriteJSONOnce(r, w) - } -} - -// WriteJSONOnce writes metrics from the given registry to the specified -// io.Writer as JSON. -func WriteJSONOnce(r Registry, w io.Writer) { - json.NewEncoder(w).Encode(r) -} - -func (p *PrefixedRegistry) MarshalJSON() ([]byte, error) { - return json.Marshal(p.underlying) -} diff --git a/vendor/github.com/rcrowley/go-metrics/metrics.go b/vendor/github.com/rcrowley/go-metrics/metrics.go deleted file mode 100644 index b97a49ed1..000000000 --- a/vendor/github.com/rcrowley/go-metrics/metrics.go +++ /dev/null @@ -1,13 +0,0 @@ -// Go port of Coda Hale's Metrics library -// -// <https://github.com/rcrowley/go-metrics> -// -// Coda Hale's original work: <https://github.com/codahale/metrics> -package metrics - -// UseNilMetrics is checked by the constructor functions for all of the -// standard metrics. If it is true, the metric returned is a stub. -// -// This global kill-switch helps quantify the observer effect and makes -// for less cluttered pprof profiles. -var UseNilMetrics bool = false diff --git a/vendor/github.com/rjeczalik/notify/watcher_fsevents_cgo.go b/vendor/github.com/rjeczalik/notify/watcher_fsevents_cgo.go index 2248a1b12..a2b332a2e 100644 --- a/vendor/github.com/rjeczalik/notify/watcher_fsevents_cgo.go +++ b/vendor/github.com/rjeczalik/notify/watcher_fsevents_cgo.go @@ -48,7 +48,7 @@ var wg sync.WaitGroup // used to wait until the runloop starts // started and is ready via the wg. It also serves purpose of a dummy source, // thanks to it the runloop does not return as it also has at least one source // registered. -var source = C.CFRunLoopSourceCreate(refZero, 0, &C.CFRunLoopSourceContext{ +var source = C.CFRunLoopSourceCreate(nil, 0, &C.CFRunLoopSourceContext{ perform: (C.CFRunLoopPerformCallBack)(C.gosource), }) @@ -162,8 +162,8 @@ func (s *stream) Start() error { return nil } wg.Wait() - p := C.CFStringCreateWithCStringNoCopy(refZero, C.CString(s.path), C.kCFStringEncodingUTF8, refZero) - path := C.CFArrayCreate(refZero, (*unsafe.Pointer)(unsafe.Pointer(&p)), 1, nil) + p := C.CFStringCreateWithCStringNoCopy(nil, C.CString(s.path), C.kCFStringEncodingUTF8, nil) + path := C.CFArrayCreate(nil, (*unsafe.Pointer)(unsafe.Pointer(&p)), 1, nil) ctx := C.FSEventStreamContext{} ref := C.EventStreamCreate(&ctx, C.uintptr_t(s.info), path, C.FSEventStreamEventId(atomic.LoadUint64(&since)), latency, flags) if ref == nilstream { diff --git a/vendor/github.com/rjeczalik/notify/watcher_fsevents_go1.10.go b/vendor/github.com/rjeczalik/notify/watcher_fsevents_go1.10.go deleted file mode 100644 index 0edd3782f..000000000 --- a/vendor/github.com/rjeczalik/notify/watcher_fsevents_go1.10.go +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2017 The Notify Authors. All rights reserved. -// Use of this source code is governed by the MIT license that can be -// found in the LICENSE file. - -// +build darwin,!kqueue,go1.10 - -package notify - -const refZero = 0 diff --git a/vendor/github.com/rjeczalik/notify/watcher_fsevents_go1.9.go b/vendor/github.com/rjeczalik/notify/watcher_fsevents_go1.9.go deleted file mode 100644 index b81c3c185..000000000 --- a/vendor/github.com/rjeczalik/notify/watcher_fsevents_go1.9.go +++ /dev/null @@ -1,14 +0,0 @@ -// Copyright (c) 2017 The Notify Authors. All rights reserved. -// Use of this source code is governed by the MIT license that can be -// found in the LICENSE file. - -// +build darwin,!kqueue,cgo,!go1.10 - -package notify - -/* -#include <CoreServices/CoreServices.h> -*/ -import "C" - -var refZero = (*C.struct___CFAllocator)(nil) diff --git a/vendor/vendor.json b/vendor/vendor.json index 436022329..134158995 100644 --- a/vendor/vendor.json +++ b/vendor/vendor.json @@ -207,6 +207,30 @@ "revisionTime": "2016-12-24T10:41:01Z" }, { + "checksumSHA1": "6tNwbL5tUS0dxYzADKVZtI2d/lE=", + "path": "github.com/influxdata/influxdb/client", + "revision": "a55dd0f50edd14c9c798d3564189eb4f53914309", + "revisionTime": "2017-10-09T17:24:46Z" + }, + { + "checksumSHA1": "O4XpbSNeUhSIMD2FWtQximJiFIs=", + "path": "github.com/influxdata/influxdb/client/v2", + "revision": "b36b9f109f2da91c8941679caf5356e08eee0b2b", + "revisionTime": "2018-01-17T01:42:09Z" + }, + { + "checksumSHA1": "cfumoC9gHEUROd+fA8qK3WLFAZQ=", + "path": "github.com/influxdata/influxdb/models", + "revision": "b36b9f109f2da91c8941679caf5356e08eee0b2b", + "revisionTime": "2018-01-17T01:42:09Z" + }, + { + "checksumSHA1": "Z0Bb5PWa5WL/j5Dm2KJCLGn1l7U=", + "path": "github.com/influxdata/influxdb/pkg/escape", + "revision": "01288bdb0883a01cac999326bd34421b29acaec8", + "revisionTime": "2018-02-21T22:33:40Z" + }, + { "checksumSHA1": "vTGKMIfiMwz43y5bsgx9PrL+AVw=", "path": "github.com/jackpal/go-nat-pmp", "revision": "1fa385a6f45828c83361136b45b1a21a12139493", @@ -310,22 +334,10 @@ "revisionTime": "2017-08-14T17:01:13Z" }, { - "checksumSHA1": "KAzbLjI9MzW2tjfcAsK75lVRp6I=", - "path": "github.com/rcrowley/go-metrics", - "revision": "1f30fe9094a513ce4c700b9a54458bbb0c96996c", - "revisionTime": "2016-11-28T21:05:44Z" - }, - { - "checksumSHA1": "q/d9nXRQYKEJ/EWn+5y6jL8rPGs=", - "path": "github.com/rcrowley/go-metrics/exp", - "revision": "1f30fe9094a513ce4c700b9a54458bbb0c96996c", - "revisionTime": "2016-11-28T21:05:44Z" - }, - { - "checksumSHA1": "1ESHllhZOIBg7MnlGHUdhz047bI=", + "checksumSHA1": "28UVHMmHx0iqO0XiJsjx+fwILyI=", "path": "github.com/rjeczalik/notify", - "revision": "27b537f07230b3f917421af6dcf044038dbe57e2", - "revisionTime": "2018-01-03T13:19:05Z" + "revision": "c31e5f2cb22b3e4ef3f882f413847669bf2652b9", + "revisionTime": "2018-02-03T14:01:15Z" }, { "checksumSHA1": "5uqO4ITTDMklKi3uNaE/D9LQ5nM=", diff --git a/whisper/mailserver/mailserver.go b/whisper/mailserver/mailserver.go index 6555fd5c0..57e6505ad 100644 --- a/whisper/mailserver/mailserver.go +++ b/whisper/mailserver/mailserver.go @@ -17,7 +17,6 @@ package mailserver import ( - "bytes" "encoding/binary" "fmt" @@ -108,17 +107,16 @@ func (s *WMailServer) DeliverMail(peer *whisper.Peer, request *whisper.Envelope) return } - ok, lower, upper, topic := s.validateRequest(peer.ID(), request) + ok, lower, upper, bloom := s.validateRequest(peer.ID(), request) if ok { - s.processRequest(peer, lower, upper, topic) + s.processRequest(peer, lower, upper, bloom) } } -func (s *WMailServer) processRequest(peer *whisper.Peer, lower, upper uint32, topic whisper.TopicType) []*whisper.Envelope { +func (s *WMailServer) processRequest(peer *whisper.Peer, lower, upper uint32, bloom []byte) []*whisper.Envelope { ret := make([]*whisper.Envelope, 0) var err error var zero common.Hash - var empty whisper.TopicType kl := NewDbKey(lower, zero) ku := NewDbKey(upper, zero) i := s.db.NewIterator(&util.Range{Start: kl.raw, Limit: ku.raw}, nil) @@ -131,7 +129,7 @@ func (s *WMailServer) processRequest(peer *whisper.Peer, lower, upper uint32, to log.Error(fmt.Sprintf("RLP decoding failed: %s", err)) } - if topic == empty || envelope.Topic == topic { + if whisper.BloomFilterMatch(bloom, envelope.Bloom()) { if peer == nil { // used for test purposes ret = append(ret, &envelope) @@ -153,39 +151,45 @@ func (s *WMailServer) processRequest(peer *whisper.Peer, lower, upper uint32, to return ret } -func (s *WMailServer) validateRequest(peerID []byte, request *whisper.Envelope) (bool, uint32, uint32, whisper.TopicType) { - var topic whisper.TopicType +func (s *WMailServer) validateRequest(peerID []byte, request *whisper.Envelope) (bool, uint32, uint32, []byte) { if s.pow > 0.0 && request.PoW() < s.pow { - return false, 0, 0, topic + return false, 0, 0, nil } f := whisper.Filter{KeySym: s.key} decrypted := request.Open(&f) if decrypted == nil { log.Warn(fmt.Sprintf("Failed to decrypt p2p request")) - return false, 0, 0, topic - } - - if len(decrypted.Payload) < 8 { - log.Warn(fmt.Sprintf("Undersized p2p request")) - return false, 0, 0, topic + return false, 0, 0, nil } src := crypto.FromECDSAPub(decrypted.Src) if len(src)-len(peerID) == 1 { src = src[1:] } - if !bytes.Equal(peerID, src) { + + // if you want to check the signature, you can do it here. e.g.: + // if !bytes.Equal(peerID, src) { + if src == nil { log.Warn(fmt.Sprintf("Wrong signature of p2p request")) - return false, 0, 0, topic + return false, 0, 0, nil } - lower := binary.BigEndian.Uint32(decrypted.Payload[:4]) - upper := binary.BigEndian.Uint32(decrypted.Payload[4:8]) - - if len(decrypted.Payload) >= 8+whisper.TopicLength { - topic = whisper.BytesToTopic(decrypted.Payload[8:]) + var bloom []byte + payloadSize := len(decrypted.Payload) + if payloadSize < 8 { + log.Warn(fmt.Sprintf("Undersized p2p request")) + return false, 0, 0, nil + } else if payloadSize == 8 { + bloom = whisper.MakeFullNodeBloom() + } else if payloadSize < 8+whisper.BloomFilterSize { + log.Warn(fmt.Sprintf("Undersized bloom filter in p2p request")) + return false, 0, 0, nil + } else { + bloom = decrypted.Payload[8 : 8+whisper.BloomFilterSize] } - return true, lower, upper, topic + lower := binary.BigEndian.Uint32(decrypted.Payload[:4]) + upper := binary.BigEndian.Uint32(decrypted.Payload[4:8]) + return true, lower, upper, bloom } diff --git a/whisper/mailserver/server_test.go b/whisper/mailserver/server_test.go index c8e0a553a..d5b993afb 100644 --- a/whisper/mailserver/server_test.go +++ b/whisper/mailserver/server_test.go @@ -17,6 +17,7 @@ package mailserver import ( + "bytes" "crypto/ecdsa" "encoding/binary" "io/ioutil" @@ -61,7 +62,7 @@ func generateEnvelope(t *testing.T) *whisper.Envelope { h := crypto.Keccak256Hash([]byte("test sample data")) params := &whisper.MessageParams{ KeySym: h[:], - Topic: whisper.TopicType{}, + Topic: whisper.TopicType{0x1F, 0x7E, 0xA1, 0x7F}, Payload: []byte("test payload"), PoW: powRequirement, WorkTime: 2, @@ -121,6 +122,7 @@ func deliverTest(t *testing.T, server *WMailServer, env *whisper.Envelope) { upp: birth + 1, key: testPeerID, } + singleRequest(t, server, env, p, true) p.low, p.upp = birth+1, 0xffffffff @@ -131,14 +133,14 @@ func deliverTest(t *testing.T, server *WMailServer, env *whisper.Envelope) { p.low = birth - 1 p.upp = birth + 1 - p.topic[0]++ + p.topic[0] = 0xFF singleRequest(t, server, env, p, false) } func singleRequest(t *testing.T, server *WMailServer, env *whisper.Envelope, p *ServerTestParams, expect bool) { request := createRequest(t, p) src := crypto.FromECDSAPub(&p.key.PublicKey) - ok, lower, upper, topic := server.validateRequest(src, request) + ok, lower, upper, bloom := server.validateRequest(src, request) if !ok { t.Fatalf("request validation failed, seed: %d.", seed) } @@ -148,12 +150,13 @@ func singleRequest(t *testing.T, server *WMailServer, env *whisper.Envelope, p * if upper != p.upp { t.Fatalf("request validation failed (upper bound), seed: %d.", seed) } - if topic != p.topic { + expectedBloom := whisper.TopicToBloom(p.topic) + if !bytes.Equal(bloom, expectedBloom) { t.Fatalf("request validation failed (topic), seed: %d.", seed) } var exist bool - mail := server.processRequest(nil, p.low, p.upp, p.topic) + mail := server.processRequest(nil, p.low, p.upp, bloom) for _, msg := range mail { if msg.Hash() == env.Hash() { exist = true @@ -166,17 +169,19 @@ func singleRequest(t *testing.T, server *WMailServer, env *whisper.Envelope, p * } src[0]++ - ok, lower, upper, topic = server.validateRequest(src, request) - if ok { - t.Fatalf("request validation false positive, seed: %d (lower: %d, upper: %d).", seed, lower, upper) + ok, lower, upper, bloom = server.validateRequest(src, request) + if !ok { + // request should be valid regardless of signature + t.Fatalf("request validation false negative, seed: %d (lower: %d, upper: %d).", seed, lower, upper) } } func createRequest(t *testing.T, p *ServerTestParams) *whisper.Envelope { - data := make([]byte, 8+whisper.TopicLength) + bloom := whisper.TopicToBloom(p.topic) + data := make([]byte, 8) binary.BigEndian.PutUint32(data, p.low) binary.BigEndian.PutUint32(data[4:], p.upp) - copy(data[8:], p.topic[:]) + data = append(data, bloom...) key, err := shh.GetSymKey(keyID) if err != nil { diff --git a/whisper/whisperv2/api.go b/whisper/whisperv2/api.go deleted file mode 100644 index 5c6d17095..000000000 --- a/whisper/whisperv2/api.go +++ /dev/null @@ -1,402 +0,0 @@ -// Copyright 2015 The go-ethereum Authors -// This file is part of the go-ethereum library. -// -// The go-ethereum library is free software: you can redistribute it and/or modify -// it under the terms of the GNU Lesser General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// The go-ethereum library is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU Lesser General Public License for more details. -// -// You should have received a copy of the GNU Lesser General Public License -// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. - -package whisperv2 - -import ( - "encoding/json" - "fmt" - "sync" - "time" - - "github.com/ethereum/go-ethereum/common" - "github.com/ethereum/go-ethereum/common/hexutil" - "github.com/ethereum/go-ethereum/crypto" -) - -// PublicWhisperAPI provides the whisper RPC service. -type PublicWhisperAPI struct { - w *Whisper - - messagesMu sync.RWMutex - messages map[hexutil.Uint]*whisperFilter -} - -type whisperOfflineError struct{} - -func (e *whisperOfflineError) Error() string { - return "whisper is offline" -} - -// whisperOffLineErr is returned when the node doesn't offer the shh service. -var whisperOffLineErr = new(whisperOfflineError) - -// NewPublicWhisperAPI create a new RPC whisper service. -func NewPublicWhisperAPI(w *Whisper) *PublicWhisperAPI { - return &PublicWhisperAPI{w: w, messages: make(map[hexutil.Uint]*whisperFilter)} -} - -// Version returns the Whisper version this node offers. -func (s *PublicWhisperAPI) Version() (hexutil.Uint, error) { - if s.w == nil { - return 0, whisperOffLineErr - } - return hexutil.Uint(s.w.Version()), nil -} - -// HasIdentity checks if the the whisper node is configured with the private key -// of the specified public pair. -func (s *PublicWhisperAPI) HasIdentity(identity string) (bool, error) { - if s.w == nil { - return false, whisperOffLineErr - } - return s.w.HasIdentity(crypto.ToECDSAPub(common.FromHex(identity))), nil -} - -// NewIdentity generates a new cryptographic identity for the client, and injects -// it into the known identities for message decryption. -func (s *PublicWhisperAPI) NewIdentity() (string, error) { - if s.w == nil { - return "", whisperOffLineErr - } - - identity := s.w.NewIdentity() - return common.ToHex(crypto.FromECDSAPub(&identity.PublicKey)), nil -} - -type NewFilterArgs struct { - To string - From string - Topics [][][]byte -} - -// NewWhisperFilter creates and registers a new message filter to watch for inbound whisper messages. -func (s *PublicWhisperAPI) NewFilter(args NewFilterArgs) (hexutil.Uint, error) { - if s.w == nil { - return 0, whisperOffLineErr - } - - var id hexutil.Uint - filter := Filter{ - To: crypto.ToECDSAPub(common.FromHex(args.To)), - From: crypto.ToECDSAPub(common.FromHex(args.From)), - Topics: NewFilterTopics(args.Topics...), - Fn: func(message *Message) { - wmsg := NewWhisperMessage(message) - s.messagesMu.RLock() // Only read lock to the filter pool - defer s.messagesMu.RUnlock() - if s.messages[id] != nil { - s.messages[id].insert(wmsg) - } - }, - } - id = hexutil.Uint(s.w.Watch(filter)) - - s.messagesMu.Lock() - s.messages[id] = newWhisperFilter(id, s.w) - s.messagesMu.Unlock() - - return id, nil -} - -// GetFilterChanges retrieves all the new messages matched by a filter since the last retrieval. -func (s *PublicWhisperAPI) GetFilterChanges(filterId hexutil.Uint) []WhisperMessage { - s.messagesMu.RLock() - defer s.messagesMu.RUnlock() - - if s.messages[filterId] != nil { - if changes := s.messages[filterId].retrieve(); changes != nil { - return changes - } - } - return returnWhisperMessages(nil) -} - -// UninstallFilter disables and removes an existing filter. -func (s *PublicWhisperAPI) UninstallFilter(filterId hexutil.Uint) bool { - s.messagesMu.Lock() - defer s.messagesMu.Unlock() - - if _, ok := s.messages[filterId]; ok { - delete(s.messages, filterId) - return true - } - return false -} - -// GetMessages retrieves all the known messages that match a specific filter. -func (s *PublicWhisperAPI) GetMessages(filterId hexutil.Uint) []WhisperMessage { - // Retrieve all the cached messages matching a specific, existing filter - s.messagesMu.RLock() - defer s.messagesMu.RUnlock() - - var messages []*Message - if s.messages[filterId] != nil { - messages = s.messages[filterId].messages() - } - - return returnWhisperMessages(messages) -} - -// returnWhisperMessages converts aNhisper message to a RPC whisper message. -func returnWhisperMessages(messages []*Message) []WhisperMessage { - msgs := make([]WhisperMessage, len(messages)) - for i, msg := range messages { - msgs[i] = NewWhisperMessage(msg) - } - return msgs -} - -type PostArgs struct { - From string `json:"from"` - To string `json:"to"` - Topics [][]byte `json:"topics"` - Payload string `json:"payload"` - Priority int64 `json:"priority"` - TTL int64 `json:"ttl"` -} - -// Post injects a message into the whisper network for distribution. -func (s *PublicWhisperAPI) Post(args PostArgs) (bool, error) { - if s.w == nil { - return false, whisperOffLineErr - } - - // construct whisper message with transmission options - message := NewMessage(common.FromHex(args.Payload)) - options := Options{ - To: crypto.ToECDSAPub(common.FromHex(args.To)), - TTL: time.Duration(args.TTL) * time.Second, - Topics: NewTopics(args.Topics...), - } - - // set sender identity - if len(args.From) > 0 { - if key := s.w.GetIdentity(crypto.ToECDSAPub(common.FromHex(args.From))); key != nil { - options.From = key - } else { - return false, fmt.Errorf("unknown identity to send from: %s", args.From) - } - } - - // Wrap and send the message - pow := time.Duration(args.Priority) * time.Millisecond - envelope, err := message.Wrap(pow, options) - if err != nil { - return false, err - } - - return true, s.w.Send(envelope) -} - -// WhisperMessage is the RPC representation of a whisper message. -type WhisperMessage struct { - ref *Message - - Payload string `json:"payload"` - To string `json:"to"` - From string `json:"from"` - Sent int64 `json:"sent"` - TTL int64 `json:"ttl"` - Hash string `json:"hash"` -} - -func (args *PostArgs) UnmarshalJSON(data []byte) (err error) { - var obj struct { - From string `json:"from"` - To string `json:"to"` - Topics []string `json:"topics"` - Payload string `json:"payload"` - Priority hexutil.Uint64 `json:"priority"` - TTL hexutil.Uint64 `json:"ttl"` - } - - if err := json.Unmarshal(data, &obj); err != nil { - return err - } - - args.From = obj.From - args.To = obj.To - args.Payload = obj.Payload - args.Priority = int64(obj.Priority) // TODO(gluk256): handle overflow - args.TTL = int64(obj.TTL) // ... here too ... - - // decode topic strings - args.Topics = make([][]byte, len(obj.Topics)) - for i, topic := range obj.Topics { - args.Topics[i] = common.FromHex(topic) - } - - return nil -} - -// UnmarshalJSON implements the json.Unmarshaler interface, invoked to convert a -// JSON message blob into a WhisperFilterArgs structure. -func (args *NewFilterArgs) UnmarshalJSON(b []byte) (err error) { - // Unmarshal the JSON message and sanity check - var obj struct { - To interface{} `json:"to"` - From interface{} `json:"from"` - Topics interface{} `json:"topics"` - } - if err := json.Unmarshal(b, &obj); err != nil { - return err - } - - // Retrieve the simple data contents of the filter arguments - if obj.To == nil { - args.To = "" - } else { - argstr, ok := obj.To.(string) - if !ok { - return fmt.Errorf("to is not a string") - } - args.To = argstr - } - if obj.From == nil { - args.From = "" - } else { - argstr, ok := obj.From.(string) - if !ok { - return fmt.Errorf("from is not a string") - } - args.From = argstr - } - // Construct the nested topic array - if obj.Topics != nil { - // Make sure we have an actual topic array - list, ok := obj.Topics.([]interface{}) - if !ok { - return fmt.Errorf("topics is not an array") - } - // Iterate over each topic and handle nil, string or array - topics := make([][]string, len(list)) - for idx, field := range list { - switch value := field.(type) { - case nil: - topics[idx] = []string{} - - case string: - topics[idx] = []string{value} - - case []interface{}: - topics[idx] = make([]string, len(value)) - for i, nested := range value { - switch value := nested.(type) { - case nil: - topics[idx][i] = "" - - case string: - topics[idx][i] = value - - default: - return fmt.Errorf("topic[%d][%d] is not a string", idx, i) - } - } - default: - return fmt.Errorf("topic[%d] not a string or array", idx) - } - } - - topicsDecoded := make([][][]byte, len(topics)) - for i, condition := range topics { - topicsDecoded[i] = make([][]byte, len(condition)) - for j, topic := range condition { - topicsDecoded[i][j] = common.FromHex(topic) - } - } - - args.Topics = topicsDecoded - } - return nil -} - -// whisperFilter is the message cache matching a specific filter, accumulating -// inbound messages until the are requested by the client. -type whisperFilter struct { - id hexutil.Uint // Filter identifier for old message retrieval - ref *Whisper // Whisper reference for old message retrieval - - cache []WhisperMessage // Cache of messages not yet polled - skip map[common.Hash]struct{} // List of retrieved messages to avoid duplication - update time.Time // Time of the last message query - - lock sync.RWMutex // Lock protecting the filter internals -} - -// messages retrieves all the cached messages from the entire pool matching the -// filter, resetting the filter's change buffer. -func (w *whisperFilter) messages() []*Message { - w.lock.Lock() - defer w.lock.Unlock() - - w.cache = nil - w.update = time.Now() - - w.skip = make(map[common.Hash]struct{}) - messages := w.ref.Messages(int(w.id)) - for _, message := range messages { - w.skip[message.Hash] = struct{}{} - } - return messages -} - -// insert injects a new batch of messages into the filter cache. -func (w *whisperFilter) insert(messages ...WhisperMessage) { - w.lock.Lock() - defer w.lock.Unlock() - - for _, message := range messages { - if _, ok := w.skip[message.ref.Hash]; !ok { - w.cache = append(w.cache, messages...) - } - } -} - -// retrieve fetches all the cached messages from the filter. -func (w *whisperFilter) retrieve() (messages []WhisperMessage) { - w.lock.Lock() - defer w.lock.Unlock() - - messages, w.cache = w.cache, nil - w.update = time.Now() - - return -} - -// newWhisperFilter creates a new serialized, poll based whisper topic filter. -func newWhisperFilter(id hexutil.Uint, ref *Whisper) *whisperFilter { - return &whisperFilter{ - id: id, - ref: ref, - update: time.Now(), - skip: make(map[common.Hash]struct{}), - } -} - -// NewWhisperMessage converts an internal message into an API version. -func NewWhisperMessage(message *Message) WhisperMessage { - return WhisperMessage{ - ref: message, - - Payload: common.ToHex(message.Payload), - From: common.ToHex(crypto.FromECDSAPub(message.Recover())), - To: common.ToHex(crypto.FromECDSAPub(message.To)), - Sent: message.Sent.Unix(), - TTL: int64(message.TTL / time.Second), - Hash: common.ToHex(message.Hash.Bytes()), - } -} diff --git a/whisper/whisperv2/doc.go b/whisper/whisperv2/doc.go deleted file mode 100644 index 7252f44b1..000000000 --- a/whisper/whisperv2/doc.go +++ /dev/null @@ -1,32 +0,0 @@ -// Copyright 2014 The go-ethereum Authors -// This file is part of the go-ethereum library. -// -// The go-ethereum library is free software: you can redistribute it and/or modify -// it under the terms of the GNU Lesser General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// The go-ethereum library is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU Lesser General Public License for more details. -// -// You should have received a copy of the GNU Lesser General Public License -// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. - -/* -Package whisper implements the Whisper PoC-1. - -(https://github.com/ethereum/wiki/wiki/Whisper-PoC-1-Protocol-Spec) - -Whisper combines aspects of both DHTs and datagram messaging systems (e.g. UDP). -As such it may be likened and compared to both, not dissimilar to the -matter/energy duality (apologies to physicists for the blatant abuse of a -fundamental and beautiful natural principle). - -Whisper is a pure identity-based messaging system. Whisper provides a low-level -(non-application-specific) but easily-accessible API without being based upon -or prejudiced by the low-level hardware attributes and characteristics, -particularly the notion of singular endpoints. -*/ -package whisperv2 diff --git a/whisper/whisperv2/envelope.go b/whisper/whisperv2/envelope.go deleted file mode 100644 index 9f1c68204..000000000 --- a/whisper/whisperv2/envelope.go +++ /dev/null @@ -1,150 +0,0 @@ -// Copyright 2014 The go-ethereum Authors -// This file is part of the go-ethereum library. -// -// The go-ethereum library is free software: you can redistribute it and/or modify -// it under the terms of the GNU Lesser General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// The go-ethereum library is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU Lesser General Public License for more details. -// -// You should have received a copy of the GNU Lesser General Public License -// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. - -// Contains the Whisper protocol Envelope element. For formal details please see -// the specs at https://github.com/ethereum/wiki/wiki/Whisper-PoC-1-Protocol-Spec#envelopes. - -package whisperv2 - -import ( - "crypto/ecdsa" - "encoding/binary" - "fmt" - "math/big" - "time" - - "github.com/ethereum/go-ethereum/common" - "github.com/ethereum/go-ethereum/common/math" - "github.com/ethereum/go-ethereum/crypto" - "github.com/ethereum/go-ethereum/crypto/ecies" - "github.com/ethereum/go-ethereum/rlp" -) - -// Envelope represents a clear-text data packet to transmit through the Whisper -// network. Its contents may or may not be encrypted and signed. -type Envelope struct { - Expiry uint32 // Whisper protocol specifies int32, really should be int64 - TTL uint32 // ^^^^^^ - Topics []Topic - Data []byte - Nonce uint32 - - hash common.Hash // Cached hash of the envelope to avoid rehashing every time -} - -// NewEnvelope wraps a Whisper message with expiration and destination data -// included into an envelope for network forwarding. -func NewEnvelope(ttl time.Duration, topics []Topic, msg *Message) *Envelope { - return &Envelope{ - Expiry: uint32(time.Now().Add(ttl).Unix()), - TTL: uint32(ttl.Seconds()), - Topics: topics, - Data: msg.bytes(), - Nonce: 0, - } -} - -// Seal closes the envelope by spending the requested amount of time as a proof -// of work on hashing the data. -func (self *Envelope) Seal(pow time.Duration) { - d := make([]byte, 64) - copy(d[:32], self.rlpWithoutNonce()) - - finish, bestBit := time.Now().Add(pow).UnixNano(), 0 - for nonce := uint32(0); time.Now().UnixNano() < finish; { - for i := 0; i < 1024; i++ { - binary.BigEndian.PutUint32(d[60:], nonce) - - d := new(big.Int).SetBytes(crypto.Keccak256(d)) - firstBit := math.FirstBitSet(d) - if firstBit > bestBit { - self.Nonce, bestBit = nonce, firstBit - } - nonce++ - } - } -} - -// rlpWithoutNonce returns the RLP encoded envelope contents, except the nonce. -func (self *Envelope) rlpWithoutNonce() []byte { - enc, _ := rlp.EncodeToBytes([]interface{}{self.Expiry, self.TTL, self.Topics, self.Data}) - return enc -} - -// Open extracts the message contained within a potentially encrypted envelope. -func (self *Envelope) Open(key *ecdsa.PrivateKey) (msg *Message, err error) { - // Split open the payload into a message construct - data := self.Data - - message := &Message{ - Flags: data[0], - Sent: time.Unix(int64(self.Expiry-self.TTL), 0), - TTL: time.Duration(self.TTL) * time.Second, - Hash: self.Hash(), - } - data = data[1:] - - if message.Flags&signatureFlag == signatureFlag { - if len(data) < signatureLength { - return nil, fmt.Errorf("unable to open envelope. First bit set but len(data) < len(signature)") - } - message.Signature, data = data[:signatureLength], data[signatureLength:] - } - message.Payload = data - - // Decrypt the message, if requested - if key == nil { - return message, nil - } - err = message.decrypt(key) - switch err { - case nil: - return message, nil - - case ecies.ErrInvalidPublicKey: // Payload isn't encrypted - return message, err - - default: - return nil, fmt.Errorf("unable to open envelope, decrypt failed: %v", err) - } -} - -// Hash returns the SHA3 hash of the envelope, calculating it if not yet done. -func (self *Envelope) Hash() common.Hash { - if (self.hash == common.Hash{}) { - enc, _ := rlp.EncodeToBytes(self) - self.hash = crypto.Keccak256Hash(enc) - } - return self.hash -} - -// DecodeRLP decodes an Envelope from an RLP data stream. -func (self *Envelope) DecodeRLP(s *rlp.Stream) error { - raw, err := s.Raw() - if err != nil { - return err - } - // The decoding of Envelope uses the struct fields but also needs - // to compute the hash of the whole RLP-encoded envelope. This - // type has the same structure as Envelope but is not an - // rlp.Decoder so we can reuse the Envelope struct definition. - type rlpenv Envelope - if err := rlp.DecodeBytes(raw, (*rlpenv)(self)); err != nil { - return err - } - self.hash = crypto.Keccak256Hash(raw) - return nil -} diff --git a/whisper/whisperv2/envelope_test.go b/whisper/whisperv2/envelope_test.go deleted file mode 100644 index 490ed9f6f..000000000 --- a/whisper/whisperv2/envelope_test.go +++ /dev/null @@ -1,158 +0,0 @@ -// Copyright 2015 The go-ethereum Authors -// This file is part of the go-ethereum library. -// -// The go-ethereum library is free software: you can redistribute it and/or modify -// it under the terms of the GNU Lesser General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// The go-ethereum library is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU Lesser General Public License for more details. -// -// You should have received a copy of the GNU Lesser General Public License -// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. - -package whisperv2 - -import ( - "bytes" - "testing" - "time" - - "github.com/ethereum/go-ethereum/crypto" - "github.com/ethereum/go-ethereum/crypto/ecies" -) - -func TestEnvelopeOpen(t *testing.T) { - payload := []byte("hello world") - message := NewMessage(payload) - - envelope, err := message.Wrap(DefaultPoW, Options{}) - if err != nil { - t.Fatalf("failed to wrap message: %v", err) - } - opened, err := envelope.Open(nil) - if err != nil { - t.Fatalf("failed to open envelope: %v", err) - } - if opened.Flags != message.Flags { - t.Fatalf("flags mismatch: have %d, want %d", opened.Flags, message.Flags) - } - if !bytes.Equal(opened.Signature, message.Signature) { - t.Fatalf("signature mismatch: have 0x%x, want 0x%x", opened.Signature, message.Signature) - } - if !bytes.Equal(opened.Payload, message.Payload) { - t.Fatalf("payload mismatch: have 0x%x, want 0x%x", opened.Payload, message.Payload) - } - if opened.Sent.Unix() != message.Sent.Unix() { - t.Fatalf("send time mismatch: have %v, want %v", opened.Sent, message.Sent) - } - if opened.TTL/time.Second != DefaultTTL/time.Second { - t.Fatalf("message TTL mismatch: have %v, want %v", opened.TTL, DefaultTTL) - } - - if opened.Hash != envelope.Hash() { - t.Fatalf("message hash mismatch: have 0x%x, want 0x%x", opened.Hash, envelope.Hash()) - } -} - -func TestEnvelopeAnonymousOpenUntargeted(t *testing.T) { - payload := []byte("hello envelope") - envelope, err := NewMessage(payload).Wrap(DefaultPoW, Options{}) - if err != nil { - t.Fatalf("failed to wrap message: %v", err) - } - opened, err := envelope.Open(nil) - if err != nil { - t.Fatalf("failed to open envelope: %v", err) - } - if opened.To != nil { - t.Fatalf("recipient mismatch: have 0x%x, want nil", opened.To) - } - if !bytes.Equal(opened.Payload, payload) { - t.Fatalf("payload mismatch: have 0x%x, want 0x%x", opened.Payload, payload) - } -} - -func TestEnvelopeAnonymousOpenTargeted(t *testing.T) { - key, err := crypto.GenerateKey() - if err != nil { - t.Fatalf("failed to generate test identity: %v", err) - } - - payload := []byte("hello envelope") - envelope, err := NewMessage(payload).Wrap(DefaultPoW, Options{ - To: &key.PublicKey, - }) - if err != nil { - t.Fatalf("failed to wrap message: %v", err) - } - opened, err := envelope.Open(nil) - if err != nil { - t.Fatalf("failed to open envelope: %v", err) - } - if opened.To != nil { - t.Fatalf("recipient mismatch: have 0x%x, want nil", opened.To) - } - if bytes.Equal(opened.Payload, payload) { - t.Fatalf("payload match, should have been encrypted: 0x%x", opened.Payload) - } -} - -func TestEnvelopeIdentifiedOpenUntargeted(t *testing.T) { - key, err := crypto.GenerateKey() - if err != nil { - t.Fatalf("failed to generate test identity: %v", err) - } - - payload := []byte("hello envelope") - envelope, err := NewMessage(payload).Wrap(DefaultPoW, Options{}) - if err != nil { - t.Fatalf("failed to wrap message: %v", err) - } - opened, err := envelope.Open(key) - switch err { - case nil: - t.Fatalf("envelope opened with bad key: %v", opened) - - case ecies.ErrInvalidPublicKey: - // Ok, key mismatch but opened - - default: - t.Fatalf("failed to open envelope: %v", err) - } - - if opened.To != nil { - t.Fatalf("recipient mismatch: have 0x%x, want nil", opened.To) - } - if !bytes.Equal(opened.Payload, payload) { - t.Fatalf("payload mismatch: have 0x%x, want 0x%x", opened.Payload, payload) - } -} - -func TestEnvelopeIdentifiedOpenTargeted(t *testing.T) { - key, err := crypto.GenerateKey() - if err != nil { - t.Fatalf("failed to generate test identity: %v", err) - } - - payload := []byte("hello envelope") - envelope, err := NewMessage(payload).Wrap(DefaultPoW, Options{ - To: &key.PublicKey, - }) - if err != nil { - t.Fatalf("failed to wrap message: %v", err) - } - opened, err := envelope.Open(key) - if err != nil { - t.Fatalf("failed to open envelope: %v", err) - } - if opened.To != nil { - t.Fatalf("recipient mismatch: have 0x%x, want nil", opened.To) - } - if !bytes.Equal(opened.Payload, payload) { - t.Fatalf("payload mismatch: have 0x%x, want 0x%x", opened.Payload, payload) - } -} diff --git a/whisper/whisperv2/filter.go b/whisper/whisperv2/filter.go deleted file mode 100644 index 7404859b7..000000000 --- a/whisper/whisperv2/filter.go +++ /dev/null @@ -1,129 +0,0 @@ -// Copyright 2014 The go-ethereum Authors -// This file is part of the go-ethereum library. -// -// The go-ethereum library is free software: you can redistribute it and/or modify -// it under the terms of the GNU Lesser General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// The go-ethereum library is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU Lesser General Public License for more details. -// -// You should have received a copy of the GNU Lesser General Public License -// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. - -// Contains the message filter for fine grained subscriptions. - -package whisperv2 - -import ( - "crypto/ecdsa" - - "github.com/ethereum/go-ethereum/event/filter" -) - -// Filter is used to subscribe to specific types of whisper messages. -type Filter struct { - To *ecdsa.PublicKey // Recipient of the message - From *ecdsa.PublicKey // Sender of the message - Topics [][]Topic // Topics to filter messages with - Fn func(msg *Message) // Handler in case of a match -} - -// NewFilterTopics creates a 2D topic array used by whisper.Filter from binary -// data elements. -func NewFilterTopics(data ...[][]byte) [][]Topic { - filter := make([][]Topic, len(data)) - for i, condition := range data { - // Handle the special case of condition == [[]byte{}] - if len(condition) == 1 && len(condition[0]) == 0 { - filter[i] = []Topic{} - continue - } - // Otherwise flatten normally - filter[i] = NewTopics(condition...) - } - return filter -} - -// NewFilterTopicsFlat creates a 2D topic array used by whisper.Filter from flat -// binary data elements. -func NewFilterTopicsFlat(data ...[]byte) [][]Topic { - filter := make([][]Topic, len(data)) - for i, element := range data { - // Only add non-wildcard topics - filter[i] = make([]Topic, 0, 1) - if len(element) > 0 { - filter[i] = append(filter[i], NewTopic(element)) - } - } - return filter -} - -// NewFilterTopicsFromStrings creates a 2D topic array used by whisper.Filter -// from textual data elements. -func NewFilterTopicsFromStrings(data ...[]string) [][]Topic { - filter := make([][]Topic, len(data)) - for i, condition := range data { - // Handle the special case of condition == [""] - if len(condition) == 1 && condition[0] == "" { - filter[i] = []Topic{} - continue - } - // Otherwise flatten normally - filter[i] = NewTopicsFromStrings(condition...) - } - return filter -} - -// NewFilterTopicsFromStringsFlat creates a 2D topic array used by whisper.Filter from flat -// binary data elements. -func NewFilterTopicsFromStringsFlat(data ...string) [][]Topic { - filter := make([][]Topic, len(data)) - for i, element := range data { - // Only add non-wildcard topics - filter[i] = make([]Topic, 0, 1) - if element != "" { - filter[i] = append(filter[i], NewTopicFromString(element)) - } - } - return filter -} - -// filterer is the internal, fully initialized filter ready to match inbound -// messages to a variety of criteria. -type filterer struct { - to string // Recipient of the message - from string // Sender of the message - matcher *topicMatcher // Topics to filter messages with - fn func(data interface{}) // Handler in case of a match -} - -// Compare checks if the specified filter matches the current one. -func (self filterer) Compare(f filter.Filter) bool { - filter := f.(filterer) - - // Check the message sender and recipient - if len(self.to) > 0 && self.to != filter.to { - return false - } - if len(self.from) > 0 && self.from != filter.from { - return false - } - // Check the topic filtering - topics := make([]Topic, len(filter.matcher.conditions)) - for i, group := range filter.matcher.conditions { - // Message should contain a single topic entry, extract - for topics[i] = range group { - break - } - } - return self.matcher.Matches(topics) -} - -// Trigger is called when a filter successfully matches an inbound message. -func (self filterer) Trigger(data interface{}) { - self.fn(data) -} diff --git a/whisper/whisperv2/filter_test.go b/whisper/whisperv2/filter_test.go deleted file mode 100644 index ffdfd7b34..000000000 --- a/whisper/whisperv2/filter_test.go +++ /dev/null @@ -1,215 +0,0 @@ -// Copyright 2015 The go-ethereum Authors -// This file is part of the go-ethereum library. -// -// The go-ethereum library is free software: you can redistribute it and/or modify -// it under the terms of the GNU Lesser General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// The go-ethereum library is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU Lesser General Public License for more details. -// -// You should have received a copy of the GNU Lesser General Public License -// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. - -package whisperv2 - -import ( - "bytes" - - "testing" -) - -var filterTopicsCreationTests = []struct { - topics [][]string - filter [][][4]byte -}{ - { // Simple topic filter - topics: [][]string{ - {"abc", "def", "ghi"}, - {"def"}, - {"ghi", "abc"}, - }, - filter: [][][4]byte{ - {{0x4e, 0x03, 0x65, 0x7a}, {0x34, 0x60, 0x7c, 0x9b}, {0x21, 0x41, 0x7d, 0xf9}}, - {{0x34, 0x60, 0x7c, 0x9b}}, - {{0x21, 0x41, 0x7d, 0xf9}, {0x4e, 0x03, 0x65, 0x7a}}, - }, - }, - { // Wild-carded topic filter - topics: [][]string{ - {"abc", "def", "ghi"}, - {}, - {""}, - {"def"}, - }, - filter: [][][4]byte{ - {{0x4e, 0x03, 0x65, 0x7a}, {0x34, 0x60, 0x7c, 0x9b}, {0x21, 0x41, 0x7d, 0xf9}}, - {}, - {}, - {{0x34, 0x60, 0x7c, 0x9b}}, - }, - }, -} - -var filterTopicsCreationFlatTests = []struct { - topics []string - filter [][][4]byte -}{ - { // Simple topic list - topics: []string{"abc", "def", "ghi"}, - filter: [][][4]byte{ - {{0x4e, 0x03, 0x65, 0x7a}}, - {{0x34, 0x60, 0x7c, 0x9b}}, - {{0x21, 0x41, 0x7d, 0xf9}}, - }, - }, - { // Wild-carded topic list - topics: []string{"abc", "", "ghi"}, - filter: [][][4]byte{ - {{0x4e, 0x03, 0x65, 0x7a}}, - {}, - {{0x21, 0x41, 0x7d, 0xf9}}, - }, - }, -} - -func TestFilterTopicsCreation(t *testing.T) { - // Check full filter creation - for i, tt := range filterTopicsCreationTests { - // Check the textual creation - filter := NewFilterTopicsFromStrings(tt.topics...) - if len(filter) != len(tt.topics) { - t.Errorf("test %d: condition count mismatch: have %v, want %v", i, len(filter), len(tt.topics)) - continue - } - for j, condition := range filter { - if len(condition) != len(tt.filter[j]) { - t.Errorf("test %d, condition %d: size mismatch: have %v, want %v", i, j, len(condition), len(tt.filter[j])) - continue - } - for k := 0; k < len(condition); k++ { - if !bytes.Equal(condition[k][:], tt.filter[j][k][:]) { - t.Errorf("test %d, condition %d, segment %d: filter mismatch: have 0x%x, want 0x%x", i, j, k, condition[k], tt.filter[j][k]) - } - } - } - // Check the binary creation - binary := make([][][]byte, len(tt.topics)) - for j, condition := range tt.topics { - binary[j] = make([][]byte, len(condition)) - for k, segment := range condition { - binary[j][k] = []byte(segment) - } - } - filter = NewFilterTopics(binary...) - if len(filter) != len(tt.topics) { - t.Errorf("test %d: condition count mismatch: have %v, want %v", i, len(filter), len(tt.topics)) - continue - } - for j, condition := range filter { - if len(condition) != len(tt.filter[j]) { - t.Errorf("test %d, condition %d: size mismatch: have %v, want %v", i, j, len(condition), len(tt.filter[j])) - continue - } - for k := 0; k < len(condition); k++ { - if !bytes.Equal(condition[k][:], tt.filter[j][k][:]) { - t.Errorf("test %d, condition %d, segment %d: filter mismatch: have 0x%x, want 0x%x", i, j, k, condition[k], tt.filter[j][k]) - } - } - } - } - // Check flat filter creation - for i, tt := range filterTopicsCreationFlatTests { - // Check the textual creation - filter := NewFilterTopicsFromStringsFlat(tt.topics...) - if len(filter) != len(tt.topics) { - t.Errorf("test %d: condition count mismatch: have %v, want %v", i, len(filter), len(tt.topics)) - continue - } - for j, condition := range filter { - if len(condition) != len(tt.filter[j]) { - t.Errorf("test %d, condition %d: size mismatch: have %v, want %v", i, j, len(condition), len(tt.filter[j])) - continue - } - for k := 0; k < len(condition); k++ { - if !bytes.Equal(condition[k][:], tt.filter[j][k][:]) { - t.Errorf("test %d, condition %d, segment %d: filter mismatch: have 0x%x, want 0x%x", i, j, k, condition[k], tt.filter[j][k]) - } - } - } - // Check the binary creation - binary := make([][]byte, len(tt.topics)) - for j, topic := range tt.topics { - binary[j] = []byte(topic) - } - filter = NewFilterTopicsFlat(binary...) - if len(filter) != len(tt.topics) { - t.Errorf("test %d: condition count mismatch: have %v, want %v", i, len(filter), len(tt.topics)) - continue - } - for j, condition := range filter { - if len(condition) != len(tt.filter[j]) { - t.Errorf("test %d, condition %d: size mismatch: have %v, want %v", i, j, len(condition), len(tt.filter[j])) - continue - } - for k := 0; k < len(condition); k++ { - if !bytes.Equal(condition[k][:], tt.filter[j][k][:]) { - t.Errorf("test %d, condition %d, segment %d: filter mismatch: have 0x%x, want 0x%x", i, j, k, condition[k], tt.filter[j][k]) - } - } - } - } -} - -var filterCompareTests = []struct { - matcher filterer - message filterer - match bool -}{ - { // Wild-card filter matching anything - matcher: filterer{to: "", from: "", matcher: newTopicMatcher()}, - message: filterer{to: "to", from: "from", matcher: newTopicMatcher(NewFilterTopicsFromStringsFlat("topic")...)}, - match: true, - }, - { // Filter matching the to field - matcher: filterer{to: "to", from: "", matcher: newTopicMatcher()}, - message: filterer{to: "to", from: "from", matcher: newTopicMatcher(NewFilterTopicsFromStringsFlat("topic")...)}, - match: true, - }, - { // Filter rejecting the to field - matcher: filterer{to: "to", from: "", matcher: newTopicMatcher()}, - message: filterer{to: "", from: "from", matcher: newTopicMatcher(NewFilterTopicsFromStringsFlat("topic")...)}, - match: false, - }, - { // Filter matching the from field - matcher: filterer{to: "", from: "from", matcher: newTopicMatcher()}, - message: filterer{to: "to", from: "from", matcher: newTopicMatcher(NewFilterTopicsFromStringsFlat("topic")...)}, - match: true, - }, - { // Filter rejecting the from field - matcher: filterer{to: "", from: "from", matcher: newTopicMatcher()}, - message: filterer{to: "to", from: "", matcher: newTopicMatcher(NewFilterTopicsFromStringsFlat("topic")...)}, - match: false, - }, - { // Filter matching the topic field - matcher: filterer{to: "", from: "from", matcher: newTopicMatcher(NewFilterTopicsFromStringsFlat("topic")...)}, - message: filterer{to: "to", from: "from", matcher: newTopicMatcher(NewFilterTopicsFromStringsFlat("topic")...)}, - match: true, - }, - { // Filter rejecting the topic field - matcher: filterer{to: "", from: "", matcher: newTopicMatcher(NewFilterTopicsFromStringsFlat("topic")...)}, - message: filterer{to: "to", from: "from", matcher: newTopicMatcher()}, - match: false, - }, -} - -func TestFilterCompare(t *testing.T) { - for i, tt := range filterCompareTests { - if match := tt.matcher.Compare(tt.message); match != tt.match { - t.Errorf("test %d: match mismatch: have %v, want %v", i, match, tt.match) - } - } -} diff --git a/whisper/whisperv2/main.go b/whisper/whisperv2/main.go deleted file mode 100644 index be4160489..000000000 --- a/whisper/whisperv2/main.go +++ /dev/null @@ -1,106 +0,0 @@ -// Copyright 2014 The go-ethereum Authors -// This file is part of the go-ethereum library. -// -// The go-ethereum library is free software: you can redistribute it and/or modify -// it under the terms of the GNU Lesser General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// The go-ethereum library is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU Lesser General Public License for more details. -// -// You should have received a copy of the GNU Lesser General Public License -// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. - -// +build none - -// Contains a simple whisper peer setup and self messaging to allow playing -// around with the protocol and API without a fancy client implementation. - -package main - -import ( - "fmt" - "log" - "os" - "time" - - "github.com/ethereum/go-ethereum/common" - "github.com/ethereum/go-ethereum/crypto" - "github.com/ethereum/go-ethereum/logger" - "github.com/ethereum/go-ethereum/p2p" - "github.com/ethereum/go-ethereum/p2p/nat" - "github.com/ethereum/go-ethereum/whisper" -) - -func main() { - logger.AddLogSystem(logger.NewStdLogSystem(os.Stdout, log.LstdFlags, logger.InfoLevel)) - - // Generate the peer identity - key, err := crypto.GenerateKey() - if err != nil { - fmt.Printf("Failed to generate peer key: %v.\n", err) - os.Exit(-1) - } - name := common.MakeName("whisper-go", "1.0") - shh := whisper.New() - - // Create an Ethereum peer to communicate through - server := p2p.Server{ - PrivateKey: key, - MaxPeers: 10, - Name: name, - Protocols: []p2p.Protocol{shh.Protocol()}, - ListenAddr: ":30300", - NAT: nat.Any(), - } - fmt.Println("Starting Ethereum peer...") - if err := server.Start(); err != nil { - fmt.Printf("Failed to start Ethereum peer: %v.\n", err) - os.Exit(1) - } - - // Send a message to self to check that something works - payload := fmt.Sprintf("Hello world, this is %v. In case you're wondering, the time is %v", name, time.Now()) - if err := selfSend(shh, []byte(payload)); err != nil { - fmt.Printf("Failed to self message: %v.\n", err) - os.Exit(-1) - } -} - -// SendSelf wraps a payload into a Whisper envelope and forwards it to itself. -func selfSend(shh *whisper.Whisper, payload []byte) error { - ok := make(chan struct{}) - - // Start watching for self messages, output any arrivals - id := shh.NewIdentity() - shh.Watch(whisper.Filter{ - To: &id.PublicKey, - Fn: func(msg *whisper.Message) { - fmt.Printf("Message received: %s, signed with 0x%x.\n", string(msg.Payload), msg.Signature) - close(ok) - }, - }) - // Wrap the payload and encrypt it - msg := whisper.NewMessage(payload) - envelope, err := msg.Wrap(whisper.DefaultPoW, whisper.Options{ - From: id, - To: &id.PublicKey, - TTL: whisper.DefaultTTL, - }) - if err != nil { - return fmt.Errorf("failed to seal message: %v", err) - } - // Dump the message into the system and wait for it to pop back out - if err := shh.Send(envelope); err != nil { - return fmt.Errorf("failed to send self-message: %v", err) - } - select { - case <-ok: - case <-time.After(time.Second): - return fmt.Errorf("failed to receive message in time") - } - return nil -} diff --git a/whisper/whisperv2/message.go b/whisper/whisperv2/message.go deleted file mode 100644 index 66648c3be..000000000 --- a/whisper/whisperv2/message.go +++ /dev/null @@ -1,158 +0,0 @@ -// Copyright 2014 The go-ethereum Authors -// This file is part of the go-ethereum library. -// -// The go-ethereum library is free software: you can redistribute it and/or modify -// it under the terms of the GNU Lesser General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// The go-ethereum library is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU Lesser General Public License for more details. -// -// You should have received a copy of the GNU Lesser General Public License -// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. - -// Contains the Whisper protocol Message element. For formal details please see -// the specs at https://github.com/ethereum/wiki/wiki/Whisper-PoC-1-Protocol-Spec#messages. - -package whisperv2 - -import ( - "crypto/ecdsa" - crand "crypto/rand" - "fmt" - "math/rand" - "time" - - "github.com/ethereum/go-ethereum/common" - "github.com/ethereum/go-ethereum/crypto" - "github.com/ethereum/go-ethereum/crypto/ecies" - "github.com/ethereum/go-ethereum/log" -) - -// Message represents an end-user data packet to transmit through the Whisper -// protocol. These are wrapped into Envelopes that need not be understood by -// intermediate nodes, just forwarded. -type Message struct { - Flags byte // First bit is signature presence, rest reserved and should be random - Signature []byte - Payload []byte - - Sent time.Time // Time when the message was posted into the network - TTL time.Duration // Maximum time to live allowed for the message - - To *ecdsa.PublicKey // Message recipient (identity used to decode the message) - Hash common.Hash // Message envelope hash to act as a unique id -} - -// Options specifies the exact way a message should be wrapped into an Envelope. -type Options struct { - From *ecdsa.PrivateKey - To *ecdsa.PublicKey - TTL time.Duration - Topics []Topic -} - -// NewMessage creates and initializes a non-signed, non-encrypted Whisper message. -func NewMessage(payload []byte) *Message { - // Construct an initial flag set: no signature, rest random - flags := byte(rand.Intn(256)) - flags &= ^signatureFlag - - // Assemble and return the message - return &Message{ - Flags: flags, - Payload: payload, - Sent: time.Now(), - } -} - -// Wrap bundles the message into an Envelope to transmit over the network. -// -// pow (Proof Of Work) controls how much time to spend on hashing the message, -// inherently controlling its priority through the network (smaller hash, bigger -// priority). -// -// The user can control the amount of identity, privacy and encryption through -// the options parameter as follows: -// - options.From == nil && options.To == nil: anonymous broadcast -// - options.From != nil && options.To == nil: signed broadcast (known sender) -// - options.From == nil && options.To != nil: encrypted anonymous message -// - options.From != nil && options.To != nil: encrypted signed message -func (self *Message) Wrap(pow time.Duration, options Options) (*Envelope, error) { - // Use the default TTL if non was specified - if options.TTL == 0 { - options.TTL = DefaultTTL - } - self.TTL = options.TTL - - // Sign and encrypt the message if requested - if options.From != nil { - if err := self.sign(options.From); err != nil { - return nil, err - } - } - if options.To != nil { - if err := self.encrypt(options.To); err != nil { - return nil, err - } - } - // Wrap the processed message, seal it and return - envelope := NewEnvelope(options.TTL, options.Topics, self) - envelope.Seal(pow) - - return envelope, nil -} - -// sign calculates and sets the cryptographic signature for the message , also -// setting the sign flag. -func (self *Message) sign(key *ecdsa.PrivateKey) (err error) { - self.Flags |= signatureFlag - self.Signature, err = crypto.Sign(self.hash(), key) - return -} - -// Recover retrieves the public key of the message signer. -func (self *Message) Recover() *ecdsa.PublicKey { - defer func() { recover() }() // in case of invalid signature - - // Short circuit if no signature is present - if self.Signature == nil { - return nil - } - // Otherwise try and recover the signature - pub, err := crypto.SigToPub(self.hash(), self.Signature) - if err != nil { - log.Error(fmt.Sprintf("Could not get public key from signature: %v", err)) - return nil - } - return pub -} - -// encrypt encrypts a message payload with a public key. -func (self *Message) encrypt(key *ecdsa.PublicKey) (err error) { - self.Payload, err = ecies.Encrypt(crand.Reader, ecies.ImportECDSAPublic(key), self.Payload, nil, nil) - return -} - -// decrypt decrypts an encrypted payload with a private key. -func (self *Message) decrypt(key *ecdsa.PrivateKey) error { - cleartext, err := ecies.ImportECDSA(key).Decrypt(crand.Reader, self.Payload, nil, nil) - if err == nil { - self.Payload = cleartext - } - return err -} - -// hash calculates the SHA3 checksum of the message flags and payload. -func (self *Message) hash() []byte { - return crypto.Keccak256(append([]byte{self.Flags}, self.Payload...)) -} - -// bytes flattens the message contents (flags, signature and payload) into a -// single binary blob. -func (self *Message) bytes() []byte { - return append([]byte{self.Flags}, append(self.Signature, self.Payload...)...) -} diff --git a/whisper/whisperv2/message_test.go b/whisper/whisperv2/message_test.go deleted file mode 100644 index c760ac54c..000000000 --- a/whisper/whisperv2/message_test.go +++ /dev/null @@ -1,158 +0,0 @@ -// Copyright 2014 The go-ethereum Authors -// This file is part of the go-ethereum library. -// -// The go-ethereum library is free software: you can redistribute it and/or modify -// it under the terms of the GNU Lesser General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// The go-ethereum library is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU Lesser General Public License for more details. -// -// You should have received a copy of the GNU Lesser General Public License -// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. - -package whisperv2 - -import ( - "bytes" - "crypto/elliptic" - "testing" - "time" - - "github.com/ethereum/go-ethereum/crypto" -) - -// Tests whether a message can be wrapped without any identity or encryption. -func TestMessageSimpleWrap(t *testing.T) { - payload := []byte("hello world") - - msg := NewMessage(payload) - if _, err := msg.Wrap(DefaultPoW, Options{}); err != nil { - t.Fatalf("failed to wrap message: %v", err) - } - if msg.Flags&signatureFlag != 0 { - t.Fatalf("signature flag mismatch: have %d, want %d", msg.Flags&signatureFlag, 0) - } - if len(msg.Signature) != 0 { - t.Fatalf("signature found for simple wrapping: 0x%x", msg.Signature) - } - if !bytes.Equal(msg.Payload, payload) { - t.Fatalf("payload mismatch after wrapping: have 0x%x, want 0x%x", msg.Payload, payload) - } - if msg.TTL/time.Second != DefaultTTL/time.Second { - t.Fatalf("message TTL mismatch: have %v, want %v", msg.TTL, DefaultTTL) - } -} - -// Tests whether a message can be signed, and wrapped in plain-text. -func TestMessageCleartextSignRecover(t *testing.T) { - key, err := crypto.GenerateKey() - if err != nil { - t.Fatalf("failed to create crypto key: %v", err) - } - payload := []byte("hello world") - - msg := NewMessage(payload) - if _, err := msg.Wrap(DefaultPoW, Options{ - From: key, - }); err != nil { - t.Fatalf("failed to sign message: %v", err) - } - if msg.Flags&signatureFlag != signatureFlag { - t.Fatalf("signature flag mismatch: have %d, want %d", msg.Flags&signatureFlag, signatureFlag) - } - if !bytes.Equal(msg.Payload, payload) { - t.Fatalf("payload mismatch after signing: have 0x%x, want 0x%x", msg.Payload, payload) - } - - pubKey := msg.Recover() - if pubKey == nil { - t.Fatalf("failed to recover public key") - } - p1 := elliptic.Marshal(crypto.S256(), key.PublicKey.X, key.PublicKey.Y) - p2 := elliptic.Marshal(crypto.S256(), pubKey.X, pubKey.Y) - if !bytes.Equal(p1, p2) { - t.Fatalf("public key mismatch: have 0x%x, want 0x%x", p2, p1) - } -} - -// Tests whether a message can be encrypted and decrypted using an anonymous -// sender (i.e. no signature). -func TestMessageAnonymousEncryptDecrypt(t *testing.T) { - key, err := crypto.GenerateKey() - if err != nil { - t.Fatalf("failed to create recipient crypto key: %v", err) - } - payload := []byte("hello world") - - msg := NewMessage(payload) - envelope, err := msg.Wrap(DefaultPoW, Options{ - To: &key.PublicKey, - }) - if err != nil { - t.Fatalf("failed to encrypt message: %v", err) - } - if msg.Flags&signatureFlag != 0 { - t.Fatalf("signature flag mismatch: have %d, want %d", msg.Flags&signatureFlag, 0) - } - if len(msg.Signature) != 0 { - t.Fatalf("signature found for anonymous message: 0x%x", msg.Signature) - } - - out, err := envelope.Open(key) - if err != nil { - t.Fatalf("failed to open encrypted message: %v", err) - } - if !bytes.Equal(out.Payload, payload) { - t.Errorf("payload mismatch: have 0x%x, want 0x%x", out.Payload, payload) - } -} - -// Tests whether a message can be properly signed and encrypted. -func TestMessageFullCrypto(t *testing.T) { - fromKey, err := crypto.GenerateKey() - if err != nil { - t.Fatalf("failed to create sender crypto key: %v", err) - } - toKey, err := crypto.GenerateKey() - if err != nil { - t.Fatalf("failed to create recipient crypto key: %v", err) - } - - payload := []byte("hello world") - msg := NewMessage(payload) - envelope, err := msg.Wrap(DefaultPoW, Options{ - From: fromKey, - To: &toKey.PublicKey, - }) - if err != nil { - t.Fatalf("failed to encrypt message: %v", err) - } - if msg.Flags&signatureFlag != signatureFlag { - t.Fatalf("signature flag mismatch: have %d, want %d", msg.Flags&signatureFlag, signatureFlag) - } - if len(msg.Signature) == 0 { - t.Fatalf("no signature found for signed message") - } - - out, err := envelope.Open(toKey) - if err != nil { - t.Fatalf("failed to open encrypted message: %v", err) - } - if !bytes.Equal(out.Payload, payload) { - t.Errorf("payload mismatch: have 0x%x, want 0x%x", out.Payload, payload) - } - - pubKey := out.Recover() - if pubKey == nil { - t.Fatalf("failed to recover public key") - } - p1 := elliptic.Marshal(crypto.S256(), fromKey.PublicKey.X, fromKey.PublicKey.Y) - p2 := elliptic.Marshal(crypto.S256(), pubKey.X, pubKey.Y) - if !bytes.Equal(p1, p2) { - t.Fatalf("public key mismatch: have 0x%x, want 0x%x", p2, p1) - } -} diff --git a/whisper/whisperv2/peer.go b/whisper/whisperv2/peer.go deleted file mode 100644 index 71798408b..000000000 --- a/whisper/whisperv2/peer.go +++ /dev/null @@ -1,174 +0,0 @@ -// Copyright 2014 The go-ethereum Authors -// This file is part of the go-ethereum library. -// -// The go-ethereum library is free software: you can redistribute it and/or modify -// it under the terms of the GNU Lesser General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// The go-ethereum library is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU Lesser General Public License for more details. -// -// You should have received a copy of the GNU Lesser General Public License -// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. - -package whisperv2 - -import ( - "fmt" - "time" - - "github.com/ethereum/go-ethereum/common" - "github.com/ethereum/go-ethereum/log" - "github.com/ethereum/go-ethereum/p2p" - "github.com/ethereum/go-ethereum/rlp" - "gopkg.in/fatih/set.v0" -) - -// peer represents a whisper protocol peer connection. -type peer struct { - host *Whisper - peer *p2p.Peer - ws p2p.MsgReadWriter - - known *set.Set // Messages already known by the peer to avoid wasting bandwidth - - quit chan struct{} -} - -// newPeer creates a new whisper peer object, but does not run the handshake itself. -func newPeer(host *Whisper, remote *p2p.Peer, rw p2p.MsgReadWriter) *peer { - return &peer{ - host: host, - peer: remote, - ws: rw, - known: set.New(), - quit: make(chan struct{}), - } -} - -// start initiates the peer updater, periodically broadcasting the whisper packets -// into the network. -func (self *peer) start() { - go self.update() - log.Debug(fmt.Sprintf("%v: whisper started", self.peer)) -} - -// stop terminates the peer updater, stopping message forwarding to it. -func (self *peer) stop() { - close(self.quit) - log.Debug(fmt.Sprintf("%v: whisper stopped", self.peer)) -} - -// handshake sends the protocol initiation status message to the remote peer and -// verifies the remote status too. -func (self *peer) handshake() error { - // Send the handshake status message asynchronously - errc := make(chan error, 1) - go func() { - errc <- p2p.SendItems(self.ws, statusCode, protocolVersion) - }() - // Fetch the remote status packet and verify protocol match - packet, err := self.ws.ReadMsg() - if err != nil { - return err - } - if packet.Code != statusCode { - return fmt.Errorf("peer sent %x before status packet", packet.Code) - } - s := rlp.NewStream(packet.Payload, uint64(packet.Size)) - if _, err := s.List(); err != nil { - return fmt.Errorf("bad status message: %v", err) - } - peerVersion, err := s.Uint() - if err != nil { - return fmt.Errorf("bad status message: %v", err) - } - if peerVersion != protocolVersion { - return fmt.Errorf("protocol version mismatch %d != %d", peerVersion, protocolVersion) - } - // Wait until out own status is consumed too - if err := <-errc; err != nil { - return fmt.Errorf("failed to send status packet: %v", err) - } - return nil -} - -// update executes periodic operations on the peer, including message transmission -// and expiration. -func (self *peer) update() { - // Start the tickers for the updates - expire := time.NewTicker(expirationCycle) - transmit := time.NewTicker(transmissionCycle) - - // Loop and transmit until termination is requested - for { - select { - case <-expire.C: - self.expire() - - case <-transmit.C: - if err := self.broadcast(); err != nil { - log.Info(fmt.Sprintf("%v: broadcast failed: %v", self.peer, err)) - return - } - - case <-self.quit: - return - } - } -} - -// mark marks an envelope known to the peer so that it won't be sent back. -func (self *peer) mark(envelope *Envelope) { - self.known.Add(envelope.Hash()) -} - -// marked checks if an envelope is already known to the remote peer. -func (self *peer) marked(envelope *Envelope) bool { - return self.known.Has(envelope.Hash()) -} - -// expire iterates over all the known envelopes in the host and removes all -// expired (unknown) ones from the known list. -func (self *peer) expire() { - // Assemble the list of available envelopes - available := set.NewNonTS() - for _, envelope := range self.host.envelopes() { - available.Add(envelope.Hash()) - } - // Cross reference availability with known status - unmark := make(map[common.Hash]struct{}) - self.known.Each(func(v interface{}) bool { - if !available.Has(v.(common.Hash)) { - unmark[v.(common.Hash)] = struct{}{} - } - return true - }) - // Dump all known but unavailable - for hash := range unmark { - self.known.Remove(hash) - } -} - -// broadcast iterates over the collection of envelopes and transmits yet unknown -// ones over the network. -func (self *peer) broadcast() error { - // Fetch the envelopes and collect the unknown ones - envelopes := self.host.envelopes() - transmit := make([]*Envelope, 0, len(envelopes)) - for _, envelope := range envelopes { - if !self.marked(envelope) { - transmit = append(transmit, envelope) - self.mark(envelope) - } - } - // Transmit the unknown batch (potentially empty) - if err := p2p.Send(self.ws, messagesCode, transmit); err != nil { - return err - } - log.Trace(fmt.Sprint(self.peer, "broadcasted", len(transmit), "message(s)")) - return nil -} diff --git a/whisper/whisperv2/peer_test.go b/whisper/whisperv2/peer_test.go deleted file mode 100644 index 87ca5063d..000000000 --- a/whisper/whisperv2/peer_test.go +++ /dev/null @@ -1,261 +0,0 @@ -// Copyright 2015 The go-ethereum Authors -// This file is part of the go-ethereum library. -// -// The go-ethereum library is free software: you can redistribute it and/or modify -// it under the terms of the GNU Lesser General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// The go-ethereum library is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU Lesser General Public License for more details. -// -// You should have received a copy of the GNU Lesser General Public License -// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. - -package whisperv2 - -import ( - "testing" - "time" - - "github.com/ethereum/go-ethereum/p2p" - "github.com/ethereum/go-ethereum/p2p/discover" -) - -type testPeer struct { - client *Whisper - stream *p2p.MsgPipeRW - termed chan struct{} -} - -func startTestPeer() *testPeer { - // Create a simulated P2P remote peer and data streams to it - remote := p2p.NewPeer(discover.NodeID{}, "", nil) - tester, tested := p2p.MsgPipe() - - // Create a whisper client and connect with it to the tester peer - client := New() - client.Start(nil) - - termed := make(chan struct{}) - go func() { - defer client.Stop() - defer close(termed) - defer tested.Close() - - client.handlePeer(remote, tested) - }() - - return &testPeer{ - client: client, - stream: tester, - termed: termed, - } -} - -func startTestPeerInited() (*testPeer, error) { - peer := startTestPeer() - - if err := p2p.ExpectMsg(peer.stream, statusCode, []uint64{protocolVersion}); err != nil { - peer.stream.Close() - return nil, err - } - if err := p2p.SendItems(peer.stream, statusCode, protocolVersion); err != nil { - peer.stream.Close() - return nil, err - } - return peer, nil -} - -func TestPeerStatusMessage(t *testing.T) { - tester := startTestPeer() - - // Wait for the handshake status message and check it - if err := p2p.ExpectMsg(tester.stream, statusCode, []uint64{protocolVersion}); err != nil { - t.Fatalf("status message mismatch: %v", err) - } - // Terminate the node - tester.stream.Close() - - select { - case <-tester.termed: - case <-time.After(time.Second): - t.Fatalf("local close timed out") - } -} - -func TestPeerHandshakeFail(t *testing.T) { - tester := startTestPeer() - - // Wait for and check the handshake - if err := p2p.ExpectMsg(tester.stream, statusCode, []uint64{protocolVersion}); err != nil { - t.Fatalf("status message mismatch: %v", err) - } - // Send an invalid handshake status and verify disconnect - if err := p2p.SendItems(tester.stream, messagesCode); err != nil { - t.Fatalf("failed to send malformed status: %v", err) - } - select { - case <-tester.termed: - case <-time.After(time.Second): - t.Fatalf("remote close timed out") - } -} - -func TestPeerHandshakeSuccess(t *testing.T) { - tester := startTestPeer() - - // Wait for and check the handshake - if err := p2p.ExpectMsg(tester.stream, statusCode, []uint64{protocolVersion}); err != nil { - t.Fatalf("status message mismatch: %v", err) - } - // Send a valid handshake status and make sure connection stays live - if err := p2p.SendItems(tester.stream, statusCode, protocolVersion); err != nil { - t.Fatalf("failed to send status: %v", err) - } - select { - case <-tester.termed: - t.Fatalf("valid handshake disconnected") - - case <-time.After(100 * time.Millisecond): - } - // Clean up the test - tester.stream.Close() - - select { - case <-tester.termed: - case <-time.After(time.Second): - t.Fatalf("local close timed out") - } -} - -func TestPeerSend(t *testing.T) { - // Start a tester and execute the handshake - tester, err := startTestPeerInited() - if err != nil { - t.Fatalf("failed to start initialized peer: %v", err) - } - defer tester.stream.Close() - - // Construct a message and inject into the tester - message := NewMessage([]byte("peer broadcast test message")) - envelope, err := message.Wrap(DefaultPoW, Options{ - TTL: DefaultTTL, - }) - if err != nil { - t.Fatalf("failed to wrap message: %v", err) - } - if err := tester.client.Send(envelope); err != nil { - t.Fatalf("failed to send message: %v", err) - } - // Check that the message is eventually forwarded - payload := []interface{}{envelope} - if err := p2p.ExpectMsg(tester.stream, messagesCode, payload); err != nil { - t.Fatalf("message mismatch: %v", err) - } - // Make sure that even with a re-insert, an empty batch is received - if err := tester.client.Send(envelope); err != nil { - t.Fatalf("failed to send message: %v", err) - } - if err := p2p.ExpectMsg(tester.stream, messagesCode, []interface{}{}); err != nil { - t.Fatalf("message mismatch: %v", err) - } -} - -func TestPeerDeliver(t *testing.T) { - // Start a tester and execute the handshake - tester, err := startTestPeerInited() - if err != nil { - t.Fatalf("failed to start initialized peer: %v", err) - } - defer tester.stream.Close() - - // Watch for all inbound messages - arrived := make(chan struct{}, 1) - tester.client.Watch(Filter{ - Fn: func(message *Message) { - arrived <- struct{}{} - }, - }) - // Construct a message and deliver it to the tester peer - message := NewMessage([]byte("peer broadcast test message")) - envelope, err := message.Wrap(DefaultPoW, Options{ - TTL: DefaultTTL, - }) - if err != nil { - t.Fatalf("failed to wrap message: %v", err) - } - if err := p2p.Send(tester.stream, messagesCode, []*Envelope{envelope}); err != nil { - t.Fatalf("failed to transfer message: %v", err) - } - // Check that the message is delivered upstream - select { - case <-arrived: - case <-time.After(time.Second): - t.Fatalf("message delivery timeout") - } - // Check that a resend is not delivered - if err := p2p.Send(tester.stream, messagesCode, []*Envelope{envelope}); err != nil { - t.Fatalf("failed to transfer message: %v", err) - } - select { - case <-time.After(2 * transmissionCycle): - case <-arrived: - t.Fatalf("repeating message arrived") - } -} - -func TestPeerMessageExpiration(t *testing.T) { - // Start a tester and execute the handshake - tester, err := startTestPeerInited() - if err != nil { - t.Fatalf("failed to start initialized peer: %v", err) - } - defer tester.stream.Close() - - // Fetch the peer instance for later inspection - tester.client.peerMu.RLock() - if peers := len(tester.client.peers); peers != 1 { - t.Fatalf("peer pool size mismatch: have %v, want %v", peers, 1) - } - var peer *peer - for peer = range tester.client.peers { - break - } - tester.client.peerMu.RUnlock() - - // Construct a message and pass it through the tester - message := NewMessage([]byte("peer test message")) - envelope, err := message.Wrap(DefaultPoW, Options{ - TTL: time.Second, - }) - if err != nil { - t.Fatalf("failed to wrap message: %v", err) - } - if err := tester.client.Send(envelope); err != nil { - t.Fatalf("failed to send message: %v", err) - } - payload := []interface{}{envelope} - if err := p2p.ExpectMsg(tester.stream, messagesCode, payload); err != nil { - // A premature empty message may have been broadcast, check the next too - if err := p2p.ExpectMsg(tester.stream, messagesCode, payload); err != nil { - t.Fatalf("message mismatch: %v", err) - } - } - // Check that the message is inside the cache - if !peer.known.Has(envelope.Hash()) { - t.Fatalf("message not found in cache") - } - // Discard messages until expiration and check cache again - exp := time.Now().Add(time.Second + 2*expirationCycle + 100*time.Millisecond) - for time.Now().Before(exp) { - if err := p2p.ExpectMsg(tester.stream, messagesCode, []interface{}{}); err != nil { - t.Fatalf("message mismatch: %v", err) - } - } - if peer.known.Has(envelope.Hash()) { - t.Fatalf("message not expired from cache") - } -} diff --git a/whisper/whisperv2/topic.go b/whisper/whisperv2/topic.go deleted file mode 100644 index 3e2b47bd3..000000000 --- a/whisper/whisperv2/topic.go +++ /dev/null @@ -1,140 +0,0 @@ -// Copyright 2015 The go-ethereum Authors -// This file is part of the go-ethereum library. -// -// The go-ethereum library is free software: you can redistribute it and/or modify -// it under the terms of the GNU Lesser General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// The go-ethereum library is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU Lesser General Public License for more details. -// -// You should have received a copy of the GNU Lesser General Public License -// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. - -// Contains the Whisper protocol Topic element. For formal details please see -// the specs at https://github.com/ethereum/wiki/wiki/Whisper-PoC-1-Protocol-Spec#topics. - -package whisperv2 - -import "github.com/ethereum/go-ethereum/crypto" - -// Topic represents a cryptographically secure, probabilistic partial -// classifications of a message, determined as the first (left) 4 bytes of the -// SHA3 hash of some arbitrary data given by the original author of the message. -type Topic [4]byte - -// NewTopic creates a topic from the 4 byte prefix of the SHA3 hash of the data. -// -// Note, empty topics are considered the wildcard, and cannot be used in messages. -func NewTopic(data []byte) Topic { - prefix := [4]byte{} - copy(prefix[:], crypto.Keccak256(data)[:4]) - return Topic(prefix) -} - -// NewTopics creates a list of topics from a list of binary data elements, by -// iteratively calling NewTopic on each of them. -func NewTopics(data ...[]byte) []Topic { - topics := make([]Topic, len(data)) - for i, element := range data { - topics[i] = NewTopic(element) - } - return topics -} - -// NewTopicFromString creates a topic using the binary data contents of the -// specified string. -func NewTopicFromString(data string) Topic { - return NewTopic([]byte(data)) -} - -// NewTopicsFromStrings creates a list of topics from a list of textual data -// elements, by iteratively calling NewTopicFromString on each of them. -func NewTopicsFromStrings(data ...string) []Topic { - topics := make([]Topic, len(data)) - for i, element := range data { - topics[i] = NewTopicFromString(element) - } - return topics -} - -// String converts a topic byte array to a string representation. -func (self *Topic) String() string { - return string(self[:]) -} - -// topicMatcher is a filter expression to verify if a list of topics contained -// in an arriving message matches some topic conditions. The topic matcher is -// built up of a list of conditions, each of which must be satisfied by the -// corresponding topic in the message. Each condition may require: a) an exact -// topic match; b) a match from a set of topics; or c) a wild-card matching all. -// -// If a message contains more topics than required by the matcher, those beyond -// the condition count are ignored and assumed to match. -// -// Consider the following sample topic matcher: -// sample := { -// {TopicA1, TopicA2, TopicA3}, -// {TopicB}, -// nil, -// {TopicD1, TopicD2} -// } -// In order for a message to pass this filter, it should enumerate at least 4 -// topics, the first any of [TopicA1, TopicA2, TopicA3], the second mandatory -// "TopicB", the third is ignored by the filter and the fourth either "TopicD1" -// or "TopicD2". If the message contains further topics, the filter will match -// them too. -type topicMatcher struct { - conditions []map[Topic]struct{} -} - -// newTopicMatcher create a topic matcher from a list of topic conditions. -func newTopicMatcher(topics ...[]Topic) *topicMatcher { - matcher := make([]map[Topic]struct{}, len(topics)) - for i, condition := range topics { - matcher[i] = make(map[Topic]struct{}) - for _, topic := range condition { - matcher[i][topic] = struct{}{} - } - } - return &topicMatcher{conditions: matcher} -} - -// newTopicMatcherFromBinary create a topic matcher from a list of binary conditions. -func newTopicMatcherFromBinary(data ...[][]byte) *topicMatcher { - topics := make([][]Topic, len(data)) - for i, condition := range data { - topics[i] = NewTopics(condition...) - } - return newTopicMatcher(topics...) -} - -// newTopicMatcherFromStrings creates a topic matcher from a list of textual -// conditions. -func newTopicMatcherFromStrings(data ...[]string) *topicMatcher { - topics := make([][]Topic, len(data)) - for i, condition := range data { - topics[i] = NewTopicsFromStrings(condition...) - } - return newTopicMatcher(topics...) -} - -// Matches checks if a list of topics matches this particular condition set. -func (self *topicMatcher) Matches(topics []Topic) bool { - // Mismatch if there aren't enough topics - if len(self.conditions) > len(topics) { - return false - } - // Check each topic condition for existence (skip wild-cards) - for i := 0; i < len(topics) && i < len(self.conditions); i++ { - if len(self.conditions[i]) > 0 { - if _, ok := self.conditions[i][topics[i]]; !ok { - return false - } - } - } - return true -} diff --git a/whisper/whisperv2/topic_test.go b/whisper/whisperv2/topic_test.go deleted file mode 100644 index bb6568996..000000000 --- a/whisper/whisperv2/topic_test.go +++ /dev/null @@ -1,215 +0,0 @@ -// Copyright 2015 The go-ethereum Authors -// This file is part of the go-ethereum library. -// -// The go-ethereum library is free software: you can redistribute it and/or modify -// it under the terms of the GNU Lesser General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// The go-ethereum library is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU Lesser General Public License for more details. -// -// You should have received a copy of the GNU Lesser General Public License -// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. - -package whisperv2 - -import ( - "bytes" - "testing" -) - -var topicCreationTests = []struct { - data []byte - hash [4]byte -}{ - {hash: [4]byte{0x8f, 0x9a, 0x2b, 0x7d}, data: []byte("test name")}, - {hash: [4]byte{0xf2, 0x6e, 0x77, 0x79}, data: []byte("some other test")}, -} - -func TestTopicCreation(t *testing.T) { - // Create the topics individually - for i, tt := range topicCreationTests { - topic := NewTopic(tt.data) - if !bytes.Equal(topic[:], tt.hash[:]) { - t.Errorf("binary test %d: hash mismatch: have %v, want %v.", i, topic, tt.hash) - } - } - for i, tt := range topicCreationTests { - topic := NewTopicFromString(string(tt.data)) - if !bytes.Equal(topic[:], tt.hash[:]) { - t.Errorf("textual test %d: hash mismatch: have %v, want %v.", i, topic, tt.hash) - } - } - // Create the topics in batches - binaryData := make([][]byte, len(topicCreationTests)) - for i, tt := range topicCreationTests { - binaryData[i] = tt.data - } - textualData := make([]string, len(topicCreationTests)) - for i, tt := range topicCreationTests { - textualData[i] = string(tt.data) - } - - topics := NewTopics(binaryData...) - for i, tt := range topicCreationTests { - if !bytes.Equal(topics[i][:], tt.hash[:]) { - t.Errorf("binary batch test %d: hash mismatch: have %v, want %v.", i, topics[i], tt.hash) - } - } - topics = NewTopicsFromStrings(textualData...) - for i, tt := range topicCreationTests { - if !bytes.Equal(topics[i][:], tt.hash[:]) { - t.Errorf("textual batch test %d: hash mismatch: have %v, want %v.", i, topics[i], tt.hash) - } - } -} - -var topicMatcherCreationTest = struct { - binary [][][]byte - textual [][]string - matcher []map[[4]byte]struct{} -}{ - binary: [][][]byte{ - {}, - { - []byte("Topic A"), - }, - { - []byte("Topic B1"), - []byte("Topic B2"), - []byte("Topic B3"), - }, - }, - textual: [][]string{ - {}, - {"Topic A"}, - {"Topic B1", "Topic B2", "Topic B3"}, - }, - matcher: []map[[4]byte]struct{}{ - {}, - { - {0x25, 0xfc, 0x95, 0x66}: {}, - }, - { - {0x93, 0x6d, 0xec, 0x09}: {}, - {0x25, 0x23, 0x34, 0xd3}: {}, - {0x6b, 0xc2, 0x73, 0xd1}: {}, - }, - }, -} - -func TestTopicMatcherCreation(t *testing.T) { - test := topicMatcherCreationTest - - matcher := newTopicMatcherFromBinary(test.binary...) - for i, cond := range matcher.conditions { - for topic := range cond { - if _, ok := test.matcher[i][topic]; !ok { - t.Errorf("condition %d; extra topic found: 0x%x", i, topic[:]) - } - } - } - for i, cond := range test.matcher { - for topic := range cond { - if _, ok := matcher.conditions[i][topic]; !ok { - t.Errorf("condition %d; topic not found: 0x%x", i, topic[:]) - } - } - } - - matcher = newTopicMatcherFromStrings(test.textual...) - for i, cond := range matcher.conditions { - for topic := range cond { - if _, ok := test.matcher[i][topic]; !ok { - t.Errorf("condition %d; extra topic found: 0x%x", i, topic[:]) - } - } - } - for i, cond := range test.matcher { - for topic := range cond { - if _, ok := matcher.conditions[i][topic]; !ok { - t.Errorf("condition %d; topic not found: 0x%x", i, topic[:]) - } - } - } -} - -var topicMatcherTests = []struct { - filter [][]string - topics []string - match bool -}{ - // Empty topic matcher should match everything - { - filter: [][]string{}, - topics: []string{}, - match: true, - }, - { - filter: [][]string{}, - topics: []string{"a", "b", "c"}, - match: true, - }, - // Fixed topic matcher should match strictly, but only prefix - { - filter: [][]string{{"a"}, {"b"}}, - topics: []string{"a"}, - match: false, - }, - { - filter: [][]string{{"a"}, {"b"}}, - topics: []string{"a", "b"}, - match: true, - }, - { - filter: [][]string{{"a"}, {"b"}}, - topics: []string{"a", "b", "c"}, - match: true, - }, - // Multi-matcher should match any from a sub-group - { - filter: [][]string{{"a1", "a2"}}, - topics: []string{"a"}, - match: false, - }, - { - filter: [][]string{{"a1", "a2"}}, - topics: []string{"a1"}, - match: true, - }, - { - filter: [][]string{{"a1", "a2"}}, - topics: []string{"a2"}, - match: true, - }, - // Wild-card condition should match anything - { - filter: [][]string{{}, {"b"}}, - topics: []string{"a"}, - match: false, - }, - { - filter: [][]string{{}, {"b"}}, - topics: []string{"a", "b"}, - match: true, - }, - { - filter: [][]string{{}, {"b"}}, - topics: []string{"b", "b"}, - match: true, - }, -} - -func TestTopicMatcher(t *testing.T) { - for i, tt := range topicMatcherTests { - topics := NewTopicsFromStrings(tt.topics...) - - matcher := newTopicMatcherFromStrings(tt.filter...) - if match := matcher.Matches(topics); match != tt.match { - t.Errorf("test %d: match mismatch: have %v, want %v", i, match, tt.match) - } - } -} diff --git a/whisper/whisperv2/whisper.go b/whisper/whisperv2/whisper.go deleted file mode 100644 index e111a3414..000000000 --- a/whisper/whisperv2/whisper.go +++ /dev/null @@ -1,378 +0,0 @@ -// Copyright 2014 The go-ethereum Authors -// This file is part of the go-ethereum library. -// -// The go-ethereum library is free software: you can redistribute it and/or modify -// it under the terms of the GNU Lesser General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// The go-ethereum library is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU Lesser General Public License for more details. -// -// You should have received a copy of the GNU Lesser General Public License -// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. - -package whisperv2 - -import ( - "crypto/ecdsa" - "fmt" - "sync" - "time" - - "github.com/ethereum/go-ethereum/common" - "github.com/ethereum/go-ethereum/crypto" - "github.com/ethereum/go-ethereum/crypto/ecies" - "github.com/ethereum/go-ethereum/event/filter" - "github.com/ethereum/go-ethereum/log" - "github.com/ethereum/go-ethereum/p2p" - "github.com/ethereum/go-ethereum/rpc" - - "gopkg.in/fatih/set.v0" -) - -const ( - statusCode = 0x00 - messagesCode = 0x01 - - protocolVersion uint64 = 0x02 - protocolName = "shh" - - signatureFlag = byte(1 << 7) - signatureLength = 65 - - expirationCycle = 800 * time.Millisecond - transmissionCycle = 300 * time.Millisecond -) - -const ( - DefaultTTL = 50 * time.Second - DefaultPoW = 50 * time.Millisecond -) - -type MessageEvent struct { - To *ecdsa.PrivateKey - From *ecdsa.PublicKey - Message *Message -} - -// Whisper represents a dark communication interface through the Ethereum -// network, using its very own P2P communication layer. -type Whisper struct { - protocol p2p.Protocol - filters *filter.Filters - - keys map[string]*ecdsa.PrivateKey - - messages map[common.Hash]*Envelope // Pool of messages currently tracked by this node - expirations map[uint32]*set.SetNonTS // Message expiration pool (TODO: something lighter) - poolMu sync.RWMutex // Mutex to sync the message and expiration pools - - peers map[*peer]struct{} // Set of currently active peers - peerMu sync.RWMutex // Mutex to sync the active peer set - - quit chan struct{} -} - -// New creates a Whisper client ready to communicate through the Ethereum P2P -// network. -func New() *Whisper { - whisper := &Whisper{ - filters: filter.New(), - keys: make(map[string]*ecdsa.PrivateKey), - messages: make(map[common.Hash]*Envelope), - expirations: make(map[uint32]*set.SetNonTS), - peers: make(map[*peer]struct{}), - quit: make(chan struct{}), - } - whisper.filters.Start() - - // p2p whisper sub protocol handler - whisper.protocol = p2p.Protocol{ - Name: protocolName, - Version: uint(protocolVersion), - Length: 2, - Run: whisper.handlePeer, - } - - return whisper -} - -// APIs returns the RPC descriptors the Whisper implementation offers -func (s *Whisper) APIs() []rpc.API { - return []rpc.API{ - { - Namespace: "shh", - Version: "1.0", - Service: NewPublicWhisperAPI(s), - Public: true, - }, - } -} - -// Protocols returns the whisper sub-protocols ran by this particular client. -func (self *Whisper) Protocols() []p2p.Protocol { - return []p2p.Protocol{self.protocol} -} - -// Version returns the whisper sub-protocols version number. -func (self *Whisper) Version() uint { - return self.protocol.Version -} - -// NewIdentity generates a new cryptographic identity for the client, and injects -// it into the known identities for message decryption. -func (self *Whisper) NewIdentity() *ecdsa.PrivateKey { - key, err := crypto.GenerateKey() - if err != nil { - panic(err) - } - self.keys[string(crypto.FromECDSAPub(&key.PublicKey))] = key - - return key -} - -// HasIdentity checks if the the whisper node is configured with the private key -// of the specified public pair. -func (self *Whisper) HasIdentity(key *ecdsa.PublicKey) bool { - return self.keys[string(crypto.FromECDSAPub(key))] != nil -} - -// GetIdentity retrieves the private key of the specified public identity. -func (self *Whisper) GetIdentity(key *ecdsa.PublicKey) *ecdsa.PrivateKey { - return self.keys[string(crypto.FromECDSAPub(key))] -} - -// Watch installs a new message handler to run in case a matching packet arrives -// from the whisper network. -func (self *Whisper) Watch(options Filter) int { - filter := filterer{ - to: string(crypto.FromECDSAPub(options.To)), - from: string(crypto.FromECDSAPub(options.From)), - matcher: newTopicMatcher(options.Topics...), - fn: func(data interface{}) { - options.Fn(data.(*Message)) - }, - } - return self.filters.Install(filter) -} - -// Unwatch removes an installed message handler. -func (self *Whisper) Unwatch(id int) { - self.filters.Uninstall(id) -} - -// Send injects a message into the whisper send queue, to be distributed in the -// network in the coming cycles. -func (self *Whisper) Send(envelope *Envelope) error { - return self.add(envelope) -} - -// Start implements node.Service, starting the background data propagation thread -// of the Whisper protocol. -func (self *Whisper) Start(*p2p.Server) error { - log.Info("Whisper started") - go self.update() - return nil -} - -// Stop implements node.Service, stopping the background data propagation thread -// of the Whisper protocol. -func (self *Whisper) Stop() error { - close(self.quit) - log.Info("Whisper stopped") - return nil -} - -// Messages retrieves all the currently pooled messages matching a filter id. -func (self *Whisper) Messages(id int) []*Message { - messages := make([]*Message, 0) - if filter := self.filters.Get(id); filter != nil { - for _, envelope := range self.messages { - if message := self.open(envelope); message != nil { - if self.filters.Match(filter, createFilter(message, envelope.Topics)) { - messages = append(messages, message) - } - } - } - } - return messages -} - -// handlePeer is called by the underlying P2P layer when the whisper sub-protocol -// connection is negotiated. -func (self *Whisper) handlePeer(peer *p2p.Peer, rw p2p.MsgReadWriter) error { - // Create the new peer and start tracking it - whisperPeer := newPeer(self, peer, rw) - - self.peerMu.Lock() - self.peers[whisperPeer] = struct{}{} - self.peerMu.Unlock() - - defer func() { - self.peerMu.Lock() - delete(self.peers, whisperPeer) - self.peerMu.Unlock() - }() - - // Run the peer handshake and state updates - if err := whisperPeer.handshake(); err != nil { - return err - } - whisperPeer.start() - defer whisperPeer.stop() - - // Read and process inbound messages directly to merge into client-global state - for { - // Fetch the next packet and decode the contained envelopes - packet, err := rw.ReadMsg() - if err != nil { - return err - } - var envelopes []*Envelope - if err := packet.Decode(&envelopes); err != nil { - log.Info(fmt.Sprintf("%v: failed to decode envelope: %v", peer, err)) - continue - } - // Inject all envelopes into the internal pool - for _, envelope := range envelopes { - if err := self.add(envelope); err != nil { - // TODO Punish peer here. Invalid envelope. - log.Debug(fmt.Sprintf("%v: failed to pool envelope: %v", peer, err)) - } - whisperPeer.mark(envelope) - } - } -} - -// add inserts a new envelope into the message pool to be distributed within the -// whisper network. It also inserts the envelope into the expiration pool at the -// appropriate time-stamp. -func (self *Whisper) add(envelope *Envelope) error { - self.poolMu.Lock() - defer self.poolMu.Unlock() - - // short circuit when a received envelope has already expired - if envelope.Expiry < uint32(time.Now().Unix()) { - return nil - } - - // Insert the message into the tracked pool - hash := envelope.Hash() - if _, ok := self.messages[hash]; ok { - log.Trace(fmt.Sprintf("whisper envelope already cached: %x\n", hash)) - return nil - } - self.messages[hash] = envelope - - // Insert the message into the expiration pool for later removal - if self.expirations[envelope.Expiry] == nil { - self.expirations[envelope.Expiry] = set.NewNonTS() - } - if !self.expirations[envelope.Expiry].Has(hash) { - self.expirations[envelope.Expiry].Add(hash) - - // Notify the local node of a message arrival - go self.postEvent(envelope) - } - log.Trace(fmt.Sprintf("cached whisper envelope %x\n", hash)) - return nil -} - -// postEvent opens an envelope with the configured identities and delivers the -// message upstream from application processing. -func (self *Whisper) postEvent(envelope *Envelope) { - if message := self.open(envelope); message != nil { - self.filters.Notify(createFilter(message, envelope.Topics), message) - } -} - -// open tries to decrypt a whisper envelope with all the configured identities, -// returning the decrypted message and the key used to achieve it. If not keys -// are configured, open will return the payload as if non encrypted. -func (self *Whisper) open(envelope *Envelope) *Message { - // Short circuit if no identity is set, and assume clear-text - if len(self.keys) == 0 { - if message, err := envelope.Open(nil); err == nil { - return message - } - } - // Iterate over the keys and try to decrypt the message - for _, key := range self.keys { - message, err := envelope.Open(key) - if err == nil { - message.To = &key.PublicKey - return message - } else if err == ecies.ErrInvalidPublicKey { - return message - } - } - // Failed to decrypt, don't return anything - return nil -} - -// createFilter creates a message filter to check against installed handlers. -func createFilter(message *Message, topics []Topic) filter.Filter { - matcher := make([][]Topic, len(topics)) - for i, topic := range topics { - matcher[i] = []Topic{topic} - } - return filterer{ - to: string(crypto.FromECDSAPub(message.To)), - from: string(crypto.FromECDSAPub(message.Recover())), - matcher: newTopicMatcher(matcher...), - } -} - -// update loops until the lifetime of the whisper node, updating its internal -// state by expiring stale messages from the pool. -func (self *Whisper) update() { - // Start a ticker to check for expirations - expire := time.NewTicker(expirationCycle) - - // Repeat updates until termination is requested - for { - select { - case <-expire.C: - self.expire() - - case <-self.quit: - return - } - } -} - -// expire iterates over all the expiration timestamps, removing all stale -// messages from the pools. -func (self *Whisper) expire() { - self.poolMu.Lock() - defer self.poolMu.Unlock() - - now := uint32(time.Now().Unix()) - for then, hashSet := range self.expirations { - // Short circuit if a future time - if then > now { - continue - } - // Dump all expired messages and remove timestamp - hashSet.Each(func(v interface{}) bool { - delete(self.messages, v.(common.Hash)) - return true - }) - self.expirations[then].Clear() - } -} - -// envelopes retrieves all the messages currently pooled by the node. -func (self *Whisper) envelopes() []*Envelope { - self.poolMu.RLock() - defer self.poolMu.RUnlock() - - envelopes := make([]*Envelope, 0, len(self.messages)) - for _, envelope := range self.messages { - envelopes = append(envelopes, envelope) - } - return envelopes -} diff --git a/whisper/whisperv2/whisper_test.go b/whisper/whisperv2/whisper_test.go deleted file mode 100644 index 1e0d3f85d..000000000 --- a/whisper/whisperv2/whisper_test.go +++ /dev/null @@ -1,216 +0,0 @@ -// Copyright 2014 The go-ethereum Authors -// This file is part of the go-ethereum library. -// -// The go-ethereum library is free software: you can redistribute it and/or modify -// it under the terms of the GNU Lesser General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// The go-ethereum library is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU Lesser General Public License for more details. -// -// You should have received a copy of the GNU Lesser General Public License -// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. - -package whisperv2 - -import ( - "testing" - "time" - - "github.com/ethereum/go-ethereum/p2p" - "github.com/ethereum/go-ethereum/p2p/discover" -) - -func startTestCluster(n int) []*Whisper { - // Create the batch of simulated peers - nodes := make([]*p2p.Peer, n) - for i := 0; i < n; i++ { - nodes[i] = p2p.NewPeer(discover.NodeID{}, "", nil) - } - whispers := make([]*Whisper, n) - for i := 0; i < n; i++ { - whispers[i] = New() - whispers[i].Start(nil) - } - // Wire all the peers to the root one - for i := 1; i < n; i++ { - src, dst := p2p.MsgPipe() - - go whispers[0].handlePeer(nodes[i], src) - go whispers[i].handlePeer(nodes[0], dst) - } - return whispers -} - -func TestSelfMessage(t *testing.T) { - // Start the single node cluster - client := startTestCluster(1)[0] - - // Start watching for self messages, signal any arrivals - self := client.NewIdentity() - done := make(chan struct{}) - - client.Watch(Filter{ - To: &self.PublicKey, - Fn: func(msg *Message) { - close(done) - }, - }) - // Send a dummy message to oneself - msg := NewMessage([]byte("self whisper")) - envelope, err := msg.Wrap(DefaultPoW, Options{ - From: self, - To: &self.PublicKey, - TTL: DefaultTTL, - }) - if err != nil { - t.Fatalf("failed to wrap message: %v", err) - } - // Dump the message into the system and wait for it to pop back out - if err := client.Send(envelope); err != nil { - t.Fatalf("failed to send self-message: %v", err) - } - select { - case <-done: - case <-time.After(time.Second): - t.Fatalf("self-message receive timeout") - } -} - -func TestDirectMessage(t *testing.T) { - // Start the sender-recipient cluster - cluster := startTestCluster(2) - - sender := cluster[0] - senderId := sender.NewIdentity() - - recipient := cluster[1] - recipientId := recipient.NewIdentity() - - // Watch for arriving messages on the recipient - done := make(chan struct{}) - recipient.Watch(Filter{ - To: &recipientId.PublicKey, - Fn: func(msg *Message) { - close(done) - }, - }) - // Send a dummy message from the sender - msg := NewMessage([]byte("direct whisper")) - envelope, err := msg.Wrap(DefaultPoW, Options{ - From: senderId, - To: &recipientId.PublicKey, - TTL: DefaultTTL, - }) - if err != nil { - t.Fatalf("failed to wrap message: %v", err) - } - if err := sender.Send(envelope); err != nil { - t.Fatalf("failed to send direct message: %v", err) - } - // Wait for an arrival or a timeout - select { - case <-done: - case <-time.After(time.Second): - t.Fatalf("direct message receive timeout") - } -} - -func TestAnonymousBroadcast(t *testing.T) { - testBroadcast(true, t) -} - -func TestIdentifiedBroadcast(t *testing.T) { - testBroadcast(false, t) -} - -func testBroadcast(anonymous bool, t *testing.T) { - // Start the single sender multi recipient cluster - cluster := startTestCluster(3) - - sender := cluster[1] - targets := cluster[1:] - for _, target := range targets { - if !anonymous { - target.NewIdentity() - } - } - // Watch for arriving messages on the recipients - dones := make([]chan struct{}, len(targets)) - for i := 0; i < len(targets); i++ { - done := make(chan struct{}) // need for the closure - dones[i] = done - - targets[i].Watch(Filter{ - Topics: NewFilterTopicsFromStringsFlat("broadcast topic"), - Fn: func(msg *Message) { - close(done) - }, - }) - } - // Send a dummy message from the sender - msg := NewMessage([]byte("broadcast whisper")) - envelope, err := msg.Wrap(DefaultPoW, Options{ - Topics: NewTopicsFromStrings("broadcast topic"), - TTL: DefaultTTL, - }) - if err != nil { - t.Fatalf("failed to wrap message: %v", err) - } - if err := sender.Send(envelope); err != nil { - t.Fatalf("failed to send broadcast message: %v", err) - } - // Wait for an arrival on each recipient, or timeouts - timeout := time.After(time.Second) - for _, done := range dones { - select { - case <-done: - case <-timeout: - t.Fatalf("broadcast message receive timeout") - } - } -} - -func TestMessageExpiration(t *testing.T) { - // Start the single node cluster and inject a dummy message - node := startTestCluster(1)[0] - - message := NewMessage([]byte("expiring message")) - envelope, err := message.Wrap(DefaultPoW, Options{TTL: time.Second}) - if err != nil { - t.Fatalf("failed to wrap message: %v", err) - } - if err := node.Send(envelope); err != nil { - t.Fatalf("failed to inject message: %v", err) - } - // Check that the message is inside the cache - node.poolMu.RLock() - _, found := node.messages[envelope.Hash()] - node.poolMu.RUnlock() - - if !found { - t.Fatalf("message not found in cache") - } - // Wait for expiration and check cache again - time.Sleep(time.Second) // wait for expiration - time.Sleep(2 * expirationCycle) // wait for cleanup cycle - - node.poolMu.RLock() - _, found = node.messages[envelope.Hash()] - node.poolMu.RUnlock() - if found { - t.Fatalf("message not expired from cache") - } - - // Check that adding an expired envelope doesn't do anything. - node.add(envelope) - node.poolMu.RLock() - _, found = node.messages[envelope.Hash()] - node.poolMu.RUnlock() - if found { - t.Fatalf("message was added to cache") - } -} diff --git a/whisper/whisperv5/api.go b/whisper/whisperv5/api.go index b4494d0d6..ee566625c 100644 --- a/whisper/whisperv5/api.go +++ b/whisper/whisperv5/api.go @@ -60,32 +60,9 @@ func NewPublicWhisperAPI(w *Whisper) *PublicWhisperAPI { w: w, lastUsed: make(map[string]time.Time), } - - go api.run() return api } -// run the api event loop. -// this loop deletes filter that have not been used within filterTimeout -func (api *PublicWhisperAPI) run() { - timeout := time.NewTicker(2 * time.Minute) - for { - <-timeout.C - - api.mu.Lock() - for id, lastUsed := range api.lastUsed { - if time.Since(lastUsed).Seconds() >= filterTimeout { - delete(api.lastUsed, id) - if err := api.w.Unsubscribe(id); err != nil { - log.Error("could not unsubscribe whisper filter", "error", err) - } - log.Debug("delete whisper filter (timeout)", "id", id) - } - } - api.mu.Unlock() - } -} - // Version returns the Whisper sub-protocol version. func (api *PublicWhisperAPI) Version(ctx context.Context) string { return ProtocolVersionStr diff --git a/whisper/whisperv6/api.go b/whisper/whisperv6/api.go index a2c75a41c..96e2b17e7 100644 --- a/whisper/whisperv6/api.go +++ b/whisper/whisperv6/api.go @@ -61,32 +61,9 @@ func NewPublicWhisperAPI(w *Whisper) *PublicWhisperAPI { w: w, lastUsed: make(map[string]time.Time), } - - go api.run() return api } -// run the api event loop. -// this loop deletes filter that have not been used within filterTimeout -func (api *PublicWhisperAPI) run() { - timeout := time.NewTicker(2 * time.Minute) - for { - <-timeout.C - - api.mu.Lock() - for id, lastUsed := range api.lastUsed { - if time.Since(lastUsed).Seconds() >= filterTimeout { - delete(api.lastUsed, id) - if err := api.w.Unsubscribe(id); err != nil { - log.Error("could not unsubscribe whisper filter", "error", err) - } - log.Debug("delete whisper filter (timeout)", "id", id) - } - } - api.mu.Unlock() - } -} - // Version returns the Whisper sub-protocol version. func (api *PublicWhisperAPI) Version(ctx context.Context) string { return ProtocolVersionStr @@ -219,6 +196,19 @@ func (api *PublicWhisperAPI) DeleteSymKey(ctx context.Context, id string) bool { return api.w.DeleteSymKey(id) } +// MakeLightClient turns the node into light client, which does not forward +// any incoming messages, and sends only messages originated in this node. +func (api *PublicWhisperAPI) MakeLightClient(ctx context.Context) bool { + api.w.lightClient = true + return api.w.lightClient +} + +// CancelLightClient cancels light client mode. +func (api *PublicWhisperAPI) CancelLightClient(ctx context.Context) bool { + api.w.lightClient = false + return !api.w.lightClient +} + //go:generate gencodec -type NewMessage -field-override newMessageOverride -out gen_newmessage_json.go // NewMessage represents a new whisper message that is posted through the RPC. diff --git a/whisper/whisperv6/doc.go b/whisper/whisperv6/doc.go index d5d7fed60..066a9766d 100644 --- a/whisper/whisperv6/doc.go +++ b/whisper/whisperv6/doc.go @@ -60,7 +60,7 @@ const ( aesKeyLength = 32 // in bytes aesNonceLength = 12 // in bytes; for more info please see cipher.gcmStandardNonceSize & aesgcm.NonceSize() keyIDSize = 32 // in bytes - bloomFilterSize = 64 // in bytes + BloomFilterSize = 64 // in bytes flagsLength = 1 EnvelopeHeaderLength = 20 diff --git a/whisper/whisperv6/envelope.go b/whisper/whisperv6/envelope.go index c7bea2bb9..2f947f1a4 100644 --- a/whisper/whisperv6/envelope.go +++ b/whisper/whisperv6/envelope.go @@ -208,6 +208,10 @@ func (e *Envelope) OpenSymmetric(key []byte) (msg *ReceivedMessage, err error) { // Open tries to decrypt an envelope, and populates the message fields in case of success. func (e *Envelope) Open(watcher *Filter) (msg *ReceivedMessage) { + if watcher == nil { + return nil + } + // The API interface forbids filters doing both symmetric and asymmetric encryption. if watcher.expectsAsymmetricEncryption() && watcher.expectsSymmetricEncryption() { return nil @@ -249,7 +253,7 @@ func (e *Envelope) Bloom() []byte { // TopicToBloom converts the topic (4 bytes) to the bloom filter (64 bytes) func TopicToBloom(topic TopicType) []byte { - b := make([]byte, bloomFilterSize) + b := make([]byte, BloomFilterSize) var index [3]int for j := 0; j < 3; j++ { index[j] = int(topic[j]) diff --git a/whisper/whisperv6/filter.go b/whisper/whisperv6/filter.go index eb0c65fa3..2f170ddeb 100644 --- a/whisper/whisperv6/filter.go +++ b/whisper/whisperv6/filter.go @@ -35,6 +35,7 @@ type Filter struct { PoW float64 // Proof of work as described in the Whisper spec AllowP2P bool // Indicates whether this filter is interested in direct peer-to-peer messages SymKeyHash common.Hash // The Keccak256Hash of the symmetric key, needed for optimization + id string // unique identifier Messages map[common.Hash]*ReceivedMessage mutex sync.RWMutex @@ -43,15 +44,21 @@ type Filter struct { // Filters represents a collection of filters type Filters struct { watchers map[string]*Filter - whisper *Whisper - mutex sync.RWMutex + + topicMatcher map[TopicType]map[*Filter]struct{} // map a topic to the filters that are interested in being notified when a message matches that topic + allTopicsMatcher map[*Filter]struct{} // list all the filters that will be notified of a new message, no matter what its topic is + + whisper *Whisper + mutex sync.RWMutex } // NewFilters returns a newly created filter collection func NewFilters(w *Whisper) *Filters { return &Filters{ - watchers: make(map[string]*Filter), - whisper: w, + watchers: make(map[string]*Filter), + topicMatcher: make(map[TopicType]map[*Filter]struct{}), + allTopicsMatcher: make(map[*Filter]struct{}), + whisper: w, } } @@ -81,7 +88,9 @@ func (fs *Filters) Install(watcher *Filter) (string, error) { watcher.SymKeyHash = crypto.Keccak256Hash(watcher.KeySym) } + watcher.id = id fs.watchers[id] = watcher + fs.addTopicMatcher(watcher) return id, err } @@ -91,12 +100,51 @@ func (fs *Filters) Uninstall(id string) bool { fs.mutex.Lock() defer fs.mutex.Unlock() if fs.watchers[id] != nil { + fs.removeFromTopicMatchers(fs.watchers[id]) delete(fs.watchers, id) return true } return false } +// addTopicMatcher adds a filter to the topic matchers. +// If the filter's Topics array is empty, it will be tried on every topic. +// Otherwise, it will be tried on the topics specified. +func (fs *Filters) addTopicMatcher(watcher *Filter) { + if len(watcher.Topics) == 0 { + fs.allTopicsMatcher[watcher] = struct{}{} + } else { + for _, t := range watcher.Topics { + topic := BytesToTopic(t) + if fs.topicMatcher[topic] == nil { + fs.topicMatcher[topic] = make(map[*Filter]struct{}) + } + fs.topicMatcher[topic][watcher] = struct{}{} + } + } +} + +// removeFromTopicMatchers removes a filter from the topic matchers +func (fs *Filters) removeFromTopicMatchers(watcher *Filter) { + delete(fs.allTopicsMatcher, watcher) + for _, topic := range watcher.Topics { + delete(fs.topicMatcher[BytesToTopic(topic)], watcher) + } +} + +// getWatchersByTopic returns a slice containing the filters that +// match a specific topic +func (fs *Filters) getWatchersByTopic(topic TopicType) []*Filter { + res := make([]*Filter, 0, len(fs.allTopicsMatcher)) + for watcher := range fs.allTopicsMatcher { + res = append(res, watcher) + } + for watcher := range fs.topicMatcher[topic] { + res = append(res, watcher) + } + return res +} + // Get returns a filter from the collection with a specific ID func (fs *Filters) Get(id string) *Filter { fs.mutex.RLock() @@ -112,11 +160,10 @@ func (fs *Filters) NotifyWatchers(env *Envelope, p2pMessage bool) { fs.mutex.RLock() defer fs.mutex.RUnlock() - i := -1 // only used for logging info - for _, watcher := range fs.watchers { - i++ + candidates := fs.getWatchersByTopic(env.Topic) + for _, watcher := range candidates { if p2pMessage && !watcher.AllowP2P { - log.Trace(fmt.Sprintf("msg [%x], filter [%d]: p2p messages are not allowed", env.Hash(), i)) + log.Trace(fmt.Sprintf("msg [%x], filter [%s]: p2p messages are not allowed", env.Hash(), watcher.id)) continue } @@ -128,10 +175,10 @@ func (fs *Filters) NotifyWatchers(env *Envelope, p2pMessage bool) { if match { msg = env.Open(watcher) if msg == nil { - log.Trace("processing message: failed to open", "message", env.Hash().Hex(), "filter", i) + log.Trace("processing message: failed to open", "message", env.Hash().Hex(), "filter", watcher.id) } } else { - log.Trace("processing message: does not match", "message", env.Hash().Hex(), "filter", i) + log.Trace("processing message: does not match", "message", env.Hash().Hex(), "filter", watcher.id) } } @@ -144,20 +191,6 @@ func (fs *Filters) NotifyWatchers(env *Envelope, p2pMessage bool) { } } -func (f *Filter) processEnvelope(env *Envelope) *ReceivedMessage { - if f.MatchEnvelope(env) { - msg := env.Open(f) - if msg != nil { - return msg - } - - log.Trace("processing envelope: failed to open", "hash", env.Hash().Hex()) - } else { - log.Trace("processing envelope: does not match", "hash", env.Hash().Hex()) - } - return nil -} - func (f *Filter) expectsAsymmetricEncryption() bool { return f.KeyAsym != nil } @@ -194,16 +227,17 @@ func (f *Filter) Retrieve() (all []*ReceivedMessage) { // MatchMessage checks if the filter matches an already decrypted // message (i.e. a Message that has already been handled by -// MatchEnvelope when checked by a previous filter) +// MatchEnvelope when checked by a previous filter). +// Topics are not checked here, since this is done by topic matchers. func (f *Filter) MatchMessage(msg *ReceivedMessage) bool { if f.PoW > 0 && msg.PoW < f.PoW { return false } if f.expectsAsymmetricEncryption() && msg.isAsymmetricEncryption() { - return IsPubKeyEqual(&f.KeyAsym.PublicKey, msg.Dst) && f.MatchTopic(msg.Topic) + return IsPubKeyEqual(&f.KeyAsym.PublicKey, msg.Dst) } else if f.expectsSymmetricEncryption() && msg.isSymmetricEncryption() { - return f.SymKeyHash == msg.SymKeyHash && f.MatchTopic(msg.Topic) + return f.SymKeyHash == msg.SymKeyHash } return false } @@ -211,27 +245,9 @@ func (f *Filter) MatchMessage(msg *ReceivedMessage) bool { // MatchEnvelope checks if it's worth decrypting the message. If // it returns `true`, client code is expected to attempt decrypting // the message and subsequently call MatchMessage. +// Topics are not checked here, since this is done by topic matchers. func (f *Filter) MatchEnvelope(envelope *Envelope) bool { - if f.PoW > 0 && envelope.pow < f.PoW { - return false - } - - return f.MatchTopic(envelope.Topic) -} - -// MatchTopic checks that the filter captures a given topic. -func (f *Filter) MatchTopic(topic TopicType) bool { - if len(f.Topics) == 0 { - // any topic matches - return true - } - - for _, bt := range f.Topics { - if matchSingleTopic(topic, bt) { - return true - } - } - return false + return f.PoW <= 0 || envelope.pow >= f.PoW } func matchSingleTopic(topic TopicType, bt []byte) bool { diff --git a/whisper/whisperv6/filter_test.go b/whisper/whisperv6/filter_test.go index e7230ef38..0bb7986c3 100644 --- a/whisper/whisperv6/filter_test.go +++ b/whisper/whisperv6/filter_test.go @@ -303,9 +303,8 @@ func TestMatchEnvelope(t *testing.T) { t.Fatalf("failed generateMessageParams with seed %d: %s.", seed, err) } - params.Topic[0] = 0xFF // ensure mismatch + params.Topic[0] = 0xFF // topic mismatch - // mismatch with pseudo-random data msg, err := NewSentMessage(params) if err != nil { t.Fatalf("failed to create new message with seed %d: %s.", seed, err) @@ -314,14 +313,6 @@ func TestMatchEnvelope(t *testing.T) { if err != nil { t.Fatalf("failed Wrap with seed %d: %s.", seed, err) } - match := fsym.MatchEnvelope(env) - if match { - t.Fatalf("failed MatchEnvelope symmetric with seed %d.", seed) - } - match = fasym.MatchEnvelope(env) - if match { - t.Fatalf("failed MatchEnvelope asymmetric with seed %d.", seed) - } // encrypt symmetrically i := mrand.Int() % 4 @@ -337,7 +328,7 @@ func TestMatchEnvelope(t *testing.T) { } // symmetric + matching topic: match - match = fsym.MatchEnvelope(env) + match := fsym.MatchEnvelope(env) if !match { t.Fatalf("failed MatchEnvelope() symmetric with seed %d.", seed) } @@ -396,7 +387,7 @@ func TestMatchEnvelope(t *testing.T) { // asymmetric + matching topic: match fasym.Topics[i] = fasym.Topics[i+1] match = fasym.MatchEnvelope(env) - if match { + if !match { t.Fatalf("failed MatchEnvelope(asymmetric + matching topic) with seed %d.", seed) } @@ -431,7 +422,8 @@ func TestMatchEnvelope(t *testing.T) { // filter with topic + envelope without topic: mismatch fasym.Topics = fsym.Topics match = fasym.MatchEnvelope(env) - if match { + if !match { + // topic mismatch should have no affect, as topics are handled by topic matchers t.Fatalf("failed MatchEnvelope(filter without topic + envelope without topic) with seed %d.", seed) } } @@ -487,7 +479,8 @@ func TestMatchMessageSym(t *testing.T) { // topic mismatch f.Topics[index][0]++ - if f.MatchMessage(msg) { + if !f.MatchMessage(msg) { + // topic mismatch should have no affect, as topics are handled by topic matchers t.Fatalf("failed MatchEnvelope(topic mismatch) with seed %d.", seed) } f.Topics[index][0]-- @@ -580,7 +573,8 @@ func TestMatchMessageAsym(t *testing.T) { // topic mismatch f.Topics[index][0]++ - if f.MatchMessage(msg) { + if !f.MatchMessage(msg) { + // topic mismatch should have no affect, as topics are handled by topic matchers t.Fatalf("failed MatchEnvelope(topic mismatch) with seed %d.", seed) } f.Topics[index][0]-- @@ -829,8 +823,9 @@ func TestVariableTopics(t *testing.T) { f.Topics[i][lastTopicByte]++ match = f.MatchEnvelope(env) - if match { - t.Fatalf("MatchEnvelope symmetric with seed %d, step %d: false positive.", seed, i) + if !match { + // topic mismatch should have no affect, as topics are handled by topic matchers + t.Fatalf("MatchEnvelope symmetric with seed %d, step %d.", seed, i) } } } diff --git a/whisper/whisperv6/peer.go b/whisper/whisperv6/peer.go index 4ef0f3c43..2bf1c905b 100644 --- a/whisper/whisperv6/peer.go +++ b/whisper/whisperv6/peer.go @@ -19,6 +19,7 @@ package whisperv6 import ( "fmt" "math" + "sync" "time" "github.com/ethereum/go-ethereum/common" @@ -36,6 +37,7 @@ type Peer struct { trusted bool powRequirement float64 + bloomMu sync.Mutex bloomFilter []byte fullNode bool @@ -54,7 +56,7 @@ func newPeer(host *Whisper, remote *p2p.Peer, rw p2p.MsgReadWriter) *Peer { powRequirement: 0.0, known: set.New(), quit: make(chan struct{}), - bloomFilter: makeFullNodeBloom(), + bloomFilter: MakeFullNodeBloom(), fullNode: true, } } @@ -118,7 +120,7 @@ func (peer *Peer) handshake() error { err = s.Decode(&bloom) if err == nil { sz := len(bloom) - if sz != bloomFilterSize && sz != 0 { + if sz != BloomFilterSize && sz != 0 { return fmt.Errorf("peer [%x] sent bad status message: wrong bloom filter size %d", peer.ID(), sz) } peer.setBloomFilter(bloom) @@ -225,20 +227,24 @@ func (peer *Peer) notifyAboutBloomFilterChange(bloom []byte) error { } func (peer *Peer) bloomMatch(env *Envelope) bool { - return peer.fullNode || bloomFilterMatch(peer.bloomFilter, env.Bloom()) + peer.bloomMu.Lock() + defer peer.bloomMu.Unlock() + return peer.fullNode || BloomFilterMatch(peer.bloomFilter, env.Bloom()) } func (peer *Peer) setBloomFilter(bloom []byte) { + peer.bloomMu.Lock() + defer peer.bloomMu.Unlock() peer.bloomFilter = bloom peer.fullNode = isFullNode(bloom) if peer.fullNode && peer.bloomFilter == nil { - peer.bloomFilter = makeFullNodeBloom() + peer.bloomFilter = MakeFullNodeBloom() } } -func makeFullNodeBloom() []byte { - bloom := make([]byte, bloomFilterSize) - for i := 0; i < bloomFilterSize; i++ { +func MakeFullNodeBloom() []byte { + bloom := make([]byte, BloomFilterSize) + for i := 0; i < BloomFilterSize; i++ { bloom[i] = 0xFF } return bloom diff --git a/whisper/whisperv6/peer_test.go b/whisper/whisperv6/peer_test.go index 9ce5eed8b..ec985ae65 100644 --- a/whisper/whisperv6/peer_test.go +++ b/whisper/whisperv6/peer_test.go @@ -23,6 +23,7 @@ import ( mrand "math/rand" "net" "sync" + "sync/atomic" "testing" "time" @@ -71,7 +72,7 @@ var keys = []string{ } type TestData struct { - started int + started int64 counter [NumNodes]int mutex sync.RWMutex } @@ -151,7 +152,7 @@ func resetParams(t *testing.T) { } func initBloom(t *testing.T) { - masterBloomFilter = make([]byte, bloomFilterSize) + masterBloomFilter = make([]byte, BloomFilterSize) _, err := mrand.Read(masterBloomFilter) if err != nil { t.Fatalf("rand failed: %s.", err) @@ -163,7 +164,7 @@ func initBloom(t *testing.T) { masterBloomFilter[i] = 0xFF } - if !bloomFilterMatch(masterBloomFilter, msgBloom) { + if !BloomFilterMatch(masterBloomFilter, msgBloom) { t.Fatalf("bloom mismatch on initBloom.") } } @@ -177,7 +178,7 @@ func initialize(t *testing.T) { for i := 0; i < NumNodes; i++ { var node TestNode - b := make([]byte, bloomFilterSize) + b := make([]byte, BloomFilterSize) copy(b, masterBloomFilter) node.shh = New(&DefaultConfig) node.shh.SetMinimumPoW(masterPow) @@ -240,9 +241,7 @@ func startServer(t *testing.T, s *p2p.Server) { t.Fatalf("failed to start the fisrt server.") } - result.mutex.Lock() - defer result.mutex.Unlock() - result.started++ + atomic.AddInt64(&result.started, 1) } func stopServers() { @@ -472,7 +471,10 @@ func checkPowExchange(t *testing.T) { func checkBloomFilterExchangeOnce(t *testing.T, mustPass bool) bool { for i, node := range nodes { for peer := range node.shh.peers { - if !bytes.Equal(peer.bloomFilter, masterBloomFilter) { + peer.bloomMu.Lock() + equals := bytes.Equal(peer.bloomFilter, masterBloomFilter) + peer.bloomMu.Unlock() + if !equals { if mustPass { t.Fatalf("node %d: failed to exchange bloom filter requirement in round %d. \n%x expected \n%x got", i, round, masterBloomFilter, peer.bloomFilter) @@ -500,11 +502,13 @@ func checkBloomFilterExchange(t *testing.T) { func waitForServersToStart(t *testing.T) { const iterations = 200 + var started int64 for j := 0; j < iterations; j++ { time.Sleep(50 * time.Millisecond) - if result.started == NumNodes { + started = atomic.LoadInt64(&result.started) + if started == NumNodes { return } } - t.Fatalf("Failed to start all the servers, running: %d", result.started) + t.Fatalf("Failed to start all the servers, running: %d", started) } diff --git a/whisper/whisperv6/whisper.go b/whisper/whisperv6/whisper.go index 600f9cb28..880cced09 100644 --- a/whisper/whisperv6/whisper.go +++ b/whisper/whisperv6/whisper.go @@ -82,6 +82,8 @@ type Whisper struct { syncAllowance int // maximum time in seconds allowed to process the whisper-related messages + lightClient bool // indicates is this node is pure light client (does not forward any messages) + statsMu sync.Mutex // guard stats stats Statistics // Statistics of whisper node @@ -230,11 +232,11 @@ func (whisper *Whisper) SetMaxMessageSize(size uint32) error { // SetBloomFilter sets the new bloom filter func (whisper *Whisper) SetBloomFilter(bloom []byte) error { - if len(bloom) != bloomFilterSize { + if len(bloom) != BloomFilterSize { return fmt.Errorf("invalid bloom filter size: %d", len(bloom)) } - b := make([]byte, bloomFilterSize) + b := make([]byte, BloomFilterSize) copy(b, bloom) whisper.settings.Store(bloomFilterIdx, b) @@ -556,14 +558,14 @@ func (whisper *Whisper) Subscribe(f *Filter) (string, error) { // updateBloomFilter recalculates the new value of bloom filter, // and informs the peers if necessary. func (whisper *Whisper) updateBloomFilter(f *Filter) { - aggregate := make([]byte, bloomFilterSize) + aggregate := make([]byte, BloomFilterSize) for _, t := range f.Topics { top := BytesToTopic(t) b := TopicToBloom(top) aggregate = addBloom(aggregate, b) } - if !bloomFilterMatch(whisper.BloomFilter(), aggregate) { + if !BloomFilterMatch(whisper.BloomFilter(), aggregate) { // existing bloom filter must be updated aggregate = addBloom(whisper.BloomFilter(), aggregate) whisper.SetBloomFilter(aggregate) @@ -587,11 +589,8 @@ func (whisper *Whisper) Unsubscribe(id string) error { // Send injects a message into the whisper send queue, to be distributed in the // network in the coming cycles. func (whisper *Whisper) Send(envelope *Envelope) error { - ok, err := whisper.add(envelope) - if err != nil { - return err - } - if !ok { + ok, err := whisper.add(envelope, false) + if err == nil && !ok { return fmt.Errorf("failed to add envelope") } return err @@ -673,7 +672,7 @@ func (whisper *Whisper) runMessageLoop(p *Peer, rw p2p.MsgReadWriter) error { trouble := false for _, env := range envelopes { - cached, err := whisper.add(env) + cached, err := whisper.add(env, whisper.lightClient) if err != nil { trouble = true log.Error("bad envelope received, peer will be disconnected", "peer", p.peer.ID(), "err", err) @@ -702,7 +701,7 @@ func (whisper *Whisper) runMessageLoop(p *Peer, rw p2p.MsgReadWriter) error { case bloomFilterExCode: var bloom []byte err := packet.Decode(&bloom) - if err == nil && len(bloom) != bloomFilterSize { + if err == nil && len(bloom) != BloomFilterSize { err = fmt.Errorf("wrong bloom filter size %d", len(bloom)) } @@ -746,7 +745,8 @@ func (whisper *Whisper) runMessageLoop(p *Peer, rw p2p.MsgReadWriter) error { // add inserts a new envelope into the message pool to be distributed within the // whisper network. It also inserts the envelope into the expiration pool at the // appropriate time-stamp. In case of error, connection should be dropped. -func (whisper *Whisper) add(envelope *Envelope) (bool, error) { +// param isP2P indicates whether the message is peer-to-peer (should not be forwarded). +func (whisper *Whisper) add(envelope *Envelope, isP2P bool) (bool, error) { now := uint32(time.Now().Unix()) sent := envelope.Expiry - envelope.TTL @@ -779,11 +779,11 @@ func (whisper *Whisper) add(envelope *Envelope) (bool, error) { } } - if !bloomFilterMatch(whisper.BloomFilter(), envelope.Bloom()) { + if !BloomFilterMatch(whisper.BloomFilter(), envelope.Bloom()) { // maybe the value was recently changed, and the peers did not adjust yet. // in this case the previous value is retrieved by BloomFilterTolerance() // for a short period of peer synchronization. - if !bloomFilterMatch(whisper.BloomFilterTolerance(), envelope.Bloom()) { + if !BloomFilterMatch(whisper.BloomFilterTolerance(), envelope.Bloom()) { return false, fmt.Errorf("envelope does not match bloom filter, hash=[%v], bloom: \n%x \n%x \n%x", envelope.Hash().Hex(), whisper.BloomFilter(), envelope.Bloom(), envelope.Topic) } @@ -811,7 +811,7 @@ func (whisper *Whisper) add(envelope *Envelope) (bool, error) { whisper.statsMu.Lock() whisper.stats.memoryUsed += envelope.size() whisper.statsMu.Unlock() - whisper.postEvent(envelope, false) // notify the local node about the new message + whisper.postEvent(envelope, isP2P) // notify the local node about the new message if whisper.mailServer != nil { whisper.mailServer.Archive(envelope) } @@ -928,24 +928,6 @@ func (whisper *Whisper) Envelopes() []*Envelope { return all } -// Messages iterates through all currently floating envelopes -// and retrieves all the messages, that this filter could decrypt. -func (whisper *Whisper) Messages(id string) []*ReceivedMessage { - result := make([]*ReceivedMessage, 0) - whisper.poolMu.RLock() - defer whisper.poolMu.RUnlock() - - if filter := whisper.filters.Get(id); filter != nil { - for _, env := range whisper.envelopes { - msg := filter.processEnvelope(env) - if msg != nil { - result = append(result, msg) - } - } - } - return result -} - // isEnvelopeCached checks if envelope with specific hash has already been received and cached. func (whisper *Whisper) isEnvelopeCached(hash common.Hash) bool { whisper.poolMu.Lock() @@ -1043,12 +1025,12 @@ func isFullNode(bloom []byte) bool { return true } -func bloomFilterMatch(filter, sample []byte) bool { +func BloomFilterMatch(filter, sample []byte) bool { if filter == nil { return true } - for i := 0; i < bloomFilterSize; i++ { + for i := 0; i < BloomFilterSize; i++ { f := filter[i] s := sample[i] if (f | s) != f { @@ -1060,8 +1042,8 @@ func bloomFilterMatch(filter, sample []byte) bool { } func addBloom(a, b []byte) []byte { - c := make([]byte, bloomFilterSize) - for i := 0; i < bloomFilterSize; i++ { + c := make([]byte, BloomFilterSize) + for i := 0; i < BloomFilterSize; i++ { c[i] = a[i] | b[i] } return c diff --git a/whisper/whisperv6/whisper_test.go b/whisper/whisperv6/whisper_test.go index 99e5f0bbb..7fe256309 100644 --- a/whisper/whisperv6/whisper_test.go +++ b/whisper/whisperv6/whisper_test.go @@ -75,10 +75,6 @@ func TestWhisperBasic(t *testing.T) { if len(mail) != 0 { t.Fatalf("failed w.Envelopes().") } - m := w.Messages("non-existent") - if len(m) != 0 { - t.Fatalf("failed w.Messages.") - } derived := pbkdf2.Key([]byte(peerID), nil, 65356, aesKeyLength, sha256.New) if !validateDataIntegrity(derived, aesKeyLength) { @@ -593,7 +589,7 @@ func TestCustomization(t *testing.T) { } // check w.messages() - id, err := w.Subscribe(f) + _, err = w.Subscribe(f) if err != nil { t.Fatalf("failed subscribe with seed %d: %s.", seed, err) } @@ -602,11 +598,6 @@ func TestCustomization(t *testing.T) { if len(mail) > 0 { t.Fatalf("received premature mail") } - - mail = w.Messages(id) - if len(mail) != 2 { - t.Fatalf("failed to get whisper messages") - } } func TestSymmetricSendCycle(t *testing.T) { @@ -835,11 +826,11 @@ func TestSymmetricSendKeyMismatch(t *testing.T) { func TestBloom(t *testing.T) { topic := TopicType{0, 0, 255, 6} b := TopicToBloom(topic) - x := make([]byte, bloomFilterSize) + x := make([]byte, BloomFilterSize) x[0] = byte(1) x[32] = byte(1) - x[bloomFilterSize-1] = byte(128) - if !bloomFilterMatch(x, b) || !bloomFilterMatch(b, x) { + x[BloomFilterSize-1] = byte(128) + if !BloomFilterMatch(x, b) || !BloomFilterMatch(b, x) { t.Fatalf("bloom filter does not match the mask") } @@ -851,11 +842,11 @@ func TestBloom(t *testing.T) { if err != nil { t.Fatalf("math rand error") } - if !bloomFilterMatch(b, b) { + if !BloomFilterMatch(b, b) { t.Fatalf("bloom filter does not match self") } x = addBloom(x, b) - if !bloomFilterMatch(x, b) { + if !BloomFilterMatch(x, b) { t.Fatalf("bloom filter does not match combined bloom") } if !isFullNode(nil) { @@ -865,16 +856,16 @@ func TestBloom(t *testing.T) { if isFullNode(x) { t.Fatalf("isFullNode false positive") } - for i := 0; i < bloomFilterSize; i++ { + for i := 0; i < BloomFilterSize; i++ { b[i] = byte(255) } if !isFullNode(b) { t.Fatalf("isFullNode false negative") } - if bloomFilterMatch(x, b) { + if BloomFilterMatch(x, b) { t.Fatalf("bloomFilterMatch false positive") } - if !bloomFilterMatch(b, x) { + if !BloomFilterMatch(b, x) { t.Fatalf("bloomFilterMatch false negative") } @@ -888,7 +879,7 @@ func TestBloom(t *testing.T) { t.Fatalf("failed to set bloom filter: %s", err) } f = w.BloomFilter() - if !bloomFilterMatch(f, x) || !bloomFilterMatch(x, f) { + if !BloomFilterMatch(f, x) || !BloomFilterMatch(x, f) { t.Fatalf("retireved wrong bloom filter") } } |