aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--core/vm/sqlvm/checker/expr.go323
1 files changed, 238 insertions, 85 deletions
diff --git a/core/vm/sqlvm/checker/expr.go b/core/vm/sqlvm/checker/expr.go
index 9ea754a48..d62f82ede 100644
--- a/core/vm/sqlvm/checker/expr.go
+++ b/core/vm/sqlvm/checker/expr.go
@@ -141,7 +141,7 @@ func checkExpr(n ast.ExprNode,
return n
case *ast.InOperatorNode:
- return n
+ return checkInOperator(n, s, o, c, el, tr, ta)
case *ast.FunctionOperatorNode:
return n
@@ -1388,12 +1388,57 @@ func elAppendTypeErrorOperandValueNode(el *errors.ErrorList, n ast.Valuer,
Severity: errors.ErrorSeverityError,
Prefix: fn,
Message: fmt.Sprintf(
- "cannot use %s with %s because the other operand is expected "+
- "to be %s",
+ "cannot use %s with %s because it is already used with %s",
describeValueNodeType(n), op, describeValueNodeType(nExpected)),
}, nil)
}
+func compatibleValueNodes(expected, given ast.Valuer) bool {
+ switch expected.(type) {
+ case *ast.BoolValueNode:
+ switch given.(type) {
+ case *ast.BoolValueNode:
+ case *ast.NullValueNode:
+ default:
+ return false
+ }
+ case *ast.AddressValueNode:
+ switch given.(type) {
+ case *ast.AddressValueNode:
+ case *ast.NullValueNode:
+ default:
+ return false
+ }
+ case *ast.IntegerValueNode:
+ switch given.(type) {
+ case *ast.IntegerValueNode:
+ case *ast.DecimalValueNode:
+ case *ast.NullValueNode:
+ default:
+ return false
+ }
+ case *ast.DecimalValueNode:
+ switch given.(type) {
+ case *ast.IntegerValueNode:
+ case *ast.DecimalValueNode:
+ case *ast.NullValueNode:
+ default:
+ return false
+ }
+ case *ast.BytesValueNode:
+ switch given.(type) {
+ case *ast.BytesValueNode:
+ case *ast.NullValueNode:
+ default:
+ return false
+ }
+ case *ast.NullValueNode:
+ default:
+ panic(unknownValueNodeType(expected))
+ }
+ return true
+}
+
func extractConstantValue(n ast.Valuer) constantValue {
switch n := n.(type) {
case *ast.BoolValueNode:
@@ -1417,7 +1462,7 @@ func unknownConstantValueType(v constantValue) string {
return fmt.Sprintf("unknown constant value type %T", v)
}
-func findNilConstantValue(v constantValue) constantValue {
+func newNilConstantValue(v constantValue) constantValue {
switch v.(type) {
case constantValueBool:
return newConstantValueBoolFromNil()
@@ -1439,53 +1484,7 @@ func foldRelationalOperator(n ast.BinaryOperator, object, subject ast.Valuer,
evalDecimal func(decimal.NullDecimal, decimal.NullDecimal) ast.BoolValue,
) *ast.BoolValueNode {
- compatibleTypes := func() bool {
- switch object.(type) {
- case *ast.BoolValueNode:
- switch subject.(type) {
- case *ast.BoolValueNode:
- case *ast.NullValueNode:
- default:
- return false
- }
- case *ast.AddressValueNode:
- switch subject.(type) {
- case *ast.AddressValueNode:
- case *ast.NullValueNode:
- default:
- return false
- }
- case *ast.IntegerValueNode:
- switch subject.(type) {
- case *ast.IntegerValueNode:
- case *ast.DecimalValueNode:
- case *ast.NullValueNode:
- default:
- return false
- }
- case *ast.DecimalValueNode:
- switch subject.(type) {
- case *ast.IntegerValueNode:
- case *ast.DecimalValueNode:
- case *ast.NullValueNode:
- default:
- return false
- }
- case *ast.BytesValueNode:
- switch subject.(type) {
- case *ast.BytesValueNode:
- case *ast.NullValueNode:
- default:
- return false
- }
- case *ast.NullValueNode:
- default:
- panic(unknownValueNodeType(object))
- }
- return true
- }
-
- if !compatibleTypes() {
+ if !compatibleValueNodes(object, subject) {
elAppendTypeErrorOperandValueNode(el, subject, fn, op, object)
return nil
}
@@ -1498,9 +1497,9 @@ func foldRelationalOperator(n ast.BinaryOperator, object, subject ast.Valuer,
arg1 = newConstantValueBoolFromNil()
arg2 = newConstantValueBoolFromNil()
} else if arg1 == nil {
- arg1 = findNilConstantValue(arg2)
+ arg1 = newNilConstantValue(arg2)
} else if arg2 == nil {
- arg2 = findNilConstantValue(arg1)
+ arg2 = newNilConstantValue(arg1)
}
// Now we are sure that all interfaces are non-nil.
@@ -1529,7 +1528,7 @@ func foldRelationalOperator(n ast.BinaryOperator, object, subject ast.Valuer,
func checkRelationalOperator(n ast.BinaryOperator,
s schema.Schema, o CheckOptions, c *schemaCache, el *errors.ErrorList,
- tr schema.TableRef, ta typeAction, fn, op string,
+ tr schema.TableRef, ta typeAction, fn, op string, requireOrdered bool,
evalBool func(ast.BoolValue, ast.BoolValue) ast.BoolValue,
evalBytes func([]byte, []byte) ast.BoolValue,
evalDecimal func(decimal.NullDecimal, decimal.NullDecimal) ast.BoolValue,
@@ -1542,13 +1541,17 @@ func checkRelationalOperator(n ast.BinaryOperator,
object := n.GetObject()
dtObject := object.GetType()
- if !validateOrderedType(dtObject, el, object, fn, op) {
- return nil
- }
+
subject := n.GetSubject()
dtSubject := subject.GetType()
- if !validateOrderedType(dtSubject, el, subject, fn, op) {
- return nil
+
+ if requireOrdered {
+ if !validateOrderedType(dtObject, el, object, fn, op) {
+ return nil
+ }
+ if !validateOrderedType(dtSubject, el, subject, fn, op) {
+ return nil
+ }
}
if _, ok := inferBinaryOperatorType(n, s, o, c, el, tr, fn, op); !ok {
@@ -1577,7 +1580,7 @@ func checkGreaterOrEqualOperator(n *ast.GreaterOrEqualOperatorNode,
fn := "CheckGreaterOrEqualOperator"
op := "binary operator >="
- return checkRelationalOperator(n, s, o, c, el, tr, ta, fn, op,
+ return checkRelationalOperator(n, s, o, c, el, tr, ta, fn, op, true,
func(v1, v2 ast.BoolValue) ast.BoolValue {
return v1.GreaterOrEqual(v2)
},
@@ -1604,7 +1607,7 @@ func checkLessOrEqualOperator(n *ast.LessOrEqualOperatorNode,
fn := "CheckLessOrEqualOperator"
op := "binary operator <="
- return checkRelationalOperator(n, s, o, c, el, tr, ta, fn, op,
+ return checkRelationalOperator(n, s, o, c, el, tr, ta, fn, op, true,
func(v1, v2 ast.BoolValue) ast.BoolValue {
return v1.LessOrEqual(v2)
},
@@ -1631,7 +1634,7 @@ func checkNotEqualOperator(n *ast.NotEqualOperatorNode,
fn := "CheckNotEqualOperator"
op := "binary operator <>"
- return checkRelationalOperator(n, s, o, c, el, tr, ta, fn, op,
+ return checkRelationalOperator(n, s, o, c, el, tr, ta, fn, op, false,
func(v1, v2 ast.BoolValue) ast.BoolValue {
return v1.NotEqual(v2)
},
@@ -1650,6 +1653,24 @@ func checkNotEqualOperator(n *ast.NotEqualOperatorNode,
)
}
+func evalEqualBool(v1, v2 ast.BoolValue) ast.BoolValue {
+ return v1.Equal(v2)
+}
+
+func evalEqualBytes(v1, v2 []byte) ast.BoolValue {
+ if v1 == nil || v2 == nil {
+ return ast.BoolValueUnknown
+ }
+ return ast.NewBoolValueFromBool(bytes.Equal(v1, v2))
+}
+
+func evalEqualDecimal(v1, v2 decimal.NullDecimal) ast.BoolValue {
+ if !v1.Valid || !v2.Valid {
+ return ast.BoolValueUnknown
+ }
+ return ast.NewBoolValueFromBool(v1.Decimal.Equal(v2.Decimal))
+}
+
func checkEqualOperator(n *ast.EqualOperatorNode,
s schema.Schema, o CheckOptions, c *schemaCache, el *errors.ErrorList,
tr schema.TableRef, ta typeAction) ast.ExprNode {
@@ -1657,23 +1678,8 @@ func checkEqualOperator(n *ast.EqualOperatorNode,
fn := "CheckEqualOperator"
op := "binary operator ="
- return checkRelationalOperator(n, s, o, c, el, tr, ta, fn, op,
- func(v1, v2 ast.BoolValue) ast.BoolValue {
- return v1.Equal(v2)
- },
- func(v1, v2 []byte) ast.BoolValue {
- if v1 == nil || v2 == nil {
- return ast.BoolValueUnknown
- }
- return ast.NewBoolValueFromBool(bytes.Equal(v1, v2))
- },
- func(v1, v2 decimal.NullDecimal) ast.BoolValue {
- if !v1.Valid || !v2.Valid {
- return ast.BoolValueUnknown
- }
- return ast.NewBoolValueFromBool(v1.Decimal.Equal(v2.Decimal))
- },
- )
+ return checkRelationalOperator(n, s, o, c, el, tr, ta, fn, op, false,
+ evalEqualBool, evalEqualBytes, evalEqualDecimal)
}
func checkGreaterOperator(n *ast.GreaterOperatorNode,
@@ -1683,7 +1689,7 @@ func checkGreaterOperator(n *ast.GreaterOperatorNode,
fn := "CheckGreaterOperator"
op := "binary operator >"
- return checkRelationalOperator(n, s, o, c, el, tr, ta, fn, op,
+ return checkRelationalOperator(n, s, o, c, el, tr, ta, fn, op, true,
func(v1, v2 ast.BoolValue) ast.BoolValue {
return v1.Greater(v2)
},
@@ -1709,7 +1715,7 @@ func checkLessOperator(n *ast.LessOperatorNode,
fn := "CheckLessOperator"
op := "binary operator <"
- return checkRelationalOperator(n, s, o, c, el, tr, ta, fn, op,
+ return checkRelationalOperator(n, s, o, c, el, tr, ta, fn, op, true,
func(v1, v2 ast.BoolValue) ast.BoolValue {
return v1.Greater(v2)
},
@@ -2383,11 +2389,14 @@ func checkLikeOperator(n *ast.LikeOperatorNode,
}
if escape != nil {
if v, n, ok := extractOne(escape); ok {
- if !n && len(v) != 1 {
- panic("escape byte must be exactly one byte")
+ if n {
+ null = true
+ } else {
+ if len(v) != 1 {
+ panic("escape byte must be exactly one byte")
+ }
+ vesc = v[0]
}
- vesc = v[0]
- null = null || n
} else {
return nil, nil, 0, false, false
}
@@ -2483,3 +2492,147 @@ func checkLikeOperator(n *ast.LikeOperatorNode,
return verifyTypeAction(r, fn, dt, el, ta)
}
+
+func checkInOperator(n *ast.InOperatorNode,
+ s schema.Schema, o CheckOptions, c *schemaCache, el *errors.ErrorList,
+ tr schema.TableRef, ta typeAction) ast.ExprNode {
+
+ fn := "CheckInOperator"
+ op := "operator IN"
+
+ children := make([]*ast.ExprNode, 0, 1+len(n.Right))
+ children = append(children, &n.Left)
+ for i := range n.Right {
+ children = append(children, &n.Right[i])
+ }
+
+ // Check our children first.
+ hasError := false
+ for _, child := range children {
+ result := checkExpr(*child, s, o, c, el, tr, nil)
+ if result != nil {
+ *child = result
+ } else {
+ hasError = true
+ }
+ }
+ if hasError {
+ return nil
+ }
+ r := ast.ExprNode(n)
+
+ // Determine the type.
+ dtChildren := ast.DataTypePending
+ for _, child := range children {
+ dtChild := (*child).GetType()
+ if !dtChild.Pending() {
+ dtChildren = dtChild
+ break
+ }
+ }
+
+ // If the type is determined, assign it to all children.
+ if !dtChildren.Pending() {
+ assign := newTypeActionAssign(dtChildren)
+ for _, child := range children {
+ result := checkExpr(*child, s, o, c, el, tr, assign)
+ if result == nil {
+ return nil
+ }
+ *child = result
+ }
+ }
+ dt := n.GetType()
+
+ // Fold constants.
+ fold := func() bool {
+ valuers := make([]ast.Valuer, len(children))
+ // Return early if it cannot be folded.
+ for i, child := range children {
+ if valuer, ok := (*child).(ast.Valuer); ok {
+ valuers[i] = valuer
+ } else {
+ return true
+ }
+ }
+ // Determine the type by finding the first non-NULL node.
+ var typeReference ast.Valuer
+ findType:
+ for _, valuer := range valuers {
+ switch valuer.(type) {
+ case *ast.BoolValueNode,
+ *ast.AddressValueNode,
+ *ast.IntegerValueNode,
+ *ast.DecimalValueNode,
+ *ast.BytesValueNode:
+ typeReference = valuer
+ break findType
+ case *ast.NullValueNode:
+ default:
+ panic(unknownValueNodeType(valuer))
+ }
+ }
+ // Check types of all children against the type we determined above.
+ for _, valuer := range valuers {
+ if !compatibleValueNodes(typeReference, valuer) {
+ elAppendTypeErrorOperandValueNode(
+ el, valuer, fn, op, typeReference)
+ return false
+ }
+ }
+ // Extract values and assign types to NULL values.
+ constantValueReference := extractConstantValue(typeReference)
+ values := make([]constantValue, len(valuers))
+ for i, valuer := range valuers {
+ value := extractConstantValue(valuer)
+ if value == nil {
+ if constantValueReference == nil {
+ value = newConstantValueBoolFromNil()
+ } else {
+ value = newNilConstantValue(constantValueReference)
+ }
+ }
+ values[i] = value
+ }
+ // Calculate the result.
+ var vo ast.BoolValue
+ switch v1 := values[0].(type) {
+ case constantValueBool:
+ v2 := values[1].(constantValueBool)
+ vo = evalEqualBool(v1.GetBool(), v2.GetBool())
+ for _, v2i := range values[2:] {
+ v2 := v2i.(constantValueBool)
+ vo = vo.Or(evalEqualBool(v1.GetBool(), v2.GetBool()))
+ }
+ case constantValueBytes:
+ v2 := values[1].(constantValueBytes)
+ vo = evalEqualBytes(v1.GetBytes(), v2.GetBytes())
+ for _, v2i := range values[2:] {
+ v2 := v2i.(constantValueBytes)
+ vo = vo.Or(evalEqualBytes(v1.GetBytes(), v2.GetBytes()))
+ }
+ case constantValueDecimal:
+ v2 := values[1].(constantValueDecimal)
+ vo = evalEqualDecimal(v1.GetDecimal(), v2.GetDecimal())
+ for _, v2i := range values[2:] {
+ v2 := v2i.(constantValueDecimal)
+ vo = vo.Or(evalEqualDecimal(v1.GetDecimal(), v2.GetDecimal()))
+ }
+ default:
+ panic(unknownConstantValueType(v1))
+ }
+ // Make the new node.
+ node := &ast.BoolValueNode{}
+ node.SetPosition(n.GetPosition())
+ node.SetLength(n.GetLength())
+ node.SetToken(n.GetToken())
+ node.V = vo
+ r = node
+ return true
+ }
+ if !fold() {
+ return nil
+ }
+
+ return verifyTypeAction(r, fn, dt, el, ta)
+}