From 7628271e5d363a9a7283efb6f7d8b1b52e392e45 Mon Sep 17 00:00:00 2001 From: Ting-Wei Lan Date: Thu, 21 Feb 2019 18:28:32 +0800 Subject: core: vm: sqlvm: fill source code position in AST nodes Now all AST nodes should have position information recorded during parsing. These fields are intended to be used to report errors and make debugging easier. However, precise location of each token is currently unavailable. It can be done in the future if it becomes necessary. To make it easier to traverse an AST, GetChildren is modified to skip nil nodes in the output. This means callers of GetChildren don't have to check for nil in returned slices. AST printer is modified to print the position and the corresponding source code token. A few special handling for interfaces are removed because reflection works better for structs. --- core/vm/sqlvm/ast/ast.go | 98 +++++++++++++++++++++++++++++++++++++------- core/vm/sqlvm/ast/printer.go | 51 ++++++++++++----------- 2 files changed, 111 insertions(+), 38 deletions(-) (limited to 'core/vm/sqlvm/ast') diff --git a/core/vm/sqlvm/ast/ast.go b/core/vm/sqlvm/ast/ast.go index 505f8653f..3a1d12c19 100644 --- a/core/vm/sqlvm/ast/ast.go +++ b/core/vm/sqlvm/ast/ast.go @@ -11,6 +11,7 @@ import ( // Node is an interface which should be satisfied by all nodes in AST. type Node interface { + HasPosition() bool GetPosition() uint32 SetPosition(uint32) GetLength() uint32 @@ -24,6 +25,11 @@ type NodeBase struct { Length uint32 `print:"-"` } +// HasPosition returns whether the position is set. +func (n *NodeBase) HasPosition() bool { + return n.Length > 0 +} + // GetPosition returns the offset in bytes where the corresponding token starts. func (n *NodeBase) GetPosition() uint32 { return n.Position @@ -44,6 +50,18 @@ func (n *NodeBase) SetLength(length uint32) { n.Length = length } +// UpdatePosition sets the position of the destination node from two source +// nodes. It is assumed that the destination node consists of multiple tokens +// which can be mapped to child nodes of the destination node. srcLeft should +// be the node representing the left-most token of the destination node, and +// srcRight should represent the right-most token of the destination node. +func UpdatePosition(dest, srcLeft, srcRight Node) { + begin := srcLeft.GetPosition() + end := srcRight.GetPosition() + srcRight.GetLength() + dest.SetPosition(begin) + dest.SetLength(end - begin) +} + // --------------------------------------------------------------------------- // Identifiers // --------------------------------------------------------------------------- @@ -555,6 +573,15 @@ func (n *NotOperatorNode) GetType() DataType { return ComposeDataType(DataTypeMajorBool, DataTypeMinorDontCare) } +// ParenOperatorNode is a pair of '(' and ')', representing a parenthesized +// expression. +type ParenOperatorNode struct { + TaggedExprNodeBase + UnaryOperatorNode +} + +var _ UnaryOperator = (*ParenOperatorNode)(nil) + // AndOperatorNode is 'AND'. type AndOperatorNode struct { UntaggedExprNodeBase @@ -942,6 +969,7 @@ func (n *InsertWithColumnOptionNode) GetChildren() []Node { for i := 0; i < len(n.Value); i++ { size += len(n.Value[i]) } + nodes := make([]Node, size) idx := 0 for i := 0; i < len(n.Column); i, idx = i+1, idx+1 { @@ -1062,25 +1090,47 @@ var _ Node = (*SelectStmtNode)(nil) // GetChildren returns a list of child nodes used for traversing. func (n *SelectStmtNode) GetChildren() []Node { - nodes := make([]Node, len(n.Column)+2+len(n.Group)+len(n.Order)+2) + size := len(n.Column) + len(n.Group) + len(n.Order) + if n.Table != nil { + size++ + } + if n.Where != nil { + size++ + } + if n.Limit != nil { + size++ + } + if n.Offset != nil { + size++ + } + + nodes := make([]Node, size) idx := 0 for i := 0; i < len(n.Column); i, idx = i+1, idx+1 { nodes[idx] = n.Column[i] } - nodes[idx] = n.Table - idx++ - nodes[idx] = n.Where - idx++ + if n.Table != nil { + nodes[idx] = n.Table + idx++ + } + if n.Where != nil { + nodes[idx] = n.Where + idx++ + } for i := 0; i < len(n.Group); i, idx = i+1, idx+1 { nodes[idx] = n.Group[i] } for i := 0; i < len(n.Order); i, idx = i+1, idx+1 { nodes[idx] = n.Order[i] } - nodes[idx] = n.Limit - idx++ - nodes[idx] = n.Offset - idx++ + if n.Limit != nil { + nodes[idx] = n.Limit + idx++ + } + if n.Offset != nil { + nodes[idx] = n.Offset + idx++ + } return nodes } @@ -1096,14 +1146,22 @@ var _ Node = (*UpdateStmtNode)(nil) // GetChildren returns a list of child nodes used for traversing. func (n *UpdateStmtNode) GetChildren() []Node { - nodes := make([]Node, 1+len(n.Assignment)+1) + size := 1 + len(n.Assignment) + if n.Where != nil { + size++ + } + + nodes := make([]Node, size) idx := 0 nodes[idx] = n.Table + idx++ for i := 0; i < len(n.Assignment); i, idx = i+1, idx+1 { nodes[idx] = n.Assignment[i] } - nodes[idx] = n.Where - idx++ + if n.Where != nil { + nodes[idx] = n.Where + idx++ + } return nodes } @@ -1118,6 +1176,9 @@ var _ Node = (*DeleteStmtNode)(nil) // GetChildren returns a list of child nodes used for traversing. func (n *DeleteStmtNode) GetChildren() []Node { + if n.Where == nil { + return []Node{n.Table} + } return []Node{n.Table, n.Where} } @@ -1188,7 +1249,12 @@ var _ Node = (*CreateIndexStmtNode)(nil) // GetChildren returns a list of child nodes used for traversing. func (n *CreateIndexStmtNode) GetChildren() []Node { - nodes := make([]Node, 2+len(n.Column)+1) + size := 2 + len(n.Column) + if n.Unique != nil { + size++ + } + + nodes := make([]Node, size) idx := 0 nodes[idx] = n.Index idx++ @@ -1197,7 +1263,9 @@ func (n *CreateIndexStmtNode) GetChildren() []Node { for i := 0; i < len(n.Column); i, idx = i+1, idx+1 { nodes[idx] = n.Column[i] } - nodes[idx] = n.Unique - idx++ + if n.Unique != nil { + nodes[idx] = n.Unique + idx++ + } return nodes } diff --git a/core/vm/sqlvm/ast/printer.go b/core/vm/sqlvm/ast/printer.go index e9b289411..4800fc86b 100644 --- a/core/vm/sqlvm/ast/printer.go +++ b/core/vm/sqlvm/ast/printer.go @@ -35,9 +35,11 @@ func formatString(s string) string { return fmt.Sprintf("%v", []byte(s)) } -func printAST(w io.Writer, n interface{}, depth int, base string, detail bool) { - indent := strings.Repeat(base, depth) - indentLong := strings.Repeat(base, depth+1) +func printAST(w io.Writer, n interface{}, s []byte, prefix string, + detail bool, depth int) { + + indent := strings.Repeat(prefix, depth) + indentLong := strings.Repeat(prefix, depth+1) if n == nil { fmt.Fprintf(w, "%snil\n", indent) return @@ -56,20 +58,6 @@ func printAST(w io.Writer, n interface{}, depth int, base string, detail bool) { } name := typeOf.Name() - if op, ok := n.(UnaryOperator); ok { - fmt.Fprintf(w, "%s%s:\n", indent, name) - fmt.Fprintf(w, "%sTarget:\n", indentLong) - printAST(w, op.GetTarget(), depth+2, base, detail) - return - } - if op, ok := n.(BinaryOperator); ok { - fmt.Fprintf(w, "%s%s:\n", indent, name) - fmt.Fprintf(w, "%sObject:\n", indentLong) - printAST(w, op.GetObject(), depth+2, base, detail) - fmt.Fprintf(w, "%sSubject:\n", indentLong) - printAST(w, op.GetSubject(), depth+2, base, detail) - return - } if stringer, ok := n.(fmt.Stringer); ok { s := stringer.String() fmt.Fprintf(w, "%s%s\n", indent, formatString(s)) @@ -92,7 +80,7 @@ func printAST(w io.Writer, n interface{}, depth int, base string, detail bool) { fmt.Fprintf(w, "%s[\n", indent) for i := 0; i < l; i++ { v := valueOf.Index(i) - printAST(w, v.Interface(), depth+1, base, detail) + printAST(w, v.Interface(), s, prefix, detail, depth+1) } fmt.Fprintf(w, "%s]\n", indent) return @@ -124,15 +112,30 @@ func printAST(w io.Writer, n interface{}, depth int, base string, detail bool) { } } collect(typeOf, valueOf) + + var position string + if node, ok := n.(Node); ok { + begin := node.GetPosition() + length := node.GetLength() + if node.HasPosition() { + end := begin + length - 1 + token := s[begin : begin+length] + position = fmt.Sprintf("%d-%d %s", + begin, end, strconv.Quote(string(token))) + } else { + position = "no position info" + } + } + fmt.Fprintf(w, "%s%s", indent, name) if len(fields) == 0 { - fmt.Fprintf(w, " {}\n") + fmt.Fprintf(w, " {} // %s\n", position) return } - fmt.Fprintf(w, " {\n") + fmt.Fprintf(w, " { // %s\n", position) for i := 0; i < len(fields); i++ { fmt.Fprintf(w, "%s%s:\n", indentLong, fields[i].name) - printAST(w, fields[i].value, depth+2, base, detail) + printAST(w, fields[i].value, s, prefix, detail, depth+2) } fmt.Fprintf(w, "%s}\n", indent) return @@ -141,6 +144,8 @@ func printAST(w io.Writer, n interface{}, depth int, base string, detail bool) { } // PrintAST prints AST for debugging. -func PrintAST(w io.Writer, n interface{}, indent string, detail bool) { - printAST(w, n, 0, indent, detail) +func PrintAST(output io.Writer, node interface{}, source []byte, + indent string, detail bool) { + + printAST(output, node, source, indent, detail, 0) } -- cgit v1.2.3