diff options
Diffstat (limited to 'core')
-rw-r--r-- | core/vm/sqlvm/ast/ast.go | 105 | ||||
-rw-r--r-- | core/vm/sqlvm/ast/types.go | 12 | ||||
-rw-r--r-- | core/vm/sqlvm/ast/types_test.go | 11 | ||||
-rw-r--r-- | core/vm/sqlvm/errors/errors.go | 2 |
4 files changed, 85 insertions, 45 deletions
diff --git a/core/vm/sqlvm/ast/ast.go b/core/vm/sqlvm/ast/ast.go index e5de75d45..672770b84 100644 --- a/core/vm/sqlvm/ast/ast.go +++ b/core/vm/sqlvm/ast/ast.go @@ -1,6 +1,8 @@ package ast import ( + "fmt" + "github.com/dexon-foundation/dexon/core/vm/sqlvm/errors" "github.com/shopspring/decimal" ) @@ -309,7 +311,7 @@ func (n *NullValueNode) Value() interface{} { return n } // TypeNode is an interface which should be satisfied nodes representing types. type TypeNode interface { Node - GetType() (DataType, errors.ErrorCode) + GetType() (DataType, errors.ErrorCode, string) } // IntTypeNode represents solidity int{X} and uint{X} types. @@ -327,12 +329,30 @@ func (n *IntTypeNode) GetChildren() []Node { } // GetType returns the type represented by the node. -func (n *IntTypeNode) GetType() (DataType, errors.ErrorCode) { - if n.Size%8 != 0 || n.Size == 0 || n.Size > 256 { +func (n *IntTypeNode) GetType() (DataType, errors.ErrorCode, string) { + isMultiple := n.Size%8 == 0 + isNotZero := n.Size != 0 + isInRange := n.Size <= 256 + if !isMultiple || !isNotZero || !isInRange { + name := "int" + code := errors.ErrorCodeInvalidIntSize if n.Unsigned { - return DataTypeUnknown, errors.ErrorCodeInvalidUintSize + name = "uint" + code = errors.ErrorCodeInvalidUintSize + } + if !isMultiple { + return DataTypeUnknown, code, fmt.Sprintf( + "%s size %d is not a multiple of 8", name, n.Size) + } + if !isNotZero { + return DataTypeUnknown, code, fmt.Sprintf( + "%s size cannot be zero", name) } - return DataTypeUnknown, errors.ErrorCodeInvalidIntSize + if !isInRange { + return DataTypeUnknown, code, fmt.Sprintf( + "%s size %d cannot be larger than 256", name, n.Size) + } + panic("unreachable") } var major DataTypeMajor var minor DataTypeMinor @@ -342,7 +362,7 @@ func (n *IntTypeNode) GetType() (DataType, errors.ErrorCode) { major = DataTypeMajorInt } minor = DataTypeMinor(n.Size/8 - 1) - return ComposeDataType(major, minor), errors.ErrorCodeNil + return ComposeDataType(major, minor), errors.ErrorCodeNil, "" } // FixedTypeNode represents solidity fixed{M}x{N} and ufixed{M}x{N} types. @@ -361,18 +381,41 @@ func (n *FixedTypeNode) GetChildren() []Node { } // GetType returns the type represented by the node. -func (n *FixedTypeNode) GetType() (DataType, errors.ErrorCode) { - if n.Size%8 != 0 || n.Size == 0 || n.Size > 256 { +func (n *FixedTypeNode) GetType() (DataType, errors.ErrorCode, string) { + sizeIsMultiple := n.Size%8 == 0 + sizeIsNotZero := n.Size != 0 + sizeIsInRange := n.Size <= 256 + fractionalDigitsInRange := n.FractionalDigits <= 80 + if !sizeIsMultiple || !sizeIsNotZero || !sizeIsInRange || + !fractionalDigitsInRange { + name := "fixed" + code := errors.ErrorCodeInvalidFixedSize if n.Unsigned { - return DataTypeUnknown, errors.ErrorCodeInvalidUfixedSize + name = "ufixed" + code = errors.ErrorCodeInvalidFixedSize } - return DataTypeUnknown, errors.ErrorCodeInvalidFixedSize - } - if n.FractionalDigits > 80 { + if !sizeIsMultiple { + return DataTypeUnknown, code, fmt.Sprintf( + "%s size %d is not a multiple of 8", name, n.Size) + } + if !sizeIsNotZero { + return DataTypeUnknown, code, fmt.Sprintf( + "%s size cannot be zero", name) + } + if !sizeIsInRange { + return DataTypeUnknown, code, fmt.Sprintf( + "%s size %d cannot be larger than 256", name, n.Size) + } + code = errors.ErrorCodeInvalidFixedFractionalDigits if n.Unsigned { - return DataTypeUnknown, errors.ErrorCodeInvalidUfixedFractionalDigits + code = errors.ErrorCodeInvalidUfixedFractionalDigits + } + if !fractionalDigitsInRange { + return DataTypeUnknown, code, fmt.Sprintf( + "%s fractional digits %d cannot be larger than 80", + name, n.FractionalDigits) } - return DataTypeUnknown, errors.ErrorCodeInvalidFixedFractionalDigits + panic("unreachable") } var major DataTypeMajor var minor DataTypeMinor @@ -383,7 +426,7 @@ func (n *FixedTypeNode) GetType() (DataType, errors.ErrorCode) { } major += DataTypeMajor(n.Size/8 - 1) minor = DataTypeMinor(n.FractionalDigits) - return ComposeDataType(major, minor), errors.ErrorCodeNil + return ComposeDataType(major, minor), errors.ErrorCodeNil, "" } // DynamicBytesTypeNode represents solidity bytes type. @@ -399,9 +442,9 @@ func (n *DynamicBytesTypeNode) GetChildren() []Node { } // GetType returns the type represented by the node. -func (n *DynamicBytesTypeNode) GetType() (DataType, errors.ErrorCode) { +func (n *DynamicBytesTypeNode) GetType() (DataType, errors.ErrorCode, string) { return ComposeDataType(DataTypeMajorDynamicBytes, DataTypeMinorDontCare), - errors.ErrorCodeNil + errors.ErrorCodeNil, "" } // FixedBytesTypeNode represents solidity bytes{X} type. @@ -418,13 +461,23 @@ func (n *FixedBytesTypeNode) GetChildren() []Node { } // GetType returns the type represented by the node. -func (n *FixedBytesTypeNode) GetType() (DataType, errors.ErrorCode) { - if n.Size == 0 || n.Size > 32 { - return DataTypeUnknown, errors.ErrorCodeInvalidBytesSize +func (n *FixedBytesTypeNode) GetType() (DataType, errors.ErrorCode, string) { + isNotZero := n.Size != 0 + isInRange := n.Size <= 32 + if !isNotZero || !isInRange { + code := errors.ErrorCodeInvalidBytesSize + if !isNotZero { + return DataTypeUnknown, code, "bytes size cannot be zero" + } + if !isInRange { + return DataTypeUnknown, code, fmt.Sprintf( + "bytes size %d cannot be larger than 32", n.Size) + } + panic("unreachable") } major := DataTypeMajorFixedBytes minor := DataTypeMinor(n.Size - 1) - return ComposeDataType(major, minor), errors.ErrorCodeNil + return ComposeDataType(major, minor), errors.ErrorCodeNil, "" } // AddressTypeNode represents solidity address type. @@ -440,9 +493,9 @@ func (n *AddressTypeNode) GetChildren() []Node { } // GetType returns the type represented by the node. -func (n *AddressTypeNode) GetType() (DataType, errors.ErrorCode) { +func (n *AddressTypeNode) GetType() (DataType, errors.ErrorCode, string) { return ComposeDataType(DataTypeMajorAddress, DataTypeMinorDontCare), - errors.ErrorCodeNil + errors.ErrorCodeNil, "" } // BoolTypeNode represents solidity bool type. @@ -458,9 +511,9 @@ func (n *BoolTypeNode) GetChildren() []Node { } // GetType returns the type represented by the node. -func (n *BoolTypeNode) GetType() (DataType, errors.ErrorCode) { +func (n *BoolTypeNode) GetType() (DataType, errors.ErrorCode, string) { return ComposeDataType(DataTypeMajorBool, DataTypeMinorDontCare), - errors.ErrorCodeNil + errors.ErrorCodeNil, "" } // --------------------------------------------------------------------------- @@ -825,7 +878,7 @@ func (n *CastOperatorNode) IsConstant() bool { // GetType returns the type of CAST expression, which is always the target type. func (n *CastOperatorNode) GetType() DataType { - if dt, code := n.TargetType.GetType(); code == errors.ErrorCodeNil { + if dt, code, _ := n.TargetType.GetType(); code == errors.ErrorCodeNil { return dt } return DataTypeUnknown diff --git a/core/vm/sqlvm/ast/types.go b/core/vm/sqlvm/ast/types.go index c37e6c012..5151334a9 100644 --- a/core/vm/sqlvm/ast/types.go +++ b/core/vm/sqlvm/ast/types.go @@ -101,18 +101,6 @@ func (dt DataType) Size() uint8 { } } -// DataTypeEncode encodes data type node into DataType. -func DataTypeEncode(n TypeNode) (DataType, error) { - if n == nil { - return DataTypeUnknown, se.ErrorCodeDataTypeEncode - } - t, code := n.GetType() - if code == se.ErrorCodeNil { - return t, nil - } - return t, code -} - // DataTypeDecode decodes DataType into data type node. func DataTypeDecode(t DataType) (TypeNode, error) { major, minor := DecomposeDataType(t) diff --git a/core/vm/sqlvm/ast/types_test.go b/core/vm/sqlvm/ast/types_test.go index 8373aa344..15e9cdec0 100644 --- a/core/vm/sqlvm/ast/types_test.go +++ b/core/vm/sqlvm/ast/types_test.go @@ -25,8 +25,9 @@ func (s *TypesTestSuite) requireEncodeAndDecodeDecimalNoError( func (s *TypesTestSuite) requireEncodeAndDecodeNoError( d DataType, t TypeNode) { - encode, err := DataTypeEncode(t) - s.Require().NoError(err) + encode, code, message := t.GetType() + s.Require().Zero(code) + s.Require().Empty(message) s.Require().Equal(d, encode) decode, err := DataTypeDecode(d) s.Require().NoError(err) @@ -34,8 +35,9 @@ func (s *TypesTestSuite) requireEncodeAndDecodeNoError( } func (s *TypesTestSuite) requireEncodeError(input TypeNode) { - _, err := DataTypeEncode(input) - s.Require().Error(err) + _, code, message := input.GetType() + s.Require().NotZero(code) + s.Require().NotEmpty(message) } func (s *TypesTestSuite) requireDecodeError(input DataType) { @@ -71,7 +73,6 @@ func (s *TypesTestSuite) TestEncodeAndDecode() { } func (s *TypesTestSuite) TestEncodeError() { - s.requireEncodeError(nil) s.requireEncodeError(&IntTypeNode{Size: 1}) s.requireEncodeError(&IntTypeNode{Size: 257}) s.requireEncodeError(&FixedBytesTypeNode{Size: 0}) diff --git a/core/vm/sqlvm/errors/errors.go b/core/vm/sqlvm/errors/errors.go index c327d1b43..41faae93e 100644 --- a/core/vm/sqlvm/errors/errors.go +++ b/core/vm/sqlvm/errors/errors.go @@ -108,7 +108,6 @@ const ( ErrorCodeInvalidUfixedSize ErrorCodeInvalidFixedFractionalDigits ErrorCodeInvalidUfixedFractionalDigits - ErrorCodeDataTypeEncode ErrorCodeDataTypeDecode ErrorCodeDecimalEncode ErrorCodeDecimalDecode @@ -142,7 +141,6 @@ var errorCodeMap = [...]string{ ErrorCodeInvalidUfixedSize: "invalid ufixed size", ErrorCodeInvalidFixedFractionalDigits: "invalid fixed fractional digits", ErrorCodeInvalidUfixedFractionalDigits: "invalid ufixed fractional digits", - ErrorCodeDataTypeEncode: "data type encode failed", ErrorCodeDataTypeDecode: "data type decode failed", ErrorCodeDecimalEncode: "decimal encode failed", ErrorCodeDecimalDecode: "decimal decode failed", |