From 37e5816bcdaaca2380ce5a56d9a0834340733b31 Mon Sep 17 00:00:00 2001
From: Felix Lange <fjl@twurst.com>
Date: Mon, 28 Nov 2016 00:58:22 +0100
Subject: common: use package hexutil for fixed size type encoding

---
 common/types.go      | 58 ++++++++--------------------------------------------
 common/types_test.go | 32 +++++++++++++++++++----------
 2 files changed, 30 insertions(+), 60 deletions(-)

(limited to 'common')

diff --git a/common/types.go b/common/types.go
index 70b7e7aae..8a456e965 100644
--- a/common/types.go
+++ b/common/types.go
@@ -17,14 +17,12 @@
 package common
 
 import (
-	"encoding/hex"
-	"encoding/json"
-	"errors"
 	"fmt"
 	"math/big"
 	"math/rand"
 	"reflect"
-	"strings"
+
+	"github.com/ethereum/go-ethereum/common/hexutil"
 )
 
 const (
@@ -32,8 +30,6 @@ const (
 	AddressLength = 20
 )
 
-var hashJsonLengthErr = errors.New("common: unmarshalJSON failed: hash must be exactly 32 bytes")
-
 type (
 	// Hash represents the 32 byte Keccak256 hash of arbitrary data.
 	Hash [HashLength]byte
@@ -57,30 +53,16 @@ func HexToHash(s string) Hash    { return BytesToHash(FromHex(s)) }
 func (h Hash) Str() string   { return string(h[:]) }
 func (h Hash) Bytes() []byte { return h[:] }
 func (h Hash) Big() *big.Int { return Bytes2Big(h[:]) }
-func (h Hash) Hex() string   { return "0x" + Bytes2Hex(h[:]) }
+func (h Hash) Hex() string   { return hexutil.Encode(h[:]) }
 
 // UnmarshalJSON parses a hash in its hex from to a hash.
 func (h *Hash) UnmarshalJSON(input []byte) error {
-	length := len(input)
-	if length >= 2 && input[0] == '"' && input[length-1] == '"' {
-		input = input[1 : length-1]
-	}
-	// strip "0x" for length check
-	if len(input) > 1 && strings.ToLower(string(input[:2])) == "0x" {
-		input = input[2:]
-	}
-
-	// validate the length of the input hash
-	if len(input) != HashLength*2 {
-		return hashJsonLengthErr
-	}
-	h.SetBytes(FromHex(string(input)))
-	return nil
+	return hexutil.UnmarshalJSON("Hash", input, h[:])
 }
 
 // Serialize given hash to JSON
 func (h Hash) MarshalJSON() ([]byte, error) {
-	return json.Marshal(h.Hex())
+	return hexutil.Bytes(h[:]).MarshalJSON()
 }
 
 // Sets the hash to the value of b. If b is larger than len(h) it will panic
@@ -142,7 +124,7 @@ func (a Address) Str() string   { return string(a[:]) }
 func (a Address) Bytes() []byte { return a[:] }
 func (a Address) Big() *big.Int { return Bytes2Big(a[:]) }
 func (a Address) Hash() Hash    { return BytesToHash(a[:]) }
-func (a Address) Hex() string   { return "0x" + Bytes2Hex(a[:]) }
+func (a Address) Hex() string   { return hexutil.Encode(a[:]) }
 
 // Sets the address to the value of b. If b is larger than len(a) it will panic
 func (a *Address) SetBytes(b []byte) {
@@ -164,34 +146,12 @@ func (a *Address) Set(other Address) {
 
 // Serialize given address to JSON
 func (a Address) MarshalJSON() ([]byte, error) {
-	return json.Marshal(a.Hex())
+	return hexutil.Bytes(a[:]).MarshalJSON()
 }
 
 // Parse address from raw json data
-func (a *Address) UnmarshalJSON(data []byte) error {
-	if len(data) > 2 && data[0] == '"' && data[len(data)-1] == '"' {
-		data = data[1 : len(data)-1]
-	}
-
-	if len(data) > 2 && data[0] == '0' && data[1] == 'x' {
-		data = data[2:]
-	}
-
-	if len(data) != 2*AddressLength {
-		return fmt.Errorf("Invalid address length, expected %d got %d bytes", 2*AddressLength, len(data))
-	}
-
-	n, err := hex.Decode(a[:], data)
-	if err != nil {
-		return err
-	}
-
-	if n != AddressLength {
-		return fmt.Errorf("Invalid address")
-	}
-
-	a.Set(HexToAddress(string(data)))
-	return nil
+func (a *Address) UnmarshalJSON(input []byte) error {
+	return hexutil.UnmarshalJSON("Address", input, a[:])
 }
 
 // PP Pretty Prints a byte slice in the following format:
diff --git a/common/types_test.go b/common/types_test.go
index de67cfcb5..e84780f43 100644
--- a/common/types_test.go
+++ b/common/types_test.go
@@ -18,7 +18,10 @@ package common
 
 import (
 	"math/big"
+	"strings"
 	"testing"
+
+	"github.com/ethereum/go-ethereum/common/hexutil"
 )
 
 func TestBytesConversion(t *testing.T) {
@@ -38,19 +41,26 @@ func TestHashJsonValidation(t *testing.T) {
 	var tests = []struct {
 		Prefix string
 		Size   int
-		Error  error
+		Error  string
 	}{
-		{"", 2, hashJsonLengthErr},
-		{"", 62, hashJsonLengthErr},
-		{"", 66, hashJsonLengthErr},
-		{"", 65, hashJsonLengthErr},
-		{"0X", 64, nil},
-		{"0x", 64, nil},
-		{"0x", 62, hashJsonLengthErr},
+		{"", 62, hexutil.ErrMissingPrefix.Error()},
+		{"0x", 66, "hex string has length 66, want 64 for Hash"},
+		{"0x", 63, hexutil.ErrOddLength.Error()},
+		{"0x", 0, "hex string has length 0, want 64 for Hash"},
+		{"0x", 64, ""},
+		{"0X", 64, ""},
 	}
-	for i, test := range tests {
-		if err := h.UnmarshalJSON(append([]byte(test.Prefix), make([]byte, test.Size)...)); err != test.Error {
-			t.Errorf("test #%d: error mismatch: have %v, want %v", i, err, test.Error)
+	for _, test := range tests {
+		input := `"` + test.Prefix + strings.Repeat("0", test.Size) + `"`
+		err := h.UnmarshalJSON([]byte(input))
+		if err == nil {
+			if test.Error != "" {
+				t.Errorf("%s: error mismatch: have nil, want %q", input, test.Error)
+			}
+		} else {
+			if err.Error() != test.Error {
+				t.Errorf("%s: error mismatch: have %q, want %q", input, err, test.Error)
+			}
 		}
 	}
 }
-- 
cgit v1.2.3