From 95fab1aca3f01c33491b6a4840b0e2d814aa6e77 Mon Sep 17 00:00:00 2001
From: Jhih-Ming Huang <jm.huang@cobinhood.com>
Date: Fri, 15 Mar 2019 17:31:52 +0800
Subject: core: vm: sqlvm: ast: add size func and move error code to errors.go

Move error code to errors.go, and implement Size method for column field.
---
 core/vm/sqlvm/ast/types.go      | 48 +++++++++++++++++++++++++++--------------
 core/vm/sqlvm/ast/types_test.go | 29 ++++++++++++++++++++++---
 core/vm/sqlvm/errors/errors.go  | 11 ++++++++++
 3 files changed, 69 insertions(+), 19 deletions(-)

(limited to 'core/vm')

diff --git a/core/vm/sqlvm/ast/types.go b/core/vm/sqlvm/ast/types.go
index c006086e9..65fdc501d 100644
--- a/core/vm/sqlvm/ast/types.go
+++ b/core/vm/sqlvm/ast/types.go
@@ -1,7 +1,7 @@
 package ast
 
 import (
-	"errors"
+	"fmt"
 	"math/big"
 
 	"github.com/shopspring/decimal"
@@ -24,15 +24,6 @@ var (
 	decPairMap = make(map[DataType]decPair)
 )
 
-// Error defines.
-var (
-	ErrDataTypeEncode = errors.New("data type encode failed")
-	ErrDataTypeDecode = errors.New("data type decode failed")
-	ErrDecimalEncode  = errors.New("decimal encode failed")
-	ErrDecimalDecode  = errors.New("decimal decode failed")
-	ErrGetMinMax      = errors.New("get (min, max) failed")
-)
-
 // DataTypeMajor defines type for high byte of DataType.
 type DataTypeMajor uint8
 
@@ -77,6 +68,29 @@ func ComposeDataType(major DataTypeMajor, minor DataTypeMinor) DataType {
 	return (DataType(major) << 8) | DataType(minor)
 }
 
+// Size return the bytes of the data type occupied.
+func (dt DataType) Size() uint8 {
+	major, minor := DecomposeDataType(dt)
+	if major.IsFixedRange() {
+		return uint8(major - DataTypeMajorFixed + 1)
+	}
+	if major.IsUfixedRange() {
+		return uint8(major - DataTypeMajorUfixed + 1)
+	}
+	switch major {
+	case DataTypeMajorBool:
+		return 1
+	case DataTypeMajorDynamicBytes:
+		return common.HashLength
+	case DataTypeMajorAddress:
+		return common.AddressLength
+	case DataTypeMajorInt, DataTypeMajorUint, DataTypeMajorFixedBytes:
+		return uint8(minor + 1)
+	default:
+		panic(fmt.Sprintf("unknown data type %v", dt))
+	}
+}
+
 // IsFixedRange checks if major is in range of DataTypeMajorFixed.
 func (d DataTypeMajor) IsFixedRange() bool {
 	return d >= DataTypeMajorFixed && d-DataTypeMajorFixed <= 0x1f
@@ -90,7 +104,7 @@ func (d DataTypeMajor) IsUfixedRange() bool {
 // DataTypeEncode encodes data type node into DataType.
 func DataTypeEncode(n TypeNode) (DataType, error) {
 	if n == nil {
-		return DataTypeUnknown, ErrDataTypeEncode
+		return DataTypeUnknown, se.ErrorCodeDataTypeEncode
 	}
 	t, code := n.GetType()
 	if code == se.ErrorCodeNil {
@@ -152,7 +166,7 @@ func DataTypeDecode(t DataType) (TypeNode, error) {
 			}, nil
 		}
 	}
-	return nil, ErrDataTypeDecode
+	return nil, se.ErrorCodeDataTypeDecode
 }
 
 // Don't handle overflow here.
@@ -212,7 +226,8 @@ func DecimalEncode(dt DataType, d decimal.Decimal) ([]byte, error) {
 	major, minor := DecomposeDataType(dt)
 	switch major {
 	case DataTypeMajorInt,
-		DataTypeMajorUint:
+		DataTypeMajorUint,
+		DataTypeMajorFixedBytes:
 		return decimalEncode(int(minor)+1, d), nil
 	case DataTypeMajorAddress:
 		return decimalEncode(common.AddressLength, d), nil
@@ -228,7 +243,7 @@ func DecimalEncode(dt DataType, d decimal.Decimal) ([]byte, error) {
 			d.Shift(int32(minor))), nil
 	}
 
-	return nil, ErrDecimalEncode
+	return nil, se.ErrorCodeDecimalEncode
 }
 
 // DecimalDecode decodes decimal from bytes.
@@ -238,6 +253,7 @@ func DecimalDecode(dt DataType, b []byte) (decimal.Decimal, error) {
 	case DataTypeMajorInt:
 		return decimalDecode(true, b), nil
 	case DataTypeMajorUint,
+		DataTypeMajorFixedBytes,
 		DataTypeMajorAddress:
 		return decimalDecode(false, b), nil
 	}
@@ -248,7 +264,7 @@ func DecimalDecode(dt DataType, b []byte) (decimal.Decimal, error) {
 		return decimalDecode(false, b).Shift(-int32(minor)), nil
 	}
 
-	return decimal.Zero, ErrDecimalDecode
+	return decimal.Zero, se.ErrorCodeDecimalDecode
 }
 
 // GetMinMax returns min, max pair according to given data type.
@@ -275,7 +291,7 @@ func GetMinMax(dt DataType) (min, max decimal.Decimal, err error) {
 		bigUMax := new(big.Int).Lsh(bigIntOne, (uint(minor)+1)*8)
 		max = decimal.NewFromBigInt(bigUMax, 0).Sub(dec.One)
 	default:
-		err = ErrGetMinMax
+		err = se.ErrorCodeGetMinMax
 		return
 	}
 
diff --git a/core/vm/sqlvm/ast/types_test.go b/core/vm/sqlvm/ast/types_test.go
index fe125ba2a..ada5d487f 100644
--- a/core/vm/sqlvm/ast/types_test.go
+++ b/core/vm/sqlvm/ast/types_test.go
@@ -3,10 +3,12 @@ package ast
 import (
 	"testing"
 
-	"github.com/dexon-foundation/dexon/common"
-	dec "github.com/dexon-foundation/dexon/core/vm/sqlvm/common/decimal"
 	"github.com/shopspring/decimal"
 	"github.com/stretchr/testify/suite"
+
+	"github.com/dexon-foundation/dexon/common"
+	dec "github.com/dexon-foundation/dexon/core/vm/sqlvm/common/decimal"
+	"github.com/dexon-foundation/dexon/core/vm/sqlvm/errors"
 )
 
 type TypesTestSuite struct{ suite.Suite }
@@ -170,7 +172,7 @@ func (s *TypesTestSuite) TestGetMinMax() {
 		{"UInt16", ComposeDataType(DataTypeMajorUint, 1), decimal.Zero, decimal.New(65535, 0), nil},
 		{"Bytes1", ComposeDataType(DataTypeMajorFixedBytes, 0), decimal.Zero, decimal.New(255, 0), nil},
 		{"Bytes2", ComposeDataType(DataTypeMajorFixedBytes, 1), decimal.Zero, decimal.New(65535, 0), nil},
-		{"Dynamic Bytes", ComposeDataType(DataTypeMajorDynamicBytes, 0), decimal.Zero, decimal.Zero, ErrGetMinMax},
+		{"Dynamic Bytes", ComposeDataType(DataTypeMajorDynamicBytes, 0), decimal.Zero, decimal.Zero, errors.ErrorCodeGetMinMax},
 	}
 
 	var (
@@ -189,6 +191,27 @@ func (s *TypesTestSuite) TestGetMinMax() {
 	}
 }
 
+func (s *TypesTestSuite) TestSize() {
+	testcases := []struct {
+		Name string
+		Dt   DataType
+		Size uint8
+	}{
+		{"Bool", ComposeDataType(DataTypeMajorBool, 0), 1},
+		{"Address", ComposeDataType(DataTypeMajorAddress, 0), 20},
+		{"Int8", ComposeDataType(DataTypeMajorInt, 0), 1},
+		{"Int16", ComposeDataType(DataTypeMajorInt, 1), 2},
+		{"UInt8", ComposeDataType(DataTypeMajorUint, 0), 1},
+		{"UInt16", ComposeDataType(DataTypeMajorUint, 1), 2},
+		{"Bytes1", ComposeDataType(DataTypeMajorFixedBytes, 0), 1},
+		{"Bytes2", ComposeDataType(DataTypeMajorFixedBytes, 1), 2},
+		{"Dynamic Bytes", ComposeDataType(DataTypeMajorDynamicBytes, 0), 32},
+	}
+	for _, t := range testcases {
+		s.Require().Equalf(t.Size, t.Dt.Size(), "Testcase %v", t.Name)
+	}
+}
+
 func TestTypes(t *testing.T) {
 	suite.Run(t, new(TypesTestSuite))
 }
diff --git a/core/vm/sqlvm/errors/errors.go b/core/vm/sqlvm/errors/errors.go
index e276245e5..ae3c3b8a5 100644
--- a/core/vm/sqlvm/errors/errors.go
+++ b/core/vm/sqlvm/errors/errors.go
@@ -108,6 +108,12 @@ const (
 	ErrorCodeInvalidUfixedSize
 	ErrorCodeInvalidFixedFractionalDigits
 	ErrorCodeInvalidUfixedFractionalDigits
+	ErrorCodeDataTypeEncode
+	ErrorCodeDataTypeDecode
+	ErrorCodeDecimalEncode
+	ErrorCodeDecimalDecode
+	ErrorCodeGetMinMax
+
 	// Runtime Error
 	ErrorCodeInvalidDataType
 	ErrorCodeOverflow
@@ -135,6 +141,11 @@ var errorCodeMap = [...]string{
 	ErrorCodeInvalidUfixedSize:             "invalid ufixed size",
 	ErrorCodeInvalidFixedFractionalDigits:  "invalid fixed fractional digits",
 	ErrorCodeInvalidUfixedFractionalDigits: "invalid ufixed fractional digits",
+	ErrorCodeDataTypeEncode:                "data type encode failed",
+	ErrorCodeDataTypeDecode:                "data type decode failed",
+	ErrorCodeDecimalEncode:                 "decimal encode failed",
+	ErrorCodeDecimalDecode:                 "decimal decode failed",
+	ErrorCodeGetMinMax:                     "get (min, max) failed",
 	// Runtime Error
 	ErrorCodeInvalidDataType: "invalid data type",
 	ErrorCodeOverflow:        "overflow",
-- 
cgit v1.2.3