diff options
Diffstat (limited to 'core/vm/sqlvm/ast')
-rw-r--r-- | core/vm/sqlvm/ast/ast.go | 98 | ||||
-rw-r--r-- | core/vm/sqlvm/ast/printer.go | 51 |
2 files changed, 111 insertions, 38 deletions
diff --git a/core/vm/sqlvm/ast/ast.go b/core/vm/sqlvm/ast/ast.go index 505f8653f..3a1d12c19 100644 --- a/core/vm/sqlvm/ast/ast.go +++ b/core/vm/sqlvm/ast/ast.go @@ -11,6 +11,7 @@ import ( // Node is an interface which should be satisfied by all nodes in AST. type Node interface { + HasPosition() bool GetPosition() uint32 SetPosition(uint32) GetLength() uint32 @@ -24,6 +25,11 @@ type NodeBase struct { Length uint32 `print:"-"` } +// HasPosition returns whether the position is set. +func (n *NodeBase) HasPosition() bool { + return n.Length > 0 +} + // GetPosition returns the offset in bytes where the corresponding token starts. func (n *NodeBase) GetPosition() uint32 { return n.Position @@ -44,6 +50,18 @@ func (n *NodeBase) SetLength(length uint32) { n.Length = length } +// UpdatePosition sets the position of the destination node from two source +// nodes. It is assumed that the destination node consists of multiple tokens +// which can be mapped to child nodes of the destination node. srcLeft should +// be the node representing the left-most token of the destination node, and +// srcRight should represent the right-most token of the destination node. +func UpdatePosition(dest, srcLeft, srcRight Node) { + begin := srcLeft.GetPosition() + end := srcRight.GetPosition() + srcRight.GetLength() + dest.SetPosition(begin) + dest.SetLength(end - begin) +} + // --------------------------------------------------------------------------- // Identifiers // --------------------------------------------------------------------------- @@ -555,6 +573,15 @@ func (n *NotOperatorNode) GetType() DataType { return ComposeDataType(DataTypeMajorBool, DataTypeMinorDontCare) } +// ParenOperatorNode is a pair of '(' and ')', representing a parenthesized +// expression. +type ParenOperatorNode struct { + TaggedExprNodeBase + UnaryOperatorNode +} + +var _ UnaryOperator = (*ParenOperatorNode)(nil) + // AndOperatorNode is 'AND'. type AndOperatorNode struct { UntaggedExprNodeBase @@ -942,6 +969,7 @@ func (n *InsertWithColumnOptionNode) GetChildren() []Node { for i := 0; i < len(n.Value); i++ { size += len(n.Value[i]) } + nodes := make([]Node, size) idx := 0 for i := 0; i < len(n.Column); i, idx = i+1, idx+1 { @@ -1062,25 +1090,47 @@ var _ Node = (*SelectStmtNode)(nil) // GetChildren returns a list of child nodes used for traversing. func (n *SelectStmtNode) GetChildren() []Node { - nodes := make([]Node, len(n.Column)+2+len(n.Group)+len(n.Order)+2) + size := len(n.Column) + len(n.Group) + len(n.Order) + if n.Table != nil { + size++ + } + if n.Where != nil { + size++ + } + if n.Limit != nil { + size++ + } + if n.Offset != nil { + size++ + } + + nodes := make([]Node, size) idx := 0 for i := 0; i < len(n.Column); i, idx = i+1, idx+1 { nodes[idx] = n.Column[i] } - nodes[idx] = n.Table - idx++ - nodes[idx] = n.Where - idx++ + if n.Table != nil { + nodes[idx] = n.Table + idx++ + } + if n.Where != nil { + nodes[idx] = n.Where + idx++ + } for i := 0; i < len(n.Group); i, idx = i+1, idx+1 { nodes[idx] = n.Group[i] } for i := 0; i < len(n.Order); i, idx = i+1, idx+1 { nodes[idx] = n.Order[i] } - nodes[idx] = n.Limit - idx++ - nodes[idx] = n.Offset - idx++ + if n.Limit != nil { + nodes[idx] = n.Limit + idx++ + } + if n.Offset != nil { + nodes[idx] = n.Offset + idx++ + } return nodes } @@ -1096,14 +1146,22 @@ var _ Node = (*UpdateStmtNode)(nil) // GetChildren returns a list of child nodes used for traversing. func (n *UpdateStmtNode) GetChildren() []Node { - nodes := make([]Node, 1+len(n.Assignment)+1) + size := 1 + len(n.Assignment) + if n.Where != nil { + size++ + } + + nodes := make([]Node, size) idx := 0 nodes[idx] = n.Table + idx++ for i := 0; i < len(n.Assignment); i, idx = i+1, idx+1 { nodes[idx] = n.Assignment[i] } - nodes[idx] = n.Where - idx++ + if n.Where != nil { + nodes[idx] = n.Where + idx++ + } return nodes } @@ -1118,6 +1176,9 @@ var _ Node = (*DeleteStmtNode)(nil) // GetChildren returns a list of child nodes used for traversing. func (n *DeleteStmtNode) GetChildren() []Node { + if n.Where == nil { + return []Node{n.Table} + } return []Node{n.Table, n.Where} } @@ -1188,7 +1249,12 @@ var _ Node = (*CreateIndexStmtNode)(nil) // GetChildren returns a list of child nodes used for traversing. func (n *CreateIndexStmtNode) GetChildren() []Node { - nodes := make([]Node, 2+len(n.Column)+1) + size := 2 + len(n.Column) + if n.Unique != nil { + size++ + } + + nodes := make([]Node, size) idx := 0 nodes[idx] = n.Index idx++ @@ -1197,7 +1263,9 @@ func (n *CreateIndexStmtNode) GetChildren() []Node { for i := 0; i < len(n.Column); i, idx = i+1, idx+1 { nodes[idx] = n.Column[i] } - nodes[idx] = n.Unique - idx++ + if n.Unique != nil { + nodes[idx] = n.Unique + idx++ + } return nodes } diff --git a/core/vm/sqlvm/ast/printer.go b/core/vm/sqlvm/ast/printer.go index e9b289411..4800fc86b 100644 --- a/core/vm/sqlvm/ast/printer.go +++ b/core/vm/sqlvm/ast/printer.go @@ -35,9 +35,11 @@ func formatString(s string) string { return fmt.Sprintf("%v", []byte(s)) } -func printAST(w io.Writer, n interface{}, depth int, base string, detail bool) { - indent := strings.Repeat(base, depth) - indentLong := strings.Repeat(base, depth+1) +func printAST(w io.Writer, n interface{}, s []byte, prefix string, + detail bool, depth int) { + + indent := strings.Repeat(prefix, depth) + indentLong := strings.Repeat(prefix, depth+1) if n == nil { fmt.Fprintf(w, "%snil\n", indent) return @@ -56,20 +58,6 @@ func printAST(w io.Writer, n interface{}, depth int, base string, detail bool) { } name := typeOf.Name() - if op, ok := n.(UnaryOperator); ok { - fmt.Fprintf(w, "%s%s:\n", indent, name) - fmt.Fprintf(w, "%sTarget:\n", indentLong) - printAST(w, op.GetTarget(), depth+2, base, detail) - return - } - if op, ok := n.(BinaryOperator); ok { - fmt.Fprintf(w, "%s%s:\n", indent, name) - fmt.Fprintf(w, "%sObject:\n", indentLong) - printAST(w, op.GetObject(), depth+2, base, detail) - fmt.Fprintf(w, "%sSubject:\n", indentLong) - printAST(w, op.GetSubject(), depth+2, base, detail) - return - } if stringer, ok := n.(fmt.Stringer); ok { s := stringer.String() fmt.Fprintf(w, "%s%s\n", indent, formatString(s)) @@ -92,7 +80,7 @@ func printAST(w io.Writer, n interface{}, depth int, base string, detail bool) { fmt.Fprintf(w, "%s[\n", indent) for i := 0; i < l; i++ { v := valueOf.Index(i) - printAST(w, v.Interface(), depth+1, base, detail) + printAST(w, v.Interface(), s, prefix, detail, depth+1) } fmt.Fprintf(w, "%s]\n", indent) return @@ -124,15 +112,30 @@ func printAST(w io.Writer, n interface{}, depth int, base string, detail bool) { } } collect(typeOf, valueOf) + + var position string + if node, ok := n.(Node); ok { + begin := node.GetPosition() + length := node.GetLength() + if node.HasPosition() { + end := begin + length - 1 + token := s[begin : begin+length] + position = fmt.Sprintf("%d-%d %s", + begin, end, strconv.Quote(string(token))) + } else { + position = "no position info" + } + } + fmt.Fprintf(w, "%s%s", indent, name) if len(fields) == 0 { - fmt.Fprintf(w, " {}\n") + fmt.Fprintf(w, " {} // %s\n", position) return } - fmt.Fprintf(w, " {\n") + fmt.Fprintf(w, " { // %s\n", position) for i := 0; i < len(fields); i++ { fmt.Fprintf(w, "%s%s:\n", indentLong, fields[i].name) - printAST(w, fields[i].value, depth+2, base, detail) + printAST(w, fields[i].value, s, prefix, detail, depth+2) } fmt.Fprintf(w, "%s}\n", indent) return @@ -141,6 +144,8 @@ func printAST(w io.Writer, n interface{}, depth int, base string, detail bool) { } // PrintAST prints AST for debugging. -func PrintAST(w io.Writer, n interface{}, indent string, detail bool) { - printAST(w, n, 0, indent, detail) +func PrintAST(output io.Writer, node interface{}, source []byte, + indent string, detail bool) { + + printAST(output, node, source, indent, detail, 0) } |