diff options
Diffstat (limited to 'core/vm/sqlvm/ast')
-rw-r--r-- | core/vm/sqlvm/ast/types.go | 33 | ||||
-rw-r--r-- | core/vm/sqlvm/ast/types_test.go | 25 |
2 files changed, 50 insertions, 8 deletions
diff --git a/core/vm/sqlvm/ast/types.go b/core/vm/sqlvm/ast/types.go index 65fdc501d..3909e7575 100644 --- a/core/vm/sqlvm/ast/types.go +++ b/core/vm/sqlvm/ast/types.go @@ -169,6 +169,17 @@ func DataTypeDecode(t DataType) (TypeNode, error) { return nil, se.ErrorCodeDataTypeDecode } +func decimalToBig(d decimal.Decimal) (b *big.Int) { + if exponent := int64(d.Exponent()); exponent >= 0 { + exp := new(big.Int).Exp(bigIntTen, big.NewInt(exponent), nil) + b = new(big.Int).Mul(d.Coefficient(), exp) + } else { + exp := new(big.Int).Exp(bigIntTen, big.NewInt(-exponent), nil) + b = new(big.Int).Div(d.Coefficient(), exp) + } + return +} + // Don't handle overflow here. func decimalEncode(size int, d decimal.Decimal) []byte { ret := make([]byte, size) @@ -177,14 +188,7 @@ func decimalEncode(size int, d decimal.Decimal) []byte { return ret } - var b *big.Int - if exponent := int64(d.Exponent()); exponent >= 0 { - exp := new(big.Int).Exp(bigIntTen, big.NewInt(exponent), nil) - b = new(big.Int).Mul(d.Coefficient(), exp) - } else { - exp := new(big.Int).Exp(bigIntTen, big.NewInt(-exponent), nil) - b = new(big.Int).Div(d.Coefficient(), exp) - } + b := decimalToBig(d) if s > 0 { bs := b.Bytes() @@ -298,3 +302,16 @@ func GetMinMax(dt DataType) (min, max decimal.Decimal, err error) { decPairMap[dt] = decPair{Max: max, Min: min} return } + +// DecimalToUint64 convert decimal to uint64. +// Negative case will return error, and decimal part will be trancated. +func DecimalToUint64(d decimal.Decimal) (uint64, error) { + s := d.Sign() + if s == 0 { + return 0, nil + } + if s < 0 { + return 0, se.ErrorCodeNegDecimalToUint64 + } + return decimalToBig(d).Uint64(), nil +} diff --git a/core/vm/sqlvm/ast/types_test.go b/core/vm/sqlvm/ast/types_test.go index ada5d487f..8373aa344 100644 --- a/core/vm/sqlvm/ast/types_test.go +++ b/core/vm/sqlvm/ast/types_test.go @@ -212,6 +212,31 @@ func (s *TypesTestSuite) TestSize() { } } +func (s *TypesTestSuite) TestDecimalToUint64() { + pos := decimal.New(15, 1) + zero := decimal.Zero + neg := decimal.New(-150, -1) + posSmall := decimal.New(15, -2) + negSmall := decimal.New(-15, -2) + + testcases := []struct { + d decimal.Decimal + u uint64 + err error + }{ + {pos, 150, nil}, + {zero, 0, nil}, + {neg, 0, errors.ErrorCodeNegDecimalToUint64}, + {posSmall, 0, nil}, + {negSmall, 0, errors.ErrorCodeNegDecimalToUint64}, + } + for i, t := range testcases { + u, err := DecimalToUint64(t.d) + s.Require().Equalf(t.err, err, "err not match. testcase: %v", i) + s.Require().Equalf(t.u, u, "result not match. testcase: %v", i) + } +} + func TestTypes(t *testing.T) { suite.Run(t, new(TypesTestSuite)) } |