aboutsummaryrefslogtreecommitdiffstats
path: root/core/vm/sqlvm/ast
diff options
context:
space:
mode:
Diffstat (limited to 'core/vm/sqlvm/ast')
-rw-r--r--core/vm/sqlvm/ast/types.go33
-rw-r--r--core/vm/sqlvm/ast/types_test.go25
2 files changed, 50 insertions, 8 deletions
diff --git a/core/vm/sqlvm/ast/types.go b/core/vm/sqlvm/ast/types.go
index 65fdc501d..3909e7575 100644
--- a/core/vm/sqlvm/ast/types.go
+++ b/core/vm/sqlvm/ast/types.go
@@ -169,6 +169,17 @@ func DataTypeDecode(t DataType) (TypeNode, error) {
return nil, se.ErrorCodeDataTypeDecode
}
+func decimalToBig(d decimal.Decimal) (b *big.Int) {
+ if exponent := int64(d.Exponent()); exponent >= 0 {
+ exp := new(big.Int).Exp(bigIntTen, big.NewInt(exponent), nil)
+ b = new(big.Int).Mul(d.Coefficient(), exp)
+ } else {
+ exp := new(big.Int).Exp(bigIntTen, big.NewInt(-exponent), nil)
+ b = new(big.Int).Div(d.Coefficient(), exp)
+ }
+ return
+}
+
// Don't handle overflow here.
func decimalEncode(size int, d decimal.Decimal) []byte {
ret := make([]byte, size)
@@ -177,14 +188,7 @@ func decimalEncode(size int, d decimal.Decimal) []byte {
return ret
}
- var b *big.Int
- if exponent := int64(d.Exponent()); exponent >= 0 {
- exp := new(big.Int).Exp(bigIntTen, big.NewInt(exponent), nil)
- b = new(big.Int).Mul(d.Coefficient(), exp)
- } else {
- exp := new(big.Int).Exp(bigIntTen, big.NewInt(-exponent), nil)
- b = new(big.Int).Div(d.Coefficient(), exp)
- }
+ b := decimalToBig(d)
if s > 0 {
bs := b.Bytes()
@@ -298,3 +302,16 @@ func GetMinMax(dt DataType) (min, max decimal.Decimal, err error) {
decPairMap[dt] = decPair{Max: max, Min: min}
return
}
+
+// DecimalToUint64 convert decimal to uint64.
+// Negative case will return error, and decimal part will be trancated.
+func DecimalToUint64(d decimal.Decimal) (uint64, error) {
+ s := d.Sign()
+ if s == 0 {
+ return 0, nil
+ }
+ if s < 0 {
+ return 0, se.ErrorCodeNegDecimalToUint64
+ }
+ return decimalToBig(d).Uint64(), nil
+}
diff --git a/core/vm/sqlvm/ast/types_test.go b/core/vm/sqlvm/ast/types_test.go
index ada5d487f..8373aa344 100644
--- a/core/vm/sqlvm/ast/types_test.go
+++ b/core/vm/sqlvm/ast/types_test.go
@@ -212,6 +212,31 @@ func (s *TypesTestSuite) TestSize() {
}
}
+func (s *TypesTestSuite) TestDecimalToUint64() {
+ pos := decimal.New(15, 1)
+ zero := decimal.Zero
+ neg := decimal.New(-150, -1)
+ posSmall := decimal.New(15, -2)
+ negSmall := decimal.New(-15, -2)
+
+ testcases := []struct {
+ d decimal.Decimal
+ u uint64
+ err error
+ }{
+ {pos, 150, nil},
+ {zero, 0, nil},
+ {neg, 0, errors.ErrorCodeNegDecimalToUint64},
+ {posSmall, 0, nil},
+ {negSmall, 0, errors.ErrorCodeNegDecimalToUint64},
+ }
+ for i, t := range testcases {
+ u, err := DecimalToUint64(t.d)
+ s.Require().Equalf(t.err, err, "err not match. testcase: %v", i)
+ s.Require().Equalf(t.u, u, "result not match. testcase: %v", i)
+ }
+}
+
func TestTypes(t *testing.T) {
suite.Run(t, new(TypesTestSuite))
}