aboutsummaryrefslogtreecommitdiffstats
path: root/common
diff options
context:
space:
mode:
authorVincent Serpoul <vincent@serpoul.com>2018-07-24 21:15:07 +0800
committerFelix Lange <fjl@users.noreply.github.com>2018-07-24 21:15:07 +0800
commit2909f6d7a2ceb5b1cdeb4cc3966531018a0b8334 (patch)
treebd1fb46b09efed7478daeea0ad34c3bb643de448 /common
parentd96ba77113e1a87e0402fa4eb6a5776786f8e005 (diff)
downloaddexon-2909f6d7a2ceb5b1cdeb4cc3966531018a0b8334.tar
dexon-2909f6d7a2ceb5b1cdeb4cc3966531018a0b8334.tar.gz
dexon-2909f6d7a2ceb5b1cdeb4cc3966531018a0b8334.tar.bz2
dexon-2909f6d7a2ceb5b1cdeb4cc3966531018a0b8334.tar.lz
dexon-2909f6d7a2ceb5b1cdeb4cc3966531018a0b8334.tar.xz
dexon-2909f6d7a2ceb5b1cdeb4cc3966531018a0b8334.tar.zst
dexon-2909f6d7a2ceb5b1cdeb4cc3966531018a0b8334.zip
common: add database/sql support for Hash and Address (#15541)
Diffstat (limited to 'common')
-rw-r--r--common/types.go41
-rw-r--r--common/types_test.go180
2 files changed, 219 insertions, 2 deletions
diff --git a/common/types.go b/common/types.go
index 4d374ad24..71fe5c95c 100644
--- a/common/types.go
+++ b/common/types.go
@@ -17,6 +17,7 @@
package common
import (
+ "database/sql/driver"
"encoding/hex"
"encoding/json"
"fmt"
@@ -31,7 +32,9 @@ import (
// Lengths of hashes and addresses in bytes.
const (
- HashLength = 32
+ // HashLength is the expected length of the hash
+ HashLength = 32
+ // AddressLength is the expected length of the adddress
AddressLength = 20
)
@@ -120,6 +123,24 @@ func (h Hash) Generate(rand *rand.Rand, size int) reflect.Value {
return reflect.ValueOf(h)
}
+// Scan implements Scanner for database/sql.
+func (h *Hash) Scan(src interface{}) error {
+ srcB, ok := src.([]byte)
+ if !ok {
+ return fmt.Errorf("can't scan %T into Hash", src)
+ }
+ if len(srcB) != HashLength {
+ return fmt.Errorf("can't scan []byte of len %d into Hash, want %d", len(srcB), HashLength)
+ }
+ copy(h[:], srcB)
+ return nil
+}
+
+// Value implements valuer for database/sql.
+func (h Hash) Value() (driver.Value, error) {
+ return h[:], nil
+}
+
// UnprefixedHash allows marshaling a Hash without 0x prefix.
type UnprefixedHash Hash
@@ -229,6 +250,24 @@ func (a *Address) UnmarshalJSON(input []byte) error {
return hexutil.UnmarshalFixedJSON(addressT, input, a[:])
}
+// Scan implements Scanner for database/sql.
+func (a *Address) Scan(src interface{}) error {
+ srcB, ok := src.([]byte)
+ if !ok {
+ return fmt.Errorf("can't scan %T into Address", src)
+ }
+ if len(srcB) != AddressLength {
+ return fmt.Errorf("can't scan []byte of len %d into Address, want %d", len(srcB), AddressLength)
+ }
+ copy(a[:], srcB)
+ return nil
+}
+
+// Value implements valuer for database/sql.
+func (a Address) Value() (driver.Value, error) {
+ return a[:], nil
+}
+
// UnprefixedAddress allows marshaling an Address without 0x prefix.
type UnprefixedAddress Address
diff --git a/common/types_test.go b/common/types_test.go
index 9e0c5be3a..7095ccd01 100644
--- a/common/types_test.go
+++ b/common/types_test.go
@@ -17,9 +17,10 @@
package common
import (
+ "database/sql/driver"
"encoding/json"
-
"math/big"
+ "reflect"
"strings"
"testing"
)
@@ -193,3 +194,180 @@ func TestMixedcaseAccount_Address(t *testing.T) {
}
}
+
+func TestHash_Scan(t *testing.T) {
+ type args struct {
+ src interface{}
+ }
+ tests := []struct {
+ name string
+ args args
+ wantErr bool
+ }{
+ {
+ name: "working scan",
+ args: args{src: []byte{
+ 0xb2, 0x6f, 0x2b, 0x34, 0x2a, 0xab, 0x24, 0xbc, 0xf6, 0x3e,
+ 0xa2, 0x18, 0xc6, 0xa9, 0x27, 0x4d, 0x30, 0xab, 0x9a, 0x15,
+ 0xa2, 0x18, 0xc6, 0xa9, 0x27, 0x4d, 0x30, 0xab, 0x9a, 0x15,
+ 0x10, 0x00,
+ }},
+ wantErr: false,
+ },
+ {
+ name: "non working scan",
+ args: args{src: int64(1234567890)},
+ wantErr: true,
+ },
+ {
+ name: "invalid length scan",
+ args: args{src: []byte{
+ 0xb2, 0x6f, 0x2b, 0x34, 0x2a, 0xab, 0x24, 0xbc, 0xf6, 0x3e,
+ 0xa2, 0x18, 0xc6, 0xa9, 0x27, 0x4d, 0x30, 0xab, 0x9a, 0x15,
+ 0xa2, 0x18, 0xc6, 0xa9, 0x27, 0x4d, 0x30, 0xab, 0x9a, 0x15,
+ }},
+ wantErr: true,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ h := &Hash{}
+ if err := h.Scan(tt.args.src); (err != nil) != tt.wantErr {
+ t.Errorf("Hash.Scan() error = %v, wantErr %v", err, tt.wantErr)
+ }
+
+ if !tt.wantErr {
+ for i := range h {
+ if h[i] != tt.args.src.([]byte)[i] {
+ t.Errorf(
+ "Hash.Scan() didn't scan the %d src correctly (have %X, want %X)",
+ i, h[i], tt.args.src.([]byte)[i],
+ )
+ }
+ }
+ }
+ })
+ }
+}
+
+func TestHash_Value(t *testing.T) {
+ b := []byte{
+ 0xb2, 0x6f, 0x2b, 0x34, 0x2a, 0xab, 0x24, 0xbc, 0xf6, 0x3e,
+ 0xa2, 0x18, 0xc6, 0xa9, 0x27, 0x4d, 0x30, 0xab, 0x9a, 0x15,
+ 0xa2, 0x18, 0xc6, 0xa9, 0x27, 0x4d, 0x30, 0xab, 0x9a, 0x15,
+ 0x10, 0x00,
+ }
+ var usedH Hash
+ usedH.SetBytes(b)
+ tests := []struct {
+ name string
+ h Hash
+ want driver.Value
+ wantErr bool
+ }{
+ {
+ name: "Working value",
+ h: usedH,
+ want: b,
+ wantErr: false,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got, err := tt.h.Value()
+ if (err != nil) != tt.wantErr {
+ t.Errorf("Hash.Value() error = %v, wantErr %v", err, tt.wantErr)
+ return
+ }
+ if !reflect.DeepEqual(got, tt.want) {
+ t.Errorf("Hash.Value() = %v, want %v", got, tt.want)
+ }
+ })
+ }
+}
+
+func TestAddress_Scan(t *testing.T) {
+ type args struct {
+ src interface{}
+ }
+ tests := []struct {
+ name string
+ args args
+ wantErr bool
+ }{
+ {
+ name: "working scan",
+ args: args{src: []byte{
+ 0xb2, 0x6f, 0x2b, 0x34, 0x2a, 0xab, 0x24, 0xbc, 0xf6, 0x3e,
+ 0xa2, 0x18, 0xc6, 0xa9, 0x27, 0x4d, 0x30, 0xab, 0x9a, 0x15,
+ }},
+ wantErr: false,
+ },
+ {
+ name: "non working scan",
+ args: args{src: int64(1234567890)},
+ wantErr: true,
+ },
+ {
+ name: "invalid length scan",
+ args: args{src: []byte{
+ 0xb2, 0x6f, 0x2b, 0x34, 0x2a, 0xab, 0x24, 0xbc, 0xf6, 0x3e,
+ 0xa2, 0x18, 0xc6, 0xa9, 0x27, 0x4d, 0x30, 0xab, 0x9a,
+ }},
+ wantErr: true,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ a := &Address{}
+ if err := a.Scan(tt.args.src); (err != nil) != tt.wantErr {
+ t.Errorf("Address.Scan() error = %v, wantErr %v", err, tt.wantErr)
+ }
+
+ if !tt.wantErr {
+ for i := range a {
+ if a[i] != tt.args.src.([]byte)[i] {
+ t.Errorf(
+ "Address.Scan() didn't scan the %d src correctly (have %X, want %X)",
+ i, a[i], tt.args.src.([]byte)[i],
+ )
+ }
+ }
+ }
+ })
+ }
+}
+
+func TestAddress_Value(t *testing.T) {
+ b := []byte{
+ 0xb2, 0x6f, 0x2b, 0x34, 0x2a, 0xab, 0x24, 0xbc, 0xf6, 0x3e,
+ 0xa2, 0x18, 0xc6, 0xa9, 0x27, 0x4d, 0x30, 0xab, 0x9a, 0x15,
+ }
+ var usedA Address
+ usedA.SetBytes(b)
+ tests := []struct {
+ name string
+ a Address
+ want driver.Value
+ wantErr bool
+ }{
+ {
+ name: "Working value",
+ a: usedA,
+ want: b,
+ wantErr: false,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got, err := tt.a.Value()
+ if (err != nil) != tt.wantErr {
+ t.Errorf("Address.Value() error = %v, wantErr %v", err, tt.wantErr)
+ return
+ }
+ if !reflect.DeepEqual(got, tt.want) {
+ t.Errorf("Address.Value() = %v, want %v", got, tt.want)
+ }
+ })
+ }
+}