package parser

import (


// Parser was generated with pigeon v1.0.0-99-gbb0192c.
//go:generate pigeon -no-recover -o grammar.go grammar.peg
//go:generate sh -c "sed -f grammar.sed grammar.go > grammar_new.go"
//go:generate mv grammar_new.go grammar.go
//go:generate goimports -w grammar.go

// prepend returns a new slice holding x followed by every element of xs.
// The backing array of xs is never written to.
func prepend(x interface{}, xs []interface{}) []interface{} {
	out := make([]interface{}, 0, len(xs)+1)
	out = append(out, x)
	return append(out, xs...)
}

// assertSlice converts a generic semantic value into []interface{}.
// A nil x yields a nil slice; any value that is not a []interface{} makes
// the type assertion panic.
func assertSlice(x interface{}) []interface{} {
	if x != nil {
		return x.([]interface{})
	}
	return nil
}

func assertNodeSlice(x interface{}) []ast.Node {
    xs := assertSlice(x)
    ns := make([]ast.Node, len(xs))
    for i := 0; i < len(xs); i++ {
        if xs[i] != nil {
            ns[i] = xs[i].(ast.Node)
    return ns

func assertExprSlice(x interface{}) []ast.ExprNode {
    xs := assertSlice(x)
    es := make([]ast.ExprNode, len(xs))
    for i := 0; i < len(xs); i++ {
        if xs[i] != nil {
            es[i] = xs[i].(ast.ExprNode)
    return es

// isAddress reports whether the hex digits h should be flagged as an
// address rather than a plain integer.
//
// TODO(wmin0): finish it. Detection is not implemented yet, so every input
// is reported as not being an address.
func isAddress(h []byte) bool {
	_ = h // unused until address detection is implemented
	return false
}

func hexToInteger(h []byte) *ast.IntegerValueNode {
    d := decimal.Zero
    l := len(h)
    base := decimal.New(16, 0)
    for idx, b := range h {
        i, err := strconv.ParseInt(string([]byte{b}), 16, 32)
        if err != nil {
            panic(fmt.Sprintf("invalid hex digit %s: %v", []byte{b}, err))
        d = d.Add(
            decimal.New(i, 0).
                Mul(base.Pow(decimal.New(int64(l-idx-1), 0))),
    node := &ast.IntegerValueNode{}
    node.IsAddress = isAddress(h)
    node.V = d
    return node

// hexToBytes decodes the hexadecimal text h into raw bytes. It panics when
// h is not valid hex (an odd number of digits or a non-hex character).
func hexToBytes(h []byte) []byte {
	out := make([]byte, hex.DecodedLen(len(h)))
	if _, err := hex.Decode(out, h); err != nil {
		panic(fmt.Sprintf("invalid hex string %s: %v", h, err))
	}
	return out
}

func convertNumError(err error) errors.ErrorCode {
    if err == nil {
        return errors.ErrorCodeNil
    switch err.(*strconv.NumError).Err {
    case strconv.ErrSyntax:
        return errors.ErrorCodeInvalidIntegerSyntax
    case strconv.ErrRange:
        return errors.ErrorCodeIntegerOutOfRange
    panic(fmt.Sprintf("unknown NumError: %v", err))

func convertDecimalError(err error) errors.ErrorCode {
    if err == nil {
        return errors.ErrorCodeNil
    errStr := err.Error()
    if strings.HasSuffix(errStr, "decimal: fractional part too long") {
        return errors.ErrorCodeFractionalPartTooLong
    } else if strings.HasSuffix(errStr, "decimal: exponent is not numeric") {
        return errors.ErrorCodeInvalidNumberSyntax
    } else if strings.HasSuffix(errStr, "decimal: too many .s") {
        return errors.ErrorCodeInvalidNumberSyntax
    panic(fmt.Sprintf("unknown decimal error: %v", err))

func toUint(b []byte) (uint32, errors.ErrorCode) {
    i, err := strconv.ParseUint(string(b), 10, 32)
    return uint32(i), convertNumError(err)

func toDecimal(b []byte) (decimal.Decimal, errors.ErrorCode) {
    if len(b) > 0 && b[0] == byte('.') {
        b = append([]byte{'0'}, b...)
    d, err := decimal.NewFromString(string(b))
    return d, convertDecimalError(err)

// toLower returns a lower-cased copy of b, as produced by bytes.ToLower.
//
// NOTE(review): bytes.ToLower treats b as UTF-8 and folds non-ASCII letters
// as well. Parse feeds the grammar text in which each input byte became a
// code point, so original bytes >= 0xC0 may be folded as Latin-1 letters —
// confirm this is intended at the call sites.
func toLower(b []byte) []byte {
	lowered := bytes.ToLower(b)
	return lowered
}

// joinBytes concatenates the []byte values held in x. Each element must be
// a []byte or the type assertion panics. The result is a non-nil slice, even
// when x is empty.
func joinBytes(x []interface{}) []byte {
	joined := make([]byte, 0, len(x))
	for _, part := range x {
		joined = append(joined, part.([]byte)...)
	}
	return joined
}

func opSetSubject(op ast.BinaryOperator, s ast.ExprNode) ast.BinaryOperator {
    return op

func opSetObject(op ast.BinaryOperator, o ast.ExprNode) ast.BinaryOperator {
    return op

func opSetTarget(op ast.UnaryOperator, t ast.ExprNode) ast.UnaryOperator {
    return op

func joinOperator(x ast.ExprNode, o ast.ExprNode) {
    switch op := x.(type) {
    case ast.UnaryOperator:
        joinOperator(op.GetTarget(), o)
    case ast.BinaryOperator:
    case *ast.CastOperatorNode:
        op.SourceExpr = o
    case *ast.InOperatorNode:
        op.Left = o
        panic(fmt.Sprintf("unable to join operators %T and %T", x, o))

func rightJoinOperators(o ast.ExprNode, x []ast.ExprNode) ast.ExprNode {
    if len(x) == 0 {
        return o
    l := len(x)
    for idx := 0; idx < l-1; idx++ {
        joinOperator(x[idx+1], x[idx])
    joinOperator(x[0], o)
    return x[l-1]

// sanitizeBadEscape renders s in a form that is safe to embed in messages.
// Printable ASCII bytes (0x20-0x7E) other than the single quote are kept
// as-is. Every other byte, including '\'', is written as "<XX>" with two
// upper-case hex digits.
func sanitizeBadEscape(s []byte) []byte {
	o := bytes.Buffer{}
	for _, b := range s {
		if b >= 0x20 && b <= 0x7e && b != '\'' {
			o.WriteByte(b)
		} else {
			fmt.Fprintf(&o, "<%02X>", b)
		}
	}
	return o.Bytes()
}

// decodeString reverses the input encoding done by Parse. There, every raw
// byte was written as the code point with the same value, encoded as UTF-8.
// Here each decoded code point becomes one byte again. A code point above
// U+00FF (including utf8.RuneError for invalid UTF-8) cannot have come from
// that encoding and panics. An empty s yields a nil slice.
func decodeString(s []byte) []byte {
	o := bytes.Buffer{}
	for i := 0; i < len(s); {
		r, size := utf8.DecodeRune(s[i:])
		if r > 0xff {
			panic(fmt.Sprintf("invalid encoded rune U+%04X", r))
		}
		o.WriteByte(byte(r))
		i += size
	}
	return o.Bytes()
}

func resolveString(s []byte) ([]byte, []byte, errors.ErrorCode) {
    s = decodeString(s)
    o := bytes.Buffer{}
    for i, size := 0, 0; i < len(s); i += size {
        if s[i] == '\\' {
            if i+1 >= len(s) {
                panic("trailing backslash in string literal")
            switch s[i+1] {
            case '\n':
                size = 2

            case '\\':
                size = 2
            case '\'':
                size = 2
            case '"':
                size = 2
            case 'b':
                size = 2
            case 'f':
                size = 2
            case 'n':
                size = 2
            case 'r':
                size = 2
            case 't':
                size = 2
            case 'v':
                size = 2

            case 'x':
                if i+3 >= len(s) {
                    return nil, s[i:], errors.ErrorCodeEscapeSequenceTooShort
                b, err := strconv.ParseUint(string(s[i+2:i+4]), 16, 8)
                if err != nil {
                    return nil, s[i : i+4], convertNumError(err)
                size = 4

            case 'u':
                if i+5 >= len(s) {
                    return nil, s[i:], errors.ErrorCodeEscapeSequenceTooShort
                u, err := strconv.ParseUint(string(s[i+2:i+6]), 16, 16)
                if err != nil {
                    return nil, s[i : i+6], convertNumError(err)
                if u >= 0xd800 && u <= 0xdfff {
                    return nil, s[i : i+6], errors.ErrorCodeInvalidUnicodeCodePoint
                size = 6

            case 'U':
                if i+9 >= len(s) {
                    return nil, s[i:], errors.ErrorCodeEscapeSequenceTooShort
                r, err := strconv.ParseUint(string(s[i+2:i+10]), 16, 32)
                if err != nil {
                    return nil, s[i : i+10], convertNumError(err)
                if r > 0x10ffff || (r >= 0xd800 && r <= 0xdfff) {
                    return nil, s[i : i+10], errors.ErrorCodeInvalidUnicodeCodePoint
                size = 10

                return nil, s[i : i+2], errors.ErrorCodeUnknownEscapeSequence
        } else {
            size = 1
    return o.Bytes(), nil, errors.ErrorCodeNil

// Parse parses SQL commands text and return an AST.
func Parse(b []byte, o ...Option) ([]ast.Node, error) {
    // The string sent from the caller is not guaranteed to be valid UTF-8.
    // We don't really care non-ASCII characters in the string because all
    // keywords and special symbols are defined in ASCII. Therefore, as long
    // as the encoding is compatible with ASCII, we can process text with
    // unknown encoding.
    // However, pigeon requires input text to be valid UTF-8, throwing an error
    // and exiting early when it cannot decode the input as UTF-8. In order to
    // workaround it, we preprocess the input text by assuming each byte value
    // is a Unicode code point and encoding the input text as UTF-8.
    // This means that the byte offset reported by pigeon is wrong. We have to
    // scan the the error list and the AST to fix positions in these structs
    // before returning them to the caller.

    // Encode the input text.
    encBuf := bytes.Buffer{}
    encMap := map[uint32]uint32{}
    for i, c := range b {
        encMap[uint32(encBuf.Len())] = uint32(i)
    encMap[uint32(encBuf.Len())] = uint32(len(b))

    // Prepare arguments and call the parser.
    eb := encBuf.Bytes()
    options := append([]Option{Recover(false)}, o...)
    root, pigeonErr := parse("", eb, options...)
    stmts := assertNodeSlice(root)

    // Process the AST.
    if pigeonErr == nil {
        return stmts, pigeonErr

    // Process errors.
    pigeonErrList := pigeonErr.(errList)
    sqlvmErrList := make(errors.ErrorList, len(pigeonErrList))
    for i := range pigeonErrList {
        parserErr := pigeonErrList[i].(*parserError)
        if sqlvmErr, ok := parserErr.Inner.(errors.Error); ok {
            sqlvmErrList[i] = sqlvmErr
        } else {
            sqlvmErrList[i] = errors.Error{
                Position: uint32(parserErr.pos.offset),
                Category: errors.ErrorCategoryGrammar,
                Code:     errors.ErrorCodeParser,
                Token:    "",
                Prefix:   parserErr.prefix,
                Message:  parserErr.Inner.Error(),
        sqlvmErrList[i].Token =
        if offset, ok := encMap[sqlvmErrList[i].Position]; ok {
            sqlvmErrList[i].Position = offset
        } else {
                "cannot fix byte offset %d", sqlvmErrList[i].Position))
    return stmts, sqlvmErrList