diff options
Diffstat (limited to 'core/vm/sqlvm/ast')
-rw-r--r-- | core/vm/sqlvm/ast/ast.go | 2 | ||||
-rw-r--r-- | core/vm/sqlvm/ast/types.go | 95 | ||||
-rw-r--r-- | core/vm/sqlvm/ast/types_test.go | 74 |
3 files changed, 170 insertions, 1 deletions
diff --git a/core/vm/sqlvm/ast/ast.go b/core/vm/sqlvm/ast/ast.go index 80daacfe1..0a82ac76c 100644 --- a/core/vm/sqlvm/ast/ast.go +++ b/core/vm/sqlvm/ast/ast.go @@ -135,7 +135,7 @@ type Valuer interface { // BoolValueNode is a boolean constant. type BoolValueNode struct { UntaggedExprNodeBase - V bool + V BoolValue } var _ ExprNode = (*BoolValueNode)(nil) diff --git a/core/vm/sqlvm/ast/types.go b/core/vm/sqlvm/ast/types.go index 9d7efc430..ee8036576 100644 --- a/core/vm/sqlvm/ast/types.go +++ b/core/vm/sqlvm/ast/types.go @@ -1,6 +1,7 @@ package ast import ( + "database/sql" "fmt" "math" "math/big" @@ -17,6 +18,16 @@ var ( bigIntTen = big.NewInt(10) ) +// BoolValue represents a boolean value used by SQL three-valued logic. +type BoolValue uint8 + +// Define valid values for SQL boolean type. The zero value is invalid. +const ( + BoolValueTrue BoolValue = 1 + BoolValueFalse BoolValue = 2 + BoolValueUnknown BoolValue = 3 +) + // DataTypeMajor defines type for high byte of DataType. type DataTypeMajor uint8 @@ -54,6 +65,90 @@ const ( DataTypeBad DataType = math.MaxUint16 ) +// Valid returns whether a BoolValue is valid. +func (v BoolValue) Valid() bool { + return v-1 < 3 +} + +var boolValueStringMap = [3]string{ + BoolValueTrue - 1: "TRUE", + BoolValueFalse - 1: "FALSE", + BoolValueUnknown - 1: "UNKNOWN", +} + +// String returns a string for printing a BoolValue. +func (v BoolValue) String() string { + return boolValueStringMap[v-1] +} + +var boolValueNullBoolMap = [3]sql.NullBool{ + BoolValueTrue - 1: {Valid: true, Bool: true}, + BoolValueFalse - 1: {Valid: true, Bool: false}, + BoolValueUnknown - 1: {Valid: false, Bool: false}, +} + +// NullBool converts a BoolValue to a sql.NullBool. +func (v BoolValue) NullBool() sql.NullBool { + return boolValueNullBoolMap[v-1] +} + +var boolValueAndTruthTable = [3][3]BoolValue{ + BoolValueTrue - 1: { + BoolValueTrue - 1: BoolValueTrue, + BoolValueFalse - 1: BoolValueFalse, + BoolValueUnknown - 1: BoolValueUnknown, + }, + BoolValueFalse - 1: { + BoolValueTrue - 1: BoolValueFalse, + BoolValueFalse - 1: BoolValueFalse, + BoolValueUnknown - 1: BoolValueFalse, + }, + BoolValueUnknown - 1: { + BoolValueTrue - 1: BoolValueUnknown, + BoolValueFalse - 1: BoolValueFalse, + BoolValueUnknown - 1: BoolValueUnknown, + }, +} + +// And returns v AND v2. +func (v BoolValue) And(v2 BoolValue) BoolValue { + return boolValueAndTruthTable[v-1][v2-1] +} + +var boolValueOrTruthTable = [3][3]BoolValue{ + BoolValueTrue - 1: { + BoolValueTrue - 1: BoolValueTrue, + BoolValueFalse - 1: BoolValueTrue, + BoolValueUnknown - 1: BoolValueTrue, + }, + BoolValueFalse - 1: { + BoolValueTrue - 1: BoolValueTrue, + BoolValueFalse - 1: BoolValueFalse, + BoolValueUnknown - 1: BoolValueUnknown, + }, + BoolValueUnknown - 1: { + BoolValueTrue - 1: BoolValueTrue, + BoolValueFalse - 1: BoolValueUnknown, + BoolValueUnknown - 1: BoolValueUnknown, + }, +} + +// Or returns v OR v2. +func (v BoolValue) Or(v2 BoolValue) BoolValue { + return boolValueOrTruthTable[v-1][v2-1] +} + +var boolValueNotTruthTable = [3]BoolValue{ + BoolValueTrue - 1: BoolValueFalse, + BoolValueFalse - 1: BoolValueTrue, + BoolValueUnknown - 1: BoolValueUnknown, +} + +// Not returns NOT v. +func (v BoolValue) Not() BoolValue { + return boolValueNotTruthTable[v-1] +} + // DecomposeDataType to major and minor part with given data type. func DecomposeDataType(t DataType) (DataTypeMajor, DataTypeMinor) { return DataTypeMajor(t >> 8), DataTypeMinor(t & 0xff) diff --git a/core/vm/sqlvm/ast/types_test.go b/core/vm/sqlvm/ast/types_test.go index 89a000251..d2051c4b5 100644 --- a/core/vm/sqlvm/ast/types_test.go +++ b/core/vm/sqlvm/ast/types_test.go @@ -1,6 +1,7 @@ package ast import ( + "database/sql" "testing" "github.com/shopspring/decimal" @@ -235,6 +236,79 @@ func (s *TypesTestSuite) TestDecimalToUint64() { } } +func (s *TypesTestSuite) TestBoolValueValidity() { + var v BoolValue + s.Require().False(v.Valid()) + s.Require().Panics(func() { _ = v.String() }) + s.Require().Panics(func() { _ = v.NullBool() }) + v = BoolValue(1) + s.Require().True(v.Valid()) + s.Require().Equal("TRUE", v.String()) + s.Require().Equal(sql.NullBool{Valid: true, Bool: true}, v.NullBool()) + v = BoolValue(4) + s.Require().False(v.Valid()) + s.Require().Panics(func() { _ = v.String() }) + s.Require().Panics(func() { _ = v.NullBool() }) +} + +func (s *TypesTestSuite) TestBoolValueOperations() { + and := func(v, v2 BoolValue) BoolValue { + if v == BoolValueFalse || v2 == BoolValueFalse { + return BoolValueFalse + } + if v == BoolValueUnknown || v2 == BoolValueUnknown { + return BoolValueUnknown + } + // v is true. + return v2 + } + or := func(v, v2 BoolValue) BoolValue { + if v == BoolValueTrue || v2 == BoolValueTrue { + return BoolValueTrue + } + if v == BoolValueUnknown || v2 == BoolValueUnknown { + return BoolValueUnknown + } + // v is false. + return v2 + } + not := func(v BoolValue) BoolValue { + switch v { + case BoolValueTrue: + return BoolValueFalse + case BoolValueFalse: + return BoolValueTrue + case BoolValueUnknown: + return BoolValueUnknown + } + // v is invalid. + return v + } + values := [3]BoolValue{BoolValueTrue, BoolValueFalse, BoolValueUnknown} + for _, v := range values { + for _, v2 := range values { + expected := and(v, v2) + actual := v.And(v2) + s.Require().Equalf(expected, actual, + "%v AND %v = %v, but %v is returned", v, v2, expected, actual) + } + } + for _, v := range values { + for _, v2 := range values { + expected := or(v, v2) + actual := v.Or(v2) + s.Require().Equalf(expected, actual, + "%v OR %v = %v, but %v is returned", v, v2, expected, actual) + } + } + for _, v := range values { + expected := not(v) + actual := v.Not() + s.Require().Equalf(expected, actual, + "NOT %v = %v, but %v is returned", v, expected, actual) + } +} + func TestTypes(t *testing.T) { suite.Run(t, new(TypesTestSuite)) } |