From 126145a7eca2c054367abc65291ab1c90c51b998 Mon Sep 17 00:00:00 2001
From: Ting-Wei Lan <tingwei.lan@cobinhood.com>
Date: Mon, 4 Mar 2019 16:06:14 +0800
Subject: core: vm: sqlvm: limit the depth of AST to 1024

Since we traverse an AST by calling functions recursively, we have to
protect the parser by limiting the depth of an AST.
---
 core/vm/sqlvm/ast/constants.go |  5 +++++
 core/vm/sqlvm/errors/errors.go |  4 ++++
 core/vm/sqlvm/parser/parser.go | 48 ++++++++++++++++++++++++++++++++++++------
 3 files changed, 50 insertions(+), 7 deletions(-)
 create mode 100644 core/vm/sqlvm/ast/constants.go

(limited to 'core/vm/sqlvm')

diff --git a/core/vm/sqlvm/ast/constants.go b/core/vm/sqlvm/ast/constants.go
new file mode 100644
index 000000000..a11a182ea
--- /dev/null
+++ b/core/vm/sqlvm/ast/constants.go
@@ -0,0 +1,5 @@
+package ast
+
+// DepthLimit is the limit of AST depth used to prevent exhausting the stack
+// when traversing the tree recursively.
+const DepthLimit = 1024
diff --git a/core/vm/sqlvm/errors/errors.go b/core/vm/sqlvm/errors/errors.go
index 886f3beb7..696d061c8 100644
--- a/core/vm/sqlvm/errors/errors.go
+++ b/core/vm/sqlvm/errors/errors.go
@@ -61,12 +61,14 @@ type ErrorCategory uint16
 // Error category starts from 1. Zero value is invalid.
 const (
 	ErrorCategoryNil ErrorCategory = iota
+	ErrorCategoryLimit
 	ErrorCategoryGrammar
 	ErrorCategorySemantic
 	ErrorCategoryRuntime
 )
 
 var errorCategoryMap = [...]string{
+	ErrorCategoryLimit:    "limit",
 	ErrorCategoryGrammar:  "grammar",
 	ErrorCategorySemantic: "semantic",
 	ErrorCategoryRuntime:  "runtime",
@@ -82,6 +84,7 @@ type ErrorCode uint16
 // Error code starts from 1. Zero value is invalid.
 const (
 	ErrorCodeNil ErrorCode = iota
+	ErrorCodeDepthLimitReached
 	ErrorCodeParser
 	ErrorCodeInvalidIntegerSyntax
 	ErrorCodeInvalidNumberSyntax
@@ -108,6 +111,7 @@ const (
 )
 
 var errorCodeMap = [...]string{
+	ErrorCodeDepthLimitReached:             "depth limit reached",
 	ErrorCodeParser:                        "parser error",
 	ErrorCodeInvalidIntegerSyntax:          "invalid integer syntax",
 	ErrorCodeInvalidNumberSyntax:           "invalid number syntax",
diff --git a/core/vm/sqlvm/parser/parser.go b/core/vm/sqlvm/parser/parser.go
index a90fec71c..8ed94e7aa 100644
--- a/core/vm/sqlvm/parser/parser.go
+++ b/core/vm/sqlvm/parser/parser.go
@@ -9,20 +9,40 @@ import (
 	"github.com/dexon-foundation/dexon/core/vm/sqlvm/parser/internal"
 )
 
-func walkSelfFirst(n ast.Node, v func(ast.Node, []ast.Node)) {
+type visitor func(ast.Node, []ast.Node)
+
+func walkSelfFirst(n ast.Node, v visitor) bool {
+	return walkSelfFirstWithDepth(n, v, 0)
+}
+
+func walkSelfFirstWithDepth(n ast.Node, v visitor, d int) bool {
+	if d >= ast.DepthLimit {
+		return false
+	}
 	c := n.GetChildren()
+	r := true
 	v(n, c)
 	for i := range c {
-		walkSelfFirst(c[i], v)
+		r = r && walkSelfFirstWithDepth(c[i], v, d+1)
 	}
+	return r
 }
 
-func walkChildrenFirst(n ast.Node, v func(ast.Node, []ast.Node)) {
+func walkChildrenFirst(n ast.Node, v visitor) bool {
+	return walkChildrenFirstWithDepth(n, v, 0)
+}
+
+func walkChildrenFirstWithDepth(n ast.Node, v visitor, d int) bool {
+	if d >= ast.DepthLimit {
+		return false
+	}
 	c := n.GetChildren()
+	r := true
 	for i := range c {
-		walkChildrenFirst(c[i], v)
+		r = r && walkChildrenFirstWithDepth(c[i], v, d+1)
 	}
 	v(n, c)
+	return r
 }
 
 // Parse parses SQL commands text and return an AST.
@@ -65,7 +85,8 @@ func Parse(b []byte) ([]ast.Node, error) {
 		if stmts[i] == nil {
 			continue
 		}
-		walkChildrenFirst(stmts[i], func(n ast.Node, c []ast.Node) {
+		r := true
+		r = r && walkChildrenFirst(stmts[i], func(n ast.Node, c []ast.Node) {
 			minBegin := uint32(len(eb))
 			maxEnd := uint32(0)
 			for _, cn := range append(c, n) {
@@ -83,7 +104,7 @@ func Parse(b []byte) ([]ast.Node, error) {
 			n.SetPosition(minBegin)
 			n.SetLength(maxEnd - minBegin)
 		})
-		walkSelfFirst(stmts[i], func(n ast.Node, _ []ast.Node) {
+		r = r && walkSelfFirst(stmts[i], func(n ast.Node, _ []ast.Node) {
 			begin := n.GetPosition()
 			end := begin + n.GetLength()
 			fixedBegin, ok := encMap[begin]
@@ -97,9 +118,22 @@ func Parse(b []byte) ([]ast.Node, error) {
 			n.SetPosition(fixedBegin)
 			n.SetLength(fixedEnd - fixedBegin)
 		})
+		if !r {
+			return nil, errors.ErrorList{
+				errors.Error{
+					Position: 0,
+					Category: errors.ErrorCategoryLimit,
+					Code:     errors.ErrorCodeDepthLimitReached,
+					Token:    "",
+					Prefix:   "",
+					Message: fmt.Sprintf("reach syntax tree depth limit %d",
+						ast.DepthLimit),
+				},
+			}
+		}
 	}
 	if pigeonErr == nil {
-		return stmts, pigeonErr
+		return stmts, nil
 	}
 
 	// Process errors.
-- 
cgit v1.2.3