aboutsummaryrefslogtreecommitdiffstats
path: root/core
diff options
context:
space:
mode:
Diffstat (limited to 'core')
-rw-r--r--core/vm/sqlvm/ast/types.go33
-rw-r--r--core/vm/sqlvm/ast/types_test.go25
-rw-r--r--core/vm/sqlvm/errors/errors.go14
-rw-r--r--core/vm/sqlvm/runtime/instructions.go36
4 files changed, 82 insertions, 26 deletions
diff --git a/core/vm/sqlvm/ast/types.go b/core/vm/sqlvm/ast/types.go
index 65fdc501d..3909e7575 100644
--- a/core/vm/sqlvm/ast/types.go
+++ b/core/vm/sqlvm/ast/types.go
@@ -169,6 +169,17 @@ func DataTypeDecode(t DataType) (TypeNode, error) {
return nil, se.ErrorCodeDataTypeDecode
}
+func decimalToBig(d decimal.Decimal) (b *big.Int) {
+ if exponent := int64(d.Exponent()); exponent >= 0 {
+ exp := new(big.Int).Exp(bigIntTen, big.NewInt(exponent), nil)
+ b = new(big.Int).Mul(d.Coefficient(), exp)
+ } else {
+ exp := new(big.Int).Exp(bigIntTen, big.NewInt(-exponent), nil)
+ b = new(big.Int).Div(d.Coefficient(), exp)
+ }
+ return
+}
+
// Don't handle overflow here.
func decimalEncode(size int, d decimal.Decimal) []byte {
ret := make([]byte, size)
@@ -177,14 +188,7 @@ func decimalEncode(size int, d decimal.Decimal) []byte {
return ret
}
- var b *big.Int
- if exponent := int64(d.Exponent()); exponent >= 0 {
- exp := new(big.Int).Exp(bigIntTen, big.NewInt(exponent), nil)
- b = new(big.Int).Mul(d.Coefficient(), exp)
- } else {
- exp := new(big.Int).Exp(bigIntTen, big.NewInt(-exponent), nil)
- b = new(big.Int).Div(d.Coefficient(), exp)
- }
+ b := decimalToBig(d)
if s > 0 {
bs := b.Bytes()
@@ -298,3 +302,16 @@ func GetMinMax(dt DataType) (min, max decimal.Decimal, err error) {
decPairMap[dt] = decPair{Max: max, Min: min}
return
}
+
+// DecimalToUint64 convert decimal to uint64.
+// Negative case will return error, and decimal part will be trancated.
+func DecimalToUint64(d decimal.Decimal) (uint64, error) {
+ s := d.Sign()
+ if s == 0 {
+ return 0, nil
+ }
+ if s < 0 {
+ return 0, se.ErrorCodeNegDecimalToUint64
+ }
+ return decimalToBig(d).Uint64(), nil
+}
diff --git a/core/vm/sqlvm/ast/types_test.go b/core/vm/sqlvm/ast/types_test.go
index ada5d487f..8373aa344 100644
--- a/core/vm/sqlvm/ast/types_test.go
+++ b/core/vm/sqlvm/ast/types_test.go
@@ -212,6 +212,31 @@ func (s *TypesTestSuite) TestSize() {
}
}
+func (s *TypesTestSuite) TestDecimalToUint64() {
+ pos := decimal.New(15, 1)
+ zero := decimal.Zero
+ neg := decimal.New(-150, -1)
+ posSmall := decimal.New(15, -2)
+ negSmall := decimal.New(-15, -2)
+
+ testcases := []struct {
+ d decimal.Decimal
+ u uint64
+ err error
+ }{
+ {pos, 150, nil},
+ {zero, 0, nil},
+ {neg, 0, errors.ErrorCodeNegDecimalToUint64},
+ {posSmall, 0, nil},
+ {negSmall, 0, errors.ErrorCodeNegDecimalToUint64},
+ }
+ for i, t := range testcases {
+ u, err := DecimalToUint64(t.d)
+ s.Require().Equalf(t.err, err, "err not match. testcase: %v", i)
+ s.Require().Equalf(t.u, u, "result not match. testcase: %v", i)
+ }
+}
+
func TestTypes(t *testing.T) {
suite.Run(t, new(TypesTestSuite))
}
diff --git a/core/vm/sqlvm/errors/errors.go b/core/vm/sqlvm/errors/errors.go
index ae3c3b8a5..c327d1b43 100644
--- a/core/vm/sqlvm/errors/errors.go
+++ b/core/vm/sqlvm/errors/errors.go
@@ -121,6 +121,7 @@ const (
ErrorCodeIndexOutOfRange
ErrorCodeInvalidCastType
ErrorCodeDividedByZero
+ ErrorCodeNegDecimalToUint64
)
var errorCodeMap = [...]string{
@@ -147,12 +148,13 @@ var errorCodeMap = [...]string{
ErrorCodeDecimalDecode: "decimal decode failed",
ErrorCodeGetMinMax: "get (min, max) failed",
// Runtime Error
- ErrorCodeInvalidDataType: "invalid data type",
- ErrorCodeOverflow: "overflow",
- ErrorCodeUnderflow: "underflow",
- ErrorCodeIndexOutOfRange: "index out of range",
- ErrorCodeInvalidCastType: "invalid cast type",
- ErrorCodeDividedByZero: "divide by zero",
+ ErrorCodeInvalidDataType: "invalid data type",
+ ErrorCodeOverflow: "overflow",
+ ErrorCodeUnderflow: "underflow",
+ ErrorCodeIndexOutOfRange: "index out of range",
+ ErrorCodeInvalidCastType: "invalid cast type",
+ ErrorCodeDividedByZero: "divide by zero",
+ ErrorCodeNegDecimalToUint64: "negative deciaml to uint64",
}
func (c ErrorCode) Error() string {
diff --git a/core/vm/sqlvm/runtime/instructions.go b/core/vm/sqlvm/runtime/instructions.go
index 5a61d5f80..17691d0d4 100644
--- a/core/vm/sqlvm/runtime/instructions.go
+++ b/core/vm/sqlvm/runtime/instructions.go
@@ -58,20 +58,27 @@ type Operand struct {
RegisterIndex uint
}
-func (o *Operand) toUint64() []uint64 {
- result := make([]uint64, len(o.Data))
+func (o *Operand) toUint64() (result []uint64, err error) {
+ result = make([]uint64, len(o.Data))
for i, tuple := range o.Data {
- result[i] = uint64(tuple[0].Value.IntPart())
+ result[i], err = ast.DecimalToUint64(tuple[0].Value)
+ if err != nil {
+ return
+ }
}
- return result
+ return
}
-func (o *Operand) toUint8() []uint8 {
+func (o *Operand) toUint8() ([]uint8, error) {
result := make([]uint8, len(o.Data))
for i, tuple := range o.Data {
- result[i] = uint8(tuple[0].Value.IntPart())
+ u, err := ast.DecimalToUint64(tuple[0].Value)
+ if err != nil {
+ return nil, err
+ }
+ result[i] = uint8(u)
}
- return result
+ return result, nil
}
func opLoad(ctx *common.Context, input []*Operand, registers []*Operand, output int) error {
@@ -81,8 +88,14 @@ func opLoad(ctx *common.Context, input []*Operand, registers []*Operand, output
}
table := ctx.Storage.Schema[tableIdx]
- ids := input[1].toUint64()
- fields := input[2].toUint8()
+ ids, err := input[1].toUint64()
+ if err != nil {
+ return err
+ }
+ fields, err := input[2].toUint8()
+ if err != nil {
+ return err
+ }
op := Operand{
IsImmediate: false,
Data: make([]Tuple, len(ids)),
@@ -91,11 +104,10 @@ func opLoad(ctx *common.Context, input []*Operand, registers []*Operand, output
for i := range op.Data {
op.Data[i] = make([]*Raw, len(fields))
}
- meta, err := table.GetFieldType(fields)
+ op.Meta, err = table.GetFieldType(fields)
if err != nil {
return err
}
- op.Meta = meta
for i, id := range ids {
slotDataCache := make(map[dexCommon.Hash]dexCommon.Hash)
head := ctx.Storage.GetPrimaryKeyHash(table.Name, id)
@@ -103,7 +115,7 @@ func opLoad(ctx *common.Context, input []*Operand, registers []*Operand, output
col := table.Columns[int(fields[j])]
byteOffset := col.ByteOffset
slotOffset := col.SlotOffset
- dt := meta[j]
+ dt := op.Meta[j]
size := dt.Size()
slot := ctx.Storage.ShiftHashUint64(head, uint64(slotOffset))
slotData := getSlotData(ctx, slot, slotDataCache)