From c02fb054e521b1d29b3ae1e1917778a5f92dc05c Mon Sep 17 00:00:00 2001 From: Jhih-Ming Huang Date: Tue, 19 Mar 2019 16:13:37 +0800 Subject: core: vm: sqlvm: ast: implement decimal to uint64 deciaml.IntPart() returns int64, so we have to implement a function to convert deciaml to uint64 for reading primary id from Raw. --- core/vm/sqlvm/ast/types.go | 33 ++++++++++++++++++++++++-------- core/vm/sqlvm/ast/types_test.go | 25 ++++++++++++++++++++++++ core/vm/sqlvm/errors/errors.go | 14 ++++++++------ core/vm/sqlvm/runtime/instructions.go | 36 +++++++++++++++++++++++------------ 4 files changed, 82 insertions(+), 26 deletions(-) (limited to 'core') diff --git a/core/vm/sqlvm/ast/types.go b/core/vm/sqlvm/ast/types.go index 65fdc501d..3909e7575 100644 --- a/core/vm/sqlvm/ast/types.go +++ b/core/vm/sqlvm/ast/types.go @@ -169,6 +169,17 @@ func DataTypeDecode(t DataType) (TypeNode, error) { return nil, se.ErrorCodeDataTypeDecode } +func decimalToBig(d decimal.Decimal) (b *big.Int) { + if exponent := int64(d.Exponent()); exponent >= 0 { + exp := new(big.Int).Exp(bigIntTen, big.NewInt(exponent), nil) + b = new(big.Int).Mul(d.Coefficient(), exp) + } else { + exp := new(big.Int).Exp(bigIntTen, big.NewInt(-exponent), nil) + b = new(big.Int).Div(d.Coefficient(), exp) + } + return +} + // Don't handle overflow here. func decimalEncode(size int, d decimal.Decimal) []byte { ret := make([]byte, size) @@ -177,14 +188,7 @@ func decimalEncode(size int, d decimal.Decimal) []byte { return ret } - var b *big.Int - if exponent := int64(d.Exponent()); exponent >= 0 { - exp := new(big.Int).Exp(bigIntTen, big.NewInt(exponent), nil) - b = new(big.Int).Mul(d.Coefficient(), exp) - } else { - exp := new(big.Int).Exp(bigIntTen, big.NewInt(-exponent), nil) - b = new(big.Int).Div(d.Coefficient(), exp) - } + b := decimalToBig(d) if s > 0 { bs := b.Bytes() @@ -298,3 +302,16 @@ func GetMinMax(dt DataType) (min, max decimal.Decimal, err error) { decPairMap[dt] = decPair{Max: max, Min: min} return } + +// DecimalToUint64 convert decimal to uint64. +// Negative case will return error, and decimal part will be trancated. +func DecimalToUint64(d decimal.Decimal) (uint64, error) { + s := d.Sign() + if s == 0 { + return 0, nil + } + if s < 0 { + return 0, se.ErrorCodeNegDecimalToUint64 + } + return decimalToBig(d).Uint64(), nil +} diff --git a/core/vm/sqlvm/ast/types_test.go b/core/vm/sqlvm/ast/types_test.go index ada5d487f..8373aa344 100644 --- a/core/vm/sqlvm/ast/types_test.go +++ b/core/vm/sqlvm/ast/types_test.go @@ -212,6 +212,31 @@ func (s *TypesTestSuite) TestSize() { } } +func (s *TypesTestSuite) TestDecimalToUint64() { + pos := decimal.New(15, 1) + zero := decimal.Zero + neg := decimal.New(-150, -1) + posSmall := decimal.New(15, -2) + negSmall := decimal.New(-15, -2) + + testcases := []struct { + d decimal.Decimal + u uint64 + err error + }{ + {pos, 150, nil}, + {zero, 0, nil}, + {neg, 0, errors.ErrorCodeNegDecimalToUint64}, + {posSmall, 0, nil}, + {negSmall, 0, errors.ErrorCodeNegDecimalToUint64}, + } + for i, t := range testcases { + u, err := DecimalToUint64(t.d) + s.Require().Equalf(t.err, err, "err not match. testcase: %v", i) + s.Require().Equalf(t.u, u, "result not match. testcase: %v", i) + } +} + func TestTypes(t *testing.T) { suite.Run(t, new(TypesTestSuite)) } diff --git a/core/vm/sqlvm/errors/errors.go b/core/vm/sqlvm/errors/errors.go index ae3c3b8a5..c327d1b43 100644 --- a/core/vm/sqlvm/errors/errors.go +++ b/core/vm/sqlvm/errors/errors.go @@ -121,6 +121,7 @@ const ( ErrorCodeIndexOutOfRange ErrorCodeInvalidCastType ErrorCodeDividedByZero + ErrorCodeNegDecimalToUint64 ) var errorCodeMap = [...]string{ @@ -147,12 +148,13 @@ var errorCodeMap = [...]string{ ErrorCodeDecimalDecode: "decimal decode failed", ErrorCodeGetMinMax: "get (min, max) failed", // Runtime Error - ErrorCodeInvalidDataType: "invalid data type", - ErrorCodeOverflow: "overflow", - ErrorCodeUnderflow: "underflow", - ErrorCodeIndexOutOfRange: "index out of range", - ErrorCodeInvalidCastType: "invalid cast type", - ErrorCodeDividedByZero: "divide by zero", + ErrorCodeInvalidDataType: "invalid data type", + ErrorCodeOverflow: "overflow", + ErrorCodeUnderflow: "underflow", + ErrorCodeIndexOutOfRange: "index out of range", + ErrorCodeInvalidCastType: "invalid cast type", + ErrorCodeDividedByZero: "divide by zero", + ErrorCodeNegDecimalToUint64: "negative deciaml to uint64", } func (c ErrorCode) Error() string { diff --git a/core/vm/sqlvm/runtime/instructions.go b/core/vm/sqlvm/runtime/instructions.go index 5a61d5f80..17691d0d4 100644 --- a/core/vm/sqlvm/runtime/instructions.go +++ b/core/vm/sqlvm/runtime/instructions.go @@ -58,20 +58,27 @@ type Operand struct { RegisterIndex uint } -func (o *Operand) toUint64() []uint64 { - result := make([]uint64, len(o.Data)) +func (o *Operand) toUint64() (result []uint64, err error) { + result = make([]uint64, len(o.Data)) for i, tuple := range o.Data { - result[i] = uint64(tuple[0].Value.IntPart()) + result[i], err = ast.DecimalToUint64(tuple[0].Value) + if err != nil { + return + } } - return result + return } -func (o *Operand) toUint8() []uint8 { +func (o *Operand) toUint8() ([]uint8, error) { result := make([]uint8, len(o.Data)) for i, tuple := range o.Data { - result[i] = uint8(tuple[0].Value.IntPart()) + u, err := ast.DecimalToUint64(tuple[0].Value) + if err != nil { + return nil, err + } + result[i] = uint8(u) } - return result + return result, nil } func opLoad(ctx *common.Context, input []*Operand, registers []*Operand, output int) error { @@ -81,8 +88,14 @@ func opLoad(ctx *common.Context, input []*Operand, registers []*Operand, output } table := ctx.Storage.Schema[tableIdx] - ids := input[1].toUint64() - fields := input[2].toUint8() + ids, err := input[1].toUint64() + if err != nil { + return err + } + fields, err := input[2].toUint8() + if err != nil { + return err + } op := Operand{ IsImmediate: false, Data: make([]Tuple, len(ids)), @@ -91,11 +104,10 @@ func opLoad(ctx *common.Context, input []*Operand, registers []*Operand, output for i := range op.Data { op.Data[i] = make([]*Raw, len(fields)) } - meta, err := table.GetFieldType(fields) + op.Meta, err = table.GetFieldType(fields) if err != nil { return err } - op.Meta = meta for i, id := range ids { slotDataCache := make(map[dexCommon.Hash]dexCommon.Hash) head := ctx.Storage.GetPrimaryKeyHash(table.Name, id) @@ -103,7 +115,7 @@ func opLoad(ctx *common.Context, input []*Operand, registers []*Operand, output col := table.Columns[int(fields[j])] byteOffset := col.ByteOffset slotOffset := col.SlotOffset - dt := meta[j] + dt := op.Meta[j] size := dt.Size() slot := ctx.Storage.ShiftHashUint64(head, uint64(slotOffset)) slotData := getSlotData(ctx, slot, slotDataCache) -- cgit v1.2.3