diff options
-rw-r--r-- | AST.cpp | 5 | ||||
-rwxr-xr-x | AST.h | 53 | ||||
-rw-r--r-- | ASTForward.h | 2 | ||||
-rw-r--r-- | ASTPrinter.cpp | 24 | ||||
-rw-r--r-- | ASTPrinter.h | 4 | ||||
-rw-r--r-- | ASTVisitor.h | 8 | ||||
-rw-r--r-- | AST_accept.h | 34 | ||||
-rw-r--r-- | Parser.cpp | 44 | ||||
-rw-r--r-- | Parser.h | 3 | ||||
-rw-r--r-- | Token.h | 1 |
10 files changed, 174 insertions, 4 deletions
@@ -183,6 +183,11 @@ string FunctionDefinition::getCanonicalSignature() const return getName() + FunctionType(*this).getCanonicalSignature(); } +void ModifierDefinition::checkTypeRequirements() +{ + m_body->checkTypeRequirements(); +} + void Block::checkTypeRequirements() { for (shared_ptr<Statement> const& statement: m_statements) @@ -161,12 +161,14 @@ public: std::vector<ASTPointer<InheritanceSpecifier>> const& _baseContracts, std::vector<ASTPointer<StructDefinition>> const& _definedStructs, std::vector<ASTPointer<VariableDeclaration>> const& _stateVariables, - std::vector<ASTPointer<FunctionDefinition>> const& _definedFunctions): + std::vector<ASTPointer<FunctionDefinition>> const& _definedFunctions, + std::vector<ASTPointer<ModifierDefinition>> const& _functionModifiers): Declaration(_location, _name), m_baseContracts(_baseContracts), m_definedStructs(_definedStructs), m_stateVariables(_stateVariables), m_definedFunctions(_definedFunctions), + m_functionModifiers(_functionModifiers), m_documentation(_documentation) {} @@ -207,6 +209,7 @@ private: std::vector<ASTPointer<StructDefinition>> m_definedStructs; std::vector<ASTPointer<VariableDeclaration>> m_stateVariables; std::vector<ASTPointer<FunctionDefinition>> m_definedFunctions; + std::vector<ASTPointer<ModifierDefinition>> m_functionModifiers; ASTPointer<ASTString> m_documentation; std::vector<ContractDefinition const*> m_linearizedBaseContracts; @@ -362,6 +365,39 @@ private: }; /** + * Definition of a function modifier. + */ +class ModifierDefinition: public Declaration +{ +public: + ModifierDefinition(Location const& _location, + ASTPointer<ASTString> const& _name, + ASTPointer<ASTString> const& _documentation, + ASTPointer<ParameterList> const& _parameters, + ASTPointer<Block> const& _body): + Declaration(_location, _name), m_documentation(_documentation), + m_parameters(_parameters), m_body(_body) {} + + virtual void accept(ASTVisitor& _visitor) override; + virtual void accept(ASTConstVisitor& _visitor) const override; + + std::vector<ASTPointer<VariableDeclaration>> const& getParameters() const { return m_parameters->getParameters(); } + ParameterList const& getParameterList() const { return *m_parameters; } + Block const& getBody() const { return *m_body; } + + /// @return A shared pointer of an ASTString. + /// Can contain a nullptr in which case indicates absence of documentation + ASTPointer<ASTString> const& getDocumentation() const { return m_documentation; } + + void checkTypeRequirements(); + +private: + ASTPointer<ASTString> m_documentation; + ASTPointer<ParameterList> m_parameters; + ASTPointer<Block> m_body; +}; + +/** * Pseudo AST node that is used as declaration for "this", "msg", "tx", "block" and the global * functions when such an identifier is encountered. Will never have a valid location in the source code. */ @@ -503,6 +539,21 @@ private: }; /** + * Special placeholder statement denoted by "_" used in function modifiers. This is replaced by + * the original function when the modifier is applied. + */ +class PlaceholderStatement: public Statement +{ +public: + PlaceholderStatement(Location const& _location): Statement(_location) {} + + virtual void accept(ASTVisitor& _visitor) override; + virtual void accept(ASTConstVisitor& _visitor) const override; + + virtual void checkTypeRequirements() override { } +}; + +/** * If-statement with an optional "else" part. Note that "else if" is modeled by having a new * if-statement as the false (else) body. */ diff --git a/ASTForward.h b/ASTForward.h index da0a8812..b86f18cf 100644 --- a/ASTForward.h +++ b/ASTForward.h @@ -43,6 +43,7 @@ class StructDefinition; class ParameterList; class FunctionDefinition; class VariableDeclaration; +class ModifierDefinition; class MagicVariableDeclaration; class TypeName; class ElementaryTypeName; @@ -50,6 +51,7 @@ class UserDefinedTypeName; class Mapping; class Statement; class Block; +class PlaceholderStatement; class IfStatement; class BreakableStatement; class WhileStatement; diff --git a/ASTPrinter.cpp b/ASTPrinter.cpp index 916fca1e..86cd028a 100644 --- a/ASTPrinter.cpp +++ b/ASTPrinter.cpp @@ -87,6 +87,13 @@ bool ASTPrinter::visit(VariableDeclaration const& _node) return goDeeper(); } +bool ASTPrinter::visit(ModifierDefinition const& _node) +{ + writeLine("ModifierDefinition \"" + _node.getName() + "\""); + printSourcePart(_node); + return goDeeper(); +} + bool ASTPrinter::visit(TypeName const& _node) { writeLine("TypeName"); @@ -129,6 +136,13 @@ bool ASTPrinter::visit(Block const& _node) return goDeeper(); } +bool ASTPrinter::visit(PlaceholderStatement const& _node) +{ + writeLine("PlaceholderStatement"); + printSourcePart(_node); + return goDeeper(); +} + bool ASTPrinter::visit(IfStatement const& _node) { writeLine("IfStatement"); @@ -322,6 +336,11 @@ void ASTPrinter::endVisit(VariableDeclaration const&) m_indentation--; } +void ASTPrinter::endVisit(ModifierDefinition const&) +{ + m_indentation--; +} + void ASTPrinter::endVisit(TypeName const&) { m_indentation--; @@ -352,6 +371,11 @@ void ASTPrinter::endVisit(Block const&) m_indentation--; } +void ASTPrinter::endVisit(PlaceholderStatement const&) +{ + m_indentation--; +} + void ASTPrinter::endVisit(IfStatement const&) { m_indentation--; diff --git a/ASTPrinter.h b/ASTPrinter.h index fc5fb4ac..ef5c5164 100644 --- a/ASTPrinter.h +++ b/ASTPrinter.h @@ -48,12 +48,14 @@ public: bool visit(ParameterList const& _node) override; bool visit(FunctionDefinition const& _node) override; bool visit(VariableDeclaration const& _node) override; + bool visit(ModifierDefinition const& _node) override; bool visit(TypeName const& _node) override; bool visit(ElementaryTypeName const& _node) override; bool visit(UserDefinedTypeName const& _node) override; bool visit(Mapping const& _node) override; bool visit(Statement const& _node) override; bool visit(Block const& _node) override; + bool visit(PlaceholderStatement const& _node) override; bool visit(IfStatement const& _node) override; bool visit(BreakableStatement const& _node) override; bool visit(WhileStatement const& _node) override; @@ -82,12 +84,14 @@ public: void endVisit(ParameterList const&) override; void endVisit(FunctionDefinition const&) override; void endVisit(VariableDeclaration const&) override; + void endVisit(ModifierDefinition const&) override; void endVisit(TypeName const&) override; void endVisit(ElementaryTypeName const&) override; void endVisit(UserDefinedTypeName const&) override; void endVisit(Mapping const&) override; void endVisit(Statement const&) override; void endVisit(Block const&) override; + void endVisit(PlaceholderStatement const&) override; void endVisit(IfStatement const&) override; void endVisit(BreakableStatement const&) override; void endVisit(WhileStatement const&) override; diff --git a/ASTVisitor.h b/ASTVisitor.h index 33a8a338..94bb9e0b 100644 --- a/ASTVisitor.h +++ b/ASTVisitor.h @@ -49,12 +49,14 @@ public: virtual bool visit(ParameterList&) { return true; } virtual bool visit(FunctionDefinition&) { return true; } virtual bool visit(VariableDeclaration&) { return true; } + virtual bool visit(ModifierDefinition&) { return true; } virtual bool visit(TypeName&) { return true; } virtual bool visit(ElementaryTypeName&) { return true; } virtual bool visit(UserDefinedTypeName&) { return true; } virtual bool visit(Mapping&) { return true; } virtual bool visit(Statement&) { return true; } virtual bool visit(Block&) { return true; } + virtual bool visit(PlaceholderStatement&) { return true; } virtual bool visit(IfStatement&) { return true; } virtual bool visit(BreakableStatement&) { return true; } virtual bool visit(WhileStatement&) { return true; } @@ -85,12 +87,14 @@ public: virtual void endVisit(ParameterList&) { } virtual void endVisit(FunctionDefinition&) { } virtual void endVisit(VariableDeclaration&) { } + virtual void endVisit(ModifierDefinition&) { } virtual void endVisit(TypeName&) { } virtual void endVisit(ElementaryTypeName&) { } virtual void endVisit(UserDefinedTypeName&) { } virtual void endVisit(Mapping&) { } virtual void endVisit(Statement&) { } virtual void endVisit(Block&) { } + virtual void endVisit(PlaceholderStatement&) { } virtual void endVisit(IfStatement&) { } virtual void endVisit(BreakableStatement&) { } virtual void endVisit(WhileStatement&) { } @@ -125,12 +129,14 @@ public: virtual bool visit(ParameterList const&) { return true; } virtual bool visit(FunctionDefinition const&) { return true; } virtual bool visit(VariableDeclaration const&) { return true; } + virtual bool visit(ModifierDefinition const&) { return true; } virtual bool visit(TypeName const&) { return true; } virtual bool visit(ElementaryTypeName const&) { return true; } virtual bool visit(UserDefinedTypeName const&) { return true; } virtual bool visit(Mapping const&) { return true; } virtual bool visit(Statement const&) { return true; } virtual bool visit(Block const&) { return true; } + virtual bool visit(PlaceholderStatement const&) { return true; } virtual bool visit(IfStatement const&) { return true; } virtual bool visit(BreakableStatement const&) { return true; } virtual bool visit(WhileStatement const&) { return true; } @@ -161,12 +167,14 @@ public: virtual void endVisit(ParameterList const&) { } virtual void endVisit(FunctionDefinition const&) { } virtual void endVisit(VariableDeclaration const&) { } + virtual void endVisit(ModifierDefinition const&) { } virtual void endVisit(TypeName const&) { } virtual void endVisit(ElementaryTypeName const&) { } virtual void endVisit(UserDefinedTypeName const&) { } virtual void endVisit(Mapping const&) { } virtual void endVisit(Statement const&) { } virtual void endVisit(Block const&) { } + virtual void endVisit(PlaceholderStatement const&) { } virtual void endVisit(IfStatement const&) { } virtual void endVisit(BreakableStatement const&) { } virtual void endVisit(WhileStatement const&) { } diff --git a/AST_accept.h b/AST_accept.h index b77cfe1c..89786d6f 100644 --- a/AST_accept.h +++ b/AST_accept.h @@ -65,6 +65,7 @@ void ContractDefinition::accept(ASTVisitor& _visitor) listAccept(m_definedStructs, _visitor); listAccept(m_stateVariables, _visitor); listAccept(m_definedFunctions, _visitor); + listAccept(m_functionModifiers, _visitor); } _visitor.endVisit(*this); } @@ -77,6 +78,7 @@ void ContractDefinition::accept(ASTConstVisitor& _visitor) const listAccept(m_definedStructs, _visitor); listAccept(m_stateVariables, _visitor); listAccept(m_definedFunctions, _visitor); + listAccept(m_functionModifiers, _visitor); } _visitor.endVisit(*this); } @@ -175,6 +177,26 @@ void VariableDeclaration::accept(ASTConstVisitor& _visitor) const _visitor.endVisit(*this); } +void ModifierDefinition::accept(ASTVisitor& _visitor) +{ + if (_visitor.visit(*this)) + { + m_parameters->accept(_visitor); + m_body->accept(_visitor); + } + _visitor.endVisit(*this); +} + +void ModifierDefinition::accept(ASTConstVisitor& _visitor) const +{ + if (_visitor.visit(*this)) + { + m_parameters->accept(_visitor); + m_body->accept(_visitor); + } + _visitor.endVisit(*this); +} + void TypeName::accept(ASTVisitor& _visitor) { _visitor.visit(*this); @@ -245,6 +267,18 @@ void Block::accept(ASTConstVisitor& _visitor) const _visitor.endVisit(*this); } +void PlaceholderStatement::accept(ASTVisitor& _visitor) +{ + _visitor.visit(*this); + _visitor.endVisit(*this); +} + +void PlaceholderStatement::accept(ASTConstVisitor& _visitor) const +{ + _visitor.visit(*this); + _visitor.endVisit(*this); +} + void IfStatement::accept(ASTVisitor& _visitor) { if (_visitor.visit(*this)) @@ -121,6 +121,7 @@ ASTPointer<ContractDefinition> Parser::parseContractDefinition() vector<ASTPointer<StructDefinition>> structs; vector<ASTPointer<VariableDeclaration>> stateVariables; vector<ASTPointer<FunctionDefinition>> functions; + vector<ASTPointer<ModifierDefinition>> modifiers; if (m_scanner->getCurrentToken() == Token::IS) do { @@ -152,13 +153,15 @@ ASTPointer<ContractDefinition> Parser::parseContractDefinition() stateVariables.push_back(parseVariableDeclaration(allowVar)); expectToken(Token::SEMICOLON); } + else if (currentToken == Token::MODIFIER) + modifiers.push_back(parseModifierDefinition()); else - BOOST_THROW_EXCEPTION(createParserError("Function, variable or struct declaration expected.")); + BOOST_THROW_EXCEPTION(createParserError("Function, variable, struct or modifier declaration expected.")); } nodeFactory.markEndPosition(); expectToken(Token::RBRACE); return nodeFactory.createNode<ContractDefinition>(name, docString, baseContracts, structs, - stateVariables, functions); + stateVariables, functions, modifiers); } ASTPointer<InheritanceSpecifier> Parser::parseInheritanceSpecifier() @@ -242,6 +245,33 @@ ASTPointer<VariableDeclaration> Parser::parseVariableDeclaration(bool _allowVar) return nodeFactory.createNode<VariableDeclaration>(type, expectIdentifierToken()); } +ASTPointer<ModifierDefinition> Parser::parseModifierDefinition() +{ + ScopeGuard resetModifierFlag([this]() { m_insideModifier = false; }); + m_insideModifier = true; + + ASTNodeFactory nodeFactory(*this); + ASTPointer<ASTString> docstring; + if (m_scanner->getCurrentCommentLiteral() != "") + docstring = make_shared<ASTString>(m_scanner->getCurrentCommentLiteral()); + + expectToken(Token::MODIFIER); + ASTPointer<ASTString> name(expectIdentifierToken()); + ASTPointer<ParameterList> parameters; + if (m_scanner->getCurrentToken() == Token::LPAREN) + parameters = parseParameterList(); + else + { + // create an empty parameter list at a zero-length location + ASTNodeFactory nodeFactory(*this); + nodeFactory.setLocationEmpty(); + parameters = nodeFactory.createNode<ParameterList>(vector<ASTPointer<VariableDeclaration>>()); + } + ASTPointer<Block> block = parseBlock(); + nodeFactory.setEndPositionFromNode(block); + return nodeFactory.createNode<ModifierDefinition>(name, docstring, parameters, block); +} + ASTPointer<TypeName> Parser::parseTypeName(bool _allowVar) { ASTPointer<TypeName> type; @@ -354,8 +384,16 @@ ASTPointer<Statement> Parser::parseStatement() nodeFactory.setEndPositionFromNode(expression); } statement = nodeFactory.createNode<Return>(expression); + break; } - break; + case Token::IDENTIFIER: + if (m_insideModifier && m_scanner->getCurrentLiteral() == "_") + { + statement = ASTNodeFactory(*this).createNode<PlaceholderStatement>(); + m_scanner->next(); + return statement; + } + // fall-through default: statement = parseVarDefOrExprStmt(); } @@ -53,6 +53,7 @@ private: ASTPointer<FunctionDefinition> parseFunctionDefinition(bool _isPublic, ASTString const* _contractName); ASTPointer<StructDefinition> parseStructDefinition(); ASTPointer<VariableDeclaration> parseVariableDeclaration(bool _allowVar); + ASTPointer<ModifierDefinition> parseModifierDefinition(); ASTPointer<TypeName> parseTypeName(bool _allowVar); ASTPointer<Mapping> parseMapping(); ASTPointer<ParameterList> parseParameterList(bool _allowEmpty = true); @@ -90,6 +91,8 @@ private: ParserError createParserError(std::string const& _description) const; std::shared_ptr<Scanner> m_scanner; + /// Flag that signifies whether '_' is parsed as a PlaceholderStatement or a regular identifier. + bool m_insideModifier = false; }; } @@ -159,6 +159,7 @@ namespace solidity K(IF, "if", 0) \ K(IMPORT, "import", 0) \ K(MAPPING, "mapping", 0) \ + K(MODIFIER, "modifier", 0) \ K(NEW, "new", 0) \ K(PUBLIC, "public", 0) \ K(PRIVATE, "private", 0) \ |