From 06266ab9f5f5491fed858a1cee668d37d2b3a44c Mon Sep 17 00:00:00 2001 From: Ting-Wei Lan <tingwei.lan@cobinhood.com> Date: Thu, 25 Apr 2019 16:17:06 +0800 Subject: code backup 20 --- core/vm/sqlvm/checker/actions.go | 147 +++ core/vm/sqlvm/checker/checker.go | 1714 +++++++++++++++++++++++++++++++++ core/vm/sqlvm/checker/utils.go | 493 ++++++++++ core/vm/sqlvm/checkers/actions.go | 147 --- core/vm/sqlvm/checkers/checkers.go | 1579 ------------------------------ core/vm/sqlvm/checkers/utils.go | 471 --------- core/vm/sqlvm/cmd/ast-checker/main.go | 45 +- 7 files changed, 2380 insertions(+), 2216 deletions(-) create mode 100644 core/vm/sqlvm/checker/actions.go create mode 100644 core/vm/sqlvm/checker/checker.go create mode 100644 core/vm/sqlvm/checker/utils.go delete mode 100644 core/vm/sqlvm/checkers/actions.go delete mode 100644 core/vm/sqlvm/checkers/checkers.go delete mode 100644 core/vm/sqlvm/checkers/utils.go (limited to 'core') diff --git a/core/vm/sqlvm/checker/actions.go b/core/vm/sqlvm/checker/actions.go new file mode 100644 index 000000000..6adb9b912 --- /dev/null +++ b/core/vm/sqlvm/checker/actions.go @@ -0,0 +1,147 @@ +package checker + +import ( + "fmt" + + "github.com/dexon-foundation/dexon/core/vm/sqlvm/ast" + "github.com/dexon-foundation/dexon/core/vm/sqlvm/errors" + "github.com/dexon-foundation/dexon/core/vm/sqlvm/schema" +) + +// CheckOptions stores boolean options for Check* functions. +type CheckOptions uint32 + +const ( + // CheckWithSafeMath enables overflow and underflow checks during expression + // evaluation. An error will be thrown when the result is out of range. + CheckWithSafeMath CheckOptions = 1 << iota + // CheckWithSafeCast enables overflow and underflow checks during casting. + // An error will be thrown if the value does not fit in the target type. + CheckWithSafeCast + // CheckWithConstantOnly restricts the expression to be a constant. An error + // will be thrown if the expression cannot be folded into a constant. + CheckWithConstantOnly +) + +// CheckCreate runs CREATE commands to generate a database schema. It modifies +// AST in-place during evaluation of expressions. +func CheckCreate(ss []ast.StmtNode, o CheckOptions) (schema.Schema, error) { + fn := "CheckCreate" + s := schema.Schema{} + c := newSchemaCache() + el := errors.ErrorList{} + + for idx := range ss { + if ss[idx] == nil { + continue + } + + switch n := ss[idx].(type) { + case *ast.CreateTableStmtNode: + checkCreateTableStmt(n, &s, o, c, &el) + case *ast.CreateIndexStmtNode: + checkCreateIndexStmt(n, &s, o, c, &el) + default: + el.Append(errors.Error{ + Position: ss[idx].GetPosition(), + Length: ss[idx].GetLength(), + Category: errors.ErrorCategoryCommand, + Code: errors.ErrorCodeDisallowedCommand, + Severity: errors.ErrorSeverityError, + Prefix: fn, + Message: fmt.Sprintf( + "command %s is not allowed when creating a contract", + ast.QuoteIdentifier(ss[idx].GetVerb())), + }, nil) + } + } + + if len(s) == 0 && len(el) == 0 { + el.Append(errors.Error{ + Position: 0, + Length: 0, + Category: errors.ErrorCategoryCommand, + Code: errors.ErrorCodeNoCommand, + Severity: errors.ErrorSeverityError, + Prefix: fn, + Message: "creating a contract without a table is not allowed", + }, nil) + } + if len(el) != 0 { + return s, el + } + return s, nil +} + +// CheckQuery checks and modifies SELECT commands with a given database schema. +func CheckQuery(ss []ast.StmtNode, s schema.Schema, o CheckOptions) error { + fn := "CheckQuery" + c := newSchemaCache() + el := errors.ErrorList{} + + for idx := range ss { + if ss[idx] == nil { + continue + } + + switch n := ss[idx].(type) { + case *ast.SelectStmtNode: + checkSelectStmt(n, s, o, c, &el) + default: + el.Append(errors.Error{ + Position: ss[idx].GetPosition(), + Length: ss[idx].GetLength(), + Category: errors.ErrorCategoryCommand, + Code: errors.ErrorCodeDisallowedCommand, + Severity: errors.ErrorSeverityError, + Prefix: fn, + Message: fmt.Sprintf( + "command %s is not allowed when calling query", + ast.QuoteIdentifier(ss[idx].GetVerb())), + }, nil) + } + } + if len(el) != 0 { + return el + } + return nil +} + +// CheckExec checks and modifies UPDATE, DELETE, INSERT commands with a given +// database schema. +func CheckExec(ss []ast.StmtNode, s schema.Schema, o CheckOptions) error { + fn := "CheckExec" + c := newSchemaCache() + el := errors.ErrorList{} + + for idx := range ss { + if ss[idx] == nil { + continue + } + + switch n := ss[idx].(type) { + case *ast.UpdateStmtNode: + checkUpdateStmt(n, s, o, c, &el) + case *ast.DeleteStmtNode: + checkDeleteStmt(n, s, o, c, &el) + case *ast.InsertStmtNode: + checkInsertStmt(n, s, o, c, &el) + default: + el.Append(errors.Error{ + Position: ss[idx].GetPosition(), + Length: ss[idx].GetLength(), + Category: errors.ErrorCategoryCommand, + Code: errors.ErrorCodeDisallowedCommand, + Severity: errors.ErrorSeverityError, + Prefix: fn, + Message: fmt.Sprintf( + "command %s is not allowed when calling exec", + ast.QuoteIdentifier(ss[idx].GetVerb())), + }, nil) + } + } + if len(el) != 0 { + return el + } + return nil +} diff --git a/core/vm/sqlvm/checker/checker.go b/core/vm/sqlvm/checker/checker.go new file mode 100644 index 000000000..dbdad68fa --- /dev/null +++ b/core/vm/sqlvm/checker/checker.go @@ -0,0 +1,1714 @@ +package checker + +import ( + "fmt" + "sort" + + "github.com/dexon-foundation/decimal" + + "github.com/dexon-foundation/dexon/core/vm/sqlvm/ast" + "github.com/dexon-foundation/dexon/core/vm/sqlvm/errors" + "github.com/dexon-foundation/dexon/core/vm/sqlvm/schema" +) + +// In addition to the convention mentioned in utils.go, we have these variable +// names in this file: +// +// ftd -> foreign table descriptor +// ftn -> foreign table name +// fcd -> foreign column descriptor +// fcn -> foreign column name +// fid -> foreign index descriptor +// fin -> foreign index name +// +// fmid -> first matching index descriptor +// fmir -> first matching index reference +// fmin -> first matching index name + +// findFirstMatchingIndex finds the first index in 'haystack' matching the +// declaration of 'needle' with attributes specified in 'attrDontCare' ignored. +// This function is considered as a part of the interface, so it have to work +// deterministically. +func findFirstMatchingIndex(haystack []schema.Index, needle schema.Index, + attrDontCare schema.IndexAttr) (schema.IndexRef, bool) { + + compareAttr := func(a1, a2 schema.IndexAttr) bool { + a1 = a1.GetDeclaredFlags() | attrDontCare + a2 = a2.GetDeclaredFlags() | attrDontCare + return a1 == a2 + } + + compareColumns := func(c1, c2 []schema.ColumnRef) bool { + if len(c1) != len(c2) { + return false + } + for ci := range c1 { + if c1[ci] != c2[ci] { + return false + } + } + return true + } + + for ii := range haystack { + if compareAttr(haystack[ii].Attr, needle.Attr) && + compareColumns(haystack[ii].Columns, needle.Columns) { + ir := schema.IndexRef(ii) + return ir, true + } + } + return 0, false +} + +func checkCreateTableStmt(n *ast.CreateTableStmtNode, s *schema.Schema, + o CheckOptions, c *schemaCache, el *errors.ErrorList) { + + fn := "CheckCreateTableStmt" + hasError := false + + if c.Begin() != 0 { + panic("schema cache must not have any open scope") + } + defer func() { + if hasError { + c.Rollback() + return + } + c.Commit() + }() + + // Return early if there are too many tables. We cannot ignore this error + // because it will overflow schema.TableRef, which is used as a part of + // column key in schemaCache. + if len(*s) > schema.MaxTableRef { + el.Append(errors.Error{ + Position: n.GetPosition(), + Length: n.GetLength(), + Category: errors.ErrorCategoryLimit, + Code: errors.ErrorCodeTooManyTables, + Severity: errors.ErrorSeverityError, + Prefix: fn, + Message: fmt.Sprintf("cannot have more than %d tables", + schema.MaxTableRef+1), + }, &hasError) + return + } + + table := schema.Table{} + tr := schema.TableRef(len(*s)) + td := schema.TableDescriptor{Table: tr} + + if len(n.Table.Name) == 0 { + el.Append(errors.Error{ + Position: n.Table.GetPosition(), + Length: n.Table.GetLength(), + Category: errors.ErrorCategorySemantic, + Code: errors.ErrorCodeEmptyTableName, + Severity: errors.ErrorSeverityError, + Prefix: fn, + Message: "cannot create a table with an empty name", + }, &hasError) + } + + tn := n.Table.Name + if !c.AddTable(string(tn), td) { + el.Append(errors.Error{ + Position: n.Table.GetPosition(), + Length: n.Table.GetLength(), + Category: errors.ErrorCategorySemantic, + Code: errors.ErrorCodeDuplicateTableName, + Severity: errors.ErrorSeverityError, + Prefix: fn, + Message: fmt.Sprintf("table %s already exists", + ast.QuoteIdentifier(tn)), + }, &hasError) + } + table.Name = n.Table.Name + table.Columns = make([]schema.Column, 0, len(n.Column)) + + // Handle the primary key index. + pk := []schema.ColumnRef{} + // Handle sequences. + seq := 0 + // Handle indices for unique constraints. + type localIndex struct { + index schema.Index + node ast.Node + } + localIndices := []localIndex{} + // Handle indices for foreign key constraints. + type foreignNewIndex struct { + table schema.TableDescriptor + index schema.Index + node ast.Node + } + foreignNewIndices := []foreignNewIndex{} + type foreignExistingIndex struct { + index schema.IndexDescriptor + node ast.Node + } + foreignExistingIndices := []foreignExistingIndex{} + + for ci := range n.Column { + if len(table.Columns) > schema.MaxColumnRef { + el.Append(errors.Error{ + Position: n.Column[ci].GetPosition(), + Length: n.Column[ci].GetLength(), + Category: errors.ErrorCategoryLimit, + Code: errors.ErrorCodeTooManyColumns, + Severity: errors.ErrorSeverityError, + Prefix: fn, + Message: fmt.Sprintf("cannot have more than %d columns", + schema.MaxColumnRef+1), + }, &hasError) + return + } + + column := schema.Column{} + ok := func() (ok bool) { + innerHasError := false + defer func() { ok = !innerHasError }() + + // Block access to the outer hasError variable. + hasError := struct{}{} + // Suppress “declared and not used” error. + _ = hasError + + c.Begin() + defer func() { + if innerHasError { + c.Rollback() + return + } + c.Commit() + }() + + cr := schema.ColumnRef(len(table.Columns)) + cd := schema.ColumnDescriptor{Table: tr, Column: cr} + + if len(n.Column[ci].Column.Name) == 0 { + el.Append(errors.Error{ + Position: n.Column[ci].Column.GetPosition(), + Length: n.Column[ci].Column.GetLength(), + Category: errors.ErrorCategorySemantic, + Code: errors.ErrorCodeEmptyColumnName, + Severity: errors.ErrorSeverityError, + Prefix: fn, + Message: "cannot declare a column with an empty name", + }, &innerHasError) + } + + cn := n.Column[ci].Column.Name + if !c.AddColumn(string(cn), cd) { + el.Append(errors.Error{ + Position: n.Column[ci].Column.GetPosition(), + Length: n.Column[ci].Column.GetLength(), + Category: errors.ErrorCategorySemantic, + Code: errors.ErrorCodeDuplicateColumnName, + Severity: errors.ErrorSeverityError, + Prefix: fn, + Message: fmt.Sprintf("column %s already exists", + ast.QuoteIdentifier(cn)), + }, &innerHasError) + } else { + column.Name = n.Column[ci].Column.Name + } + + dt, code, message := n.Column[ci].DataType.GetType() + if code == errors.ErrorCodeNil { + if !dt.ValidColumn() { + el.Append(errors.Error{ + Position: n.Column[ci].DataType.GetPosition(), + Length: n.Column[ci].DataType.GetLength(), + Category: errors.ErrorCategorySemantic, + Code: errors.ErrorCodeInvalidColumnDataType, + Severity: errors.ErrorSeverityError, + Prefix: fn, + Message: fmt.Sprintf( + "cannot declare a column with type %s", dt.String()), + }, &innerHasError) + } + } else { + el.Append(errors.Error{ + Position: n.Column[ci].DataType.GetPosition(), + Length: n.Column[ci].DataType.GetLength(), + Category: errors.ErrorCategorySemantic, + Code: code, + Severity: errors.ErrorSeverityError, + Prefix: fn, + Message: message, + }, &innerHasError) + } + column.Type = dt + + // Backup lengths of slices in case we have to rollback. We don't + // have to copy slice headers or data stored in underlying arrays + // because we always append data at the end. + defer func(LPK, SEQ, LLI, LFNI, LFEI int) { + if innerHasError { + pk = pk[:LPK] + seq = SEQ + localIndices = localIndices[:LLI] + foreignNewIndices = foreignNewIndices[:LFNI] + foreignExistingIndices = foreignExistingIndices[:LFEI] + } + }( + len(pk), seq, len(localIndices), len(foreignNewIndices), + len(foreignExistingIndices), + ) + + // cs -> constraint + // csi -> constraint index + for csi := range n.Column[ci].Constraint { + // Cases are sorted in the same order as internal/grammar.peg. + cs: + switch cs := n.Column[ci].Constraint[csi].(type) { + case *ast.PrimaryOptionNode: + pk = append(pk, cr) + column.Attr |= schema.ColumnAttrPrimaryKey + + case *ast.NotNullOptionNode: + column.Attr |= schema.ColumnAttrNotNull + + case *ast.UniqueOptionNode: + if (column.Attr & schema.ColumnAttrUnique) != 0 { + // Don't create duplicate indices on a column. + break cs + } + column.Attr |= schema.ColumnAttrUnique + indexName := fmt.Sprintf("%s_%s_unique", + table.Name, column.Name) + idx := schema.Index{ + Name: []byte(indexName), + Attr: schema.IndexAttrUnique, + Columns: []schema.ColumnRef{cr}, + } + localIndices = append(localIndices, localIndex{ + index: idx, + node: cs, + }) + + case *ast.DefaultOptionNode: + if (column.Attr & schema.ColumnAttrHasDefault) != 0 { + el.Append(errors.Error{ + Position: cs.GetPosition(), + Length: cs.GetLength(), + Category: errors.ErrorCategorySemantic, + Code: errors.ErrorCodeMultipleDefaultValues, + Severity: errors.ErrorSeverityError, + Prefix: fn, + Message: "cannot have multiple default values", + }, &innerHasError) + break cs + } + column.Attr |= schema.ColumnAttrHasDefault + + value := cs.Value + value = checkExpr(cs.Value, *s, o|CheckWithConstantOnly, + c, el, 0, newTypeActionAssign(column.Type)) + if value == nil { + innerHasError = true + break cs + } + cs.Value = value + + switch v := cs.Value.(ast.Valuer).(type) { + case *ast.BoolValueNode: + sb := v.V.NullBool() + if !sb.Valid { + el.Append(errors.Error{ + Position: cs.GetPosition(), + Length: cs.GetLength(), + Category: errors.ErrorCategorySemantic, + Code: errors.ErrorCodeNullDefaultValue, + Severity: errors.ErrorSeverityError, + Prefix: fn, + Message: "default value must not be NULL", + }, &innerHasError) + break cs + } + column.Default = sb.Bool + + case *ast.AddressValueNode: + column.Default = v.V + + case *ast.IntegerValueNode: + column.Default = v.V + + case *ast.DecimalValueNode: + column.Default = v.V + + case *ast.BytesValueNode: + column.Default = v.V + + case *ast.NullValueNode: + el.Append(errors.Error{ + Position: cs.GetPosition(), + Length: cs.GetLength(), + Category: errors.ErrorCategorySemantic, + Code: errors.ErrorCodeNullDefaultValue, + Severity: errors.ErrorSeverityError, + Prefix: fn, + Message: "default value must not be NULL", + }, &innerHasError) + break cs + + default: + panic(unknownValueNodeType(v)) + } + + case *ast.ForeignOptionNode: + if len(column.ForeignKeys) > schema.MaxForeignKeys { + el.Append(errors.Error{ + Position: cs.GetPosition(), + Length: cs.GetLength(), + Category: errors.ErrorCategoryLimit, + Code: errors.ErrorCodeTooManyForeignKeys, + Severity: errors.ErrorSeverityError, + Prefix: fn, + Message: fmt.Sprintf( + "cannot have more than %d foreign key "+ + "constraints in a column", + schema.MaxForeignKeys+1), + }, &innerHasError) + break cs + } + column.Attr |= schema.ColumnAttrHasForeignKey + ftn := cs.Table.Name + ftd, found := c.FindTableInBase(string(ftn)) + if !found { + el.Append(errors.Error{ + Position: cs.Table.GetPosition(), + Length: cs.Table.GetLength(), + Category: errors.ErrorCategorySemantic, + Code: errors.ErrorCodeTableNotFound, + Severity: errors.ErrorSeverityError, + Prefix: fn, + Message: fmt.Sprintf( + "foreign table %s does not exist", + ast.QuoteIdentifier(ftn)), + }, &innerHasError) + break cs + } + fcn := cs.Column.Name + fcd, found := c.FindColumnInBase(ftd.Table, string(fcn)) + if !found { + el.Append(errors.Error{ + Position: cs.Column.GetPosition(), + Length: cs.Column.GetLength(), + Category: errors.ErrorCategorySemantic, + Code: errors.ErrorCodeColumnNotFound, + Severity: errors.ErrorSeverityError, + Prefix: fn, + Message: fmt.Sprintf( + "column %s does not exist in foreign table %s", + ast.QuoteIdentifier(fcn), + ast.QuoteIdentifier(ftn)), + }, &innerHasError) + break cs + } + foreignType := (*s)[fcd.Table].Columns[fcd.Column].Type + if !foreignType.Equal(column.Type) { + el.Append(errors.Error{ + Position: n.Column[ci].DataType.GetPosition(), + Length: n.Column[ci].DataType.GetLength(), + Category: errors.ErrorCategorySemantic, + Code: errors.ErrorCodeForeignKeyDataTypeMismatch, + Severity: errors.ErrorSeverityError, + Prefix: fn, + Message: fmt.Sprintf( + "foreign column has type %s (%04x), but "+ + "this column has type %s (%04x)", + foreignType.String(), uint16(foreignType), + column.Type.String(), uint16(column.Type)), + }, &innerHasError) + break cs + } + + idx := schema.Index{ + Attr: schema.IndexAttrReferenced, + Columns: []schema.ColumnRef{fcd.Column}, + } + fmir, found := findFirstMatchingIndex( + (*s)[ftd.Table].Indices, idx, schema.IndexAttrUnique) + if found { + fmid := schema.IndexDescriptor{ + Table: ftd.Table, + Index: fmir, + } + foreignExistingIndices = append( + foreignExistingIndices, foreignExistingIndex{ + index: fmid, + node: cs, + }) + } else { + if len(column.ForeignKeys) > 0 { + idx.Name = []byte(fmt.Sprintf("%s_%s_foreign_key_%d", + table.Name, column.Name, len(column.ForeignKeys))) + } else { + idx.Name = []byte(fmt.Sprintf("%s_%s_foreign_key", + table.Name, column.Name)) + } + foreignNewIndices = append( + foreignNewIndices, foreignNewIndex{ + table: ftd, + index: idx, + node: cs, + }) + } + column.ForeignKeys = append(column.ForeignKeys, fcd) + + case *ast.AutoIncrementOptionNode: + if (column.Attr & schema.ColumnAttrHasSequence) != 0 { + // Don't process AUTOINCREMENT twice. + break cs + } + // We set the flag regardless of the error because we + // don't want to produce duplicate errors. + column.Attr |= schema.ColumnAttrHasSequence + if seq > schema.MaxSequenceRef { + el.Append(errors.Error{ + Position: cs.GetPosition(), + Length: cs.GetLength(), + Category: errors.ErrorCategoryLimit, + Code: errors.ErrorCodeTooManySequences, + Severity: errors.ErrorSeverityError, + Prefix: fn, + Message: fmt.Sprintf( + "cannot have more than %d sequences", + schema.MaxSequenceRef+1), + }, &innerHasError) + break cs + } + major, _ := ast.DecomposeDataType(column.Type) + switch major { + case ast.DataTypeMajorInt, ast.DataTypeMajorUint: + default: + el.Append(errors.Error{ + Position: cs.GetPosition(), + Length: cs.GetLength(), + Category: errors.ErrorCategorySemantic, + Code: errors.ErrorCodeInvalidAutoIncrementDataType, + Prefix: fn, + Message: fmt.Sprintf( + "AUTOINCREMENT is only supported on "+ + "INT and UINT types, but this column "+ + "has type %s (%04x)", + column.Type.String(), uint16(column.Type)), + }, &innerHasError) + break cs + } + column.Sequence = schema.SequenceRef(seq) + seq++ + + default: + panic(fmt.Sprintf("unknown constraint type %T", c)) + } + } + + // The return value will be set by the first defer function. + return + }() + + // If an error occurs in the function, stop here and continue + // processing the next column. + if !ok { + hasError = true + continue + } + + // Commit the column. + table.Columns = append(table.Columns, column) + } + + // Return early if there is any error. + if hasError { + return + } + + mustAddIndex := func(name *[]byte, id schema.IndexDescriptor) { + for !c.AddIndex(string(*name), id, true) { + *name = append(*name, '_') + } + } + + // Create the primary key index. This is the first index on the table, so + // it is not possible to exceed the limit on the number of indices. + ir := schema.IndexRef(len(table.Indices)) + if len(pk) > 0 { + idx := schema.Index{ + Name: []byte(fmt.Sprintf("%s_primary_key", table.Name)), + Attr: schema.IndexAttrUnique, + Columns: pk, + } + id := schema.IndexDescriptor{Table: tr, Index: ir} + mustAddIndex(&idx.Name, id) + table.Indices = append(table.Indices, idx) + } + + // Create indices for the current table. + for ii := range localIndices { + if len(table.Indices) > schema.MaxIndexRef { + el.Append(errors.Error{ + Position: localIndices[ii].node.GetPosition(), + Length: localIndices[ii].node.GetLength(), + Category: errors.ErrorCategoryLimit, + Code: errors.ErrorCodeTooManyIndices, + Severity: errors.ErrorSeverityError, + Prefix: fn, + Message: fmt.Sprintf("cannot have more than %d indices", + schema.MaxIndexRef+1), + }, &hasError) + return + } + idx := localIndices[ii].index + ir := schema.IndexRef(len(table.Indices)) + id := schema.IndexDescriptor{Table: tr, Index: ir} + mustAddIndex(&idx.Name, id) + table.Indices = append(table.Indices, idx) + } + + // Create indices for foreign tables. + for ii := range foreignNewIndices { + ftd := foreignNewIndices[ii].table + if len((*s)[ftd.Table].Indices) > schema.MaxIndexRef { + el.Append(errors.Error{ + Position: foreignNewIndices[ii].node.GetPosition(), + Length: foreignNewIndices[ii].node.GetLength(), + Category: errors.ErrorCategoryLimit, + Code: errors.ErrorCodeTooManyIndices, + Severity: errors.ErrorSeverityError, + Prefix: fn, + Message: fmt.Sprintf( + "table %s already has %d indices", + ast.QuoteIdentifier((*s)[ftd.Table].Name), + schema.MaxIndexRef+1), + }, &hasError) + return + } + idx := foreignNewIndices[ii].index + ir := schema.IndexRef(len((*s)[ftd.Table].Indices)) + id := schema.IndexDescriptor{Table: ftd.Table, Index: ir} + mustAddIndex(&idx.Name, id) + (*s)[ftd.Table].Indices = append((*s)[ftd.Table].Indices, idx) + defer func(tr schema.TableRef, length schema.IndexRef) { + if hasError { + (*s)[tr].Indices = (*s)[tr].Indices[:ir] + } + }(ftd.Table, ir) + } + + // Mark existing indices as referenced. + for ii := range foreignExistingIndices { + fid := foreignExistingIndices[ii].index + (*s)[fid.Table].Indices[fid.Index].Attr |= schema.IndexAttrReferenced + } + + // Finally, we can commit the table definition to the schema. + *s = append(*s, table) +} + +func checkCreateIndexStmt(n *ast.CreateIndexStmtNode, s *schema.Schema, + o CheckOptions, c *schemaCache, el *errors.ErrorList) { + + fn := "CheckCreateIndexStmt" + hasError := false + + if c.Begin() != 0 { + panic("schema cache must not have any open scope") + } + defer func() { + if hasError { + c.Rollback() + return + } + c.Commit() + }() + + tn := n.Table.Name + td, found := c.FindTableInBase(string(tn)) + if !found { + el.Append(errors.Error{ + Position: n.Table.GetPosition(), + Length: n.Table.GetLength(), + Category: errors.ErrorCategorySemantic, + Code: errors.ErrorCodeTableNotFound, + Severity: errors.ErrorSeverityError, + Prefix: fn, + Message: fmt.Sprintf( + "index table %s does not exist", + ast.QuoteIdentifier(tn)), + }, &hasError) + return + } + + if len(n.Column) > schema.MaxColumnRef { + begin := n.Column[0].GetPosition() + last := len(n.Column) - 1 + end := n.Column[last].GetPosition() + n.Column[last].GetLength() + el.Append(errors.Error{ + Position: begin, + Length: end - begin, + Category: errors.ErrorCategoryLimit, + Code: errors.ErrorCodeTooManyColumns, + Severity: errors.ErrorSeverityError, + Prefix: fn, + Message: fmt.Sprintf( + "cannot create an index on more than %d columns", + schema.MaxColumnRef+1), + }, &hasError) + return + } + + columnRefs := newColumnRefSlice(uint8(len(n.Column))) + for ci := range n.Column { + cn := n.Column[ci].Name + cd, found := c.FindColumnInBase(td.Table, string(cn)) + if !found { + el.Append(errors.Error{ + Position: n.Column[ci].GetPosition(), + Length: n.Column[ci].GetLength(), + Category: errors.ErrorCategorySemantic, + Code: errors.ErrorCodeColumnNotFound, + Severity: errors.ErrorSeverityError, + Prefix: fn, + Message: fmt.Sprintf( + "column %s does not exist in index table %s", + ast.QuoteIdentifier(cn), + ast.QuoteIdentifier(tn)), + }, &hasError) + continue + } + columnRefs.Append(cd.Column, uint8(ci)) + } + if hasError { + return + } + + sort.Stable(columnRefs) + for ci := 1; ci < len(n.Column); ci++ { + if columnRefs.columns[ci] == columnRefs.columns[ci-1] { + el.Append(errors.Error{ + Position: n.Column[columnRefs.nodes[ci]].GetPosition(), + Length: n.Column[columnRefs.nodes[ci]].GetLength(), + Category: errors.ErrorCategorySemantic, + Code: errors.ErrorCodeDuplicateIndexColumn, + Severity: errors.ErrorSeverityError, + Prefix: fn, + Message: fmt.Sprintf( + "column %s already exists in the column list", + ast.QuoteIdentifier(n.Column[columnRefs.nodes[ci]].Name)), + }, &hasError) + return + } + } + + index := schema.Index{} + index.Columns = columnRefs.columns + + if len((*s)[td.Table].Indices) > schema.MaxIndexRef { + el.Append(errors.Error{ + Position: n.GetPosition(), + Length: n.GetLength(), + Category: errors.ErrorCategoryLimit, + Code: errors.ErrorCodeTooManyIndices, + Severity: errors.ErrorSeverityError, + Prefix: fn, + Message: fmt.Sprintf( + "cannot have more than %d indices in table %s", + schema.MaxIndexRef+1, + ast.QuoteIdentifier(tn)), + }, &hasError) + return + } + + ir := schema.IndexRef(len((*s)[td.Table].Indices)) + id := schema.IndexDescriptor{Table: td.Table, Index: ir} + if n.Unique != nil { + index.Attr |= schema.IndexAttrUnique + } + + if len(n.Index.Name) == 0 { + el.Append(errors.Error{ + Position: n.Index.GetPosition(), + Length: n.Table.GetLength(), + Category: errors.ErrorCategorySemantic, + Code: errors.ErrorCodeEmptyIndexName, + Severity: errors.ErrorSeverityError, + Prefix: fn, + Message: "cannot create an index with an empty name", + }, &hasError) + return + } + + // If there is an existing index that is automatically created, rename it + // instead of creating a new one. + rename := false + fmir, found := findFirstMatchingIndex((*s)[id.Table].Indices, index, 0) + if found { + fmid := schema.IndexDescriptor{Table: id.Table, Index: fmir} + fmin := (*s)[id.Table].Indices[fmir].Name + fminString := string(fmin) + fmidCache, auto, found := c.FindIndexInBase(fminString) + if !found { + panic(fmt.Sprintf("index %s exists in the schema, "+ + "but it cannot be found in the schema cache", + ast.QuoteIdentifier(fmin))) + } + if fmidCache != fmid { + panic(fmt.Sprintf("index %s has descriptor %+v, "+ + "but the schema cache records it as %+v", + ast.QuoteIdentifier(fmin), fmid, fmidCache)) + } + if auto { + if !c.DeleteIndex(fminString) { + panic(fmt.Sprintf("unable to mark index %s for deletion", + ast.QuoteIdentifier(fmin))) + } + rename = true + id = fmid + ir = id.Index + } + } + + in := n.Index.Name + if !c.AddIndex(string(in), id, false) { + el.Append(errors.Error{ + Position: n.Index.GetPosition(), + Length: n.Index.GetLength(), + Category: errors.ErrorCategorySemantic, + Code: errors.ErrorCodeDuplicateIndexName, + Severity: errors.ErrorSeverityError, + Prefix: fn, + Message: fmt.Sprintf("index %s already exists", + ast.QuoteIdentifier(in)), + }, &hasError) + return + } + + // Commit the change into the schema. + if rename { + (*s)[id.Table].Indices[id.Index].Name = n.Index.Name + } else { + index.Name = n.Index.Name + (*s)[id.Table].Indices = append((*s)[id.Table].Indices, index) + } +} + +func checkSelectStmt(n *ast.SelectStmtNode, s schema.Schema, + o CheckOptions, c *schemaCache, el *errors.ErrorList) { +} + +func checkUpdateStmt(n *ast.UpdateStmtNode, s schema.Schema, + o CheckOptions, c *schemaCache, el *errors.ErrorList) { +} + +func checkDeleteStmt(n *ast.DeleteStmtNode, s schema.Schema, + o CheckOptions, c *schemaCache, el *errors.ErrorList) { +} + +func checkInsertStmt(n *ast.InsertStmtNode, s schema.Schema, + o CheckOptions, c *schemaCache, el *errors.ErrorList) { +} + +func checkExpr(n ast.ExprNode, + s schema.Schema, o CheckOptions, c *schemaCache, el *errors.ErrorList, + tr schema.TableRef, ta typeAction) ast.ExprNode { + + switch n := n.(type) { + case *ast.IdentifierNode: + return checkVariable(n, s, o, c, el, tr, ta) + + case *ast.BoolValueNode: + return checkBoolValue(n, o, el, ta) + + case *ast.AddressValueNode: + return checkAddressValue(n, o, el, ta) + + case *ast.IntegerValueNode: + return checkIntegerValue(n, o, el, ta) + + case *ast.DecimalValueNode: + return checkDecimalValue(n, o, el, ta) + + case *ast.BytesValueNode: + return checkBytesValue(n, o, el, ta) + + case *ast.NullValueNode: + return checkNullValue(n, o, el, ta) + + case *ast.PosOperatorNode: + return checkPosOperator(n, s, o, c, el, tr, ta) + + case *ast.NegOperatorNode: + return checkNegOperator(n, s, o, c, el, tr, ta) + + case *ast.NotOperatorNode: + return n + + case *ast.ParenOperatorNode: + return n + + case *ast.AndOperatorNode: + return n + + case *ast.OrOperatorNode: + return n + + case *ast.GreaterOrEqualOperatorNode: + return n + + case *ast.LessOrEqualOperatorNode: + return n + + case *ast.NotEqualOperatorNode: + return n + + case *ast.EqualOperatorNode: + return n + + case *ast.GreaterOperatorNode: + return n + + case *ast.LessOperatorNode: + return n + + case *ast.ConcatOperatorNode: + return n + + case *ast.AddOperatorNode: + return n + + case *ast.SubOperatorNode: + return n + + case *ast.MulOperatorNode: + return n + + case *ast.DivOperatorNode: + return n + + case *ast.ModOperatorNode: + return n + + case *ast.IsOperatorNode: + return n + + case *ast.LikeOperatorNode: + return n + + case *ast.CastOperatorNode: + return n + + case *ast.InOperatorNode: + return n + + case *ast.FunctionOperatorNode: + return n + + default: + panic(fmt.Sprintf("unknown expression type %T", n)) + } +} + +func elAppendTypeErrorMismatch(el *errors.ErrorList, n ast.ExprNode, + fn string, dtExpected, dtGiven ast.DataType) { + el.Append(errors.Error{ + Position: n.GetPosition(), + Length: n.GetLength(), + Category: errors.ErrorCategorySemantic, + Code: errors.ErrorCodeTypeError, + Severity: errors.ErrorSeverityError, + Prefix: fn, + Message: fmt.Sprintf( + "expect %s (%04x), but %s (%04x) is given", + dtExpected.String(), uint16(dtExpected), + dtGiven.String(), uint16(dtGiven)), + }, nil) +} +func checkVariable(n *ast.IdentifierNode, + s schema.Schema, o CheckOptions, c *schemaCache, el *errors.ErrorList, + tr schema.TableRef, ta typeAction) ast.ExprNode { + + fn := "CheckVariable" + + if (o & CheckWithConstantOnly) != 0 { + el.Append(errors.Error{ + Position: n.GetPosition(), + Length: n.GetLength(), + Category: errors.ErrorCategorySemantic, + Code: errors.ErrorCodeNonConstantExpression, + Severity: errors.ErrorSeverityError, + Prefix: fn, + Message: fmt.Sprintf("%s is not a constant", + ast.QuoteIdentifier(n.Name)), + }, nil) + return nil + } + + cn := string(n.Name) + cd, found := c.FindColumnInBaseWithFallback(tr, cn, s) + if !found { + el.Append(errors.Error{ + Position: n.GetPosition(), + Length: n.GetLength(), + Category: errors.ErrorCategorySemantic, + Code: errors.ErrorCodeColumnNotFound, + Severity: errors.ErrorSeverityError, + Prefix: fn, + Message: fmt.Sprintf( + "cannot find column %s in table %s", + ast.QuoteIdentifier(n.Name), + ast.QuoteIdentifier(s[tr].Name)), + }, nil) + return nil + } + + cr := cd.Column + dt := s[tr].Columns[cr].Type + switch a := ta.(type) { + case typeActionInferDefault: + case typeActionInferWithSize: + case typeActionAssign: + if !dt.Equal(a.dt) { + elAppendTypeErrorMismatch(el, n, fn, a.dt, dt) + return nil + } + } + + n.SetType(dt) + n.Desc = cd + return n +} + +func unknownValueNodeType(n ast.Valuer) string { + return fmt.Sprintf("unknown constant type %T", n) +} + +func describeValueNodeType(n ast.Valuer) string { + switch n.(type) { + case *ast.BoolValueNode: + return "boolean constant" + case *ast.AddressValueNode: + return "address constant" + case *ast.IntegerValueNode, *ast.DecimalValueNode: + return "number constant" + case *ast.BytesValueNode: + return "bytes constant" + case *ast.NullValueNode: + return "null constant" + default: + panic(unknownValueNodeType(n)) + } +} + +func elAppendTypeErrorValueNode(el *errors.ErrorList, n ast.Valuer, + fn string, dt ast.DataType) { + el.Append(errors.Error{ + Position: n.GetPosition(), + Length: n.GetLength(), + Category: errors.ErrorCategorySemantic, + Code: errors.ErrorCodeTypeError, + Severity: errors.ErrorSeverityError, + Prefix: fn, + Message: fmt.Sprintf( + "expect %s (%04x), but %s is given", + dt.String(), uint16(dt), describeValueNodeType(n)), + }, nil) +} + +func checkBoolValue(n *ast.BoolValueNode, + o CheckOptions, el *errors.ErrorList, ta typeAction) ast.ExprNode { + + fn := "CheckBoolValue" + + switch a := ta.(type) { + case typeActionInferDefault: + case typeActionInferWithSize: + case typeActionAssign: + major, _ := ast.DecomposeDataType(a.dt) + if major != ast.DataTypeMajorBool { + elAppendTypeErrorValueNode(el, n, fn, a.dt) + return nil + } + } + return n +} + +func checkAddressValue(n *ast.AddressValueNode, + o CheckOptions, el *errors.ErrorList, ta typeAction) ast.ExprNode { + + fn := "CheckAddressValue" + + switch a := ta.(type) { + case typeActionInferDefault: + case typeActionInferWithSize: + case typeActionAssign: + major, _ := ast.DecomposeDataType(a.dt) + if major != ast.DataTypeMajorAddress { + elAppendTypeErrorValueNode(el, n, fn, a.dt) + return nil + } + } + return n +} + +func mustGetMinMax(dt ast.DataType) (decimal.Decimal, decimal.Decimal) { + min, max, ok := dt.GetMinMax() + if !ok { + panic(fmt.Sprintf("GetMinMax does not handle %v", dt)) + } + return min, max +} + +func mustDecimalEncode(dt ast.DataType, d decimal.Decimal) []byte { + b, ok := ast.DecimalEncode(dt, d) + if !ok { + panic(fmt.Sprintf("DecimalEncode does not handle %v", dt)) + } + return b +} + +func mustDecimalDecode(dt ast.DataType, b []byte) decimal.Decimal { + d, ok := ast.DecimalDecode(dt, b) + if !ok { + panic(fmt.Sprintf("DecimalDecode does not handle %v", dt)) + } + return d +} + +func cropDecimal(dt ast.DataType, d decimal.Decimal) decimal.Decimal { + b := mustDecimalEncode(dt, d) + return mustDecimalDecode(dt, b) +} + +func elAppendConstantTooLongError(el *errors.ErrorList, n ast.Valuer, + fn string, v decimal.Decimal) { + el.Append(errors.Error{ + Position: n.GetPosition(), + Length: n.GetLength(), + Category: errors.ErrorCategorySemantic, + Code: errors.ErrorCodeConstantTooLong, + Severity: errors.ErrorSeverityError, + Prefix: fn, + Message: fmt.Sprintf( + "constant expression %s has more than %d digits", + ast.QuoteString(n.GetToken()), MaxIntegerPartDigits), + }, nil) +} + +func elAppendOverflowError(el *errors.ErrorList, n ast.Valuer, + fn string, dt ast.DataType, v, min, max decimal.Decimal) { + el.Append(errors.Error{ + Position: n.GetPosition(), + Length: n.GetLength(), + Category: errors.ErrorCategorySemantic, + Code: errors.ErrorCodeOverflow, + Severity: errors.ErrorSeverityError, + Prefix: fn, + Message: fmt.Sprintf( + "number %s (%s) overflows %s (%04x)", + ast.QuoteString(n.GetToken()), v.String(), + dt.String(), uint16(dt)), + }, nil) + el.Append(errors.Error{ + Position: n.GetPosition(), + Length: n.GetLength(), + Category: 0, + Code: 0, + Severity: errors.ErrorSeverityNote, + Prefix: fn, + Message: fmt.Sprintf( + "the range of %s is [%s, %s]", + dt.String(), min.String(), max.String()), + }, nil) +} + +func elAppendOverflowWarning(el *errors.ErrorList, n ast.Valuer, + fn string, dt ast.DataType, from, to decimal.Decimal) { + el.Append(errors.Error{ + Position: n.GetPosition(), + Length: n.GetLength(), + Category: 0, + Code: 0, + Severity: errors.ErrorSeverityWarning, + Prefix: fn, + Message: fmt.Sprintf( + "number %s (%s) overflows %s (%04x), converted to %s", + ast.QuoteString(n.GetToken()), from.String(), + dt.String(), uint16(dt), to.String()), + }, nil) +} + +func checkIntegerValue(n *ast.IntegerValueNode, + o CheckOptions, el *errors.ErrorList, ta typeAction) ast.ExprNode { + + fn := "CheckIntegerValue" + + normalizeDecimal(&n.V) + if !safeDecimalRange(n.V) { + elAppendConstantTooLongError(el, n, fn, n.V) + return nil + } + + infer := func(size int) (ast.DataType, bool) { + // The first case: assume V fits in the signed integer. + minor := ast.DataTypeMinor(size - 1) + dt := ast.ComposeDataType(ast.DataTypeMajorInt, minor) + min, max := mustGetMinMax(dt) + // Return if V < min. V must be negative so it must be signed. + if n.V.LessThan(min) { + if (o & CheckWithSafeMath) != 0 { + elAppendOverflowError(el, n, fn, dt, n.V, min, max) + return dt, false + } + cropped := cropDecimal(dt, n.V) + elAppendOverflowWarning(el, n, fn, dt, n.V, cropped) + normalizeDecimal(&cropped) + n.V = cropped + return dt, true + } + // We are done if V fits in the signed integer. + if n.V.LessThanOrEqual(max) { + return dt, true + } + + // The second case: V is a non-negative integer, but it does not fit + // in the signed integer. Test whether the unsigned integer works. + dt = ast.ComposeDataType(ast.DataTypeMajorUint, minor) + min, max = mustGetMinMax(dt) + // Return if V > max. We don't have to test whether V < min because min + // is always zero and we already know V is non-negative. + if n.V.GreaterThan(max) { + if (o & CheckWithSafeMath) != 0 { + elAppendOverflowError(el, n, fn, dt, n.V, min, max) + return dt, false + } + cropped := cropDecimal(dt, n.V) + elAppendOverflowWarning(el, n, fn, dt, n.V, cropped) + normalizeDecimal(&cropped) + n.V = cropped + return dt, true + } + return dt, true + } + + dt := ast.DataTypePending + switch a := ta.(type) { + case typeActionInferDefault: + // Default to int256 or uint256. + var ok bool + dt, ok = infer(256 / 8) + if !ok { + return nil + } + + case typeActionInferWithSize: + var ok bool + dt, ok = infer(a.size) + if !ok { + return nil + } + + case typeActionAssign: + dt = a.dt + major, _ := ast.DecomposeDataType(dt) + switch { + case major == ast.DataTypeMajorAddress: + if !n.IsAddress { + el.Append(errors.Error{ + Position: n.GetPosition(), + Length: n.GetLength(), + Category: errors.ErrorCategorySemantic, + Code: errors.ErrorCodeInvalidAddressChecksum, + Severity: errors.ErrorSeverityError, + Prefix: fn, + Message: fmt.Sprintf( + "expect %s (%04x), but %s is not an address constant", + dt.String(), uint16(dt), n.GetToken()), + }, nil) + return nil + } + // Redirect to checkAddressValue if it becomes an address. + addrNode := &ast.AddressValueNode{} + addrNode.SetPosition(addrNode.GetPosition()) + addrNode.SetLength(addrNode.GetLength()) + addrNode.SetToken(addrNode.GetToken()) + addrNode.V = mustDecimalEncode(ast.ComposeDataType( + ast.DataTypeMajorUint, ast.DataTypeMinor(160/8-1)), n.V) + return checkAddressValue(addrNode, o, el, ta) + + case major == ast.DataTypeMajorInt, + major == ast.DataTypeMajorUint, + major.IsFixedRange(), + major.IsUfixedRange(): + min, max := mustGetMinMax(dt) + if n.V.LessThan(min) || n.V.GreaterThan(max) { + if (o & CheckWithSafeMath) != 0 { + elAppendOverflowError(el, n, fn, dt, n.V, min, max) + return nil + } + cropped := cropDecimal(dt, n.V) + elAppendOverflowWarning(el, n, fn, dt, n.V, cropped) + normalizeDecimal(&cropped) + n.V = cropped + } + + default: + elAppendTypeErrorValueNode(el, n, fn, dt) + return nil + } + } + + if !dt.Pending() { + n.SetType(dt) + } + return n +} + +func checkDecimalValue(n *ast.DecimalValueNode, + o CheckOptions, el *errors.ErrorList, ta typeAction) ast.ExprNode { + + fn := "CheckDecimalValue" + + normalizeDecimal(&n.V) + if !safeDecimalRange(n.V) { + elAppendConstantTooLongError(el, n, fn, n.V) + return nil + } + + // Redirect to checkIntegerValue if the value is an integer. + if intPart := n.V.Truncate(0); n.V.Equal(intPart) { + intNode := &ast.IntegerValueNode{} + intNode.SetPosition(n.GetPosition()) + intNode.SetLength(n.GetLength()) + intNode.SetToken(n.GetToken()) + intNode.SetType(n.GetType()) + intNode.IsAddress = false + intNode.V = n.V + return checkIntegerValue(intNode, o, el, ta) + } + + infer := func(size, fractionalDigits int) (ast.DataType, bool) { + major := ast.DataTypeMajorFixed + ast.DataTypeMajor(size-1) + minor := ast.DataTypeMinor(fractionalDigits) + dt := ast.ComposeDataType(major, minor) + min, max := mustGetMinMax(dt) + if n.V.LessThan(min) { + if (o & CheckWithSafeMath) != 0 { + elAppendOverflowError(el, n, fn, dt, n.V, min, max) + return dt, false + } + cropped := cropDecimal(dt, n.V) + elAppendOverflowWarning(el, n, fn, dt, n.V, cropped) + normalizeDecimal(&cropped) + n.V = cropped + return dt, false + } + if n.V.LessThanOrEqual(max) { + return dt, true + } + + major = ast.DataTypeMajorUfixed + ast.DataTypeMajor(size-1) + minor = ast.DataTypeMinor(fractionalDigits) + dt = ast.ComposeDataType(major, minor) + min, max = mustGetMinMax(dt) + if n.V.GreaterThan(max) { + if (o & CheckWithSafeMath) != 0 { + elAppendOverflowError(el, n, fn, dt, n.V, min, max) + return dt, false + } + cropped := cropDecimal(dt, n.V) + elAppendOverflowWarning(el, n, fn, dt, n.V, cropped) + normalizeDecimal(&cropped) + n.V = cropped + return dt, true + } + return dt, true + } + + // Now we are sure the number we are dealing has fractional part. + dt := ast.DataTypePending + switch a := ta.(type) { + case typeActionInferDefault: + // Default to fixed128x18 and ufixed128x18. + var ok bool + dt, ok = infer(128/8, 18) + if !ok { + return nil + } + + case typeActionInferWithSize: + // It is unclear that what the size hint means for fixed-point numbers, + // so we just ignore it and do the same thing as the above case. + var ok bool + dt, ok = infer(128/8, 18) + if !ok { + return nil + } + + case typeActionAssign: + dt = a.dt + major, _ := ast.DecomposeDataType(dt) + switch { + case major.IsFixedRange(), + major.IsUfixedRange(): + min, max := mustGetMinMax(dt) + if n.V.LessThan(min) || n.V.GreaterThan(max) { + if (o & CheckWithSafeMath) != 0 { + elAppendOverflowError(el, n, fn, dt, n.V, min, max) + return nil + } + cropped := cropDecimal(dt, n.V) + elAppendOverflowWarning(el, n, fn, dt, n.V, cropped) + normalizeDecimal(&cropped) + n.V = cropped + } + + case major == ast.DataTypeMajorInt, + major == ast.DataTypeMajorUint: + el.Append(errors.Error{ + Position: n.GetPosition(), + Length: n.GetLength(), + Category: errors.ErrorCategorySemantic, + Code: errors.ErrorCodeTypeError, + Severity: errors.ErrorSeverityError, + Prefix: fn, + Message: fmt.Sprintf( + "expect %s (%04x), but the number %s has fractional part", + dt.String(), uint16(dt), n.V.String()), + }, nil) + return nil + + default: + elAppendTypeErrorValueNode(el, n, fn, dt) + return nil + } + } + + if !dt.Pending() { + n.SetType(dt) + _, minor := ast.DecomposeDataType(dt) + fractionalDigits := int32(minor) + n.V = n.V.Round(fractionalDigits) + normalizeDecimal(&n.V) + } + return n +} + +func checkBytesValue(n *ast.BytesValueNode, + o CheckOptions, el *errors.ErrorList, ta typeAction) ast.ExprNode { + + fn := "CheckBytesValue" + + dt := ast.DataTypePending + +executeTypeAction: + switch a := ta.(type) { + case typeActionInferDefault: + // Default to bytes. + major := ast.DataTypeMajorDynamicBytes + minor := ast.DataTypeMinorDontCare + dt = ast.ComposeDataType(major, minor) + ta = newTypeActionAssign(dt) + goto executeTypeAction + + case typeActionInferWithSize: + major := ast.DataTypeMajorFixedBytes + minor := ast.DataTypeMinor(a.size - 1) + dt = ast.ComposeDataType(major, minor) + ta = newTypeActionAssign(dt) + goto executeTypeAction + + case typeActionAssign: + dt = a.dt + major, minor := ast.DecomposeDataType(dt) + switch major { + case ast.DataTypeMajorDynamicBytes: + // Do nothing because it is always valid. + + case ast.DataTypeMajorFixedBytes: + sizeGiven := len(n.V) + sizeExpected := int(minor) + 1 + if sizeGiven != sizeExpected { + el.Append(errors.Error{ + Position: n.GetPosition(), + Length: n.GetLength(), + Category: errors.ErrorCategorySemantic, + Code: errors.ErrorCodeTypeError, + Severity: errors.ErrorSeverityError, + Prefix: fn, + Message: fmt.Sprintf( + "expect %s (%04x), but %s has %d bytes", + dt.String(), uint16(dt), + ast.QuoteString(n.V), sizeGiven), + }, nil) + return nil + } + + default: + elAppendTypeErrorValueNode(el, n, fn, dt) + return nil + } + } + + if !dt.Pending() { + n.SetType(dt) + } + return n +} + +func checkNullValue(n *ast.NullValueNode, + o CheckOptions, el *errors.ErrorList, ta typeAction) ast.ExprNode { + + dt := ast.DataTypePending + switch a := ta.(type) { + case typeActionInferDefault: + dt = ast.DataTypeNull + case typeActionInferWithSize: + dt = ast.DataTypeNull + case typeActionAssign: + dt = a.dt + } + + if !dt.Pending() { + n.SetType(dt) + } + return n +} + +func elAppendTypeErrorOperatorValueNode(el *errors.ErrorList, n ast.Valuer, + fn string, op string) { + el.Append(errors.Error{ + Position: n.GetPosition(), + Length: n.GetLength(), + Category: errors.ErrorCategorySemantic, + Code: errors.ErrorCodeTypeError, + Severity: errors.ErrorSeverityError, + Prefix: fn, + Message: fmt.Sprintf("%s is not defined for %s", + op, describeValueNodeType(n)), + }, nil) +} + +func elAppendTypeErrorOperatorDataType(el *errors.ErrorList, n ast.ExprNode, + fn string, op string, dt ast.DataType) { + el.Append(errors.Error{ + Position: n.GetPosition(), + Length: n.GetLength(), + Category: errors.ErrorCategorySemantic, + Code: errors.ErrorCodeTypeError, + Severity: errors.ErrorSeverityError, + Prefix: fn, + Message: fmt.Sprintf("%s is not defined for %s (%04x)", + op, dt.String(), uint16(dt)), + }, nil) +} + +func checkPosOperator(n *ast.PosOperatorNode, + s schema.Schema, o CheckOptions, c *schemaCache, el *errors.ErrorList, + tr schema.TableRef, ta typeAction) ast.ExprNode { + + fn := "CheckPosOperator" + op := "unary operator +" + + target := n.GetTarget() + target = checkExpr(target, s, o, c, el, tr, nil) + if target == nil { + return nil + } + r := ast.ExprNode(target) + + dtTarget := target.GetType() + if !dtTarget.Pending() { + major, _ := ast.DecomposeDataType(dtTarget) + switch { + case major == ast.DataTypeMajorInt, + major == ast.DataTypeMajorUint, + major.IsFixedRange(), + major.IsUfixedRange(): + default: + elAppendTypeErrorOperatorDataType(el, target, fn, op, dtTarget) + return nil + } + } + dt := dtTarget + + if target, ok := target.(ast.Valuer); ok { + switch v := target.(type) { + case *ast.IntegerValueNode: + node := &ast.IntegerValueNode{} + node.SetPosition(n.GetPosition()) + node.SetLength(n.GetLength()) + node.SetToken(n.GetToken()) + node.SetType(dt) + node.IsAddress = false + node.V = v.V + r = node + + case *ast.DecimalValueNode: + node := &ast.DecimalValueNode{} + node.SetPosition(n.GetPosition()) + node.SetLength(n.GetLength()) + node.SetToken(n.GetToken()) + node.SetType(dt) + node.V = v.V + r = node + + case *ast.NullValueNode: + if dt.Pending() { + elAppendTypeErrorOperatorValueNode(el, v, fn, op) + return nil + } + node := &ast.NullValueNode{} + node.SetPosition(n.GetPosition()) + node.SetLength(n.GetLength()) + node.SetToken(n.GetToken()) + node.SetType(dt) + r = node + + case *ast.BoolValueNode: + elAppendTypeErrorOperatorValueNode(el, v, fn, op) + return nil + case *ast.AddressValueNode: + elAppendTypeErrorOperatorValueNode(el, v, fn, op) + return nil + case *ast.BytesValueNode: + elAppendTypeErrorOperatorValueNode(el, v, fn, op) + return nil + default: + panic(unknownValueNodeType(v)) + } + } + + if dt.Pending() { + r = checkExpr(r, s, o, c, el, tr, ta) + } else { + switch a := ta.(type) { + case typeActionInferDefault: + case typeActionInferWithSize: + case typeActionAssign: + if !dt.Equal(a.dt) { + elAppendTypeErrorMismatch(el, n, fn, a.dt, dt) + return nil + } + } + } + return r +} + +func checkNegOperator(n *ast.NegOperatorNode, + s schema.Schema, o CheckOptions, c *schemaCache, el *errors.ErrorList, + tr schema.TableRef, ta typeAction) ast.ExprNode { + + fn := "CheckNegOperator" + op := "unary operator -" + + target := n.GetTarget() + target = checkExpr(target, s, o, c, el, tr, nil) + if target == nil { + return nil + } + n.SetTarget(target) + r := ast.ExprNode(n) + + dtTarget := target.GetType() + if !dtTarget.Pending() { + major, _ := ast.DecomposeDataType(dtTarget) + switch { + case major == ast.DataTypeMajorInt, + major == ast.DataTypeMajorUint, + major.IsFixedRange(), + major.IsUfixedRange(): + default: + elAppendTypeErrorOperatorDataType(el, target, fn, op, dtTarget) + return nil + } + } + dt := dtTarget + + eval := func(n ast.Valuer, v decimal.Decimal) (decimal.Decimal, bool) { + r := v.Neg() + if !dt.Pending() { + min, max := mustGetMinMax(dt) + if r.LessThan(min) || r.GreaterThan(max) { + if (o & CheckWithSafeMath) != 0 { + elAppendOverflowError(el, n, fn, dt, r, min, max) + return r, false + } + cropped := cropDecimal(dt, r) + elAppendOverflowWarning(el, n, fn, dt, r, cropped) + r = cropped + } + } + normalizeDecimal(&r) + return r, true + } + if target, ok := target.(ast.Valuer); ok { + switch v := target.(type) { + case *ast.IntegerValueNode: + node := &ast.IntegerValueNode{} + node.SetPosition(n.GetPosition()) + node.SetLength(n.GetLength()) + node.SetToken(n.GetToken()) + node.SetType(dt) + node.IsAddress = false + node.V, ok = eval(node, v.V) + if !ok { + return nil + } + r = node + + case *ast.DecimalValueNode: + node := &ast.DecimalValueNode{} + node.SetPosition(n.GetPosition()) + node.SetLength(n.GetLength()) + node.SetToken(n.GetToken()) + node.SetType(dt) + node.V, ok = eval(node, v.V) + if !ok { + return nil + } + r = node + + case *ast.NullValueNode: + if dt.Pending() { + elAppendTypeErrorOperatorValueNode(el, v, fn, op) + return nil + } + node := &ast.NullValueNode{} + node.SetPosition(n.GetPosition()) + node.SetLength(n.GetLength()) + node.SetToken(n.GetToken()) + node.SetType(dt) + r = node + + case *ast.BoolValueNode: + elAppendTypeErrorOperatorValueNode(el, v, fn, op) + return nil + case *ast.AddressValueNode: + elAppendTypeErrorOperatorValueNode(el, v, fn, op) + return nil + case *ast.BytesValueNode: + elAppendTypeErrorOperatorValueNode(el, v, fn, op) + return nil + default: + panic(unknownValueNodeType(v)) + } + } + + if dt.Pending() { + r = checkExpr(r, s, o, c, el, tr, ta) + } else { + switch a := ta.(type) { + case typeActionInferDefault: + case typeActionInferWithSize: + case typeActionAssign: + if !dt.Equal(a.dt) { + elAppendTypeErrorMismatch(el, n, fn, a.dt, dt) + return nil + } + } + } + return r +} diff --git a/core/vm/sqlvm/checker/utils.go b/core/vm/sqlvm/checker/utils.go new file mode 100644 index 000000000..34b73af4a --- /dev/null +++ b/core/vm/sqlvm/checker/utils.go @@ -0,0 +1,493 @@ +package checker + +import ( + "github.com/dexon-foundation/decimal" + + "github.com/dexon-foundation/dexon/core/vm/sqlvm/ast" + "github.com/dexon-foundation/dexon/core/vm/sqlvm/schema" +) + +// Variable name convention: +// +// fn -> function name +// el -> error list +// +// td -> table descriptor +// tr -> table reference +// ti -> table index +// tn -< table name +// +// cd -> column descriptor +// cr -> column reference +// ci -> column index +// cn -> column name +// +// id -> index descriptor +// ir -> index reference +// ii -> index index +// in -> index name + +const ( + MaxIntegerPartDigits int32 = 200 + MaxFractionalPartDigits int32 = 200 +) + +var ( + MaxConstant = func() decimal.Decimal { + max := (decimal.New(1, MaxIntegerPartDigits). + Sub(decimal.New(1, -MaxFractionalPartDigits))) + normalizeDecimal(&max) + return max + }() + MinConstant = MaxConstant.Neg() +) + +func normalizeDecimal(d *decimal.Decimal) { + if d.Exponent() != -MaxFractionalPartDigits { + *d = d.Rescale(-MaxFractionalPartDigits) + } +} + +func safeDecimalRange(d decimal.Decimal) bool { + return d.GreaterThanOrEqual(MinConstant) && d.LessThanOrEqual(MaxConstant) +} + +// schemaCache is a multi-layer symbol table used to support the checker. +// It allows changes to be easily rolled back by keeping modifications in a +// separate layer, providing an experience similar to a database transaction. +type schemaCache struct { + base schemaCacheBase + scopes []schemaCacheScope +} + +type schemaCacheIndexValue struct { + id schema.IndexDescriptor + auto bool +} + +type schemaCacheColumnKey struct { + tr schema.TableRef + n string +} + +type schemaCacheBase struct { + table map[string]schema.TableDescriptor + index map[string]schemaCacheIndexValue + column map[schemaCacheColumnKey]schema.ColumnDescriptor +} + +func (lower *schemaCacheBase) Merge(upper schemaCacheScope) { + // Process deletions. + for n := range upper.tableDeleted { + delete(lower.table, n) + } + for n := range upper.indexDeleted { + delete(lower.index, n) + } + for ck := range upper.columnDeleted { + delete(lower.column, ck) + } + + // Process additions. + for n, td := range upper.table { + lower.table[n] = td + } + for n, iv := range upper.index { + lower.index[n] = iv + } + for ck, cd := range upper.column { + lower.column[ck] = cd + } +} + +type schemaCacheScope struct { + table map[string]schema.TableDescriptor + tableDeleted map[string]struct{} + index map[string]schemaCacheIndexValue + indexDeleted map[string]struct{} + column map[schemaCacheColumnKey]schema.ColumnDescriptor + columnDeleted map[schemaCacheColumnKey]struct{} +} + +func (lower *schemaCacheScope) Merge(upper schemaCacheScope) { + // Process deletions. + for n := range upper.tableDeleted { + delete(lower.table, n) + lower.tableDeleted[n] = struct{}{} + } + for n := range upper.indexDeleted { + delete(lower.index, n) + lower.indexDeleted[n] = struct{}{} + } + for ck := range upper.columnDeleted { + delete(lower.column, ck) + lower.columnDeleted[ck] = struct{}{} + } + + // Process additions. + for n, td := range upper.table { + lower.table[n] = td + } + for n, iv := range upper.index { + lower.index[n] = iv + } + for ck, cd := range upper.column { + lower.column[ck] = cd + } +} + +func newSchemaCache() *schemaCache { + return &schemaCache{ + base: schemaCacheBase{ + table: map[string]schema.TableDescriptor{}, + index: map[string]schemaCacheIndexValue{}, + column: map[schemaCacheColumnKey]schema.ColumnDescriptor{}, + }, + } +} + +func (c *schemaCache) Begin() int { + position := len(c.scopes) + scope := schemaCacheScope{ + table: map[string]schema.TableDescriptor{}, + tableDeleted: map[string]struct{}{}, + index: map[string]schemaCacheIndexValue{}, + indexDeleted: map[string]struct{}{}, + column: map[schemaCacheColumnKey]schema.ColumnDescriptor{}, + columnDeleted: map[schemaCacheColumnKey]struct{}{}, + } + c.scopes = append(c.scopes, scope) + return position +} + +func (c *schemaCache) Rollback() { + if len(c.scopes) == 0 { + panic("there is no scope to rollback") + } + c.scopes = c.scopes[:len(c.scopes)-1] +} + +func (c *schemaCache) RollbackTo(position int) { + for position <= len(c.scopes) { + c.Rollback() + } +} + +func (c *schemaCache) Commit() { + if len(c.scopes) == 0 { + panic("there is no scope to commit") + } + if len(c.scopes) == 1 { + c.base.Merge(c.scopes[0]) + } else { + src := len(c.scopes) - 1 + dst := len(c.scopes) - 2 + c.scopes[dst].Merge(c.scopes[src]) + } + c.scopes = c.scopes[:len(c.scopes)-1] +} + +func (c *schemaCache) CommitTo(position int) { + for position <= len(c.scopes) { + c.Commit() + } +} + +func (c *schemaCache) FindTableInBase(n string) ( + schema.TableDescriptor, bool) { + + td, exists := c.base.table[n] + return td, exists +} + +func (c *schemaCache) FindTableInScope(n string) ( + schema.TableDescriptor, bool) { + + for si := range c.scopes { + si = len(c.scopes) - si - 1 + if td, exists := c.scopes[si].table[n]; exists { + return td, true + } + if _, exists := c.scopes[si].tableDeleted[n]; exists { + return schema.TableDescriptor{}, false + } + } + return c.FindTableInBase(n) +} + +func (c *schemaCache) FindTableInBaseWithFallback(n string, + fallback schema.Schema) (schema.TableDescriptor, bool) { + + if td, found := c.FindTableInBase(n); found { + return td, true + } + if fallback == nil { + return schema.TableDescriptor{}, false + } + + s := fallback + for ti := range s { + if n == string(s[ti].Name) { + tr := schema.TableRef(ti) + td := schema.TableDescriptor{Table: tr} + c.base.table[n] = td + return td, true + } + } + return schema.TableDescriptor{}, false +} + +func (c *schemaCache) FindIndexInBase(n string) ( + schema.IndexDescriptor, bool, bool) { + + iv, exists := c.base.index[n] + return iv.id, iv.auto, exists +} + +func (c *schemaCache) FindIndexInScope(n string) ( + schema.IndexDescriptor, bool, bool) { + + for si := range c.scopes { + si = len(c.scopes) - si - 1 + if iv, exists := c.scopes[si].index[n]; exists { + return iv.id, iv.auto, true + } + if _, exists := c.scopes[si].indexDeleted[n]; exists { + return schema.IndexDescriptor{}, false, false + } + } + return c.FindIndexInBase(n) +} + +func (c *schemaCache) FindIndexInBaseWithFallback(n string, + fallback schema.Schema) (schema.IndexDescriptor, bool, bool) { + + if id, auto, found := c.FindIndexInBase(n); found { + return id, auto, true + } + if fallback == nil { + return schema.IndexDescriptor{}, false, false + } + + s := fallback + for ti := range s { + for ii := range s[ti].Indices { + if n == string(s[ti].Indices[ii].Name) { + tr := schema.TableRef(ti) + ir := schema.IndexRef(ii) + id := schema.IndexDescriptor{Table: tr, Index: ir} + iv := schemaCacheIndexValue{id: id, auto: false} + c.base.index[n] = iv + return id, false, true + } + } + } + return schema.IndexDescriptor{}, false, false +} + +func (c *schemaCache) FindColumnInBase(tr schema.TableRef, n string) ( + schema.ColumnDescriptor, bool) { + + cd, exists := c.base.column[schemaCacheColumnKey{tr: tr, n: n}] + return cd, exists +} + +func (c *schemaCache) FindColumnInScope(tr schema.TableRef, n string) ( + schema.ColumnDescriptor, bool) { + + ck := schemaCacheColumnKey{tr: tr, n: n} + for si := range c.scopes { + si = len(c.scopes) - si - 1 + if cd, exists := c.scopes[si].column[ck]; exists { + return cd, true + } + if _, exists := c.scopes[si].columnDeleted[ck]; exists { + return schema.ColumnDescriptor{}, false + } + } + return c.FindColumnInBase(tr, n) +} + +func (c *schemaCache) FindColumnInBaseWithFallback(tr schema.TableRef, n string, + fallback schema.Schema) (schema.ColumnDescriptor, bool) { + + if cd, found := c.FindColumnInBase(tr, n); found { + return cd, true + } + if fallback == nil { + return schema.ColumnDescriptor{}, false + } + + s := fallback + for ci := range s[tr].Columns { + if n == string(s[tr].Columns[ci].Name) { + cr := schema.ColumnRef(ci) + cd := schema.ColumnDescriptor{Table: tr, Column: cr} + ck := schemaCacheColumnKey{tr: tr, n: n} + c.base.column[ck] = cd + return cd, true + } + } + return schema.ColumnDescriptor{}, false +} + +func (c *schemaCache) AddTable(n string, + td schema.TableDescriptor) bool { + + top := len(c.scopes) - 1 + if _, found := c.FindTableInScope(n); found { + return false + } + + c.scopes[top].table[n] = td + return true +} + +func (c *schemaCache) AddIndex(n string, + id schema.IndexDescriptor, auto bool) bool { + + top := len(c.scopes) - 1 + if _, _, found := c.FindIndexInScope(n); found { + return false + } + + iv := schemaCacheIndexValue{id: id, auto: auto} + c.scopes[top].index[n] = iv + return true +} + +func (c *schemaCache) AddColumn(n string, + cd schema.ColumnDescriptor) bool { + + top := len(c.scopes) - 1 + tr := cd.Table + if _, found := c.FindColumnInScope(tr, n); found { + return false + } + + ck := schemaCacheColumnKey{tr: tr, n: n} + c.scopes[top].column[ck] = cd + return true +} + +func (c *schemaCache) DeleteTable(n string) bool { + top := len(c.scopes) - 1 + if _, found := c.FindTableInScope(n); !found { + return false + } + + delete(c.scopes[top].table, n) + c.scopes[top].tableDeleted[n] = struct{}{} + return true +} + +func (c *schemaCache) DeleteIndex(n string) bool { + top := len(c.scopes) - 1 + if _, _, found := c.FindIndexInScope(n); !found { + return false + } + + delete(c.scopes[top].index, n) + c.scopes[top].indexDeleted[n] = struct{}{} + return true +} + +func (c *schemaCache) DeleteColumn(tr schema.TableRef, n string) bool { + top := len(c.scopes) - 1 + if _, found := c.FindColumnInScope(tr, n); !found { + return false + } + + ck := schemaCacheColumnKey{tr: tr, n: n} + delete(c.scopes[top].column, ck) + c.scopes[top].columnDeleted[ck] = struct{}{} + return true +} + +// columnRefSlice implements sort.Interface. It allows sorting a slice of +// schema.ColumnRef while keeping references to AST nodes they originate from. +type columnRefSlice struct { + columns []schema.ColumnRef + nodes []uint8 +} + +func newColumnRefSlice(c uint8) columnRefSlice { + return columnRefSlice{ + columns: make([]schema.ColumnRef, 0, c), + nodes: make([]uint8, 0, c), + } +} + +func (s *columnRefSlice) Append(c schema.ColumnRef, i uint8) { + s.columns = append(s.columns, c) + s.nodes = append(s.nodes, i) +} + +func (s columnRefSlice) Len() int { + return len(s.columns) +} + +func (s columnRefSlice) Less(i, j int) bool { + return s.columns[i] < s.columns[j] +} + +func (s columnRefSlice) Swap(i, j int) { + s.columns[i], s.columns[j] = s.columns[j], s.columns[i] + s.nodes[i], s.nodes[j] = s.nodes[j], s.nodes[i] +} + +// typeAction represents an action on type inference requested from the parent +// node. An action is usually only applied on a single node. It is seldom +// propagated to child nodes because we want to delay the assignment of types +// until it is necessary, making constant operations easier to use without +// being restricted by data types. +//go-sumtype:decl typeAction +type typeAction interface { + ˉtypeAction() +} + +// typeActionInferDefault requests the node to infer the type using its default +// rule. It usually means that the parent node does not care the data type, +// such as the select list in a SELECT statement. It is an advisory request. +// If the type of the node is already determined, it should ignore the request. +type typeActionInferDefault struct{} + +func newTypeActionInferDefaultSize() typeActionInferDefault { + return typeActionInferDefault{} +} + +var _ typeAction = typeActionInferDefault{} + +func (typeActionInferDefault) ˉtypeAction() {} + +// typeActionInferWithSize requests the node to infer the type with size +// requirement. The size is measured in bytes. It is indented to be used in +// CAST to support conversion between integer and fixed-size bytes types. +// It is an advisory request. If the type is already determined, the request is +// ignored and the parent node should be able to handle the problem by itself. +type typeActionInferWithSize struct { + size int +} + +func newTypeActionInferWithSize(bytes int) typeActionInferWithSize { + return typeActionInferWithSize{size: bytes} +} + +var _ typeAction = typeActionInferWithSize{} + +func (typeActionInferWithSize) ˉtypeAction() {} + +type typeActionAssign struct { + dt ast.DataType +} + +// newTypeActionAssign requests the node to have a specific type. It is a +// mandatory request. If the node is unable to meet the requirement, it should +// throw an error. It is not allowed to ignore the request. +func newTypeActionAssign(expected ast.DataType) typeActionAssign { + return typeActionAssign{dt: expected} +} + +var _ typeAction = typeActionAssign{} + +func (typeActionAssign) ˉtypeAction() {} diff --git a/core/vm/sqlvm/checkers/actions.go b/core/vm/sqlvm/checkers/actions.go deleted file mode 100644 index c15029d9a..000000000 --- a/core/vm/sqlvm/checkers/actions.go +++ /dev/null @@ -1,147 +0,0 @@ -package checkers - -import ( - "fmt" - - "github.com/dexon-foundation/dexon/core/vm/sqlvm/ast" - "github.com/dexon-foundation/dexon/core/vm/sqlvm/errors" - "github.com/dexon-foundation/dexon/core/vm/sqlvm/schema" -) - -// CheckOptions stores boolean options for Check* functions. -type CheckOptions uint32 - -const ( - // CheckWithSafeMath enables overflow and underflow checks during expression - // evaluation. An error will be thrown when the result is out of range. - CheckWithSafeMath CheckOptions = 1 << iota - // CheckWithSafeCast enables overflow and underflow checks during casting. - // An error will be thrown if the value does not fit in the target type. - CheckWithSafeCast - // CheckWithConstantOnly restricts the expression to be a constant. An error - // will be thrown if the expression cannot be folded into a constant. - CheckWithConstantOnly -) - -// CheckCreate runs CREATE commands to generate a database schema. It modifies -// AST in-place during evaluation of expressions. -func CheckCreate(ss []ast.StmtNode, o CheckOptions) (schema.Schema, error) { - fn := "CheckCreate" - s := schema.Schema{} - c := newSchemaCache() - el := errors.ErrorList{} - - for idx := range ss { - if ss[idx] == nil { - continue - } - - switch n := ss[idx].(type) { - case *ast.CreateTableStmtNode: - checkCreateTableStmt(n, &s, o, c, &el) - case *ast.CreateIndexStmtNode: - checkCreateIndexStmt(n, &s, o, c, &el) - default: - el.Append(errors.Error{ - Position: ss[idx].GetPosition(), - Length: ss[idx].GetLength(), - Category: errors.ErrorCategoryCommand, - Code: errors.ErrorCodeDisallowedCommand, - Severity: errors.ErrorSeverityError, - Prefix: fn, - Message: fmt.Sprintf( - "command %s is not allowed when creating a contract", - ast.QuoteIdentifier(ss[idx].GetVerb())), - }, nil) - } - } - - if len(s) == 0 && len(el) == 0 { - el.Append(errors.Error{ - Position: 0, - Length: 0, - Category: errors.ErrorCategoryCommand, - Code: errors.ErrorCodeNoCommand, - Severity: errors.ErrorSeverityError, - Prefix: fn, - Message: "creating a contract without a table is not allowed", - }, nil) - } - if len(el) != 0 { - return s, el - } - return s, nil -} - -// CheckQuery checks and modifies SELECT commands with a given database schema. -func CheckQuery(ss []ast.StmtNode, s schema.Schema, o CheckOptions) error { - fn := "CheckQuery" - c := newSchemaCache() - el := errors.ErrorList{} - - for idx := range ss { - if ss[idx] == nil { - continue - } - - switch n := ss[idx].(type) { - case *ast.SelectStmtNode: - checkSelectStmt(n, s, o, c, &el) - default: - el.Append(errors.Error{ - Position: ss[idx].GetPosition(), - Length: ss[idx].GetLength(), - Category: errors.ErrorCategoryCommand, - Code: errors.ErrorCodeDisallowedCommand, - Severity: errors.ErrorSeverityError, - Prefix: fn, - Message: fmt.Sprintf( - "command %s is not allowed when calling query", - ast.QuoteIdentifier(ss[idx].GetVerb())), - }, nil) - } - } - if len(el) != 0 { - return el - } - return nil -} - -// CheckExec checks and modifies UPDATE, DELETE, INSERT commands with a given -// database schema. -func CheckExec(ss []ast.StmtNode, s schema.Schema, o CheckOptions) error { - fn := "CheckExec" - c := newSchemaCache() - el := errors.ErrorList{} - - for idx := range ss { - if ss[idx] == nil { - continue - } - - switch n := ss[idx].(type) { - case *ast.UpdateStmtNode: - checkUpdateStmt(n, s, o, c, &el) - case *ast.DeleteStmtNode: - checkDeleteStmt(n, s, o, c, &el) - case *ast.InsertStmtNode: - checkInsertStmt(n, s, o, c, &el) - default: - el.Append(errors.Error{ - Position: ss[idx].GetPosition(), - Length: ss[idx].GetLength(), - Category: errors.ErrorCategoryCommand, - Code: errors.ErrorCodeDisallowedCommand, - Severity: errors.ErrorSeverityError, - Prefix: fn, - Message: fmt.Sprintf( - "command %s is not allowed when calling exec", - ast.QuoteIdentifier(ss[idx].GetVerb())), - }, nil) - } - } - if len(el) != 0 { - return el - } - return nil -} diff --git a/core/vm/sqlvm/checkers/checkers.go b/core/vm/sqlvm/checkers/checkers.go deleted file mode 100644 index 7881aa149..000000000 --- a/core/vm/sqlvm/checkers/checkers.go +++ /dev/null @@ -1,1579 +0,0 @@ -package checkers - -import ( - "fmt" - "sort" - - "github.com/dexon-foundation/decimal" - - "github.com/dexon-foundation/dexon/core/vm/sqlvm/ast" - "github.com/dexon-foundation/dexon/core/vm/sqlvm/errors" - "github.com/dexon-foundation/dexon/core/vm/sqlvm/schema" -) - -// In addition to the convention mentioned in utils.go, we have these variable -// names in this file: -// -// ftd -> foreign table descriptor -// ftn -> foreign table name -// fcd -> foreign column descriptor -// fcn -> foreign column name -// fid -> foreign index descriptor -// fin -> foreign index name -// -// fmid -> first matching index descriptor -// fmir -> first matching index reference -// fmin -> first matching index name - -// findFirstMatchingIndex finds the first index in 'haystack' matching the -// declaration of 'needle' with attributes specified in 'attrDontCare' ignored. -// This function is considered as a part of the interface, so it have to work -// deterministically. -func findFirstMatchingIndex(haystack []schema.Index, needle schema.Index, - attrDontCare schema.IndexAttr) (schema.IndexRef, bool) { - - compareAttr := func(a1, a2 schema.IndexAttr) bool { - a1 = a1.GetDeclaredFlags() | attrDontCare - a2 = a2.GetDeclaredFlags() | attrDontCare - return a1 == a2 - } - - compareColumns := func(c1, c2 []schema.ColumnRef) bool { - if len(c1) != len(c2) { - return false - } - for ci := range c1 { - if c1[ci] != c2[ci] { - return false - } - } - return true - } - - for ii := range haystack { - if compareAttr(haystack[ii].Attr, needle.Attr) && - compareColumns(haystack[ii].Columns, needle.Columns) { - ir := schema.IndexRef(ii) - return ir, true - } - } - return 0, false -} - -func checkCreateTableStmt(n *ast.CreateTableStmtNode, s *schema.Schema, - o CheckOptions, c *schemaCache, el *errors.ErrorList) { - - fn := "CheckCreateTableStmt" - hasError := false - - if c.Begin() != 0 { - panic("schema cache must not have any open scope") - } - defer func() { - if hasError { - c.Rollback() - return - } - c.Commit() - }() - - // Return early if there are too many tables. We cannot ignore this error - // because it will overflow schema.TableRef, which is used as a part of - // column key in schemaCache. - if len(*s) > schema.MaxTableRef { - el.Append(errors.Error{ - Position: n.GetPosition(), - Length: n.GetLength(), - Category: errors.ErrorCategoryLimit, - Code: errors.ErrorCodeTooManyTables, - Severity: errors.ErrorSeverityError, - Prefix: fn, - Message: fmt.Sprintf("cannot have more than %d tables", - schema.MaxTableRef+1), - }, &hasError) - return - } - - table := schema.Table{} - tr := schema.TableRef(len(*s)) - td := schema.TableDescriptor{Table: tr} - - if len(n.Table.Name) == 0 { - el.Append(errors.Error{ - Position: n.Table.GetPosition(), - Length: n.Table.GetLength(), - Category: errors.ErrorCategorySemantic, - Code: errors.ErrorCodeEmptyTableName, - Severity: errors.ErrorSeverityError, - Prefix: fn, - Message: "cannot create a table with an empty name", - }, &hasError) - } - - tn := n.Table.Name - if !c.AddTable(string(tn), td) { - el.Append(errors.Error{ - Position: n.Table.GetPosition(), - Length: n.Table.GetLength(), - Category: errors.ErrorCategorySemantic, - Code: errors.ErrorCodeDuplicateTableName, - Severity: errors.ErrorSeverityError, - Prefix: fn, - Message: fmt.Sprintf("table %s already exists", - ast.QuoteIdentifier(tn)), - }, &hasError) - } - table.Name = n.Table.Name - table.Columns = make([]schema.Column, 0, len(n.Column)) - - // Handle the primary key index. - pk := []schema.ColumnRef{} - // Handle sequences. - seq := 0 - // Handle indices for unique constraints. - type localIndex struct { - index schema.Index - node ast.Node - } - localIndices := []localIndex{} - // Handle indices for foreign key constraints. - type foreignNewIndex struct { - table schema.TableDescriptor - index schema.Index - node ast.Node - } - foreignNewIndices := []foreignNewIndex{} - type foreignExistingIndex struct { - index schema.IndexDescriptor - node ast.Node - } - foreignExistingIndices := []foreignExistingIndex{} - - for ci := range n.Column { - if len(table.Columns) > schema.MaxColumnRef { - el.Append(errors.Error{ - Position: n.Column[ci].GetPosition(), - Length: n.Column[ci].GetLength(), - Category: errors.ErrorCategoryLimit, - Code: errors.ErrorCodeTooManyColumns, - Severity: errors.ErrorSeverityError, - Prefix: fn, - Message: fmt.Sprintf("cannot have more than %d columns", - schema.MaxColumnRef+1), - }, &hasError) - return - } - - column := schema.Column{} - ok := func() (ok bool) { - innerHasError := false - defer func() { ok = !innerHasError }() - - // Block access to the outer hasError variable. - hasError := struct{}{} - // Suppress “declared and not used” error. - _ = hasError - - c.Begin() - defer func() { - if innerHasError { - c.Rollback() - return - } - c.Commit() - }() - - cr := schema.ColumnRef(len(table.Columns)) - cd := schema.ColumnDescriptor{Table: tr, Column: cr} - - if len(n.Column[ci].Column.Name) == 0 { - el.Append(errors.Error{ - Position: n.Column[ci].Column.GetPosition(), - Length: n.Column[ci].Column.GetLength(), - Category: errors.ErrorCategorySemantic, - Code: errors.ErrorCodeEmptyColumnName, - Severity: errors.ErrorSeverityError, - Prefix: fn, - Message: "cannot declare a column with an empty name", - }, &innerHasError) - } - - cn := n.Column[ci].Column.Name - if !c.AddColumn(string(cn), cd) { - el.Append(errors.Error{ - Position: n.Column[ci].Column.GetPosition(), - Length: n.Column[ci].Column.GetLength(), - Category: errors.ErrorCategorySemantic, - Code: errors.ErrorCodeDuplicateColumnName, - Severity: errors.ErrorSeverityError, - Prefix: fn, - Message: fmt.Sprintf("column %s already exists", - ast.QuoteIdentifier(cn)), - }, &innerHasError) - } else { - column.Name = n.Column[ci].Column.Name - } - - dt, code, message := n.Column[ci].DataType.GetType() - if code == errors.ErrorCodeNil { - if !dt.ValidColumn() { - el.Append(errors.Error{ - Position: n.Column[ci].DataType.GetPosition(), - Length: n.Column[ci].DataType.GetLength(), - Category: errors.ErrorCategorySemantic, - Code: errors.ErrorCodeInvalidColumnDataType, - Severity: errors.ErrorSeverityError, - Prefix: fn, - Message: fmt.Sprintf( - "cannot declare a column with type %s", dt.String()), - }, &innerHasError) - } - } else { - el.Append(errors.Error{ - Position: n.Column[ci].DataType.GetPosition(), - Length: n.Column[ci].DataType.GetLength(), - Category: errors.ErrorCategorySemantic, - Code: code, - Severity: errors.ErrorSeverityError, - Prefix: fn, - Message: message, - }, &innerHasError) - } - column.Type = dt - - // Backup lengths of slices in case we have to rollback. We don't - // have to copy slice headers or data stored in underlying arrays - // because we always append data at the end. - defer func(LPK, SEQ, LLI, LFNI, LFEI int) { - if innerHasError { - pk = pk[:LPK] - seq = SEQ - localIndices = localIndices[:LLI] - foreignNewIndices = foreignNewIndices[:LFNI] - foreignExistingIndices = foreignExistingIndices[:LFEI] - } - }( - len(pk), seq, len(localIndices), len(foreignNewIndices), - len(foreignExistingIndices), - ) - - // cs -> constraint - // csi -> constraint index - for csi := range n.Column[ci].Constraint { - // Cases are sorted in the same order as internal/grammar.peg. - cs: - switch cs := n.Column[ci].Constraint[csi].(type) { - case *ast.PrimaryOptionNode: - pk = append(pk, cr) - column.Attr |= schema.ColumnAttrPrimaryKey - - case *ast.NotNullOptionNode: - column.Attr |= schema.ColumnAttrNotNull - - case *ast.UniqueOptionNode: - if (column.Attr & schema.ColumnAttrUnique) != 0 { - // Don't create duplicate indices on a column. - break cs - } - column.Attr |= schema.ColumnAttrUnique - indexName := fmt.Sprintf("%s_%s_unique", - table.Name, column.Name) - idx := schema.Index{ - Name: []byte(indexName), - Attr: schema.IndexAttrUnique, - Columns: []schema.ColumnRef{cr}, - } - localIndices = append(localIndices, localIndex{ - index: idx, - node: cs, - }) - - case *ast.DefaultOptionNode: - if (column.Attr & schema.ColumnAttrHasDefault) != 0 { - el.Append(errors.Error{ - Position: cs.GetPosition(), - Length: cs.GetLength(), - Category: errors.ErrorCategorySemantic, - Code: errors.ErrorCodeMultipleDefaultValues, - Severity: errors.ErrorSeverityError, - Prefix: fn, - Message: "cannot have multiple default values", - }, &innerHasError) - break cs - } - column.Attr |= schema.ColumnAttrHasDefault - - value := cs.Value - value = checkExpr(cs.Value, *s, o|CheckWithConstantOnly, - c, el, 0, newTypeActionAssign(column.Type)) - if value == nil { - innerHasError = true - break cs - } - cs.Value = value - - switch v := cs.Value.(ast.Valuer).(type) { - case *ast.BoolValueNode: - sb := v.V.NullBool() - if !sb.Valid { - el.Append(errors.Error{ - Position: cs.GetPosition(), - Length: cs.GetLength(), - Category: errors.ErrorCategorySemantic, - Code: errors.ErrorCodeNullDefaultValue, - Severity: errors.ErrorSeverityError, - Prefix: fn, - Message: "default value must not be NULL", - }, &innerHasError) - break cs - } - column.Default = sb.Bool - - case *ast.AddressValueNode: - column.Default = v.V - - case *ast.IntegerValueNode: - column.Default = v.V - - case *ast.DecimalValueNode: - column.Default = v.V - - case *ast.BytesValueNode: - column.Default = v.V - - case *ast.NullValueNode: - el.Append(errors.Error{ - Position: cs.GetPosition(), - Length: cs.GetLength(), - Category: errors.ErrorCategorySemantic, - Code: errors.ErrorCodeNullDefaultValue, - Severity: errors.ErrorSeverityError, - Prefix: fn, - Message: "default value must not be NULL", - }, &innerHasError) - break cs - - default: - panic(unknownValueNodeType(v)) - } - - case *ast.ForeignOptionNode: - if len(column.ForeignKeys) > schema.MaxForeignKeys { - el.Append(errors.Error{ - Position: cs.GetPosition(), - Length: cs.GetLength(), - Category: errors.ErrorCategoryLimit, - Code: errors.ErrorCodeTooManyForeignKeys, - Severity: errors.ErrorSeverityError, - Prefix: fn, - Message: fmt.Sprintf( - "cannot have more than %d foreign key "+ - "constraints in a column", - schema.MaxForeignKeys+1), - }, &innerHasError) - break cs - } - column.Attr |= schema.ColumnAttrHasForeignKey - ftn := cs.Table.Name - ftd, found := c.FindTableInBase(string(ftn)) - if !found { - el.Append(errors.Error{ - Position: cs.Table.GetPosition(), - Length: cs.Table.GetLength(), - Category: errors.ErrorCategorySemantic, - Code: errors.ErrorCodeTableNotFound, - Severity: errors.ErrorSeverityError, - Prefix: fn, - Message: fmt.Sprintf( - "foreign table %s does not exist", - ast.QuoteIdentifier(ftn)), - }, &innerHasError) - break cs - } - fcn := cs.Column.Name - fcd, found := c.FindColumnInBase(ftd.Table, string(fcn)) - if !found { - el.Append(errors.Error{ - Position: cs.Column.GetPosition(), - Length: cs.Column.GetLength(), - Category: errors.ErrorCategorySemantic, - Code: errors.ErrorCodeColumnNotFound, - Severity: errors.ErrorSeverityError, - Prefix: fn, - Message: fmt.Sprintf( - "column %s does not exist in foreign table %s", - ast.QuoteIdentifier(fcn), - ast.QuoteIdentifier(ftn)), - }, &innerHasError) - break cs - } - foreignType := (*s)[fcd.Table].Columns[fcd.Column].Type - if !foreignType.Equal(column.Type) { - el.Append(errors.Error{ - Position: n.Column[ci].DataType.GetPosition(), - Length: n.Column[ci].DataType.GetLength(), - Category: errors.ErrorCategorySemantic, - Code: errors.ErrorCodeForeignKeyDataTypeMismatch, - Severity: errors.ErrorSeverityError, - Prefix: fn, - Message: fmt.Sprintf( - "foreign column has type %s (%04x), but "+ - "this column has type %s (%04x)", - foreignType.String(), uint16(foreignType), - column.Type.String(), uint16(column.Type)), - }, &innerHasError) - break cs - } - - idx := schema.Index{ - Attr: schema.IndexAttrReferenced, - Columns: []schema.ColumnRef{fcd.Column}, - } - fmir, found := findFirstMatchingIndex( - (*s)[ftd.Table].Indices, idx, schema.IndexAttrUnique) - if found { - fmid := schema.IndexDescriptor{ - Table: ftd.Table, - Index: fmir, - } - foreignExistingIndices = append( - foreignExistingIndices, foreignExistingIndex{ - index: fmid, - node: cs, - }) - } else { - if len(column.ForeignKeys) > 0 { - idx.Name = []byte(fmt.Sprintf("%s_%s_foreign_key_%d", - table.Name, column.Name, len(column.ForeignKeys))) - } else { - idx.Name = []byte(fmt.Sprintf("%s_%s_foreign_key", - table.Name, column.Name)) - } - foreignNewIndices = append( - foreignNewIndices, foreignNewIndex{ - table: ftd, - index: idx, - node: cs, - }) - } - column.ForeignKeys = append(column.ForeignKeys, fcd) - - case *ast.AutoIncrementOptionNode: - if (column.Attr & schema.ColumnAttrHasSequence) != 0 { - // Don't process AUTOINCREMENT twice. - break cs - } - // We set the flag regardless of the error because we - // don't want to produce duplicate errors. - column.Attr |= schema.ColumnAttrHasSequence - if seq > schema.MaxSequenceRef { - el.Append(errors.Error{ - Position: cs.GetPosition(), - Length: cs.GetLength(), - Category: errors.ErrorCategoryLimit, - Code: errors.ErrorCodeTooManySequences, - Severity: errors.ErrorSeverityError, - Prefix: fn, - Message: fmt.Sprintf( - "cannot have more than %d sequences", - schema.MaxSequenceRef+1), - }, &innerHasError) - break cs - } - major, _ := ast.DecomposeDataType(column.Type) - switch major { - case ast.DataTypeMajorInt, ast.DataTypeMajorUint: - default: - el.Append(errors.Error{ - Position: cs.GetPosition(), - Length: cs.GetLength(), - Category: errors.ErrorCategorySemantic, - Code: errors.ErrorCodeInvalidAutoIncrementDataType, - Prefix: fn, - Message: fmt.Sprintf( - "AUTOINCREMENT is only supported on "+ - "INT and UINT types, but this column "+ - "has type %s (%04x)", - column.Type.String(), uint16(column.Type)), - }, &innerHasError) - break cs - } - column.Sequence = schema.SequenceRef(seq) - seq++ - - default: - panic(fmt.Sprintf("unknown constraint type %T", c)) - } - } - - // The return value will be set by the first defer function. - return - }() - - // If an error occurs in the function, stop here and continue - // processing the next column. - if !ok { - hasError = true - continue - } - - // Commit the column. - table.Columns = append(table.Columns, column) - } - - // Return early if there is any error. - if hasError { - return - } - - mustAddIndex := func(name *[]byte, id schema.IndexDescriptor) { - for !c.AddIndex(string(*name), id, true) { - *name = append(*name, '_') - } - } - - // Create the primary key index. This is the first index on the table, so - // it is not possible to exceed the limit on the number of indices. - ir := schema.IndexRef(len(table.Indices)) - if len(pk) > 0 { - idx := schema.Index{ - Name: []byte(fmt.Sprintf("%s_primary_key", table.Name)), - Attr: schema.IndexAttrUnique, - Columns: pk, - } - id := schema.IndexDescriptor{Table: tr, Index: ir} - mustAddIndex(&idx.Name, id) - table.Indices = append(table.Indices, idx) - } - - // Create indices for the current table. - for ii := range localIndices { - if len(table.Indices) > schema.MaxIndexRef { - el.Append(errors.Error{ - Position: localIndices[ii].node.GetPosition(), - Length: localIndices[ii].node.GetLength(), - Category: errors.ErrorCategoryLimit, - Code: errors.ErrorCodeTooManyIndices, - Severity: errors.ErrorSeverityError, - Prefix: fn, - Message: fmt.Sprintf("cannot have more than %d indices", - schema.MaxIndexRef+1), - }, &hasError) - return - } - idx := localIndices[ii].index - ir := schema.IndexRef(len(table.Indices)) - id := schema.IndexDescriptor{Table: tr, Index: ir} - mustAddIndex(&idx.Name, id) - table.Indices = append(table.Indices, idx) - } - - // Create indices for foreign tables. - for ii := range foreignNewIndices { - ftd := foreignNewIndices[ii].table - if len((*s)[ftd.Table].Indices) > schema.MaxIndexRef { - el.Append(errors.Error{ - Position: foreignNewIndices[ii].node.GetPosition(), - Length: foreignNewIndices[ii].node.GetLength(), - Category: errors.ErrorCategoryLimit, - Code: errors.ErrorCodeTooManyIndices, - Severity: errors.ErrorSeverityError, - Prefix: fn, - Message: fmt.Sprintf( - "table %s already has %d indices", - ast.QuoteIdentifier((*s)[ftd.Table].Name), - schema.MaxIndexRef+1), - }, &hasError) - return - } - idx := foreignNewIndices[ii].index - ir := schema.IndexRef(len((*s)[ftd.Table].Indices)) - id := schema.IndexDescriptor{Table: ftd.Table, Index: ir} - mustAddIndex(&idx.Name, id) - (*s)[ftd.Table].Indices = append((*s)[ftd.Table].Indices, idx) - defer func(tr schema.TableRef, length schema.IndexRef) { - if hasError { - (*s)[tr].Indices = (*s)[tr].Indices[:ir] - } - }(ftd.Table, ir) - } - - // Mark existing indices as referenced. - for ii := range foreignExistingIndices { - fid := foreignExistingIndices[ii].index - (*s)[fid.Table].Indices[fid.Index].Attr |= schema.IndexAttrReferenced - } - - // Finally, we can commit the table definition to the schema. - *s = append(*s, table) -} - -func checkCreateIndexStmt(n *ast.CreateIndexStmtNode, s *schema.Schema, - o CheckOptions, c *schemaCache, el *errors.ErrorList) { - - fn := "CheckCreateIndexStmt" - hasError := false - - if c.Begin() != 0 { - panic("schema cache must not have any open scope") - } - defer func() { - if hasError { - c.Rollback() - return - } - c.Commit() - }() - - tn := n.Table.Name - td, found := c.FindTableInBase(string(tn)) - if !found { - el.Append(errors.Error{ - Position: n.Table.GetPosition(), - Length: n.Table.GetLength(), - Category: errors.ErrorCategorySemantic, - Code: errors.ErrorCodeTableNotFound, - Severity: errors.ErrorSeverityError, - Prefix: fn, - Message: fmt.Sprintf( - "index table %s does not exist", - ast.QuoteIdentifier(tn)), - }, &hasError) - return - } - - if len(n.Column) > schema.MaxColumnRef { - begin := n.Column[0].GetPosition() - last := len(n.Column) - 1 - end := n.Column[last].GetPosition() + n.Column[last].GetLength() - el.Append(errors.Error{ - Position: begin, - Length: end - begin, - Category: errors.ErrorCategoryLimit, - Code: errors.ErrorCodeTooManyColumns, - Severity: errors.ErrorSeverityError, - Prefix: fn, - Message: fmt.Sprintf( - "cannot create an index on more than %d columns", - schema.MaxColumnRef+1), - }, &hasError) - return - } - - columnRefs := newColumnRefSlice(uint8(len(n.Column))) - for ci := range n.Column { - cn := n.Column[ci].Name - cd, found := c.FindColumnInBase(td.Table, string(cn)) - if !found { - el.Append(errors.Error{ - Position: n.Column[ci].GetPosition(), - Length: n.Column[ci].GetLength(), - Category: errors.ErrorCategorySemantic, - Code: errors.ErrorCodeColumnNotFound, - Severity: errors.ErrorSeverityError, - Prefix: fn, - Message: fmt.Sprintf( - "column %s does not exist in index table %s", - ast.QuoteIdentifier(cn), - ast.QuoteIdentifier(tn)), - }, &hasError) - continue - } - columnRefs.Append(cd.Column, uint8(ci)) - } - if hasError { - return - } - - sort.Stable(columnRefs) - for ci := 1; ci < len(n.Column); ci++ { - if columnRefs.columns[ci] == columnRefs.columns[ci-1] { - el.Append(errors.Error{ - Position: n.Column[columnRefs.nodes[ci]].GetPosition(), - Length: n.Column[columnRefs.nodes[ci]].GetLength(), - Category: errors.ErrorCategorySemantic, - Code: errors.ErrorCodeDuplicateIndexColumn, - Severity: errors.ErrorSeverityError, - Prefix: fn, - Message: fmt.Sprintf( - "column %s already exists in the column list", - ast.QuoteIdentifier(n.Column[columnRefs.nodes[ci]].Name)), - }, &hasError) - return - } - } - - index := schema.Index{} - index.Columns = columnRefs.columns - - if len((*s)[td.Table].Indices) > schema.MaxIndexRef { - el.Append(errors.Error{ - Position: n.GetPosition(), - Length: n.GetLength(), - Category: errors.ErrorCategoryLimit, - Code: errors.ErrorCodeTooManyIndices, - Severity: errors.ErrorSeverityError, - Prefix: fn, - Message: fmt.Sprintf( - "cannot have more than %d indices in table %s", - schema.MaxIndexRef+1, - ast.QuoteIdentifier(tn)), - }, &hasError) - return - } - - ir := schema.IndexRef(len((*s)[td.Table].Indices)) - id := schema.IndexDescriptor{Table: td.Table, Index: ir} - if n.Unique != nil { - index.Attr |= schema.IndexAttrUnique - } - - if len(n.Index.Name) == 0 { - el.Append(errors.Error{ - Position: n.Index.GetPosition(), - Length: n.Table.GetLength(), - Category: errors.ErrorCategorySemantic, - Code: errors.ErrorCodeEmptyIndexName, - Severity: errors.ErrorSeverityError, - Prefix: fn, - Message: "cannot create an index with an empty name", - }, &hasError) - return - } - - // If there is an existing index that is automatically created, rename it - // instead of creating a new one. - rename := false - fmir, found := findFirstMatchingIndex((*s)[id.Table].Indices, index, 0) - if found { - fmid := schema.IndexDescriptor{Table: id.Table, Index: fmir} - fmin := (*s)[id.Table].Indices[fmir].Name - fminString := string(fmin) - fmidCache, auto, found := c.FindIndexInBase(fminString) - if !found { - panic(fmt.Sprintf("index %s exists in the schema, "+ - "but it cannot be found in the schema cache", - ast.QuoteIdentifier(fmin))) - } - if fmidCache != fmid { - panic(fmt.Sprintf("index %s has descriptor %+v, "+ - "but the schema cache records it as %+v", - ast.QuoteIdentifier(fmin), fmid, fmidCache)) - } - if auto { - if !c.DeleteIndex(fminString) { - panic(fmt.Sprintf("unable to mark index %s for deletion", - ast.QuoteIdentifier(fmin))) - } - rename = true - id = fmid - ir = id.Index - } - } - - in := n.Index.Name - if !c.AddIndex(string(in), id, false) { - el.Append(errors.Error{ - Position: n.Index.GetPosition(), - Length: n.Index.GetLength(), - Category: errors.ErrorCategorySemantic, - Code: errors.ErrorCodeDuplicateIndexName, - Severity: errors.ErrorSeverityError, - Prefix: fn, - Message: fmt.Sprintf("index %s already exists", - ast.QuoteIdentifier(in)), - }, &hasError) - return - } - - // Commit the change into the schema. - if rename { - (*s)[id.Table].Indices[id.Index].Name = n.Index.Name - } else { - index.Name = n.Index.Name - (*s)[id.Table].Indices = append((*s)[id.Table].Indices, index) - } -} - -func checkSelectStmt(n *ast.SelectStmtNode, s schema.Schema, - o CheckOptions, c *schemaCache, el *errors.ErrorList) { -} - -func checkUpdateStmt(n *ast.UpdateStmtNode, s schema.Schema, - o CheckOptions, c *schemaCache, el *errors.ErrorList) { -} - -func checkDeleteStmt(n *ast.DeleteStmtNode, s schema.Schema, - o CheckOptions, c *schemaCache, el *errors.ErrorList) { -} - -func checkInsertStmt(n *ast.InsertStmtNode, s schema.Schema, - o CheckOptions, c *schemaCache, el *errors.ErrorList) { -} - -func checkExpr(n ast.ExprNode, - s schema.Schema, o CheckOptions, c *schemaCache, el *errors.ErrorList, - tr schema.TableRef, ta typeAction) ast.ExprNode { - - switch n := n.(type) { - case *ast.IdentifierNode: - return checkVariable(n, s, o, c, el, tr, ta) - - case *ast.BoolValueNode: - return checkBoolValue(n, o, el, ta) - - case *ast.AddressValueNode: - return checkAddressValue(n, o, el, ta) - - case *ast.IntegerValueNode: - return checkIntegerValue(n, o, el, ta) - - case *ast.DecimalValueNode: - return checkDecimalValue(n, o, el, ta) - - case *ast.BytesValueNode: - return checkBytesValue(n, o, el, ta) - - case *ast.NullValueNode: - return checkNullValue(n, o, el, ta) - - case *ast.PosOperatorNode: - return checkPosOperator(n, s, o, c, el, tr, ta) - - case *ast.NegOperatorNode: - return n - - case *ast.NotOperatorNode: - return n - - case *ast.ParenOperatorNode: - return n - - case *ast.AndOperatorNode: - return n - - case *ast.OrOperatorNode: - return n - - case *ast.GreaterOrEqualOperatorNode: - return n - - case *ast.LessOrEqualOperatorNode: - return n - - case *ast.NotEqualOperatorNode: - return n - - case *ast.EqualOperatorNode: - return n - - case *ast.GreaterOperatorNode: - return n - - case *ast.LessOperatorNode: - return n - - case *ast.ConcatOperatorNode: - return n - - case *ast.AddOperatorNode: - return n - - case *ast.SubOperatorNode: - return n - - case *ast.MulOperatorNode: - return n - - case *ast.DivOperatorNode: - return n - - case *ast.ModOperatorNode: - return n - - case *ast.IsOperatorNode: - return n - - case *ast.LikeOperatorNode: - return n - - case *ast.CastOperatorNode: - return n - - case *ast.InOperatorNode: - return n - - case *ast.FunctionOperatorNode: - return n - - default: - panic(fmt.Sprintf("unknown expression type %T", n)) - } -} - -func elAppendTypeErrorMismatch(el *errors.ErrorList, n ast.ExprNode, - fn string, dtExpected, dtGiven ast.DataType) { - el.Append(errors.Error{ - Position: n.GetPosition(), - Length: n.GetLength(), - Category: errors.ErrorCategorySemantic, - Code: errors.ErrorCodeTypeError, - Severity: errors.ErrorSeverityError, - Prefix: fn, - Message: fmt.Sprintf( - "expect %s (%04x), but %s (%04x) is given", - dtExpected.String(), uint16(dtExpected), - dtGiven.String(), uint16(dtGiven)), - }, nil) -} -func checkVariable(n *ast.IdentifierNode, - s schema.Schema, o CheckOptions, c *schemaCache, el *errors.ErrorList, - tr schema.TableRef, ta typeAction) ast.ExprNode { - - fn := "CheckVariable" - - if (o & CheckWithConstantOnly) != 0 { - el.Append(errors.Error{ - Position: n.GetPosition(), - Length: n.GetLength(), - Category: errors.ErrorCategorySemantic, - Code: errors.ErrorCodeNonConstantExpression, - Severity: errors.ErrorSeverityError, - Prefix: fn, - Message: fmt.Sprintf("%s is not a constant", - ast.QuoteIdentifier(n.Name)), - }, nil) - return nil - } - - cn := string(n.Name) - cd, found := c.FindColumnInBase(tr, cn) - if !found { - el.Append(errors.Error{ - Position: n.GetPosition(), - Length: n.GetLength(), - Category: errors.ErrorCategorySemantic, - Code: errors.ErrorCodeColumnNotFound, - Severity: errors.ErrorSeverityError, - Prefix: fn, - Message: fmt.Sprintf( - "cannot find column %s in table %s", - ast.QuoteIdentifier(n.Name), - ast.QuoteIdentifier(s[tr].Name)), - }, nil) - return nil - } - - cr := cd.Column - dt := s[tr].Columns[cr].Type - switch a := ta.(type) { - case typeActionInferDefault: - case typeActionInferWithSize: - case typeActionAssign: - if !dt.Equal(a.dt) { - elAppendTypeErrorMismatch(el, n, fn, a.dt, dt) - return nil - } - } - - n.SetType(dt) - n.Desc = cd - return n -} - -func unknownValueNodeType(n ast.Valuer) string { - return fmt.Sprintf("unknown constant type %T", n) -} - -func describeValueNodeType(n ast.Valuer) string { - switch n.(type) { - case *ast.BoolValueNode: - return "boolean constant" - case *ast.AddressValueNode: - return "address constant" - case *ast.IntegerValueNode, *ast.DecimalValueNode: - return "number constant" - case *ast.BytesValueNode: - return "bytes constant" - case *ast.NullValueNode: - return "null constant" - default: - panic(unknownValueNodeType(n)) - } -} - -func elAppendTypeErrorValueNode(el *errors.ErrorList, n ast.Valuer, - fn string, dt ast.DataType) { - el.Append(errors.Error{ - Position: n.GetPosition(), - Length: n.GetLength(), - Category: errors.ErrorCategorySemantic, - Code: errors.ErrorCodeTypeError, - Severity: errors.ErrorSeverityError, - Prefix: fn, - Message: fmt.Sprintf( - "expect %s (%04x), but %s is given", - dt.String(), uint16(dt), describeValueNodeType(n)), - }, nil) -} - -func checkBoolValue(n *ast.BoolValueNode, - o CheckOptions, el *errors.ErrorList, ta typeAction) ast.ExprNode { - - fn := "CheckBoolValue" - - switch a := ta.(type) { - case typeActionInferDefault: - case typeActionInferWithSize: - case typeActionAssign: - major, _ := ast.DecomposeDataType(a.dt) - if major != ast.DataTypeMajorBool { - elAppendTypeErrorValueNode(el, n, fn, a.dt) - return nil - } - } - return n -} - -func checkAddressValue(n *ast.AddressValueNode, - o CheckOptions, el *errors.ErrorList, ta typeAction) ast.ExprNode { - - fn := "CheckAddressValue" - - switch a := ta.(type) { - case typeActionInferDefault: - case typeActionInferWithSize: - case typeActionAssign: - major, _ := ast.DecomposeDataType(a.dt) - if major != ast.DataTypeMajorAddress { - elAppendTypeErrorValueNode(el, n, fn, a.dt) - return nil - } - } - return n -} - -func mustGetMinMax(dt ast.DataType) (decimal.Decimal, decimal.Decimal) { - min, max, ok := dt.GetMinMax() - if !ok { - panic(fmt.Sprintf("GetMinMax does not handle %v", dt)) - } - return min, max -} - -func mustDecimalEncode(dt ast.DataType, d decimal.Decimal) []byte { - b, ok := ast.DecimalEncode(dt, d) - if !ok { - panic(fmt.Sprintf("DecimalEncode does not handle %v", dt)) - } - return b -} - -func mustDecimalDecode(dt ast.DataType, b []byte) decimal.Decimal { - d, ok := ast.DecimalDecode(dt, b) - if !ok { - panic(fmt.Sprintf("DecimalDecode does not handle %v", dt)) - } - return d -} - -func cropDecimal(dt ast.DataType, d decimal.Decimal) decimal.Decimal { - b := mustDecimalEncode(dt, d) - return mustDecimalDecode(dt, b) -} - -func elAppendConstantTooLongError(el *errors.ErrorList, n ast.Valuer, - fn string, v decimal.Decimal) { - el.Append(errors.Error{ - Position: n.GetPosition(), - Length: n.GetLength(), - Category: errors.ErrorCategorySemantic, - Code: errors.ErrorCodeConstantTooLong, - Severity: errors.ErrorSeverityError, - Prefix: fn, - Message: fmt.Sprintf( - "constant expression %s has more than %d digits", - ast.QuoteString(n.GetToken()), MaxIntegerPartDigits), - }, nil) -} - -func elAppendOverflowError(el *errors.ErrorList, n ast.Valuer, - fn string, dt ast.DataType, v, min, max decimal.Decimal) { - el.Append(errors.Error{ - Position: n.GetPosition(), - Length: n.GetLength(), - Category: errors.ErrorCategorySemantic, - Code: errors.ErrorCodeOverflow, - Severity: errors.ErrorSeverityError, - Prefix: fn, - Message: fmt.Sprintf( - "number %s (%s) overflows %s (%04x)", - ast.QuoteString(n.GetToken()), v.String(), - dt.String(), uint16(dt)), - }, nil) - el.Append(errors.Error{ - Position: n.GetPosition(), - Length: n.GetLength(), - Category: 0, - Code: 0, - Severity: errors.ErrorSeverityNote, - Prefix: fn, - Message: fmt.Sprintf( - "the range of %s is [%s, %s]", - dt.String(), min.String(), max.String()), - }, nil) -} - -func elAppendOverflowWarning(el *errors.ErrorList, n ast.Valuer, - fn string, dt ast.DataType, from, to decimal.Decimal) { - el.Append(errors.Error{ - Position: n.GetPosition(), - Length: n.GetLength(), - Category: 0, - Code: 0, - Severity: errors.ErrorSeverityWarning, - Prefix: fn, - Message: fmt.Sprintf( - "number %s (%s) overflows %s (%04x), converted to %s", - ast.QuoteString(n.GetToken()), from.String(), - dt.String(), uint16(dt), to.String()), - }, nil) -} - -func checkIntegerValue(n *ast.IntegerValueNode, - o CheckOptions, el *errors.ErrorList, ta typeAction) ast.ExprNode { - - fn := "CheckIntegerValue" - - normalizeDecimal(&n.V) - if !safeDecimalRange(n.V) { - elAppendConstantTooLongError(el, n, fn, n.V) - return nil - } - - infer := func(size int) (ast.DataType, bool) { - // The first case: assume V fits in the signed integer. - minor := ast.DataTypeMinor(size - 1) - dt := ast.ComposeDataType(ast.DataTypeMajorInt, minor) - min, max := mustGetMinMax(dt) - // Return if V < min. V must be negative so it must be signed. - if n.V.LessThan(min) { - if (o & CheckWithSafeMath) != 0 { - elAppendOverflowError(el, n, fn, dt, n.V, min, max) - return dt, false - } - cropped := cropDecimal(dt, n.V) - elAppendOverflowWarning(el, n, fn, dt, n.V, cropped) - normalizeDecimal(&cropped) - n.V = cropped - return dt, true - } - // We are done if V fits in the signed integer. - if n.V.LessThanOrEqual(max) { - return dt, true - } - - // The second case: V is a non-negative integer, but it does not fit - // in the signed integer. Test whether the unsigned integer works. - dt = ast.ComposeDataType(ast.DataTypeMajorUint, minor) - min, max = mustGetMinMax(dt) - // Return if V > max. We don't have to test whether V < min because min - // is always zero and we already know V is non-negative. - if n.V.GreaterThan(max) { - if (o & CheckWithSafeMath) != 0 { - elAppendOverflowError(el, n, fn, dt, n.V, min, max) - return dt, false - } - cropped := cropDecimal(dt, n.V) - elAppendOverflowWarning(el, n, fn, dt, n.V, cropped) - normalizeDecimal(&cropped) - n.V = cropped - return dt, true - } - return dt, true - } - - dt := ast.DataTypePending - switch a := ta.(type) { - case typeActionInferDefault: - // Default to int256 or uint256. - var ok bool - dt, ok = infer(256 / 8) - if !ok { - return nil - } - - case typeActionInferWithSize: - var ok bool - dt, ok = infer(a.size) - if !ok { - return nil - } - - case typeActionAssign: - dt = a.dt - major, _ := ast.DecomposeDataType(dt) - switch { - case major == ast.DataTypeMajorAddress: - if !n.IsAddress { - el.Append(errors.Error{ - Position: n.GetPosition(), - Length: n.GetLength(), - Category: errors.ErrorCategorySemantic, - Code: errors.ErrorCodeInvalidAddressChecksum, - Severity: errors.ErrorSeverityError, - Prefix: fn, - Message: fmt.Sprintf( - "expect %s (%04x), but %s is not an address constant", - dt.String(), uint16(dt), n.GetToken()), - }, nil) - return nil - } - // Redirect to checkAddressValue if it becomes an address. - addrNode := &ast.AddressValueNode{} - addrNode.SetPosition(addrNode.GetPosition()) - addrNode.SetLength(addrNode.GetLength()) - addrNode.SetToken(addrNode.GetToken()) - addrNode.V = mustDecimalEncode(ast.ComposeDataType( - ast.DataTypeMajorUint, ast.DataTypeMinor(160/8-1)), n.V) - return checkAddressValue(addrNode, o, el, ta) - - case major == ast.DataTypeMajorInt, - major == ast.DataTypeMajorUint, - major.IsFixedRange(), - major.IsUfixedRange(): - min, max := mustGetMinMax(dt) - if n.V.LessThan(min) || n.V.GreaterThan(max) { - if (o & CheckWithSafeMath) != 0 { - elAppendOverflowError(el, n, fn, dt, n.V, min, max) - return nil - } - cropped := cropDecimal(dt, n.V) - elAppendOverflowWarning(el, n, fn, dt, n.V, cropped) - normalizeDecimal(&cropped) - n.V = cropped - } - - default: - elAppendTypeErrorValueNode(el, n, fn, dt) - return nil - } - } - - if !dt.Pending() { - n.SetType(dt) - } - return n -} - -func checkDecimalValue(n *ast.DecimalValueNode, - o CheckOptions, el *errors.ErrorList, ta typeAction) ast.ExprNode { - - fn := "CheckDecimalValue" - - normalizeDecimal(&n.V) - if !safeDecimalRange(n.V) { - elAppendConstantTooLongError(el, n, fn, n.V) - return nil - } - - // Redirect to checkIntegerValue if the value is an integer. - if intPart := n.V.Truncate(0); n.V.Equal(intPart) { - intNode := &ast.IntegerValueNode{} - intNode.SetPosition(n.GetPosition()) - intNode.SetLength(n.GetLength()) - intNode.SetToken(n.GetToken()) - intNode.SetType(n.GetType()) - intNode.IsAddress = false - intNode.V = n.V - return checkIntegerValue(intNode, o, el, ta) - } - - infer := func(size, fractionalDigits int) (ast.DataType, bool) { - major := ast.DataTypeMajorFixed + ast.DataTypeMajor(size-1) - minor := ast.DataTypeMinor(fractionalDigits) - dt := ast.ComposeDataType(major, minor) - min, max := mustGetMinMax(dt) - if n.V.LessThan(min) { - if (o & CheckWithSafeMath) != 0 { - elAppendOverflowError(el, n, fn, dt, n.V, min, max) - return dt, false - } - cropped := cropDecimal(dt, n.V) - elAppendOverflowWarning(el, n, fn, dt, n.V, cropped) - normalizeDecimal(&cropped) - n.V = cropped - return dt, false - } - if n.V.LessThanOrEqual(max) { - return dt, true - } - - major = ast.DataTypeMajorUfixed + ast.DataTypeMajor(size-1) - minor = ast.DataTypeMinor(fractionalDigits) - dt = ast.ComposeDataType(major, minor) - min, max = mustGetMinMax(dt) - if n.V.GreaterThan(max) { - if (o & CheckWithSafeMath) != 0 { - elAppendOverflowError(el, n, fn, dt, n.V, min, max) - return dt, false - } - cropped := cropDecimal(dt, n.V) - elAppendOverflowWarning(el, n, fn, dt, n.V, cropped) - normalizeDecimal(&cropped) - n.V = cropped - return dt, true - } - return dt, true - } - - // Now we are sure the number we are dealing has fractional part. - dt := ast.DataTypePending - switch a := ta.(type) { - case typeActionInferDefault: - // Default to fixed128x18 and ufixed128x18. - var ok bool - dt, ok = infer(128/8, 18) - if !ok { - return nil - } - - case typeActionInferWithSize: - // It is unclear that what the size hint means for fixed-point numbers, - // so we just ignore it and do the same thing as the above case. - var ok bool - dt, ok = infer(128/8, 18) - if !ok { - return nil - } - - case typeActionAssign: - dt = a.dt - major, _ := ast.DecomposeDataType(dt) - switch { - case major.IsFixedRange(), - major.IsUfixedRange(): - min, max := mustGetMinMax(dt) - if n.V.LessThan(min) || n.V.GreaterThan(max) { - if (o & CheckWithSafeMath) != 0 { - elAppendOverflowError(el, n, fn, dt, n.V, min, max) - return nil - } - cropped := cropDecimal(dt, n.V) - elAppendOverflowWarning(el, n, fn, dt, n.V, cropped) - normalizeDecimal(&cropped) - n.V = cropped - } - - case major == ast.DataTypeMajorInt, - major == ast.DataTypeMajorUint: - el.Append(errors.Error{ - Position: n.GetPosition(), - Length: n.GetLength(), - Category: errors.ErrorCategorySemantic, - Code: errors.ErrorCodeTypeError, - Severity: errors.ErrorSeverityError, - Prefix: fn, - Message: fmt.Sprintf( - "expect %s (%04x), but the number %s has fractional part", - dt.String(), uint16(dt), n.V.String()), - }, nil) - return nil - - default: - elAppendTypeErrorValueNode(el, n, fn, dt) - return nil - } - } - - if !dt.Pending() { - n.SetType(dt) - _, minor := ast.DecomposeDataType(dt) - fractionalDigits := int32(minor) - n.V = n.V.Round(fractionalDigits) - normalizeDecimal(&n.V) - } - return n -} - -func checkBytesValue(n *ast.BytesValueNode, - o CheckOptions, el *errors.ErrorList, ta typeAction) ast.ExprNode { - - fn := "CheckBytesValue" - - dt := ast.DataTypePending - -executeTypeAction: - switch a := ta.(type) { - case typeActionInferDefault: - // Default to bytes. - major := ast.DataTypeMajorDynamicBytes - minor := ast.DataTypeMinorDontCare - dt = ast.ComposeDataType(major, minor) - ta = newTypeActionAssign(dt) - goto executeTypeAction - - case typeActionInferWithSize: - major := ast.DataTypeMajorFixedBytes - minor := ast.DataTypeMinor(a.size - 1) - dt = ast.ComposeDataType(major, minor) - ta = newTypeActionAssign(dt) - goto executeTypeAction - - case typeActionAssign: - dt = a.dt - major, minor := ast.DecomposeDataType(dt) - switch major { - case ast.DataTypeMajorDynamicBytes: - // Do nothing because it is always valid. - - case ast.DataTypeMajorFixedBytes: - sizeGiven := len(n.V) - sizeExpected := int(minor) + 1 - if sizeGiven != sizeExpected { - el.Append(errors.Error{ - Position: n.GetPosition(), - Length: n.GetLength(), - Category: errors.ErrorCategorySemantic, - Code: errors.ErrorCodeTypeError, - Severity: errors.ErrorSeverityError, - Prefix: fn, - Message: fmt.Sprintf( - "expect %s (%04x), but %s has %d bytes", - dt.String(), uint16(dt), - ast.QuoteString(n.V), sizeGiven), - }, nil) - return nil - } - - default: - elAppendTypeErrorValueNode(el, n, fn, dt) - return nil - } - } - - if !dt.Pending() { - n.SetType(dt) - } - return n -} - -func checkNullValue(n *ast.NullValueNode, - o CheckOptions, el *errors.ErrorList, ta typeAction) ast.ExprNode { - - dt := ast.DataTypePending - switch a := ta.(type) { - case typeActionInferDefault: - dt = ast.DataTypeNull - case typeActionInferWithSize: - dt = ast.DataTypeNull - case typeActionAssign: - dt = a.dt - } - - if !dt.Pending() { - n.SetType(dt) - } - return n -} - -func elAppendTypeErrorOperatorValueNode(el *errors.ErrorList, n ast.Valuer, - fn string, op string) { - el.Append(errors.Error{ - Position: n.GetPosition(), - Length: n.GetLength(), - Category: errors.ErrorCategorySemantic, - Code: errors.ErrorCodeTypeError, - Severity: errors.ErrorSeverityError, - Prefix: fn, - Message: fmt.Sprintf("%s is not defined for %s", - op, describeValueNodeType(n)), - }, nil) -} - -func elAppendTypeErrorOperatorDataType(el *errors.ErrorList, n ast.ExprNode, - fn string, op string) { - dt := n.GetType() - el.Append(errors.Error{ - Position: n.GetPosition(), - Length: n.GetLength(), - Category: errors.ErrorCategorySemantic, - Code: errors.ErrorCodeTypeError, - Severity: errors.ErrorSeverityError, - Prefix: fn, - Message: fmt.Sprintf("%s is not defined for %s (%04x)", - op, dt.String(), uint16(dt)), - }, nil) -} - -func checkPosOperator(n *ast.PosOperatorNode, - s schema.Schema, o CheckOptions, c *schemaCache, el *errors.ErrorList, - tr schema.TableRef, ta typeAction) ast.ExprNode { - - fn := "CheckPosOperator" - op := "unary operator +" - - target := n.GetTarget() - target = checkExpr(target, s, o, c, el, tr, nil) - if target == nil { - return nil - } - - if v, ok := target.(ast.Valuer); ok { - switch v := v.(type) { - case *ast.IntegerValueNode: - // Clone the node to reset IsAddress to false. - if v.IsAddress { - result := &ast.IntegerValueNode{} - result.SetPosition(v.GetPosition()) - result.SetLength(v.GetLength()) - result.SetToken(v.GetToken()) - result.SetType(v.GetType()) - result.IsAddress = false - result.V = v.V - target = result - } - case *ast.DecimalValueNode: - // Do nothing because the result is the same as the input. - case *ast.BoolValueNode: - elAppendTypeErrorOperatorValueNode(el, v, fn, op) - return nil - case *ast.AddressValueNode: - elAppendTypeErrorOperatorValueNode(el, v, fn, op) - return nil - case *ast.BytesValueNode: - elAppendTypeErrorOperatorValueNode(el, v, fn, op) - return nil - case *ast.NullValueNode: - elAppendTypeErrorOperatorValueNode(el, v, fn, op) - return nil - default: - panic(unknownValueNodeType(v)) - } - } - - dt := target.GetType() - if dt.Pending() { - target = checkExpr(target, s, o, c, el, tr, ta) - } else { - major, _ := ast.DecomposeDataType(dt) - switch { - case major == ast.DataTypeMajorInt, - major == ast.DataTypeMajorUint, - major.IsFixedRange(), - major.IsUfixedRange(): - default: - elAppendTypeErrorOperatorDataType(el, target, fn, op) - return nil - } - switch a := ta.(type) { - case typeActionInferDefault: - case typeActionInferWithSize: - case typeActionAssign: - if !dt.Equal(a.dt) { - elAppendTypeErrorMismatch(el, n, fn, a.dt, dt) - } - } - } - return target -} diff --git a/core/vm/sqlvm/checkers/utils.go b/core/vm/sqlvm/checkers/utils.go deleted file mode 100644 index 4a8bbd96e..000000000 --- a/core/vm/sqlvm/checkers/utils.go +++ /dev/null @@ -1,471 +0,0 @@ -package checkers - -import ( - "github.com/dexon-foundation/decimal" - - "github.com/dexon-foundation/dexon/core/vm/sqlvm/ast" - "github.com/dexon-foundation/dexon/core/vm/sqlvm/schema" -) - -// Variable name convention: -// -// fn -> function name -// el -> error list -// -// td -> table descriptor -// tr -> table reference -// ti -> table index -// tn -< table name -// -// cd -> column descriptor -// cr -> column reference -// ci -> column index -// cn -> column name -// -// id -> index descriptor -// ir -> index reference -// ii -> index index -// in -> index name - -const ( - MaxIntegerPartDigits int32 = 200 - MaxFractionalPartDigits int32 = 200 -) - -var ( - MaxConstant = func() decimal.Decimal { - max := (decimal.New(1, MaxIntegerPartDigits). - Sub(decimal.New(1, -MaxFractionalPartDigits))) - normalizeDecimal(&max) - return max - }() - MinConstant = MaxConstant.Neg() -) - -func normalizeDecimal(d *decimal.Decimal) { - if d.Exponent() != -MaxFractionalPartDigits { - *d = d.Rescale(-MaxFractionalPartDigits) - } -} - -func safeDecimalRange(d decimal.Decimal) bool { - return d.GreaterThanOrEqual(MinConstant) && d.LessThanOrEqual(MaxConstant) -} - -type schemaCache struct { - base schemaCacheBase - scopes []schemaCacheScope -} - -type schemaCacheIndexValue struct { - id schema.IndexDescriptor - auto bool -} - -type schemaCacheColumnKey struct { - tr schema.TableRef - n string -} - -type schemaCacheBase struct { - table map[string]schema.TableDescriptor - index map[string]schemaCacheIndexValue - column map[schemaCacheColumnKey]schema.ColumnDescriptor -} - -func (lower *schemaCacheBase) Merge(upper schemaCacheScope) { - // Process deletions. - for n := range upper.tableDeleted { - delete(lower.table, n) - } - for n := range upper.indexDeleted { - delete(lower.index, n) - } - for ck := range upper.columnDeleted { - delete(lower.column, ck) - } - - // Process additions. - for n, td := range upper.table { - lower.table[n] = td - } - for n, iv := range upper.index { - lower.index[n] = iv - } - for ck, cd := range upper.column { - lower.column[ck] = cd - } -} - -type schemaCacheScope struct { - table map[string]schema.TableDescriptor - tableDeleted map[string]struct{} - index map[string]schemaCacheIndexValue - indexDeleted map[string]struct{} - column map[schemaCacheColumnKey]schema.ColumnDescriptor - columnDeleted map[schemaCacheColumnKey]struct{} -} - -func (lower *schemaCacheScope) Merge(upper schemaCacheScope) { - // Process deletions. - for n := range upper.tableDeleted { - delete(lower.table, n) - lower.tableDeleted[n] = struct{}{} - } - for n := range upper.indexDeleted { - delete(lower.index, n) - lower.indexDeleted[n] = struct{}{} - } - for ck := range upper.columnDeleted { - delete(lower.column, ck) - lower.columnDeleted[ck] = struct{}{} - } - - // Process additions. - for n, td := range upper.table { - lower.table[n] = td - } - for n, iv := range upper.index { - lower.index[n] = iv - } - for ck, cd := range upper.column { - lower.column[ck] = cd - } -} - -func newSchemaCache() *schemaCache { - return &schemaCache{ - base: schemaCacheBase{ - table: map[string]schema.TableDescriptor{}, - index: map[string]schemaCacheIndexValue{}, - column: map[schemaCacheColumnKey]schema.ColumnDescriptor{}, - }, - } -} - -func (c *schemaCache) Begin() int { - position := len(c.scopes) - scope := schemaCacheScope{ - table: map[string]schema.TableDescriptor{}, - tableDeleted: map[string]struct{}{}, - index: map[string]schemaCacheIndexValue{}, - indexDeleted: map[string]struct{}{}, - column: map[schemaCacheColumnKey]schema.ColumnDescriptor{}, - columnDeleted: map[schemaCacheColumnKey]struct{}{}, - } - c.scopes = append(c.scopes, scope) - return position -} - -func (c *schemaCache) Rollback() { - if len(c.scopes) == 0 { - panic("there is no scope to rollback") - } - c.scopes = c.scopes[:len(c.scopes)-1] -} - -func (c *schemaCache) RollbackTo(position int) { - for position <= len(c.scopes) { - c.Rollback() - } -} - -func (c *schemaCache) Commit() { - if len(c.scopes) == 0 { - panic("there is no scope to commit") - } - if len(c.scopes) == 1 { - c.base.Merge(c.scopes[0]) - } else { - src := len(c.scopes) - 1 - dst := len(c.scopes) - 2 - c.scopes[dst].Merge(c.scopes[src]) - } - c.scopes = c.scopes[:len(c.scopes)-1] -} - -func (c *schemaCache) CommitTo(position int) { - for position <= len(c.scopes) { - c.Commit() - } -} - -func (c *schemaCache) FindTableInBase(n string) ( - schema.TableDescriptor, bool) { - - td, exists := c.base.table[n] - return td, exists -} - -func (c *schemaCache) FindTableInScope(n string) ( - schema.TableDescriptor, bool) { - - for si := range c.scopes { - si = len(c.scopes) - si - 1 - if td, exists := c.scopes[si].table[n]; exists { - return td, true - } - if _, exists := c.scopes[si].tableDeleted[n]; exists { - return schema.TableDescriptor{}, false - } - } - return c.FindTableInBase(n) -} - -func (c *schemaCache) FindTableInBaseWithFallback(n string, - fallback schema.Schema) (schema.TableDescriptor, bool) { - - if td, found := c.FindTableInBase(n); found { - return td, true - } - if fallback == nil { - return schema.TableDescriptor{}, false - } - - s := fallback - for ti := range s { - if n == string(s[ti].Name) { - tr := schema.TableRef(ti) - td := schema.TableDescriptor{Table: tr} - c.base.table[n] = td - return td, true - } - } - return schema.TableDescriptor{}, false -} - -func (c *schemaCache) FindIndexInBase(n string) ( - schema.IndexDescriptor, bool, bool) { - - iv, exists := c.base.index[n] - return iv.id, iv.auto, exists -} - -func (c *schemaCache) FindIndexInScope(n string) ( - schema.IndexDescriptor, bool, bool) { - - for si := range c.scopes { - si = len(c.scopes) - si - 1 - if iv, exists := c.scopes[si].index[n]; exists { - return iv.id, iv.auto, true - } - if _, exists := c.scopes[si].indexDeleted[n]; exists { - return schema.IndexDescriptor{}, false, false - } - } - return c.FindIndexInBase(n) -} - -func (c *schemaCache) FindIndexInBaseWithFallback(n string, - fallback schema.Schema) (schema.IndexDescriptor, bool, bool) { - - if id, auto, found := c.FindIndexInBase(n); found { - return id, auto, true - } - if fallback == nil { - return schema.IndexDescriptor{}, false, false - } - - s := fallback - for ti := range s { - for ii := range s[ti].Indices { - if n == string(s[ti].Indices[ii].Name) { - tr := schema.TableRef(ti) - ir := schema.IndexRef(ii) - id := schema.IndexDescriptor{Table: tr, Index: ir} - iv := schemaCacheIndexValue{id: id, auto: false} - c.base.index[n] = iv - return id, false, true - } - } - } - return schema.IndexDescriptor{}, false, false -} - -func (c *schemaCache) FindColumnInBase(tr schema.TableRef, n string) ( - schema.ColumnDescriptor, bool) { - - cd, exists := c.base.column[schemaCacheColumnKey{tr: tr, n: n}] - return cd, exists -} - -func (c *schemaCache) FindColumnInScope(tr schema.TableRef, n string) ( - schema.ColumnDescriptor, bool) { - - ck := schemaCacheColumnKey{tr: tr, n: n} - for si := range c.scopes { - si = len(c.scopes) - si - 1 - if cd, exists := c.scopes[si].column[ck]; exists { - return cd, true - } - if _, exists := c.scopes[si].columnDeleted[ck]; exists { - return schema.ColumnDescriptor{}, false - } - } - return c.FindColumnInBase(tr, n) -} - -func (c *schemaCache) FindColumnInBaseWithFallback(tr schema.TableRef, n string, - fallback schema.Schema) (schema.ColumnDescriptor, bool) { - - if cd, found := c.FindColumnInBase(tr, n); found { - return cd, true - } - if fallback == nil { - return schema.ColumnDescriptor{}, false - } - - s := fallback - for ci := range s[tr].Columns { - if n == string(s[tr].Columns[ci].Name) { - cr := schema.ColumnRef(ci) - cd := schema.ColumnDescriptor{Table: tr, Column: cr} - ck := schemaCacheColumnKey{tr: tr, n: n} - c.base.column[ck] = cd - return cd, true - } - } - return schema.ColumnDescriptor{}, false -} - -func (c *schemaCache) AddTable(n string, - td schema.TableDescriptor) bool { - - top := len(c.scopes) - 1 - if _, found := c.FindTableInScope(n); found { - return false - } - - c.scopes[top].table[n] = td - return true -} - -func (c *schemaCache) AddIndex(n string, - id schema.IndexDescriptor, auto bool) bool { - - top := len(c.scopes) - 1 - if _, _, found := c.FindIndexInScope(n); found { - return false - } - - iv := schemaCacheIndexValue{id: id, auto: auto} - c.scopes[top].index[n] = iv - return true -} - -func (c *schemaCache) AddColumn(n string, - cd schema.ColumnDescriptor) bool { - - top := len(c.scopes) - 1 - tr := cd.Table - if _, found := c.FindColumnInScope(tr, n); found { - return false - } - - ck := schemaCacheColumnKey{tr: tr, n: n} - c.scopes[top].column[ck] = cd - return true -} - -func (c *schemaCache) DeleteTable(n string) bool { - top := len(c.scopes) - 1 - if _, found := c.FindTableInScope(n); !found { - return false - } - - delete(c.scopes[top].table, n) - c.scopes[top].tableDeleted[n] = struct{}{} - return true -} - -func (c *schemaCache) DeleteIndex(n string) bool { - top := len(c.scopes) - 1 - if _, _, found := c.FindIndexInScope(n); !found { - return false - } - - delete(c.scopes[top].index, n) - c.scopes[top].indexDeleted[n] = struct{}{} - return true -} - -func (c *schemaCache) DeleteColumn(tr schema.TableRef, n string) bool { - top := len(c.scopes) - 1 - if _, found := c.FindColumnInScope(tr, n); !found { - return false - } - - ck := schemaCacheColumnKey{tr: tr, n: n} - delete(c.scopes[top].column, ck) - c.scopes[top].columnDeleted[ck] = struct{}{} - return true -} - -type columnRefSlice struct { - columns []schema.ColumnRef - nodes []uint8 -} - -func newColumnRefSlice(c uint8) columnRefSlice { - return columnRefSlice{ - columns: make([]schema.ColumnRef, 0, c), - nodes: make([]uint8, 0, c), - } -} - -func (s *columnRefSlice) Append(c schema.ColumnRef, i uint8) { - s.columns = append(s.columns, c) - s.nodes = append(s.nodes, i) -} - -func (s columnRefSlice) Len() int { - return len(s.columns) -} - -func (s columnRefSlice) Less(i, j int) bool { - return s.columns[i] < s.columns[j] -} - -func (s columnRefSlice) Swap(i, j int) { - s.columns[i], s.columns[j] = s.columns[j], s.columns[i] - s.nodes[i], s.nodes[j] = s.nodes[j], s.nodes[i] -} - -//go-sumtype:decl typeAction -type typeAction interface { - ˉtypeAction() -} - -type typeActionInferDefault struct{} - -func newTypeActionInferDefaultSize() typeActionInferDefault { - return typeActionInferDefault{} -} - -var _ typeAction = typeActionInferDefault{} - -func (typeActionInferDefault) ˉtypeAction() {} - -type typeActionInferWithSize struct { - size int -} - -func newTypeActionInferWithSize(bytes int) typeActionInferWithSize { - return typeActionInferWithSize{size: bytes} -} - -var _ typeAction = typeActionInferWithSize{} - -func (typeActionInferWithSize) ˉtypeAction() {} - -type typeActionAssign struct { - dt ast.DataType -} - -func newTypeActionAssign(expected ast.DataType) typeActionAssign { - return typeActionAssign{dt: expected} -} - -var _ typeAction = typeActionAssign{} - -func (typeActionAssign) ˉtypeAction() {} diff --git a/core/vm/sqlvm/cmd/ast-checker/main.go b/core/vm/sqlvm/cmd/ast-checker/main.go index c02b58f0f..1c90d5028 100644 --- a/core/vm/sqlvm/cmd/ast-checker/main.go +++ b/core/vm/sqlvm/cmd/ast-checker/main.go @@ -7,18 +7,18 @@ import ( "fmt" "os" - "github.com/dexon-foundation/dexon/core/vm/sqlvm/checkers" + "github.com/dexon-foundation/dexon/core/vm/sqlvm/checker" "github.com/dexon-foundation/dexon/core/vm/sqlvm/parser" "github.com/dexon-foundation/dexon/core/vm/sqlvm/schema" "github.com/dexon-foundation/dexon/rlp" ) -func create(sql string, o checkers.CheckOptions) int { +func create(sql string, o checker.CheckOptions) int { n, parseErr := parser.Parse([]byte(sql)) if parseErr != nil { fmt.Fprintf(os.Stderr, "Parse error:\n%+v\n", parseErr) } - s, checkErr := checkers.CheckCreate(n, o) + s, checkErr := checker.CheckCreate(n, o) if checkErr != nil { fmt.Fprintf(os.Stderr, "Check error:\n%+v\n", checkErr) } @@ -52,42 +52,49 @@ func decode(ss string) int { return 0 } -func query(ss, sql string, o checkers.CheckOptions) int { +func query(ss, sql string, o checker.CheckOptions) int { fmt.Fprintln(os.Stderr, "Function not implemented") return 1 } -func exec(ss, sql string, o checkers.CheckOptions) int { +func exec(ss, sql string, o checker.CheckOptions) int { fmt.Fprintln(os.Stderr, "Function not implemented") return 1 } func main() { - var noSafeMath bool - var noSafeCast bool - flag.BoolVar(&noSafeMath, "no-safe-math", false, "disable safe math") - flag.BoolVar(&noSafeCast, "no-safe-cast", false, "disable safe cast") + var safeMath bool + var safeCast bool + var constantOnly bool + flag.BoolVar(&safeMath, "safe-math", true, "") + flag.BoolVar(&safeCast, "safe-cast", true, "") + flag.BoolVar(&constantOnly, "constant-only", false, " (default false)") flag.Parse() if flag.NArg() < 1 { fmt.Fprintf(os.Stderr, - "Usage: %s <action> <arguments>\n"+ + "Usage: %s [options] <action> <arguments>\n"+ + "Options:\n"+ + " -help Show options\n"+ "Actions:\n"+ - " create <SQL> returns schema\n"+ - " decode <schema> returns SQL\n"+ - " query <schema> <SQL> returns AST\n"+ - " exec <schema> <SQL> returns AST\n", + " create (SQL) -> schema\n"+ + " decode (schema) -> SQL\n"+ + " query (schema, SQL) -> AST\n"+ + " exec (schema, SQL) -> AST\n", os.Args[0]) os.Exit(1) } - o := checkers.CheckWithSafeMath | checkers.CheckWithSafeCast - if noSafeMath { - o &= ^(checkers.CheckWithSafeMath) + var o checker.CheckOptions + if safeMath { + o |= checker.CheckWithSafeMath } - if noSafeCast { - o &= ^(checkers.CheckWithSafeCast) + if safeCast { + o |= checker.CheckWithSafeCast + } + if constantOnly { + o |= checker.CheckWithConstantOnly } action := flag.Arg(0) -- cgit v1.2.3