diff options
-rw-r--r-- | core/vm/sqlvm/checker/expr.go | 323 |
1 files changed, 238 insertions, 85 deletions
diff --git a/core/vm/sqlvm/checker/expr.go b/core/vm/sqlvm/checker/expr.go index 9ea754a48..d62f82ede 100644 --- a/core/vm/sqlvm/checker/expr.go +++ b/core/vm/sqlvm/checker/expr.go @@ -141,7 +141,7 @@ func checkExpr(n ast.ExprNode, return n case *ast.InOperatorNode: - return n + return checkInOperator(n, s, o, c, el, tr, ta) case *ast.FunctionOperatorNode: return n @@ -1388,12 +1388,57 @@ func elAppendTypeErrorOperandValueNode(el *errors.ErrorList, n ast.Valuer, Severity: errors.ErrorSeverityError, Prefix: fn, Message: fmt.Sprintf( - "cannot use %s with %s because the other operand is expected "+ - "to be %s", + "cannot use %s with %s because it is already used with %s", describeValueNodeType(n), op, describeValueNodeType(nExpected)), }, nil) } +func compatibleValueNodes(expected, given ast.Valuer) bool { + switch expected.(type) { + case *ast.BoolValueNode: + switch given.(type) { + case *ast.BoolValueNode: + case *ast.NullValueNode: + default: + return false + } + case *ast.AddressValueNode: + switch given.(type) { + case *ast.AddressValueNode: + case *ast.NullValueNode: + default: + return false + } + case *ast.IntegerValueNode: + switch given.(type) { + case *ast.IntegerValueNode: + case *ast.DecimalValueNode: + case *ast.NullValueNode: + default: + return false + } + case *ast.DecimalValueNode: + switch given.(type) { + case *ast.IntegerValueNode: + case *ast.DecimalValueNode: + case *ast.NullValueNode: + default: + return false + } + case *ast.BytesValueNode: + switch given.(type) { + case *ast.BytesValueNode: + case *ast.NullValueNode: + default: + return false + } + case *ast.NullValueNode: + default: + panic(unknownValueNodeType(expected)) + } + return true +} + func extractConstantValue(n ast.Valuer) constantValue { switch n := n.(type) { case *ast.BoolValueNode: @@ -1417,7 +1462,7 @@ func unknownConstantValueType(v constantValue) string { return fmt.Sprintf("unknown constant value type %T", v) } -func findNilConstantValue(v constantValue) constantValue { +func newNilConstantValue(v constantValue) constantValue { switch v.(type) { case constantValueBool: return newConstantValueBoolFromNil() @@ -1439,53 +1484,7 @@ func foldRelationalOperator(n ast.BinaryOperator, object, subject ast.Valuer, evalDecimal func(decimal.NullDecimal, decimal.NullDecimal) ast.BoolValue, ) *ast.BoolValueNode { - compatibleTypes := func() bool { - switch object.(type) { - case *ast.BoolValueNode: - switch subject.(type) { - case *ast.BoolValueNode: - case *ast.NullValueNode: - default: - return false - } - case *ast.AddressValueNode: - switch subject.(type) { - case *ast.AddressValueNode: - case *ast.NullValueNode: - default: - return false - } - case *ast.IntegerValueNode: - switch subject.(type) { - case *ast.IntegerValueNode: - case *ast.DecimalValueNode: - case *ast.NullValueNode: - default: - return false - } - case *ast.DecimalValueNode: - switch subject.(type) { - case *ast.IntegerValueNode: - case *ast.DecimalValueNode: - case *ast.NullValueNode: - default: - return false - } - case *ast.BytesValueNode: - switch subject.(type) { - case *ast.BytesValueNode: - case *ast.NullValueNode: - default: - return false - } - case *ast.NullValueNode: - default: - panic(unknownValueNodeType(object)) - } - return true - } - - if !compatibleTypes() { + if !compatibleValueNodes(object, subject) { elAppendTypeErrorOperandValueNode(el, subject, fn, op, object) return nil } @@ -1498,9 +1497,9 @@ func foldRelationalOperator(n ast.BinaryOperator, object, subject ast.Valuer, arg1 = newConstantValueBoolFromNil() arg2 = newConstantValueBoolFromNil() } else if arg1 == nil { - arg1 = findNilConstantValue(arg2) + arg1 = newNilConstantValue(arg2) } else if arg2 == nil { - arg2 = findNilConstantValue(arg1) + arg2 = newNilConstantValue(arg1) } // Now we are sure that all interfaces are non-nil. @@ -1529,7 +1528,7 @@ func foldRelationalOperator(n ast.BinaryOperator, object, subject ast.Valuer, func checkRelationalOperator(n ast.BinaryOperator, s schema.Schema, o CheckOptions, c *schemaCache, el *errors.ErrorList, - tr schema.TableRef, ta typeAction, fn, op string, + tr schema.TableRef, ta typeAction, fn, op string, requireOrdered bool, evalBool func(ast.BoolValue, ast.BoolValue) ast.BoolValue, evalBytes func([]byte, []byte) ast.BoolValue, evalDecimal func(decimal.NullDecimal, decimal.NullDecimal) ast.BoolValue, @@ -1542,13 +1541,17 @@ func checkRelationalOperator(n ast.BinaryOperator, object := n.GetObject() dtObject := object.GetType() - if !validateOrderedType(dtObject, el, object, fn, op) { - return nil - } + subject := n.GetSubject() dtSubject := subject.GetType() - if !validateOrderedType(dtSubject, el, subject, fn, op) { - return nil + + if requireOrdered { + if !validateOrderedType(dtObject, el, object, fn, op) { + return nil + } + if !validateOrderedType(dtSubject, el, subject, fn, op) { + return nil + } } if _, ok := inferBinaryOperatorType(n, s, o, c, el, tr, fn, op); !ok { @@ -1577,7 +1580,7 @@ func checkGreaterOrEqualOperator(n *ast.GreaterOrEqualOperatorNode, fn := "CheckGreaterOrEqualOperator" op := "binary operator >=" - return checkRelationalOperator(n, s, o, c, el, tr, ta, fn, op, + return checkRelationalOperator(n, s, o, c, el, tr, ta, fn, op, true, func(v1, v2 ast.BoolValue) ast.BoolValue { return v1.GreaterOrEqual(v2) }, @@ -1604,7 +1607,7 @@ func checkLessOrEqualOperator(n *ast.LessOrEqualOperatorNode, fn := "CheckLessOrEqualOperator" op := "binary operator <=" - return checkRelationalOperator(n, s, o, c, el, tr, ta, fn, op, + return checkRelationalOperator(n, s, o, c, el, tr, ta, fn, op, true, func(v1, v2 ast.BoolValue) ast.BoolValue { return v1.LessOrEqual(v2) }, @@ -1631,7 +1634,7 @@ func checkNotEqualOperator(n *ast.NotEqualOperatorNode, fn := "CheckNotEqualOperator" op := "binary operator <>" - return checkRelationalOperator(n, s, o, c, el, tr, ta, fn, op, + return checkRelationalOperator(n, s, o, c, el, tr, ta, fn, op, false, func(v1, v2 ast.BoolValue) ast.BoolValue { return v1.NotEqual(v2) }, @@ -1650,6 +1653,24 @@ func checkNotEqualOperator(n *ast.NotEqualOperatorNode, ) } +func evalEqualBool(v1, v2 ast.BoolValue) ast.BoolValue { + return v1.Equal(v2) +} + +func evalEqualBytes(v1, v2 []byte) ast.BoolValue { + if v1 == nil || v2 == nil { + return ast.BoolValueUnknown + } + return ast.NewBoolValueFromBool(bytes.Equal(v1, v2)) +} + +func evalEqualDecimal(v1, v2 decimal.NullDecimal) ast.BoolValue { + if !v1.Valid || !v2.Valid { + return ast.BoolValueUnknown + } + return ast.NewBoolValueFromBool(v1.Decimal.Equal(v2.Decimal)) +} + func checkEqualOperator(n *ast.EqualOperatorNode, s schema.Schema, o CheckOptions, c *schemaCache, el *errors.ErrorList, tr schema.TableRef, ta typeAction) ast.ExprNode { @@ -1657,23 +1678,8 @@ func checkEqualOperator(n *ast.EqualOperatorNode, fn := "CheckEqualOperator" op := "binary operator =" - return checkRelationalOperator(n, s, o, c, el, tr, ta, fn, op, - func(v1, v2 ast.BoolValue) ast.BoolValue { - return v1.Equal(v2) - }, - func(v1, v2 []byte) ast.BoolValue { - if v1 == nil || v2 == nil { - return ast.BoolValueUnknown - } - return ast.NewBoolValueFromBool(bytes.Equal(v1, v2)) - }, - func(v1, v2 decimal.NullDecimal) ast.BoolValue { - if !v1.Valid || !v2.Valid { - return ast.BoolValueUnknown - } - return ast.NewBoolValueFromBool(v1.Decimal.Equal(v2.Decimal)) - }, - ) + return checkRelationalOperator(n, s, o, c, el, tr, ta, fn, op, false, + evalEqualBool, evalEqualBytes, evalEqualDecimal) } func checkGreaterOperator(n *ast.GreaterOperatorNode, @@ -1683,7 +1689,7 @@ func checkGreaterOperator(n *ast.GreaterOperatorNode, fn := "CheckGreaterOperator" op := "binary operator >" - return checkRelationalOperator(n, s, o, c, el, tr, ta, fn, op, + return checkRelationalOperator(n, s, o, c, el, tr, ta, fn, op, true, func(v1, v2 ast.BoolValue) ast.BoolValue { return v1.Greater(v2) }, @@ -1709,7 +1715,7 @@ func checkLessOperator(n *ast.LessOperatorNode, fn := "CheckLessOperator" op := "binary operator <" - return checkRelationalOperator(n, s, o, c, el, tr, ta, fn, op, + return checkRelationalOperator(n, s, o, c, el, tr, ta, fn, op, true, func(v1, v2 ast.BoolValue) ast.BoolValue { return v1.Greater(v2) }, @@ -2383,11 +2389,14 @@ func checkLikeOperator(n *ast.LikeOperatorNode, } if escape != nil { if v, n, ok := extractOne(escape); ok { - if !n && len(v) != 1 { - panic("escape byte must be exactly one byte") + if n { + null = true + } else { + if len(v) != 1 { + panic("escape byte must be exactly one byte") + } + vesc = v[0] } - vesc = v[0] - null = null || n } else { return nil, nil, 0, false, false } @@ -2483,3 +2492,147 @@ func checkLikeOperator(n *ast.LikeOperatorNode, return verifyTypeAction(r, fn, dt, el, ta) } + +func checkInOperator(n *ast.InOperatorNode, + s schema.Schema, o CheckOptions, c *schemaCache, el *errors.ErrorList, + tr schema.TableRef, ta typeAction) ast.ExprNode { + + fn := "CheckInOperator" + op := "operator IN" + + children := make([]*ast.ExprNode, 0, 1+len(n.Right)) + children = append(children, &n.Left) + for i := range n.Right { + children = append(children, &n.Right[i]) + } + + // Check our children first. + hasError := false + for _, child := range children { + result := checkExpr(*child, s, o, c, el, tr, nil) + if result != nil { + *child = result + } else { + hasError = true + } + } + if hasError { + return nil + } + r := ast.ExprNode(n) + + // Determine the type. + dtChildren := ast.DataTypePending + for _, child := range children { + dtChild := (*child).GetType() + if !dtChild.Pending() { + dtChildren = dtChild + break + } + } + + // If the type is determined, assign it to all children. + if !dtChildren.Pending() { + assign := newTypeActionAssign(dtChildren) + for _, child := range children { + result := checkExpr(*child, s, o, c, el, tr, assign) + if result == nil { + return nil + } + *child = result + } + } + dt := n.GetType() + + // Fold constants. + fold := func() bool { + valuers := make([]ast.Valuer, len(children)) + // Return early if it cannot be folded. + for i, child := range children { + if valuer, ok := (*child).(ast.Valuer); ok { + valuers[i] = valuer + } else { + return true + } + } + // Determine the type by finding the first non-NULL node. + var typeReference ast.Valuer + findType: + for _, valuer := range valuers { + switch valuer.(type) { + case *ast.BoolValueNode, + *ast.AddressValueNode, + *ast.IntegerValueNode, + *ast.DecimalValueNode, + *ast.BytesValueNode: + typeReference = valuer + break findType + case *ast.NullValueNode: + default: + panic(unknownValueNodeType(valuer)) + } + } + // Check types of all children against the type we determined above. + for _, valuer := range valuers { + if !compatibleValueNodes(typeReference, valuer) { + elAppendTypeErrorOperandValueNode( + el, valuer, fn, op, typeReference) + return false + } + } + // Extract values and assign types to NULL values. + constantValueReference := extractConstantValue(typeReference) + values := make([]constantValue, len(valuers)) + for i, valuer := range valuers { + value := extractConstantValue(valuer) + if value == nil { + if constantValueReference == nil { + value = newConstantValueBoolFromNil() + } else { + value = newNilConstantValue(constantValueReference) + } + } + values[i] = value + } + // Calculate the result. + var vo ast.BoolValue + switch v1 := values[0].(type) { + case constantValueBool: + v2 := values[1].(constantValueBool) + vo = evalEqualBool(v1.GetBool(), v2.GetBool()) + for _, v2i := range values[2:] { + v2 := v2i.(constantValueBool) + vo = vo.Or(evalEqualBool(v1.GetBool(), v2.GetBool())) + } + case constantValueBytes: + v2 := values[1].(constantValueBytes) + vo = evalEqualBytes(v1.GetBytes(), v2.GetBytes()) + for _, v2i := range values[2:] { + v2 := v2i.(constantValueBytes) + vo = vo.Or(evalEqualBytes(v1.GetBytes(), v2.GetBytes())) + } + case constantValueDecimal: + v2 := values[1].(constantValueDecimal) + vo = evalEqualDecimal(v1.GetDecimal(), v2.GetDecimal()) + for _, v2i := range values[2:] { + v2 := v2i.(constantValueDecimal) + vo = vo.Or(evalEqualDecimal(v1.GetDecimal(), v2.GetDecimal())) + } + default: + panic(unknownConstantValueType(v1)) + } + // Make the new node. + node := &ast.BoolValueNode{} + node.SetPosition(n.GetPosition()) + node.SetLength(n.GetLength()) + node.SetToken(n.GetToken()) + node.V = vo + r = node + return true + } + if !fold() { + return nil + } + + return verifyTypeAction(r, fn, dt, el, ta) +} |