aboutsummaryrefslogtreecommitdiffstats
path: root/core
diff options
context:
space:
mode:
Diffstat (limited to 'core')
-rw-r--r--core/vm/sqlvm/ast/types.go (renamed from core/vm/sqlvm/ast/type.go)116
-rw-r--r--core/vm/sqlvm/ast/types_test.go (renamed from core/vm/sqlvm/ast/type_test.go)91
2 files changed, 196 insertions, 11 deletions
diff --git a/core/vm/sqlvm/ast/type.go b/core/vm/sqlvm/ast/types.go
index 06f0c0207..80cda796c 100644
--- a/core/vm/sqlvm/ast/type.go
+++ b/core/vm/sqlvm/ast/types.go
@@ -2,13 +2,25 @@ package ast
import (
"errors"
+ "math/big"
"reflect"
+
+ "github.com/shopspring/decimal"
+
+ "github.com/dexon-foundation/dexon/common"
+)
+
+var (
+ bigIntOne = big.NewInt(1)
+ bigIntTen = big.NewInt(10)
)
// Error defines.
var (
ErrDataTypeEncode = errors.New("data type encode failed")
ErrDataTypeDecode = errors.New("data type decode failed")
+ ErrDecimalEncode = errors.New("decimal encode failed")
+ ErrDecimalDecode = errors.New("decimal decode failed")
)
// DataTypeMajor defines type for high byte of DataType.
@@ -45,6 +57,16 @@ func composeDataType(major DataTypeMajor, minor DataTypeMinor) DataType {
return (DataType(major) << 8) | DataType(minor)
}
+// IsFixedRange checks if major is in range of DataTypeMajorFixed.
+func (d DataTypeMajor) IsFixedRange() bool {
+ return d >= DataTypeMajorFixed && d-DataTypeMajorFixed <= 0x1f
+}
+
+// IsUfixedRange checks if major is in range of DataTypeMajorUfixed.
+func (d DataTypeMajor) IsUfixedRange() bool {
+ return d >= DataTypeMajorUfixed && d-DataTypeMajorUfixed <= 0x1f
+}
+
// DataTypeEncode encodes data type node into DataType.
func DataTypeEncode(n interface{}) (DataType, error) {
if n == nil {
@@ -137,7 +159,7 @@ func DataTypeDecode(t DataType) (interface{}, error) {
}
}
switch {
- case major >= DataTypeMajorFixed && major-DataTypeMajorFixed <= 0x1f:
+ case major.IsFixedRange():
if minor <= 80 {
size := (uint32(major-DataTypeMajorFixed) + 1) * 8
return FixedTypeNode{
@@ -146,7 +168,7 @@ func DataTypeDecode(t DataType) (interface{}, error) {
FractionalDigits: uint32(minor),
}, nil
}
- case major >= DataTypeMajorUfixed && major-DataTypeMajorUfixed <= 0x1f:
+ case major.IsUfixedRange():
if minor <= 80 {
size := (uint32(major-DataTypeMajorUfixed) + 1) * 8
return FixedTypeNode{
@@ -158,3 +180,93 @@ func DataTypeDecode(t DataType) (interface{}, error) {
}
return nil, ErrDataTypeDecode
}
+
+// Don't handle overflow here.
+func decimalEncode(size int, d decimal.Decimal) []byte {
+ ret := make([]byte, size)
+ s := d.Sign()
+ if s == 0 {
+ return ret
+ }
+
+ exp := new(big.Int).Exp(bigIntTen, big.NewInt(int64(d.Exponent())), nil)
+ b := new(big.Int).Mul(d.Coefficient(), exp)
+
+ if s > 0 {
+ bs := b.Bytes()
+ copy(ret[size-len(bs):], bs)
+ return ret
+ }
+
+ b.Add(b, bigIntOne)
+ bs := b.Bytes()
+ copy(ret[size-len(bs):], bs)
+ for idx := range ret {
+ ret[idx] = ^ret[idx]
+ }
+ return ret
+}
+
+// Don't handle overflow here.
+func decimalDecode(signed bool, bs []byte) decimal.Decimal {
+ neg := false
+ if signed && (bs[0]&0x80 != 0) {
+ neg = true
+ for idx := range bs {
+ bs[idx] = ^bs[idx]
+ }
+ }
+
+ b := new(big.Int).SetBytes(bs)
+
+ if neg {
+ b.Add(b, bigIntOne)
+ b.Neg(b)
+ }
+
+ return decimal.NewFromBigInt(b, 0)
+}
+
+// DecimalEncode encodes decimal to bytes depend on data type.
+func DecimalEncode(dt DataType, d decimal.Decimal) ([]byte, error) {
+ major, minor := decomposeDataType(dt)
+ switch major {
+ case DataTypeMajorInt,
+ DataTypeMajorUint:
+ return decimalEncode(int(minor)+1, d), nil
+ case DataTypeMajorAddress:
+ return decimalEncode(common.AddressLength, d), nil
+ }
+ switch {
+ case major.IsFixedRange():
+ return decimalEncode(
+ int(major-DataTypeMajorFixed)+1,
+ d.Shift(int32(minor))), nil
+ case major.IsUfixedRange():
+ return decimalEncode(
+ int(major-DataTypeMajorUfixed)+1,
+ d.Shift(int32(minor))), nil
+ }
+
+ return nil, ErrDecimalEncode
+}
+
+// DecimalDecode decodes decimal from bytes.
+func DecimalDecode(dt DataType, b []byte) (decimal.Decimal, error) {
+ major, minor := decomposeDataType(dt)
+ switch major {
+ case DataTypeMajorInt:
+ return decimalDecode(true, b), nil
+ case DataTypeMajorUint,
+ DataTypeMajorAddress:
+ return decimalDecode(false, b), nil
+ }
+ switch {
+ case major.IsFixedRange():
+ return decimalDecode(true, b).Shift(-int32(minor)), nil
+ case major.IsUfixedRange():
+ return decimalDecode(false, b).Shift(-int32(minor)), nil
+ }
+
+ return decimal.Zero, ErrDecimalDecode
+}
diff --git a/core/vm/sqlvm/ast/type_test.go b/core/vm/sqlvm/ast/types_test.go
index 41c5d3a20..31ed224fb 100644
--- a/core/vm/sqlvm/ast/type_test.go
+++ b/core/vm/sqlvm/ast/types_test.go
@@ -3,12 +3,23 @@ package ast
import (
"testing"
+ "github.com/shopspring/decimal"
"github.com/stretchr/testify/suite"
)
-type TypeTestSuite struct{ suite.Suite }
+type TypesTestSuite struct{ suite.Suite }
-func (s *TypeTestSuite) requireEncodeAndDecodeNoError(
+func (s *TypesTestSuite) requireEncodeAndDecodeDecimalNoError(
+ d DataType, t decimal.Decimal, bs int) {
+ encode, err := DecimalEncode(d, t)
+ s.Require().NoError(err)
+ s.Require().Len(encode, bs)
+ decode, err := DecimalDecode(d, encode)
+ s.Require().NoError(err)
+ s.Require().Equal(t.String(), decode.String())
+}
+
+func (s *TypesTestSuite) requireEncodeAndDecodeNoError(
d DataType, t interface{}) {
encode, err := DataTypeEncode(t)
s.Require().NoError(err)
@@ -18,17 +29,17 @@ func (s *TypeTestSuite) requireEncodeAndDecodeNoError(
s.Require().Equal(t, decode)
}
-func (s *TypeTestSuite) requireEncodeError(input interface{}) {
+func (s *TypesTestSuite) requireEncodeError(input interface{}) {
_, err := DataTypeEncode(input)
s.Require().Error(err)
}
-func (s *TypeTestSuite) requireDecodeError(input DataType) {
+func (s *TypesTestSuite) requireDecodeError(input DataType) {
_, err := DataTypeDecode(input)
s.Require().Error(err)
}
-func (s *TypeTestSuite) TestEncodeAndDecode() {
+func (s *TypesTestSuite) TestEncodeAndDecode() {
s.requireEncodeAndDecodeNoError(
composeDataType(DataTypeMajorBool, 0),
BoolTypeNode{})
@@ -55,7 +66,7 @@ func (s *TypeTestSuite) TestEncodeAndDecode() {
FixedTypeNode{Unsigned: true, Size: 16, FractionalDigits: 2})
}
-func (s *TypeTestSuite) TestEncodeError() {
+func (s *TypesTestSuite) TestEncodeError() {
s.requireEncodeError(struct{}{})
s.requireEncodeError(IntTypeNode{Size: 1})
s.requireEncodeError(IntTypeNode{Size: 257})
@@ -66,7 +77,7 @@ func (s *TypeTestSuite) TestEncodeError() {
s.requireEncodeError(FixedTypeNode{Size: 8, FractionalDigits: 81})
}
-func (s *TypeTestSuite) TestDecodeError() {
+func (s *TypesTestSuite) TestDecodeError() {
s.requireDecodeError(DataTypeUnknown)
s.requireDecodeError(composeDataType(DataTypeMajorBool, 1))
s.requireDecodeError(composeDataType(DataTypeMajorAddress, 1))
@@ -79,6 +90,68 @@ func (s *TypeTestSuite) TestDecodeError() {
s.requireDecodeError(composeDataType(DataTypeMajorUfixed+0x20, 80))
}
-func TestType(t *testing.T) {
- suite.Run(t, new(TypeTestSuite))
+func (s *TypesTestSuite) TestEncodeAndDecodeDecimal() {
+ pos := decimal.New(15, 0)
+ zero := decimal.Zero
+ neg := decimal.New(-15, 0)
+
+ s.requireEncodeAndDecodeDecimalNoError(
+ composeDataType(DataTypeMajorInt, 2),
+ pos,
+ 3)
+ s.requireEncodeAndDecodeDecimalNoError(
+ composeDataType(DataTypeMajorInt, 2),
+ zero,
+ 3)
+ s.requireEncodeAndDecodeDecimalNoError(
+ composeDataType(DataTypeMajorInt, 2),
+ neg,
+ 3)
+
+ s.requireEncodeAndDecodeDecimalNoError(
+ composeDataType(DataTypeMajorUint, 2),
+ pos,
+ 3)
+ s.requireEncodeAndDecodeDecimalNoError(
+ composeDataType(DataTypeMajorUint, 2),
+ zero,
+ 3)
+
+ s.requireEncodeAndDecodeDecimalNoError(
+ composeDataType(DataTypeMajorAddress, 0),
+ pos,
+ 20)
+ s.requireEncodeAndDecodeDecimalNoError(
+ composeDataType(DataTypeMajorAddress, 0),
+ zero,
+ 20)
+
+ pos = decimal.New(15, -2)
+ neg = decimal.New(-15, -2)
+
+ s.requireEncodeAndDecodeDecimalNoError(
+ composeDataType(DataTypeMajorFixed+2, 2),
+ pos,
+ 3)
+ s.requireEncodeAndDecodeDecimalNoError(
+ composeDataType(DataTypeMajorFixed+2, 2),
+ zero,
+ 3)
+ s.requireEncodeAndDecodeDecimalNoError(
+ composeDataType(DataTypeMajorFixed+2, 2),
+ neg,
+ 3)
+
+ s.requireEncodeAndDecodeDecimalNoError(
+ composeDataType(DataTypeMajorUfixed+2, 2),
+ pos,
+ 3)
+ s.requireEncodeAndDecodeDecimalNoError(
+ composeDataType(DataTypeMajorUfixed+2, 2),
+ zero,
+ 3)
+}
+
+func TestTypes(t *testing.T) {
+ suite.Run(t, new(TypesTestSuite))
}