aboutsummaryrefslogtreecommitdiffstats
path: root/core/vm
diff options
context:
space:
mode:
authorTing-Wei Lan <lantw44@gmail.com>2019-05-03 11:57:14 +0800
committerTing-Wei Lan <tingwei.lan@cobinhood.com>2019-05-14 11:04:15 +0800
commitd4c1848efc0cd1482f324dd456e37379f8e37cf0 (patch)
treea692dd88f6fb02b2596cbda889913b97e43a39d9 /core/vm
parentba0d231fbe274063a056e47d4e8c092adff0b5c0 (diff)
downloaddexon-d4c1848efc0cd1482f324dd456e37379f8e37cf0.tar
dexon-d4c1848efc0cd1482f324dd456e37379f8e37cf0.tar.gz
dexon-d4c1848efc0cd1482f324dd456e37379f8e37cf0.tar.bz2
dexon-d4c1848efc0cd1482f324dd456e37379f8e37cf0.tar.lz
dexon-d4c1848efc0cd1482f324dd456e37379f8e37cf0.tar.xz
dexon-d4c1848efc0cd1482f324dd456e37379f8e37cf0.tar.zst
dexon-d4c1848efc0cd1482f324dd456e37379f8e37cf0.zip
code backup 28
Diffstat (limited to 'core/vm')
-rw-r--r--core/vm/sqlvm/checker/checker.go349
-rw-r--r--core/vm/sqlvm/checker/utils.go12
2 files changed, 265 insertions, 96 deletions
diff --git a/core/vm/sqlvm/checker/checker.go b/core/vm/sqlvm/checker/checker.go
index 5a2832c50..f376afc1d 100644
--- a/core/vm/sqlvm/checker/checker.go
+++ b/core/vm/sqlvm/checker/checker.go
@@ -874,19 +874,19 @@ func checkExpr(n ast.ExprNode,
return checkGreaterOrEqualOperator(n, s, o, c, el, tr, ta)
case *ast.LessOrEqualOperatorNode:
- return n
+ return checkLessOrEqualOperator(n, s, o, c, el, tr, ta)
case *ast.NotEqualOperatorNode:
- return n
+ return checkNotEqualOperator(n, s, o, c, el, tr, ta)
case *ast.EqualOperatorNode:
- return n
+ return checkEqualOperator(n, s, o, c, el, tr, ta)
case *ast.GreaterOperatorNode:
- return n
+ return checkGreaterOperator(n, s, o, c, el, tr, ta)
case *ast.LessOperatorNode:
- return n
+ return checkLessOperator(n, s, o, c, el, tr, ta)
case *ast.ConcatOperatorNode:
return n
@@ -2063,6 +2063,54 @@ func elAppendTypeErrorOperandDataType(el *errors.ErrorList, n ast.ExprNode,
}, nil)
}
+func inferBinaryOperatorType(n ast.BinaryOperator,
+ s schema.Schema, o CheckOptions, c *schemaCache, el *errors.ErrorList,
+ tr schema.TableRef, fn, op string) (ast.DataType, bool) {
+
+ object := n.GetObject()
+ dtObject := object.GetType()
+ dtObjectDetermined := !dtObject.Pending()
+
+ subject := n.GetSubject()
+ dtSubject := subject.GetType()
+ dtSubjectDetermined := !dtSubject.Pending()
+
+ switch {
+ case dtObjectDetermined && dtSubjectDetermined:
+ if !dtObject.Equal(dtSubject) {
+ elAppendTypeErrorOperandDataType(
+ el, subject, fn, op, dtObject, dtSubject)
+ return ast.DataTypeBad, false
+ }
+ return dtObject, true
+
+ case dtObjectDetermined && !dtSubjectDetermined:
+ assign := newTypeActionAssign(dtObject)
+ subject = checkExpr(subject, s, o, c, el, tr, assign)
+ if subject == nil {
+ return ast.DataTypeBad, false
+ }
+ n.SetSubject(subject)
+ return dtObject, true
+
+ case !dtObjectDetermined && dtSubjectDetermined:
+ assign := newTypeActionAssign(dtSubject)
+ object = checkExpr(object, s, o, c, el, tr, assign)
+ if object == nil {
+ return ast.DataTypeBad, false
+ }
+ n.SetObject(object)
+ return dtSubject, true
+
+ case !dtObjectDetermined && !dtSubjectDetermined:
+ // We cannot do type checking when both types are unknown.
+ return ast.DataTypePending, true
+
+ default:
+ panic("unreachable")
+ }
+}
+
func elAppendTypeErrorOperandValueNode(el *errors.ErrorList, n ast.Valuer,
fn, op string, nExpected ast.Valuer) {
@@ -2103,6 +2151,21 @@ func unknownConstantValueType(v constantValue) string {
return fmt.Sprintf("unknown constant value type %T", v)
}
+func findNilConstantValue(v constantValue) constantValue {
+ switch v.(type) {
+ case constantValueBool:
+ return newConstantValueBoolFromNil()
+ case constantValueBytes:
+ return newConstantValueBytesFromNil()
+ case constantValueDecimal:
+ return newConstantValueDecimalFromNil()
+ case nil:
+ return nil
+ default:
+ panic(unknownConstantValueType(v))
+ }
+}
+
func foldRelationalOperator(n ast.BinaryOperator, object, subject ast.Valuer,
el *errors.ErrorList, fn, op string,
evalBool func(ast.BoolValue, ast.BoolValue) ast.BoolValue,
@@ -2164,49 +2227,28 @@ func foldRelationalOperator(n ast.BinaryOperator, object, subject ast.Valuer,
arg1 := extractConstantValue(object)
arg2 := extractConstantValue(subject)
+ // Resolve nil interfaces.
+ if arg1 == nil && arg2 == nil {
+ arg1 = newConstantValueBoolFromNil()
+ arg2 = newConstantValueBoolFromNil()
+ } else if arg1 == nil {
+ arg1 = findNilConstantValue(arg2)
+ } else if arg2 == nil {
+ arg2 = findNilConstantValue(arg1)
+ }
+
+ // Now we are sure that all interfaces are non-nil.
var vo ast.BoolValue
switch v1 := arg1.(type) {
case constantValueBool:
- var v2 ast.BoolValue
- if arg2 == nil {
- v2 = ast.BoolValueUnknown
- } else {
- v2 = arg2.(constantValueBool).GetBool()
- }
- vo = evalBool(v1.GetBool(), v2)
-
+ v2 := arg2.(constantValueBool)
+ vo = evalBool(v1.GetBool(), v2.GetBool())
case constantValueBytes:
- var v2 []byte
- if arg2 == nil {
- v2 = nil
- } else {
- v2 = arg2.(constantValueBytes).GetBytes()
- }
- vo = evalBytes(v1.GetBytes(), v2)
-
+ v2 := arg2.(constantValueBytes)
+ vo = evalBytes(v1.GetBytes(), v2.GetBytes())
case constantValueDecimal:
- var v2 decimal.NullDecimal
- if arg2 == nil {
- v2 = decimal.NullDecimal{Valid: false}
- } else {
- v2 = arg2.(constantValueDecimal).GetDecimal()
- }
- vo = evalDecimal(v1.GetDecimal(), v2)
-
- case nil:
- switch v2 := arg2.(type) {
- case constantValueBool:
- vo = evalBool(ast.BoolValueUnknown, v2.GetBool())
- case constantValueBytes:
- vo = evalBytes(nil, v2.GetBytes())
- case constantValueDecimal:
- vo = evalDecimal(decimal.NullDecimal{Valid: false}, v2.GetDecimal())
- case nil:
- vo = evalBool(ast.BoolValueUnknown, ast.BoolValueUnknown)
- default:
- panic(unknownConstantValueType(v2))
- }
-
+ v2 := arg2.(constantValueDecimal)
+ vo = evalDecimal(v1.GetDecimal(), v2.GetDecimal())
default:
panic(unknownConstantValueType(v1))
}
@@ -2219,12 +2261,13 @@ func foldRelationalOperator(n ast.BinaryOperator, object, subject ast.Valuer,
return node
}
-func checkGreaterOrEqualOperator(n *ast.GreaterOrEqualOperatorNode,
+func checkRelationalOperator(n ast.BinaryOperator,
s schema.Schema, o CheckOptions, c *schemaCache, el *errors.ErrorList,
- tr schema.TableRef, ta typeAction) ast.ExprNode {
-
- fn := "CheckGreaterOrEqualOperator"
- op := "binary operator >="
+ tr schema.TableRef, ta typeAction, fn, op string,
+ evalBool func(ast.BoolValue, ast.BoolValue) ast.BoolValue,
+ evalBytes func([]byte, []byte) ast.BoolValue,
+ evalDecimal func(decimal.NullDecimal, decimal.NullDecimal) ast.BoolValue,
+) ast.ExprNode {
object := n.GetObject()
object = checkExpr(object, s, o, c, el, tr, nil)
@@ -2248,60 +2291,16 @@ func checkGreaterOrEqualOperator(n *ast.GreaterOrEqualOperatorNode,
if !validateOrderedType(dtSubject, el, subject, fn, op) {
return nil
}
- dt := n.GetType()
- dtObjectDetermined := !dtObject.Pending()
- dtSubjectDetermined := !dtSubject.Pending()
- switch {
- case dtObjectDetermined && dtSubjectDetermined:
- if !dtObject.Equal(dtSubject) {
- elAppendTypeErrorOperandDataType(
- el, subject, fn, op, dtObject, dtSubject)
- return nil
- }
-
- case dtObjectDetermined && !dtSubjectDetermined:
- assign := newTypeActionAssign(dtObject)
- subject = checkExpr(subject, s, o, c, el, tr, assign)
- if subject == nil {
- return nil
- }
- n.SetSubject(subject)
-
- case !dtObjectDetermined && dtSubjectDetermined:
- assign := newTypeActionAssign(dtSubject)
- object = checkExpr(object, s, o, c, el, tr, assign)
- if object == nil {
- return nil
- }
- n.SetObject(object)
-
- case !dtObjectDetermined && !dtSubjectDetermined:
- // We cannot do type checking when both types are unknown.
-
- default:
- panic("unreachable")
+ if _, ok := inferBinaryOperatorType(n, s, o, c, el, tr, fn, op); !ok {
+ return nil
}
+ dt := n.GetType()
if object, ok := object.(ast.Valuer); ok {
if subject, ok := subject.(ast.Valuer); ok {
node := foldRelationalOperator(n, object, subject, el, fn, op,
- func(v1, v2 ast.BoolValue) ast.BoolValue {
- return v1.GreaterOrEqual(v2)
- },
- func(v1, v2 []byte) ast.BoolValue {
- if v1 == nil || v2 == nil {
- return ast.BoolValueUnknown
- }
- return ast.NewBoolValueFromBool(bytes.Compare(v1, v2) >= 0)
- },
- func(v1, v2 decimal.NullDecimal) ast.BoolValue {
- if !v1.Valid || !v2.Valid {
- return ast.BoolValueUnknown
- }
- return ast.NewBoolValueFromBool(
- v1.Decimal.GreaterThanOrEqual(v2.Decimal))
- })
+ evalBool, evalBytes, evalDecimal)
if node == nil {
return nil
}
@@ -2320,3 +2319,161 @@ func checkGreaterOrEqualOperator(n *ast.GreaterOrEqualOperatorNode,
}
return r
}
+
+func checkGreaterOrEqualOperator(n *ast.GreaterOrEqualOperatorNode,
+ s schema.Schema, o CheckOptions, c *schemaCache, el *errors.ErrorList,
+ tr schema.TableRef, ta typeAction) ast.ExprNode {
+
+ fn := "CheckGreaterOrEqualOperator"
+ op := "binary operator >="
+
+ return checkRelationalOperator(n, s, o, c, el, tr, ta, fn, op,
+ func(v1, v2 ast.BoolValue) ast.BoolValue {
+ return v1.GreaterOrEqual(v2)
+ },
+ func(v1, v2 []byte) ast.BoolValue {
+ if v1 == nil || v2 == nil {
+ return ast.BoolValueUnknown
+ }
+ return ast.NewBoolValueFromBool(bytes.Compare(v1, v2) >= 0)
+ },
+ func(v1, v2 decimal.NullDecimal) ast.BoolValue {
+ if !v1.Valid || !v2.Valid {
+ return ast.BoolValueUnknown
+ }
+ return ast.NewBoolValueFromBool(
+ v1.Decimal.GreaterThanOrEqual(v2.Decimal))
+ },
+ )
+}
+
+func checkLessOrEqualOperator(n *ast.LessOrEqualOperatorNode,
+ s schema.Schema, o CheckOptions, c *schemaCache, el *errors.ErrorList,
+ tr schema.TableRef, ta typeAction) ast.ExprNode {
+
+ fn := "CheckLessOrEqualOperator"
+ op := "binary operator <="
+
+ return checkRelationalOperator(n, s, o, c, el, tr, ta, fn, op,
+ func(v1, v2 ast.BoolValue) ast.BoolValue {
+ return v1.LessOrEqual(v2)
+ },
+ func(v1, v2 []byte) ast.BoolValue {
+ if v1 == nil || v2 == nil {
+ return ast.BoolValueUnknown
+ }
+ return ast.NewBoolValueFromBool(bytes.Compare(v1, v2) <= 0)
+ },
+ func(v1, v2 decimal.NullDecimal) ast.BoolValue {
+ if !v1.Valid || !v2.Valid {
+ return ast.BoolValueUnknown
+ }
+ return ast.NewBoolValueFromBool(
+ v1.Decimal.GreaterThanOrEqual(v2.Decimal))
+ },
+ )
+}
+
+func checkNotEqualOperator(n *ast.NotEqualOperatorNode,
+ s schema.Schema, o CheckOptions, c *schemaCache, el *errors.ErrorList,
+ tr schema.TableRef, ta typeAction) ast.ExprNode {
+
+ fn := "CheckNotEqualOperator"
+ op := "binary operator <>"
+
+ return checkRelationalOperator(n, s, o, c, el, tr, ta, fn, op,
+ func(v1, v2 ast.BoolValue) ast.BoolValue {
+ return v1.NotEqual(v2)
+ },
+ func(v1, v2 []byte) ast.BoolValue {
+ if v1 == nil || v2 == nil {
+ return ast.BoolValueUnknown
+ }
+ return ast.NewBoolValueFromBool(!bytes.Equal(v1, v2))
+ },
+ func(v1, v2 decimal.NullDecimal) ast.BoolValue {
+ if !v1.Valid || !v2.Valid {
+ return ast.BoolValueUnknown
+ }
+ return ast.NewBoolValueFromBool(!v1.Decimal.Equal(v2.Decimal))
+ },
+ )
+}
+
+func checkEqualOperator(n *ast.EqualOperatorNode,
+ s schema.Schema, o CheckOptions, c *schemaCache, el *errors.ErrorList,
+ tr schema.TableRef, ta typeAction) ast.ExprNode {
+
+ fn := "CheckEqualOperator"
+ op := "binary operator ="
+
+ return checkRelationalOperator(n, s, o, c, el, tr, ta, fn, op,
+ func(v1, v2 ast.BoolValue) ast.BoolValue {
+ return v1.Equal(v2)
+ },
+ func(v1, v2 []byte) ast.BoolValue {
+ if v1 == nil || v2 == nil {
+ return ast.BoolValueUnknown
+ }
+ return ast.NewBoolValueFromBool(bytes.Equal(v1, v2))
+ },
+ func(v1, v2 decimal.NullDecimal) ast.BoolValue {
+ if !v1.Valid || !v2.Valid {
+ return ast.BoolValueUnknown
+ }
+ return ast.NewBoolValueFromBool(v1.Decimal.Equal(v2.Decimal))
+ },
+ )
+}
+
+func checkGreaterOperator(n *ast.GreaterOperatorNode,
+ s schema.Schema, o CheckOptions, c *schemaCache, el *errors.ErrorList,
+ tr schema.TableRef, ta typeAction) ast.ExprNode {
+
+ fn := "CheckGreaterOperator"
+ op := "binary operator >"
+
+ return checkRelationalOperator(n, s, o, c, el, tr, ta, fn, op,
+ func(v1, v2 ast.BoolValue) ast.BoolValue {
+ return v1.Greater(v2)
+ },
+ func(v1, v2 []byte) ast.BoolValue {
+ if v1 == nil || v2 == nil {
+ return ast.BoolValueUnknown
+ }
+ return ast.NewBoolValueFromBool(bytes.Compare(v1, v2) > 0)
+ },
+ func(v1, v2 decimal.NullDecimal) ast.BoolValue {
+ if !v1.Valid || !v2.Valid {
+ return ast.BoolValueUnknown
+ }
+ return ast.NewBoolValueFromBool(v1.Decimal.GreaterThan(v2.Decimal))
+ },
+ )
+}
+
+func checkLessOperator(n *ast.LessOperatorNode,
+ s schema.Schema, o CheckOptions, c *schemaCache, el *errors.ErrorList,
+ tr schema.TableRef, ta typeAction) ast.ExprNode {
+
+ fn := "CheckLessOperator"
+ op := "binary operator <"
+
+ return checkRelationalOperator(n, s, o, c, el, tr, ta, fn, op,
+ func(v1, v2 ast.BoolValue) ast.BoolValue {
+ return v1.Greater(v2)
+ },
+ func(v1, v2 []byte) ast.BoolValue {
+ if v1 == nil || v2 == nil {
+ return ast.BoolValueUnknown
+ }
+ return ast.NewBoolValueFromBool(bytes.Compare(v1, v2) < 0)
+ },
+ func(v1, v2 decimal.NullDecimal) ast.BoolValue {
+ if !v1.Valid || !v2.Valid {
+ return ast.BoolValueUnknown
+ }
+ return ast.NewBoolValueFromBool(v1.Decimal.LessThan(v2.Decimal))
+ },
+ )
+}
diff --git a/core/vm/sqlvm/checker/utils.go b/core/vm/sqlvm/checker/utils.go
index 60a3f7c86..36a77ef9e 100644
--- a/core/vm/sqlvm/checker/utils.go
+++ b/core/vm/sqlvm/checker/utils.go
@@ -507,6 +507,10 @@ func newConstantValueBool(b ast.BoolValue) constantValueBool {
return constantValueBool(b)
}
+func newConstantValueBoolFromNil() constantValueBool {
+ return constantValueBool(ast.BoolValueUnknown)
+}
+
func (b constantValueBool) GetBool() ast.BoolValue {
return ast.BoolValue(b)
}
@@ -524,6 +528,10 @@ func newConstantValueBytes(b []byte) constantValueBytes {
return constantValueBytes(b)
}
+func newConstantValueBytesFromNil() constantValueBytes {
+ return nil
+}
+
func (b constantValueBytes) GetBytes() []byte {
return []byte(b)
}
@@ -538,6 +546,10 @@ func newConstantValueDecimal(d decimal.Decimal) constantValueDecimal {
return constantValueDecimal{Decimal: d, Valid: true}
}
+func newConstantValueDecimalFromNil() constantValueDecimal {
+ return constantValueDecimal{Valid: false}
+}
+
func (d constantValueDecimal) GetDecimal() decimal.NullDecimal {
return decimal.NullDecimal(d)
}