aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--core/vm/sqlvm/ast/ast.go105
-rw-r--r--core/vm/sqlvm/ast/types.go12
-rw-r--r--core/vm/sqlvm/ast/types_test.go11
-rw-r--r--core/vm/sqlvm/errors/errors.go2
4 files changed, 85 insertions, 45 deletions
diff --git a/core/vm/sqlvm/ast/ast.go b/core/vm/sqlvm/ast/ast.go
index e5de75d45..672770b84 100644
--- a/core/vm/sqlvm/ast/ast.go
+++ b/core/vm/sqlvm/ast/ast.go
@@ -1,6 +1,8 @@
package ast
import (
+ "fmt"
+
"github.com/dexon-foundation/dexon/core/vm/sqlvm/errors"
"github.com/shopspring/decimal"
)
@@ -309,7 +311,7 @@ func (n *NullValueNode) Value() interface{} { return n }
// TypeNode is an interface which should be satisfied nodes representing types.
type TypeNode interface {
Node
- GetType() (DataType, errors.ErrorCode)
+ GetType() (DataType, errors.ErrorCode, string)
}
// IntTypeNode represents solidity int{X} and uint{X} types.
@@ -327,12 +329,30 @@ func (n *IntTypeNode) GetChildren() []Node {
}
// GetType returns the type represented by the node.
-func (n *IntTypeNode) GetType() (DataType, errors.ErrorCode) {
- if n.Size%8 != 0 || n.Size == 0 || n.Size > 256 {
+func (n *IntTypeNode) GetType() (DataType, errors.ErrorCode, string) {
+ isMultiple := n.Size%8 == 0
+ isNotZero := n.Size != 0
+ isInRange := n.Size <= 256
+ if !isMultiple || !isNotZero || !isInRange {
+ name := "int"
+ code := errors.ErrorCodeInvalidIntSize
if n.Unsigned {
- return DataTypeUnknown, errors.ErrorCodeInvalidUintSize
+ name = "uint"
+ code = errors.ErrorCodeInvalidUintSize
+ }
+ if !isMultiple {
+ return DataTypeUnknown, code, fmt.Sprintf(
+ "%s size %d is not a multiple of 8", name, n.Size)
+ }
+ if !isNotZero {
+ return DataTypeUnknown, code, fmt.Sprintf(
+ "%s size cannot be zero", name)
}
- return DataTypeUnknown, errors.ErrorCodeInvalidIntSize
+ if !isInRange {
+ return DataTypeUnknown, code, fmt.Sprintf(
+ "%s size %d cannot be larger than 256", name, n.Size)
+ }
+ panic("unreachable")
}
var major DataTypeMajor
var minor DataTypeMinor
@@ -342,7 +362,7 @@ func (n *IntTypeNode) GetType() (DataType, errors.ErrorCode) {
major = DataTypeMajorInt
}
minor = DataTypeMinor(n.Size/8 - 1)
- return ComposeDataType(major, minor), errors.ErrorCodeNil
+ return ComposeDataType(major, minor), errors.ErrorCodeNil, ""
}
// FixedTypeNode represents solidity fixed{M}x{N} and ufixed{M}x{N} types.
@@ -361,18 +381,41 @@ func (n *FixedTypeNode) GetChildren() []Node {
}
// GetType returns the type represented by the node.
-func (n *FixedTypeNode) GetType() (DataType, errors.ErrorCode) {
- if n.Size%8 != 0 || n.Size == 0 || n.Size > 256 {
+func (n *FixedTypeNode) GetType() (DataType, errors.ErrorCode, string) {
+ sizeIsMultiple := n.Size%8 == 0
+ sizeIsNotZero := n.Size != 0
+ sizeIsInRange := n.Size <= 256
+ fractionalDigitsInRange := n.FractionalDigits <= 80
+ if !sizeIsMultiple || !sizeIsNotZero || !sizeIsInRange ||
+ !fractionalDigitsInRange {
+ name := "fixed"
+ code := errors.ErrorCodeInvalidFixedSize
if n.Unsigned {
- return DataTypeUnknown, errors.ErrorCodeInvalidUfixedSize
+ name = "ufixed"
+ code = errors.ErrorCodeInvalidFixedSize
}
- return DataTypeUnknown, errors.ErrorCodeInvalidFixedSize
- }
- if n.FractionalDigits > 80 {
+ if !sizeIsMultiple {
+ return DataTypeUnknown, code, fmt.Sprintf(
+ "%s size %d is not a multiple of 8", name, n.Size)
+ }
+ if !sizeIsNotZero {
+ return DataTypeUnknown, code, fmt.Sprintf(
+ "%s size cannot be zero", name)
+ }
+ if !sizeIsInRange {
+ return DataTypeUnknown, code, fmt.Sprintf(
+ "%s size %d cannot be larger than 256", name, n.Size)
+ }
+ code = errors.ErrorCodeInvalidFixedFractionalDigits
if n.Unsigned {
- return DataTypeUnknown, errors.ErrorCodeInvalidUfixedFractionalDigits
+ code = errors.ErrorCodeInvalidUfixedFractionalDigits
+ }
+ if !fractionalDigitsInRange {
+ return DataTypeUnknown, code, fmt.Sprintf(
+ "%s fractional digits %d cannot be larger than 80",
+ name, n.FractionalDigits)
}
- return DataTypeUnknown, errors.ErrorCodeInvalidFixedFractionalDigits
+ panic("unreachable")
}
var major DataTypeMajor
var minor DataTypeMinor
@@ -383,7 +426,7 @@ func (n *FixedTypeNode) GetType() (DataType, errors.ErrorCode) {
}
major += DataTypeMajor(n.Size/8 - 1)
minor = DataTypeMinor(n.FractionalDigits)
- return ComposeDataType(major, minor), errors.ErrorCodeNil
+ return ComposeDataType(major, minor), errors.ErrorCodeNil, ""
}
// DynamicBytesTypeNode represents solidity bytes type.
@@ -399,9 +442,9 @@ func (n *DynamicBytesTypeNode) GetChildren() []Node {
}
// GetType returns the type represented by the node.
-func (n *DynamicBytesTypeNode) GetType() (DataType, errors.ErrorCode) {
+func (n *DynamicBytesTypeNode) GetType() (DataType, errors.ErrorCode, string) {
return ComposeDataType(DataTypeMajorDynamicBytes, DataTypeMinorDontCare),
- errors.ErrorCodeNil
+ errors.ErrorCodeNil, ""
}
// FixedBytesTypeNode represents solidity bytes{X} type.
@@ -418,13 +461,23 @@ func (n *FixedBytesTypeNode) GetChildren() []Node {
}
// GetType returns the type represented by the node.
-func (n *FixedBytesTypeNode) GetType() (DataType, errors.ErrorCode) {
- if n.Size == 0 || n.Size > 32 {
- return DataTypeUnknown, errors.ErrorCodeInvalidBytesSize
+func (n *FixedBytesTypeNode) GetType() (DataType, errors.ErrorCode, string) {
+ isNotZero := n.Size != 0
+ isInRange := n.Size <= 32
+ if !isNotZero || !isInRange {
+ code := errors.ErrorCodeInvalidBytesSize
+ if !isNotZero {
+ return DataTypeUnknown, code, "bytes size cannot be zero"
+ }
+ if !isInRange {
+ return DataTypeUnknown, code, fmt.Sprintf(
+ "bytes size %d cannot be larger than 32", n.Size)
+ }
+ panic("unreachable")
}
major := DataTypeMajorFixedBytes
minor := DataTypeMinor(n.Size - 1)
- return ComposeDataType(major, minor), errors.ErrorCodeNil
+ return ComposeDataType(major, minor), errors.ErrorCodeNil, ""
}
// AddressTypeNode represents solidity address type.
@@ -440,9 +493,9 @@ func (n *AddressTypeNode) GetChildren() []Node {
}
// GetType returns the type represented by the node.
-func (n *AddressTypeNode) GetType() (DataType, errors.ErrorCode) {
+func (n *AddressTypeNode) GetType() (DataType, errors.ErrorCode, string) {
return ComposeDataType(DataTypeMajorAddress, DataTypeMinorDontCare),
- errors.ErrorCodeNil
+ errors.ErrorCodeNil, ""
}
// BoolTypeNode represents solidity bool type.
@@ -458,9 +511,9 @@ func (n *BoolTypeNode) GetChildren() []Node {
}
// GetType returns the type represented by the node.
-func (n *BoolTypeNode) GetType() (DataType, errors.ErrorCode) {
+func (n *BoolTypeNode) GetType() (DataType, errors.ErrorCode, string) {
return ComposeDataType(DataTypeMajorBool, DataTypeMinorDontCare),
- errors.ErrorCodeNil
+ errors.ErrorCodeNil, ""
}
// ---------------------------------------------------------------------------
@@ -825,7 +878,7 @@ func (n *CastOperatorNode) IsConstant() bool {
// GetType returns the type of CAST expression, which is always the target type.
func (n *CastOperatorNode) GetType() DataType {
- if dt, code := n.TargetType.GetType(); code == errors.ErrorCodeNil {
+ if dt, code, _ := n.TargetType.GetType(); code == errors.ErrorCodeNil {
return dt
}
return DataTypeUnknown
diff --git a/core/vm/sqlvm/ast/types.go b/core/vm/sqlvm/ast/types.go
index c37e6c012..5151334a9 100644
--- a/core/vm/sqlvm/ast/types.go
+++ b/core/vm/sqlvm/ast/types.go
@@ -101,18 +101,6 @@ func (dt DataType) Size() uint8 {
}
}
-// DataTypeEncode encodes data type node into DataType.
-func DataTypeEncode(n TypeNode) (DataType, error) {
- if n == nil {
- return DataTypeUnknown, se.ErrorCodeDataTypeEncode
- }
- t, code := n.GetType()
- if code == se.ErrorCodeNil {
- return t, nil
- }
- return t, code
-}
-
// DataTypeDecode decodes DataType into data type node.
func DataTypeDecode(t DataType) (TypeNode, error) {
major, minor := DecomposeDataType(t)
diff --git a/core/vm/sqlvm/ast/types_test.go b/core/vm/sqlvm/ast/types_test.go
index 8373aa344..15e9cdec0 100644
--- a/core/vm/sqlvm/ast/types_test.go
+++ b/core/vm/sqlvm/ast/types_test.go
@@ -25,8 +25,9 @@ func (s *TypesTestSuite) requireEncodeAndDecodeDecimalNoError(
func (s *TypesTestSuite) requireEncodeAndDecodeNoError(
d DataType, t TypeNode) {
- encode, err := DataTypeEncode(t)
- s.Require().NoError(err)
+ encode, code, message := t.GetType()
+ s.Require().Zero(code)
+ s.Require().Empty(message)
s.Require().Equal(d, encode)
decode, err := DataTypeDecode(d)
s.Require().NoError(err)
@@ -34,8 +35,9 @@ func (s *TypesTestSuite) requireEncodeAndDecodeNoError(
}
func (s *TypesTestSuite) requireEncodeError(input TypeNode) {
- _, err := DataTypeEncode(input)
- s.Require().Error(err)
+ _, code, message := input.GetType()
+ s.Require().NotZero(code)
+ s.Require().NotEmpty(message)
}
func (s *TypesTestSuite) requireDecodeError(input DataType) {
@@ -71,7 +73,6 @@ func (s *TypesTestSuite) TestEncodeAndDecode() {
}
func (s *TypesTestSuite) TestEncodeError() {
- s.requireEncodeError(nil)
s.requireEncodeError(&IntTypeNode{Size: 1})
s.requireEncodeError(&IntTypeNode{Size: 257})
s.requireEncodeError(&FixedBytesTypeNode{Size: 0})
diff --git a/core/vm/sqlvm/errors/errors.go b/core/vm/sqlvm/errors/errors.go
index c327d1b43..41faae93e 100644
--- a/core/vm/sqlvm/errors/errors.go
+++ b/core/vm/sqlvm/errors/errors.go
@@ -108,7 +108,6 @@ const (
ErrorCodeInvalidUfixedSize
ErrorCodeInvalidFixedFractionalDigits
ErrorCodeInvalidUfixedFractionalDigits
- ErrorCodeDataTypeEncode
ErrorCodeDataTypeDecode
ErrorCodeDecimalEncode
ErrorCodeDecimalDecode
@@ -142,7 +141,6 @@ var errorCodeMap = [...]string{
ErrorCodeInvalidUfixedSize: "invalid ufixed size",
ErrorCodeInvalidFixedFractionalDigits: "invalid fixed fractional digits",
ErrorCodeInvalidUfixedFractionalDigits: "invalid ufixed fractional digits",
- ErrorCodeDataTypeEncode: "data type encode failed",
ErrorCodeDataTypeDecode: "data type decode failed",
ErrorCodeDecimalEncode: "decimal encode failed",
ErrorCodeDecimalDecode: "decimal decode failed",