diff options
-rw-r--r-- | core/vm/sqlvm/ast/ast.go | 21 | ||||
-rw-r--r-- | core/vm/sqlvm/ast/printer.go | 16 | ||||
-rw-r--r-- | core/vm/sqlvm/cmd/ast-printer/main.go | 2 | ||||
-rw-r--r-- | core/vm/sqlvm/parser/parser.go | 6 |
4 files changed, 26 insertions, 19 deletions
diff --git a/core/vm/sqlvm/ast/ast.go b/core/vm/sqlvm/ast/ast.go index 8071fbe19..71e09f21f 100644 --- a/core/vm/sqlvm/ast/ast.go +++ b/core/vm/sqlvm/ast/ast.go @@ -19,6 +19,8 @@ type Node interface { SetPosition(uint32) GetLength() uint32 SetLength(uint32) + GetToken() []byte + SetToken([]byte) GetChildren() []Node } @@ -26,6 +28,7 @@ type Node interface { type NodeBase struct { Position uint32 `print:"-"` Length uint32 `print:"-"` + Token []byte `print:"-"` } // HasPosition returns whether the position is set. @@ -53,16 +56,14 @@ func (n *NodeBase) SetLength(length uint32) { n.Length = length } -// UpdatePosition sets the position of the destination node from two source -// nodes. It is assumed that the destination node consists of multiple tokens -// which can be mapped to child nodes of the destination node. srcLeft should -// be the node representing the left-most token of the destination node, and -// srcRight should represent the right-most token of the destination node. -func UpdatePosition(dest, srcLeft, srcRight Node) { - begin := srcLeft.GetPosition() - end := srcRight.GetPosition() + srcRight.GetLength() - dest.SetPosition(begin) - dest.SetLength(end - begin) +// GetToken returns the corresponding token of the node. +func (n *NodeBase) GetToken() []byte { + return n.Token +} + +// SetToken sets the corresponding token of the node. +func (n *NodeBase) SetToken(token []byte) { + n.Token = token } // --------------------------------------------------------------------------- diff --git a/core/vm/sqlvm/ast/printer.go b/core/vm/sqlvm/ast/printer.go index ac4c43fe0..92683bef1 100644 --- a/core/vm/sqlvm/ast/printer.go +++ b/core/vm/sqlvm/ast/printer.go @@ -35,8 +35,8 @@ func formatString(s string) string { return fmt.Sprintf("%v", []byte(s)) } -func printAST(w io.Writer, n interface{}, s []byte, prefix string, - detail bool, depth int) (int, error) { +func printAST(w io.Writer, n interface{}, prefix string, detail bool, + depth int) (int, error) { indent := strings.Repeat(prefix, depth) indentLong := strings.Repeat(prefix, depth+1) @@ -80,7 +80,7 @@ func printAST(w io.Writer, n interface{}, s []byte, prefix string, } for i := 0; i < l; i++ { v := valueOf.Index(i) - b, err = printAST(w, v.Interface(), s, prefix, detail, depth+1) + b, err = printAST(w, v.Interface(), prefix, detail, depth+1) bytesWritten += b if err != nil { return bytesWritten, err @@ -124,7 +124,7 @@ func printAST(w io.Writer, n interface{}, s []byte, prefix string, length := node.GetLength() if node.HasPosition() { end := begin + length - 1 - token := s[begin : begin+length] + token := node.GetToken() position = fmt.Sprintf("%d-%d %s", begin, end, strconv.Quote(string(token))) } else { @@ -154,7 +154,7 @@ func printAST(w io.Writer, n interface{}, s []byte, prefix string, if err != nil { return bytesWritten, err } - b, err = printAST(w, fields[i].value, s, prefix, detail, depth+2) + b, err = printAST(w, fields[i].value, prefix, detail, depth+2) bytesWritten += b if err != nil { return bytesWritten, err @@ -168,8 +168,8 @@ func printAST(w io.Writer, n interface{}, s []byte, prefix string, } // PrintAST prints AST for debugging. -func PrintAST(output io.Writer, node interface{}, source []byte, - indent string, detail bool) (int, error) { +func PrintAST(output io.Writer, node interface{}, indent string, detail bool) ( + int, error) { - return printAST(output, node, source, indent, detail, 0) + return printAST(output, node, indent, detail, 0) } diff --git a/core/vm/sqlvm/cmd/ast-printer/main.go b/core/vm/sqlvm/cmd/ast-printer/main.go index d62cb4fc8..197608b6d 100644 --- a/core/vm/sqlvm/cmd/ast-printer/main.go +++ b/core/vm/sqlvm/cmd/ast-printer/main.go @@ -18,7 +18,7 @@ func main() { fmt.Fprintf(os.Stderr, "detail: %t\n", detail) s := []byte(flag.Arg(0)) n, parseErr := parser.Parse(s) - b, printErr := ast.PrintAST(os.Stdout, n, s, " ", detail) + b, printErr := ast.PrintAST(os.Stdout, n, " ", detail) if parseErr != nil { fmt.Fprintf(os.Stderr, "Parse error:\n%+v\n", parseErr) } diff --git a/core/vm/sqlvm/parser/parser.go b/core/vm/sqlvm/parser/parser.go index 85862d499..8a83cfceb 100644 --- a/core/vm/sqlvm/parser/parser.go +++ b/core/vm/sqlvm/parser/parser.go @@ -76,6 +76,11 @@ func Parse(b []byte) ([]ast.StmtNode, error) { options := []internal.Option{internal.Recover(false)} root, pigeonErr := internal.Parse("", eb, options...) + // Copy the input text. We will put references to the source code on AST + // nodes, so we have to make our own copy to prevent the AST from being + // broken by the caller if the input byte slice was modified afterwards. + b = append([]byte{}, b...) + // Process the AST. var stmts []ast.StmtNode if root != nil { @@ -117,6 +122,7 @@ func Parse(b []byte) ([]ast.StmtNode, error) { } n.SetPosition(fixedBegin) n.SetLength(fixedEnd - fixedBegin) + n.SetToken(b[fixedBegin:fixedEnd]) }) if !r { return nil, errors.ErrorList{ |