From 126145a7eca2c054367abc65291ab1c90c51b998 Mon Sep 17 00:00:00 2001 From: Ting-Wei Lan Date: Mon, 4 Mar 2019 16:06:14 +0800 Subject: core: vm: sqlvm: limit the depth of AST to 1024 Since we traverse an AST by calling functions recursively, we have to protect the parser by limiting the depth of an AST. --- core/vm/sqlvm/ast/constants.go | 5 +++++ core/vm/sqlvm/errors/errors.go | 4 ++++ core/vm/sqlvm/parser/parser.go | 48 ++++++++++++++++++++++++++++++++++++------ 3 files changed, 50 insertions(+), 7 deletions(-) create mode 100644 core/vm/sqlvm/ast/constants.go diff --git a/core/vm/sqlvm/ast/constants.go b/core/vm/sqlvm/ast/constants.go new file mode 100644 index 000000000..a11a182ea --- /dev/null +++ b/core/vm/sqlvm/ast/constants.go @@ -0,0 +1,5 @@ +package ast + +// DepthLimit is the limit of AST depth used to prevent exhausting the stack +// when traversing the tree recursively. +const DepthLimit = 1024 diff --git a/core/vm/sqlvm/errors/errors.go b/core/vm/sqlvm/errors/errors.go index 886f3beb7..696d061c8 100644 --- a/core/vm/sqlvm/errors/errors.go +++ b/core/vm/sqlvm/errors/errors.go @@ -61,12 +61,14 @@ type ErrorCategory uint16 // Error category starts from 1. Zero value is invalid. const ( ErrorCategoryNil ErrorCategory = iota + ErrorCategoryLimit ErrorCategoryGrammar ErrorCategorySemantic ErrorCategoryRuntime ) var errorCategoryMap = [...]string{ + ErrorCategoryLimit: "limit", ErrorCategoryGrammar: "grammar", ErrorCategorySemantic: "semantic", ErrorCategoryRuntime: "runtime", @@ -82,6 +84,7 @@ type ErrorCode uint16 // Error code starts from 1. Zero value is invalid. const ( ErrorCodeNil ErrorCode = iota + ErrorCodeDepthLimitReached ErrorCodeParser ErrorCodeInvalidIntegerSyntax ErrorCodeInvalidNumberSyntax @@ -108,6 +111,7 @@ const ( ) var errorCodeMap = [...]string{ + ErrorCodeDepthLimitReached: "depth limit reached", ErrorCodeParser: "parser error", ErrorCodeInvalidIntegerSyntax: "invalid integer syntax", ErrorCodeInvalidNumberSyntax: "invalid number syntax", diff --git a/core/vm/sqlvm/parser/parser.go b/core/vm/sqlvm/parser/parser.go index a90fec71c..8ed94e7aa 100644 --- a/core/vm/sqlvm/parser/parser.go +++ b/core/vm/sqlvm/parser/parser.go @@ -9,20 +9,40 @@ import ( "github.com/dexon-foundation/dexon/core/vm/sqlvm/parser/internal" ) -func walkSelfFirst(n ast.Node, v func(ast.Node, []ast.Node)) { +type visitor func(ast.Node, []ast.Node) + +func walkSelfFirst(n ast.Node, v visitor) bool { + return walkSelfFirstWithDepth(n, v, 0) +} + +func walkSelfFirstWithDepth(n ast.Node, v visitor, d int) bool { + if d >= ast.DepthLimit { + return false + } c := n.GetChildren() + r := true v(n, c) for i := range c { - walkSelfFirst(c[i], v) + r = r && walkSelfFirstWithDepth(c[i], v, d+1) } + return r } -func walkChildrenFirst(n ast.Node, v func(ast.Node, []ast.Node)) { +func walkChildrenFirst(n ast.Node, v visitor) bool { + return walkChildrenFirstWithDepth(n, v, 0) +} + +func walkChildrenFirstWithDepth(n ast.Node, v visitor, d int) bool { + if d >= ast.DepthLimit { + return false + } c := n.GetChildren() + r := true for i := range c { - walkChildrenFirst(c[i], v) + r = r && walkChildrenFirstWithDepth(c[i], v, d+1) } v(n, c) + return r } // Parse parses SQL commands text and return an AST. @@ -65,7 +85,8 @@ func Parse(b []byte) ([]ast.Node, error) { if stmts[i] == nil { continue } - walkChildrenFirst(stmts[i], func(n ast.Node, c []ast.Node) { + r := true + r = r && walkChildrenFirst(stmts[i], func(n ast.Node, c []ast.Node) { minBegin := uint32(len(eb)) maxEnd := uint32(0) for _, cn := range append(c, n) { @@ -83,7 +104,7 @@ func Parse(b []byte) ([]ast.Node, error) { n.SetPosition(minBegin) n.SetLength(maxEnd - minBegin) }) - walkSelfFirst(stmts[i], func(n ast.Node, _ []ast.Node) { + r = r && walkSelfFirst(stmts[i], func(n ast.Node, _ []ast.Node) { begin := n.GetPosition() end := begin + n.GetLength() fixedBegin, ok := encMap[begin] @@ -97,9 +118,22 @@ func Parse(b []byte) ([]ast.Node, error) { n.SetPosition(fixedBegin) n.SetLength(fixedEnd - fixedBegin) }) + if !r { + return nil, errors.ErrorList{ + errors.Error{ + Position: 0, + Category: errors.ErrorCategoryLimit, + Code: errors.ErrorCodeDepthLimitReached, + Token: "", + Prefix: "", + Message: fmt.Sprintf("reach syntax tree depth limit %d", + ast.DepthLimit), + }, + } + } } if pigeonErr == nil { - return stmts, pigeonErr + return stmts, nil } // Process errors. -- cgit v1.2.3