aboutsummaryrefslogtreecommitdiffstats
path: root/core
diff options
context:
space:
mode:
Diffstat (limited to 'core')
-rw-r--r--core/vm/sqlvm/ast/ast.go21
-rw-r--r--core/vm/sqlvm/ast/printer.go16
-rw-r--r--core/vm/sqlvm/cmd/ast-printer/main.go2
-rw-r--r--core/vm/sqlvm/parser/parser.go6
4 files changed, 26 insertions, 19 deletions
diff --git a/core/vm/sqlvm/ast/ast.go b/core/vm/sqlvm/ast/ast.go
index 8071fbe19..71e09f21f 100644
--- a/core/vm/sqlvm/ast/ast.go
+++ b/core/vm/sqlvm/ast/ast.go
@@ -19,6 +19,8 @@ type Node interface {
SetPosition(uint32)
GetLength() uint32
SetLength(uint32)
+ GetToken() []byte
+ SetToken([]byte)
GetChildren() []Node
}
@@ -26,6 +28,7 @@ type Node interface {
type NodeBase struct {
Position uint32 `print:"-"`
Length uint32 `print:"-"`
+ Token []byte `print:"-"`
}
// HasPosition returns whether the position is set.
@@ -53,16 +56,14 @@ func (n *NodeBase) SetLength(length uint32) {
n.Length = length
}
-// UpdatePosition sets the position of the destination node from two source
-// nodes. It is assumed that the destination node consists of multiple tokens
-// which can be mapped to child nodes of the destination node. srcLeft should
-// be the node representing the left-most token of the destination node, and
-// srcRight should represent the right-most token of the destination node.
-func UpdatePosition(dest, srcLeft, srcRight Node) {
- begin := srcLeft.GetPosition()
- end := srcRight.GetPosition() + srcRight.GetLength()
- dest.SetPosition(begin)
- dest.SetLength(end - begin)
+// GetToken returns the corresponding token of the node.
+func (n *NodeBase) GetToken() []byte {
+ return n.Token
+}
+
+// SetToken sets the corresponding token of the node.
+func (n *NodeBase) SetToken(token []byte) {
+ n.Token = token
}
// ---------------------------------------------------------------------------
diff --git a/core/vm/sqlvm/ast/printer.go b/core/vm/sqlvm/ast/printer.go
index ac4c43fe0..92683bef1 100644
--- a/core/vm/sqlvm/ast/printer.go
+++ b/core/vm/sqlvm/ast/printer.go
@@ -35,8 +35,8 @@ func formatString(s string) string {
return fmt.Sprintf("%v", []byte(s))
}
-func printAST(w io.Writer, n interface{}, s []byte, prefix string,
- detail bool, depth int) (int, error) {
+func printAST(w io.Writer, n interface{}, prefix string, detail bool,
+ depth int) (int, error) {
indent := strings.Repeat(prefix, depth)
indentLong := strings.Repeat(prefix, depth+1)
@@ -80,7 +80,7 @@ func printAST(w io.Writer, n interface{}, s []byte, prefix string,
}
for i := 0; i < l; i++ {
v := valueOf.Index(i)
- b, err = printAST(w, v.Interface(), s, prefix, detail, depth+1)
+ b, err = printAST(w, v.Interface(), prefix, detail, depth+1)
bytesWritten += b
if err != nil {
return bytesWritten, err
@@ -124,7 +124,7 @@ func printAST(w io.Writer, n interface{}, s []byte, prefix string,
length := node.GetLength()
if node.HasPosition() {
end := begin + length - 1
- token := s[begin : begin+length]
+ token := node.GetToken()
position = fmt.Sprintf("%d-%d %s",
begin, end, strconv.Quote(string(token)))
} else {
@@ -154,7 +154,7 @@ func printAST(w io.Writer, n interface{}, s []byte, prefix string,
if err != nil {
return bytesWritten, err
}
- b, err = printAST(w, fields[i].value, s, prefix, detail, depth+2)
+ b, err = printAST(w, fields[i].value, prefix, detail, depth+2)
bytesWritten += b
if err != nil {
return bytesWritten, err
@@ -168,8 +168,8 @@ func printAST(w io.Writer, n interface{}, s []byte, prefix string,
}
// PrintAST prints AST for debugging.
-func PrintAST(output io.Writer, node interface{}, source []byte,
- indent string, detail bool) (int, error) {
+func PrintAST(output io.Writer, node interface{}, indent string, detail bool) (
+ int, error) {
- return printAST(output, node, source, indent, detail, 0)
+ return printAST(output, node, indent, detail, 0)
}
diff --git a/core/vm/sqlvm/cmd/ast-printer/main.go b/core/vm/sqlvm/cmd/ast-printer/main.go
index d62cb4fc8..197608b6d 100644
--- a/core/vm/sqlvm/cmd/ast-printer/main.go
+++ b/core/vm/sqlvm/cmd/ast-printer/main.go
@@ -18,7 +18,7 @@ func main() {
fmt.Fprintf(os.Stderr, "detail: %t\n", detail)
s := []byte(flag.Arg(0))
n, parseErr := parser.Parse(s)
- b, printErr := ast.PrintAST(os.Stdout, n, s, " ", detail)
+ b, printErr := ast.PrintAST(os.Stdout, n, " ", detail)
if parseErr != nil {
fmt.Fprintf(os.Stderr, "Parse error:\n%+v\n", parseErr)
}
diff --git a/core/vm/sqlvm/parser/parser.go b/core/vm/sqlvm/parser/parser.go
index 85862d499..8a83cfceb 100644
--- a/core/vm/sqlvm/parser/parser.go
+++ b/core/vm/sqlvm/parser/parser.go
@@ -76,6 +76,11 @@ func Parse(b []byte) ([]ast.StmtNode, error) {
options := []internal.Option{internal.Recover(false)}
root, pigeonErr := internal.Parse("", eb, options...)
+ // Copy the input text. We will put references to the source code on AST
+ // nodes, so we have to make our own copy to prevent the AST from being
+ // broken by the caller if the input byte slice was modified afterwards.
+ b = append([]byte{}, b...)
+
// Process the AST.
var stmts []ast.StmtNode
if root != nil {
@@ -117,6 +122,7 @@ func Parse(b []byte) ([]ast.StmtNode, error) {
}
n.SetPosition(fixedBegin)
n.SetLength(fixedEnd - fixedBegin)
+ n.SetToken(b[fixedBegin:fixedEnd])
})
if !r {
return nil, errors.ErrorList{