diff options
-rw-r--r-- | core/vm/sqlvm/ast/types.go | 26 | ||||
-rw-r--r-- | core/vm/sqlvm/ast/types_test.go | 8 | ||||
-rw-r--r-- | core/vm/sqlvm/common/utilities.go | 6 | ||||
-rw-r--r-- | core/vm/sqlvm/errors/errors.go | 4 | ||||
-rw-r--r-- | core/vm/sqlvm/runtime/instructions.go | 52 | ||||
-rw-r--r-- | core/vm/sqlvm/schema/schema.go | 24 |
6 files changed, 63 insertions, 57 deletions
diff --git a/core/vm/sqlvm/ast/types.go b/core/vm/sqlvm/ast/types.go index 5f7f1b886..1d06b5695 100644 --- a/core/vm/sqlvm/ast/types.go +++ b/core/vm/sqlvm/ast/types.go @@ -356,49 +356,49 @@ func decimalDecode(signed bool, bs []byte) decimal.Decimal { } // DecimalEncode encodes decimal to bytes depend on data type. -func DecimalEncode(dt DataType, d decimal.Decimal) ([]byte, error) { +func DecimalEncode(dt DataType, d decimal.Decimal) ([]byte, bool) { major, minor := DecomposeDataType(dt) switch major { case DataTypeMajorInt, DataTypeMajorUint: - return decimalEncode(int(minor)+1, d), nil + return decimalEncode(int(minor)+1, d), true } switch { case major.IsFixedRange(): return decimalEncode( int(major-DataTypeMajorFixed)+1, - d.Shift(int32(minor))), nil + d.Shift(int32(minor))), true case major.IsUfixedRange(): return decimalEncode( int(major-DataTypeMajorUfixed)+1, - d.Shift(int32(minor))), nil + d.Shift(int32(minor))), true } - return nil, se.ErrorCodeDecimalEncode + return nil, false } // DecimalDecode decodes decimal from bytes. -func DecimalDecode(dt DataType, b []byte) (decimal.Decimal, error) { +func DecimalDecode(dt DataType, b []byte) (decimal.Decimal, bool) { major, minor := DecomposeDataType(dt) switch major { case DataTypeMajorInt: - return decimalDecode(true, b), nil + return decimalDecode(true, b), true case DataTypeMajorUint: - return decimalDecode(false, b), nil + return decimalDecode(false, b), true case DataTypeMajorBool: if b[0] == 0 { - return dec.False, nil + return dec.False, true } - return dec.True, nil + return dec.True, true } switch { case major.IsFixedRange(): - return decimalDecode(true, b).Shift(-int32(minor)), nil + return decimalDecode(true, b).Shift(-int32(minor)), true case major.IsUfixedRange(): - return decimalDecode(false, b).Shift(-int32(minor)), nil + return decimalDecode(false, b).Shift(-int32(minor)), true } - return decimal.Zero, se.ErrorCodeDecimalDecode + return decimal.Zero, false } // DecimalToUint64 convert decimal to uint64. diff --git a/core/vm/sqlvm/ast/types_test.go b/core/vm/sqlvm/ast/types_test.go index 02a51895c..d5f8e159d 100644 --- a/core/vm/sqlvm/ast/types_test.go +++ b/core/vm/sqlvm/ast/types_test.go @@ -16,11 +16,11 @@ type TypesTestSuite struct{ suite.Suite } func (s *TypesTestSuite) requireEncodeAndDecodeDecimalNoError( d DataType, t decimal.Decimal, bs int) { - encode, err := DecimalEncode(d, t) - s.Require().NoError(err) + encode, ok := DecimalEncode(d, t) + s.Require().True(ok) s.Require().Len(encode, bs) - decode, err := DecimalDecode(d, encode) - s.Require().NoError(err) + decode, ok := DecimalDecode(d, encode) + s.Require().True(ok) s.Require().Equal(t.String(), decode.String()) } diff --git a/core/vm/sqlvm/common/utilities.go b/core/vm/sqlvm/common/utilities.go index acff30c99..f1e74c664 100644 --- a/core/vm/sqlvm/common/utilities.go +++ b/core/vm/sqlvm/common/utilities.go @@ -1,6 +1,7 @@ package common import ( + "fmt" "math/big" "github.com/dexon-foundation/decimal" @@ -14,7 +15,10 @@ func uint64ToBytes(id uint64) []byte { bigIntID := new(big.Int).SetUint64(id) decimalID := decimal.NewFromBigInt(bigIntID, 0) dt := ast.ComposeDataType(ast.DataTypeMajorUint, 7) - byteID, _ := ast.DecimalEncode(dt, decimalID) + byteID, ok := ast.DecimalEncode(dt, decimalID) + if !ok { + panic(fmt.Sprintf("DecimalEncode does not handle %v", dt)) + } return byteID } diff --git a/core/vm/sqlvm/errors/errors.go b/core/vm/sqlvm/errors/errors.go index c97b3e191..444d87615 100644 --- a/core/vm/sqlvm/errors/errors.go +++ b/core/vm/sqlvm/errors/errors.go @@ -108,8 +108,6 @@ const ( ErrorCodeInvalidUfixedSize ErrorCodeInvalidFixedFractionalDigits ErrorCodeInvalidUfixedFractionalDigits - ErrorCodeDecimalEncode - ErrorCodeDecimalDecode // Runtime Error ErrorCodeInvalidOperandNum @@ -143,8 +141,6 @@ var errorCodeMap = [...]string{ ErrorCodeInvalidUfixedSize: "invalid ufixed size", ErrorCodeInvalidFixedFractionalDigits: "invalid fixed fractional digits", ErrorCodeInvalidUfixedFractionalDigits: "invalid ufixed fractional digits", - ErrorCodeDecimalEncode: "decimal encode failed", - ErrorCodeDecimalDecode: "decimal decode failed", // Runtime Error ErrorCodeInvalidOperandNum: "invalid operand number", ErrorCodeInvalidDataType: "invalid data type", diff --git a/core/vm/sqlvm/runtime/instructions.go b/core/vm/sqlvm/runtime/instructions.go index 6fa72d61a..35d3135b2 100644 --- a/core/vm/sqlvm/runtime/instructions.go +++ b/core/vm/sqlvm/runtime/instructions.go @@ -188,16 +188,16 @@ func decode(ctx *common.Context, dt ast.DataType, slot dexCommon.Hash, bytes []b case ast.DataTypeMajorFixedBytes, ast.DataTypeMajorAddress: rVal.Bytes = bytes case ast.DataTypeMajorBool, ast.DataTypeMajorInt, ast.DataTypeMajorUint: - d, err := ast.DecimalDecode(dt, bytes) - if err != nil { - return nil, err + d, ok := ast.DecimalDecode(dt, bytes) + if !ok { + panic(fmt.Sprintf("DecimalDecode does not handle %v", dt)) } rVal.Value = d } if major.IsFixedRange() || major.IsUfixedRange() { - d, err := ast.DecimalDecode(dt, bytes) - if err != nil { - return nil, err + d, ok := ast.DecimalDecode(dt, bytes) + if !ok { + panic(fmt.Sprintf("DecimalDecode does not handle %v", dt)) } rVal.Value = d } @@ -1607,16 +1607,16 @@ func (r *Raw) castValue( ctx *common.Context, origin, target ast.DataType, l int, signed, rPadding bool) (err error) { - oBytes, err := ast.DecimalEncode(origin, r.Value) - if err != nil { - return + oBytes, ok := ast.DecimalEncode(origin, r.Value) + if !ok { + panic(fmt.Sprintf("DecimalEncode does not handle %v", origin)) } bytes2 := r.shiftBytes(oBytes, l, signed, rPadding) - r.Value, err = ast.DecimalDecode(target, bytes2) - if err != nil { - return + r.Value, ok = ast.DecimalDecode(target, bytes2) + if !ok { + panic(fmt.Sprintf("DecimalDecode does not handle %v", target)) } err = flowCheck(ctx, r.Value, target) @@ -1644,9 +1644,10 @@ func (r *Raw) castInt(ctx *common.Context, origin, target ast.DataType) (err err return } - r.Bytes, err = ast.DecimalEncode(mockDt, r.Value) - if err != nil { - return + var ok bool + r.Bytes, ok = ast.DecimalEncode(mockDt, r.Value) + if !ok { + panic(fmt.Sprintf("DecimalEncode does not handle %v", origin)) } r.Value = decimal.Zero case ast.DataTypeMajorFixedBytes: @@ -1654,9 +1655,10 @@ func (r *Raw) castInt(ctx *common.Context, origin, target ast.DataType) (err err err = se.ErrorCodeInvalidCastType return } - r.Bytes, err = ast.DecimalEncode(origin, r.Value) - if err != nil { - return + var ok bool + r.Bytes, ok = ast.DecimalEncode(origin, r.Value) + if !ok { + panic(fmt.Sprintf("DecimalEncode does not handle %v", origin)) } r.Value = decimal.Zero case ast.DataTypeMajorBool: @@ -1677,9 +1679,10 @@ func (r *Raw) castFixedBytes(ctx *common.Context, origin, target ast.DataType) ( err = se.ErrorCodeInvalidCastType return } - r.Value, err = ast.DecimalDecode(target, r.Bytes) - if err != nil { - return + var ok bool + r.Value, ok = ast.DecimalDecode(target, r.Bytes) + if !ok { + panic(fmt.Sprintf("DecimalDecode does not handle %v", target)) } r.Bytes = nil case ast.DataTypeMajorFixedBytes: @@ -1701,12 +1704,13 @@ func (r *Raw) castAddress(ctx *common.Context, origin, target ast.DataType) (err switch tMajor { case ast.DataTypeMajorAddress: case ast.DataTypeMajorInt, ast.DataTypeMajorUint: - r.Value, err = ast.DecimalDecode( + var ok bool + r.Value, ok = ast.DecimalDecode( target, r.shiftBytes(r.Bytes, int(tMinor)+1, false, false), ) - if err != nil { - return + if !ok { + panic(fmt.Sprintf("DecimalDecode does not handle %v", target)) } err = flowCheck(ctx, r.Value, target) if err != nil { diff --git a/core/vm/sqlvm/schema/schema.go b/core/vm/sqlvm/schema/schema.go index 1ebb96fa3..1e87d88cf 100644 --- a/core/vm/sqlvm/schema/schema.go +++ b/core/vm/sqlvm/schema/schema.go @@ -14,8 +14,10 @@ import ( // Error defines for encode and decode. var ( - ErrEncodeUnexpectedType = errors.New("encode unexpected type") - ErrDecodeUnexpectedType = errors.New("decode unexpected type") + ErrEncodeUnexpectedDataType = errors.New("encode unexpected data type") + ErrEncodeUnexpectedDefaultType = errors.New("encode unexpected default type") + ErrDecodeUnexpectedDataType = errors.New("decode unexpected data type") + ErrDecodeUnexpectedDefaultType = errors.New("decode unexpected default type") ) // ColumnAttr defines bit flags for describing column attribute. @@ -208,13 +210,13 @@ func (c Column) EncodeRLP(w io.Writer) error { case []byte: c.Rest = d case decimal.Decimal: - var err error - c.Rest, err = ast.DecimalEncode(c.Type, d) - if err != nil { - return err + var ok bool + c.Rest, ok = ast.DecimalEncode(c.Type, d) + if !ok { + return ErrEncodeUnexpectedDataType } default: - return ErrEncodeUnexpectedType + return ErrEncodeUnexpectedDefaultType } } else { c.Rest = nil @@ -249,14 +251,14 @@ func (c *Column) DecodeRLP(s *rlp.Stream) error { case ast.DataTypeMajorFixedBytes, ast.DataTypeMajorDynamicBytes: c.Default = rest default: - d, err := ast.DecimalDecode(c.Type, rest) - if err != nil { - return err + d, ok := ast.DecimalDecode(c.Type, rest) + if !ok { + return ErrDecodeUnexpectedDataType } c.Default = d } default: - return ErrDecodeUnexpectedType + return ErrDecodeUnexpectedDefaultType } return nil |