aboutsummaryrefslogtreecommitdiffstats
path: root/common/math
diff options
context:
space:
mode:
Diffstat (limited to 'common/math')
-rw-r--r--common/math/big.go19
-rw-r--r--common/math/big_test.go13
-rw-r--r--common/math/integer.go23
-rw-r--r--common/math/integer_test.go11
4 files changed, 54 insertions, 12 deletions
diff --git a/common/math/big.go b/common/math/big.go
index 704ca40a9..0b67a1b50 100644
--- a/common/math/big.go
+++ b/common/math/big.go
@@ -18,6 +18,7 @@
package math
import (
+ "fmt"
"math/big"
)
@@ -35,6 +36,24 @@ const (
wordBytes = wordBits / 8
)
+// HexOrDecimal256 marshals big.Int as hex or decimal.
+type HexOrDecimal256 big.Int
+
+// UnmarshalText implements encoding.TextUnmarshaler.
+func (i *HexOrDecimal256) UnmarshalText(input []byte) error {
+ bigint, ok := ParseBig256(string(input))
+ if !ok {
+ return fmt.Errorf("invalid hex or decimal integer %q", input)
+ }
+ *i = HexOrDecimal256(*bigint)
+ return nil
+}
+
+// MarshalText implements encoding.TextMarshaler.
+func (i *HexOrDecimal256) MarshalText() ([]byte, error) {
+ return []byte(fmt.Sprintf("%#x", (*big.Int)(i))), nil
+}
+
// ParseBig256 parses s as a 256 bit integer in decimal or hexadecimal syntax.
// Leading zeros are accepted. The empty string parses as zero.
func ParseBig256(s string) (*big.Int, bool) {
diff --git a/common/math/big_test.go b/common/math/big_test.go
index 6eb13f4f1..deff25465 100644
--- a/common/math/big_test.go
+++ b/common/math/big_test.go
@@ -23,7 +23,7 @@ import (
"testing"
)
-func TestParseBig256(t *testing.T) {
+func TestHexOrDecimal256(t *testing.T) {
tests := []struct {
input string
num *big.Int
@@ -47,13 +47,14 @@ func TestParseBig256(t *testing.T) {
{"115792089237316195423570985008687907853269984665640564039457584007913129639936", nil, false},
}
for _, test := range tests {
- num, ok := ParseBig256(test.input)
- if ok != test.ok {
- t.Errorf("ParseBig(%q) -> ok = %t, want %t", test.input, ok, test.ok)
+ var num HexOrDecimal256
+ err := num.UnmarshalText([]byte(test.input))
+ if (err == nil) != test.ok {
+ t.Errorf("ParseBig(%q) -> (err == nil) == %t, want %t", test.input, err == nil, test.ok)
continue
}
- if num != nil && test.num != nil && num.Cmp(test.num) != 0 {
- t.Errorf("ParseBig(%q) -> %d, want %d", test.input, num, test.num)
+ if test.num != nil && (*big.Int)(&num).Cmp(test.num) != 0 {
+ t.Errorf("ParseBig(%q) -> %d, want %d", test.input, (*big.Int)(&num), test.num)
}
}
}
diff --git a/common/math/integer.go b/common/math/integer.go
index a3eeee27e..7eff4d3b0 100644
--- a/common/math/integer.go
+++ b/common/math/integer.go
@@ -16,7 +16,10 @@
package math
-import "strconv"
+import (
+ "fmt"
+ "strconv"
+)
const (
// Integer limit values.
@@ -34,6 +37,24 @@ const (
MaxUint64 = 1<<64 - 1
)
+// HexOrDecimal64 marshals uint64 as hex or decimal.
+type HexOrDecimal64 uint64
+
+// UnmarshalText implements encoding.TextUnmarshaler.
+func (i *HexOrDecimal64) UnmarshalText(input []byte) error {
+ int, ok := ParseUint64(string(input))
+ if !ok {
+ return fmt.Errorf("invalid hex or decimal integer %q", input)
+ }
+ *i = HexOrDecimal64(int)
+ return nil
+}
+
+// MarshalText implements encoding.TextMarshaler.
+func (i HexOrDecimal64) MarshalText() ([]byte, error) {
+ return []byte(fmt.Sprintf("%#x", uint64(i))), nil
+}
+
// ParseUint64 parses s as an integer in decimal or hexadecimal syntax.
// Leading zeros are accepted. The empty string parses as zero.
func ParseUint64(s string) (uint64, bool) {
diff --git a/common/math/integer_test.go b/common/math/integer_test.go
index 05bba221f..b31c7c26c 100644
--- a/common/math/integer_test.go
+++ b/common/math/integer_test.go
@@ -65,7 +65,7 @@ func TestOverflow(t *testing.T) {
}
}
-func TestParseUint64(t *testing.T) {
+func TestHexOrDecimal64(t *testing.T) {
tests := []struct {
input string
num uint64
@@ -88,12 +88,13 @@ func TestParseUint64(t *testing.T) {
{"18446744073709551617", 0, false},
}
for _, test := range tests {
- num, ok := ParseUint64(test.input)
- if ok != test.ok {
- t.Errorf("ParseUint64(%q) -> ok = %t, want %t", test.input, ok, test.ok)
+ var num HexOrDecimal64
+ err := num.UnmarshalText([]byte(test.input))
+ if (err == nil) != test.ok {
+ t.Errorf("ParseUint64(%q) -> (err == nil) = %t, want %t", test.input, err == nil, test.ok)
continue
}
- if ok && num != test.num {
+ if err == nil && uint64(num) != test.num {
t.Errorf("ParseUint64(%q) -> %d, want %d", test.input, num, test.num)
}
}