aboutsummaryrefslogtreecommitdiffstats
path: root/core/vm/sqlvm/ast
diff options
context:
space:
mode:
Diffstat (limited to 'core/vm/sqlvm/ast')
-rw-r--r--core/vm/sqlvm/ast/ast.go2
-rw-r--r--core/vm/sqlvm/ast/types.go95
-rw-r--r--core/vm/sqlvm/ast/types_test.go74
3 files changed, 170 insertions, 1 deletions
diff --git a/core/vm/sqlvm/ast/ast.go b/core/vm/sqlvm/ast/ast.go
index 80daacfe1..0a82ac76c 100644
--- a/core/vm/sqlvm/ast/ast.go
+++ b/core/vm/sqlvm/ast/ast.go
@@ -135,7 +135,7 @@ type Valuer interface {
// BoolValueNode is a boolean constant.
type BoolValueNode struct {
UntaggedExprNodeBase
- V bool
+ V BoolValue
}
var _ ExprNode = (*BoolValueNode)(nil)
diff --git a/core/vm/sqlvm/ast/types.go b/core/vm/sqlvm/ast/types.go
index 9d7efc430..ee8036576 100644
--- a/core/vm/sqlvm/ast/types.go
+++ b/core/vm/sqlvm/ast/types.go
@@ -1,6 +1,7 @@
package ast
import (
+ "database/sql"
"fmt"
"math"
"math/big"
@@ -17,6 +18,16 @@ var (
bigIntTen = big.NewInt(10)
)
+// BoolValue represents a boolean value used by SQL three-valued logic.
+type BoolValue uint8
+
+// Define valid values for SQL boolean type. The zero value is invalid.
+const (
+ BoolValueTrue BoolValue = 1
+ BoolValueFalse BoolValue = 2
+ BoolValueUnknown BoolValue = 3
+)
+
// DataTypeMajor defines type for high byte of DataType.
type DataTypeMajor uint8
@@ -54,6 +65,90 @@ const (
DataTypeBad DataType = math.MaxUint16
)
+// Valid returns whether a BoolValue is valid.
+func (v BoolValue) Valid() bool {
+ return v-1 < 3
+}
+
+var boolValueStringMap = [3]string{
+ BoolValueTrue - 1: "TRUE",
+ BoolValueFalse - 1: "FALSE",
+ BoolValueUnknown - 1: "UNKNOWN",
+}
+
+// String returns a string for printing a BoolValue.
+func (v BoolValue) String() string {
+ return boolValueStringMap[v-1]
+}
+
+var boolValueNullBoolMap = [3]sql.NullBool{
+ BoolValueTrue - 1: {Valid: true, Bool: true},
+ BoolValueFalse - 1: {Valid: true, Bool: false},
+ BoolValueUnknown - 1: {Valid: false, Bool: false},
+}
+
+// NullBool converts a BoolValue to a sql.NullBool.
+func (v BoolValue) NullBool() sql.NullBool {
+ return boolValueNullBoolMap[v-1]
+}
+
+var boolValueAndTruthTable = [3][3]BoolValue{
+ BoolValueTrue - 1: {
+ BoolValueTrue - 1: BoolValueTrue,
+ BoolValueFalse - 1: BoolValueFalse,
+ BoolValueUnknown - 1: BoolValueUnknown,
+ },
+ BoolValueFalse - 1: {
+ BoolValueTrue - 1: BoolValueFalse,
+ BoolValueFalse - 1: BoolValueFalse,
+ BoolValueUnknown - 1: BoolValueFalse,
+ },
+ BoolValueUnknown - 1: {
+ BoolValueTrue - 1: BoolValueUnknown,
+ BoolValueFalse - 1: BoolValueFalse,
+ BoolValueUnknown - 1: BoolValueUnknown,
+ },
+}
+
+// And returns v AND v2.
+func (v BoolValue) And(v2 BoolValue) BoolValue {
+ return boolValueAndTruthTable[v-1][v2-1]
+}
+
+var boolValueOrTruthTable = [3][3]BoolValue{
+ BoolValueTrue - 1: {
+ BoolValueTrue - 1: BoolValueTrue,
+ BoolValueFalse - 1: BoolValueTrue,
+ BoolValueUnknown - 1: BoolValueTrue,
+ },
+ BoolValueFalse - 1: {
+ BoolValueTrue - 1: BoolValueTrue,
+ BoolValueFalse - 1: BoolValueFalse,
+ BoolValueUnknown - 1: BoolValueUnknown,
+ },
+ BoolValueUnknown - 1: {
+ BoolValueTrue - 1: BoolValueTrue,
+ BoolValueFalse - 1: BoolValueUnknown,
+ BoolValueUnknown - 1: BoolValueUnknown,
+ },
+}
+
+// Or returns v OR v2.
+func (v BoolValue) Or(v2 BoolValue) BoolValue {
+ return boolValueOrTruthTable[v-1][v2-1]
+}
+
+var boolValueNotTruthTable = [3]BoolValue{
+ BoolValueTrue - 1: BoolValueFalse,
+ BoolValueFalse - 1: BoolValueTrue,
+ BoolValueUnknown - 1: BoolValueUnknown,
+}
+
+// Not returns NOT v.
+func (v BoolValue) Not() BoolValue {
+ return boolValueNotTruthTable[v-1]
+}
+
// DecomposeDataType to major and minor part with given data type.
func DecomposeDataType(t DataType) (DataTypeMajor, DataTypeMinor) {
return DataTypeMajor(t >> 8), DataTypeMinor(t & 0xff)
diff --git a/core/vm/sqlvm/ast/types_test.go b/core/vm/sqlvm/ast/types_test.go
index 89a000251..d2051c4b5 100644
--- a/core/vm/sqlvm/ast/types_test.go
+++ b/core/vm/sqlvm/ast/types_test.go
@@ -1,6 +1,7 @@
package ast
import (
+ "database/sql"
"testing"
"github.com/shopspring/decimal"
@@ -235,6 +236,79 @@ func (s *TypesTestSuite) TestDecimalToUint64() {
}
}
+func (s *TypesTestSuite) TestBoolValueValidity() {
+ var v BoolValue
+ s.Require().False(v.Valid())
+ s.Require().Panics(func() { _ = v.String() })
+ s.Require().Panics(func() { _ = v.NullBool() })
+ v = BoolValue(1)
+ s.Require().True(v.Valid())
+ s.Require().Equal("TRUE", v.String())
+ s.Require().Equal(sql.NullBool{Valid: true, Bool: true}, v.NullBool())
+ v = BoolValue(4)
+ s.Require().False(v.Valid())
+ s.Require().Panics(func() { _ = v.String() })
+ s.Require().Panics(func() { _ = v.NullBool() })
+}
+
+func (s *TypesTestSuite) TestBoolValueOperations() {
+ and := func(v, v2 BoolValue) BoolValue {
+ if v == BoolValueFalse || v2 == BoolValueFalse {
+ return BoolValueFalse
+ }
+ if v == BoolValueUnknown || v2 == BoolValueUnknown {
+ return BoolValueUnknown
+ }
+ // v is true.
+ return v2
+ }
+ or := func(v, v2 BoolValue) BoolValue {
+ if v == BoolValueTrue || v2 == BoolValueTrue {
+ return BoolValueTrue
+ }
+ if v == BoolValueUnknown || v2 == BoolValueUnknown {
+ return BoolValueUnknown
+ }
+ // v is false.
+ return v2
+ }
+ not := func(v BoolValue) BoolValue {
+ switch v {
+ case BoolValueTrue:
+ return BoolValueFalse
+ case BoolValueFalse:
+ return BoolValueTrue
+ case BoolValueUnknown:
+ return BoolValueUnknown
+ }
+ // v is invalid.
+ return v
+ }
+ values := [3]BoolValue{BoolValueTrue, BoolValueFalse, BoolValueUnknown}
+ for _, v := range values {
+ for _, v2 := range values {
+ expected := and(v, v2)
+ actual := v.And(v2)
+ s.Require().Equalf(expected, actual,
+ "%v AND %v = %v, but %v is returned", v, v2, expected, actual)
+ }
+ }
+ for _, v := range values {
+ for _, v2 := range values {
+ expected := or(v, v2)
+ actual := v.Or(v2)
+ s.Require().Equalf(expected, actual,
+ "%v OR %v = %v, but %v is returned", v, v2, expected, actual)
+ }
+ }
+ for _, v := range values {
+ expected := not(v)
+ actual := v.Not()
+ s.Require().Equalf(expected, actual,
+ "NOT %v = %v, but %v is returned", v, expected, actual)
+ }
+}
+
func TestTypes(t *testing.T) {
suite.Run(t, new(TypesTestSuite))
}