From 49acaa8391cef7bd069b965d7de827cb0ba683b7 Mon Sep 17 00:00:00 2001 From: Ting-Wei Lan Date: Mon, 29 Apr 2019 19:11:19 +0800 Subject: code backup 25 --- core/vm/sqlvm/checker/checker.go | 268 +++++++++++++++++++++++---------------- core/vm/sqlvm/checker/utils.go | 50 ++++++++ 2 files changed, 212 insertions(+), 106 deletions(-) (limited to 'core') diff --git a/core/vm/sqlvm/checker/checker.go b/core/vm/sqlvm/checker/checker.go index dce59f332..5dcfaa043 100644 --- a/core/vm/sqlvm/checker/checker.go +++ b/core/vm/sqlvm/checker/checker.go @@ -2052,8 +2052,8 @@ func elAppendTypeErrorOperandDataType(el *errors.ErrorList, n ast.ExprNode, Severity: errors.ErrorSeverityError, Prefix: fn, Message: fmt.Sprintf( - "cannot use %s (%04x) as an operand of %s because there is "+ - "already an operand declared as %s (%04x)", + "cannot use %s (%04x) with %s because the operand is expected "+ + "to be %s (%04x)", dtGiven.String(), uint16(dtGiven), op, dtExpected.String(), uint16(dtExpected)), }, nil) @@ -2070,12 +2070,151 @@ func elAppendTypeErrorOperandValueNode(el *errors.ErrorList, n ast.Valuer, Severity: errors.ErrorSeverityError, Prefix: fn, Message: fmt.Sprintf( - "cannot use %s as an operand of %s because there is "+ - "already an operand found to be %s", + "cannot use %s with %s because the other operand is expected "+ + "to be %s", describeValueNodeType(n), op, describeValueNodeType(nExpected)), }, nil) } +func extractConstantValue(n ast.Valuer) constantValue { + switch n := n.(type) { + case *ast.BoolValueNode: + return newConstantValueBool(n.V) + case *ast.AddressValueNode: + return newConstantValueBytes(n.V) + case *ast.IntegerValueNode: + return newConstantValueDecimal(n.V) + case *ast.DecimalValueNode: + return newConstantValueDecimal(n.V) + case *ast.BytesValueNode: + return newConstantValueBytes(n.V) + case *ast.NullValueNode: + return nil + default: + panic(unknownValueNodeType(n)) + } +} + +func unknownConstantValueType(v constantValue) string { + return fmt.Sprintf("unknown constant value type %T", v) +} + +func foldRelationalOperator(n ast.BinaryOperator, object, subject ast.Valuer, + el *errors.ErrorList, fn, op string, + evalBool func(ast.BoolValue, ast.BoolValue) ast.BoolValue, + evalBytes func([]byte, []byte) ast.BoolValue, + evalDecimal func(decimal.NullDecimal, decimal.NullDecimal) ast.BoolValue, +) *ast.BoolValueNode { + + compatibleTypes := func() bool { + switch object.(type) { + case *ast.BoolValueNode: + switch subject.(type) { + case *ast.BoolValueNode: + case *ast.NullValueNode: + default: + return false + } + case *ast.AddressValueNode: + switch subject.(type) { + case *ast.AddressValueNode: + case *ast.NullValueNode: + default: + return false + } + case *ast.IntegerValueNode: + switch subject.(type) { + case *ast.IntegerValueNode: + case *ast.DecimalValueNode: + case *ast.NullValueNode: + default: + return false + } + case *ast.DecimalValueNode: + switch subject.(type) { + case *ast.IntegerValueNode: + case *ast.DecimalValueNode: + case *ast.NullValueNode: + default: + return false + } + case *ast.BytesValueNode: + switch subject.(type) { + case *ast.BytesValueNode: + case *ast.NullValueNode: + default: + return false + } + case *ast.NullValueNode: + default: + panic(unknownValueNodeType(object)) + } + return true + } + + if !compatibleTypes() { + elAppendTypeErrorOperandValueNode(el, subject, fn, op, object) + return nil + } + + arg1 := extractConstantValue(object) + arg2 := extractConstantValue(subject) + + var vo ast.BoolValue + switch v1 := arg1.(type) { + case constantValueBool: + var v2 ast.BoolValue + if arg2 == nil { + v2 = ast.BoolValueUnknown + } else { + v2 = arg2.(constantValueBool).GetBool() + } + vo = evalBool(v1.GetBool(), v2) + + case constantValueBytes: + var v2 []byte + if arg2 == nil { + v2 = nil + } else { + v2 = arg2.(constantValueBytes).GetBytes() + } + vo = evalBytes(v1.GetBytes(), v2) + + case constantValueDecimal: + var v2 decimal.NullDecimal + if arg2 == nil { + v2 = decimal.NullDecimal{Valid: false} + } else { + v2 = arg2.(constantValueDecimal).GetDecimal() + } + vo = evalDecimal(v1.GetDecimal(), v2) + + case nil: + switch v2 := arg2.(type) { + case constantValueBool: + vo = evalBool(ast.BoolValueUnknown, v2.GetBool()) + case constantValueBytes: + vo = evalBytes(nil, v2.GetBytes()) + case constantValueDecimal: + vo = evalDecimal(decimal.NullDecimal{Valid: false}, v2.GetDecimal()) + case nil: + vo = evalBool(ast.BoolValueUnknown, ast.BoolValueUnknown) + default: + panic(unknownConstantValueType(v2)) + } + + default: + panic(unknownConstantValueType(v1)) + } + + node := &ast.BoolValueNode{} + node.SetPosition(n.GetPosition()) + node.SetLength(n.GetLength()) + node.SetToken(n.GetToken()) + node.V = vo + return node +} + func checkGreaterOrEqualOperator(n *ast.GreaterOrEqualOperatorNode, s schema.Schema, o CheckOptions, c *schemaCache, el *errors.ErrorList, tr schema.TableRef, ta typeAction) ast.ExprNode { @@ -2140,112 +2279,29 @@ func checkGreaterOrEqualOperator(n *ast.GreaterOrEqualOperatorNode, panic("unreachable") } - fold := func(object, subject ast.Valuer) bool { - var vo ast.BoolValue - eval: - switch object := object.(type) { - case *ast.BoolValueNode: - var v1, v2 ast.BoolValue - v1 = object.V - switch subject := subject.(type) { - case *ast.BoolValueNode: - v2 = subject.V - case *ast.NullValueNode: - v2 = ast.BoolValueUnknown - default: - elAppendTypeErrorOperandValueNode(el, subject, fn, op, object) - return false - } - vo = v1.GreaterOrEqual(v2) - - case *ast.AddressValueNode: - var v1, v2 []byte - v1 = object.V - switch subject := subject.(type) { - case *ast.AddressValueNode: - v2 = subject.V - case *ast.NullValueNode: - vo = ast.BoolValueUnknown - break eval - default: - elAppendTypeErrorOperandValueNode(el, subject, fn, op, object) - return false - } - vo = ast.NewBoolValueFromBool(bytes.Compare(v1, v2) >= 0) - - case *ast.IntegerValueNode: - var v1, v2 decimal.Decimal - v1 = object.V - switch subject := subject.(type) { - case *ast.IntegerValueNode: - v2 = subject.V - case *ast.DecimalValueNode: - v2 = subject.V - case *ast.NullValueNode: - vo = ast.BoolValueUnknown - break eval - default: - elAppendTypeErrorOperandValueNode(el, subject, fn, op, object) - return false - } - vo = ast.NewBoolValueFromBool(v1.GreaterThanOrEqual(v2)) - - case *ast.DecimalValueNode: - var v1, v2 decimal.Decimal - v1 = object.V - switch subject := subject.(type) { - case *ast.IntegerValueNode: - v2 = subject.V - case *ast.DecimalValueNode: - v2 = subject.V - case *ast.NullValueNode: - vo = ast.BoolValueUnknown - break eval - default: - elAppendTypeErrorOperandValueNode(el, subject, fn, op, object) - return false - } - vo = ast.NewBoolValueFromBool(v1.GreaterThanOrEqual(v2)) - - case *ast.BytesValueNode: - var v1, v2 []byte - v1 = object.V - switch subject := subject.(type) { - case *ast.BytesValueNode: - v2 = subject.V - case *ast.NullValueNode: - vo = ast.BoolValueUnknown - break eval - default: - elAppendTypeErrorOperandValueNode(el, subject, fn, op, object) - return false - } - vo = ast.NewBoolValueFromBool(bytes.Compare(v1, v2) >= 0) - - case *ast.NullValueNode: - switch subject := subject.(type) { - case *ast.BoolValueNode: - vo = ast.BoolValueUnknown.GreaterOrEqual(subject.V) - default: - vo = ast.BoolValueUnknown - } - - default: - panic(unknownValueNodeType(object)) - } - node := &ast.BoolValueNode{} - node.SetPosition(n.GetPosition()) - node.SetLength(n.GetLength()) - node.SetToken(n.GetToken()) - node.V = vo - r = node - return true - } if object, ok := object.(ast.Valuer); ok { if subject, ok := subject.(ast.Valuer); ok { - if !fold(object, subject) { + node := foldRelationalOperator(n, object, subject, el, fn, op, + func(v1, v2 ast.BoolValue) ast.BoolValue { + return v1.GreaterOrEqual(v2) + }, + func(v1, v2 []byte) ast.BoolValue { + if v1 == nil || v2 == nil { + return ast.BoolValueUnknown + } + return ast.NewBoolValueFromBool(bytes.Compare(v1, v2) >= 0) + }, + func(v1, v2 decimal.NullDecimal) ast.BoolValue { + if !v1.Valid || !v2.Valid { + return ast.BoolValueUnknown + } + return ast.NewBoolValueFromBool( + v1.Decimal.GreaterThanOrEqual(v2.Decimal)) + }) + if node == nil { return nil } + r = node } } diff --git a/core/vm/sqlvm/checker/utils.go b/core/vm/sqlvm/checker/utils.go index 34b73af4a..3ca49676f 100644 --- a/core/vm/sqlvm/checker/utils.go +++ b/core/vm/sqlvm/checker/utils.go @@ -491,3 +491,53 @@ func newTypeActionAssign(expected ast.DataType) typeActionAssign { var _ typeAction = typeActionAssign{} func (typeActionAssign) ˉtypeAction() {} + +//go-sumtype:decl constantValue +type constantValue interface { + ˉconstantValue() +} + +type constantValueBool ast.BoolValue + +var _ constantValue = constantValueBool(0) + +func (constantValueBool) ˉconstantValue() {} + +func newConstantValueBool(b ast.BoolValue) constantValueBool { + return constantValueBool(b) +} + +func (b constantValueBool) GetBool() ast.BoolValue { + return ast.BoolValue(b) +} + +type constantValueBytes []byte + +var _ constantValue = constantValueBytes{} + +func (constantValueBytes) ˉconstantValue() {} + +func newConstantValueBytes(b []byte) constantValueBytes { + if b == nil { + return constantValueBytes{} + } + return constantValueBytes(b) +} + +func (b constantValueBytes) GetBytes() []byte { + return []byte(b) +} + +type constantValueDecimal decimal.NullDecimal + +var _ constantValue = constantValueDecimal{} + +func (constantValueDecimal) ˉconstantValue() {} + +func newConstantValueDecimal(d decimal.Decimal) constantValueDecimal { + return constantValueDecimal{Decimal: d, Valid: true} +} + +func (d constantValueDecimal) GetDecimal() decimal.NullDecimal { + return decimal.NullDecimal(d) +} -- cgit v1.2.3