diff options
author | Gav Wood <g@ethdev.com> | 2015-01-27 07:01:25 +0800 |
---|---|---|
committer | Gav Wood <g@ethdev.com> | 2015-01-27 07:01:25 +0800 |
commit | 8d09d8deb75a28b43a493d8afee5e743a3b201bb (patch) | |
tree | b99e93e50d7203026cfad74e0e70ac6de654f6b1 | |
parent | fe37aad4d52f2ddd3d9fd6f6249a9a94394eaf1c (diff) | |
parent | f59cda76def4b1d24236e5a110bfebbd688c5d1a (diff) | |
download | dexon-solidity-8d09d8deb75a28b43a493d8afee5e743a3b201bb.tar dexon-solidity-8d09d8deb75a28b43a493d8afee5e743a3b201bb.tar.gz dexon-solidity-8d09d8deb75a28b43a493d8afee5e743a3b201bb.tar.bz2 dexon-solidity-8d09d8deb75a28b43a493d8afee5e743a3b201bb.tar.lz dexon-solidity-8d09d8deb75a28b43a493d8afee5e743a3b201bb.tar.xz dexon-solidity-8d09d8deb75a28b43a493d8afee5e743a3b201bb.tar.zst dexon-solidity-8d09d8deb75a28b43a493d8afee5e743a3b201bb.zip |
Merge pull request #856 from chriseth/sol_modifiers
Function modifiers.
-rw-r--r-- | AST.cpp | 131 | ||||
-rwxr-xr-x | AST.h | 143 | ||||
-rw-r--r-- | ASTForward.h | 5 | ||||
-rw-r--r-- | ASTJsonConverter.cpp | 3 | ||||
-rw-r--r-- | ASTPrinter.cpp | 36 | ||||
-rw-r--r-- | ASTPrinter.h | 6 | ||||
-rw-r--r-- | ASTVisitor.h | 12 | ||||
-rw-r--r-- | AST_accept.h | 56 | ||||
-rw-r--r-- | CallGraph.cpp | 35 | ||||
-rw-r--r-- | CallGraph.h | 15 | ||||
-rw-r--r-- | Compiler.cpp | 107 | ||||
-rw-r--r-- | Compiler.h | 16 | ||||
-rw-r--r-- | CompilerContext.cpp | 28 | ||||
-rw-r--r-- | CompilerContext.h | 13 | ||||
-rw-r--r-- | NameAndTypeResolver.cpp | 21 | ||||
-rw-r--r-- | NameAndTypeResolver.h | 4 | ||||
-rw-r--r-- | Parser.cpp | 79 | ||||
-rw-r--r-- | Parser.h | 4 | ||||
-rw-r--r-- | Token.h | 1 | ||||
-rw-r--r-- | Types.cpp | 32 | ||||
-rw-r--r-- | Types.h | 23 | ||||
-rw-r--r-- | grammar.txt | 6 |
22 files changed, 630 insertions, 146 deletions
@@ -41,10 +41,15 @@ TypeError ASTNode::createTypeError(string const& _description) const return TypeError() << errinfo_sourceLocation(getLocation()) << errinfo_comment(_description); } +TypePointer ContractDefinition::getType(ContractDefinition const* _currentContract) const +{ + return make_shared<TypeType>(make_shared<ContractType>(*this), _currentContract); +} + void ContractDefinition::checkTypeRequirements() { - for (ASTPointer<InheritanceSpecifier> const& base: getBaseContracts()) - base->checkTypeRequirements(); + for (ASTPointer<InheritanceSpecifier> const& baseSpecifier: getBaseContracts()) + baseSpecifier->checkTypeRequirements(); checkIllegalOverrides(); @@ -53,6 +58,9 @@ void ContractDefinition::checkTypeRequirements() BOOST_THROW_EXCEPTION(constructor->getReturnParameterList()->createTypeError( "Non-empty \"returns\" directive for constructor.")); + for (ASTPointer<ModifierDefinition> const& modifier: getFunctionModifiers()) + modifier->checkTypeRequirements(); + for (ASTPointer<FunctionDefinition> const& function: getDefinedFunctions()) function->checkTypeRequirements(); @@ -89,15 +97,22 @@ FunctionDefinition const* ContractDefinition::getConstructor() const void ContractDefinition::checkIllegalOverrides() const { + // TODO unify this at a later point. for this we need to put the constness and the access specifier + // into the types map<string, FunctionDefinition const*> functions; + map<string, ModifierDefinition const*> modifiers; // We search from derived to base, so the stored item causes the error. for (ContractDefinition const* contract: getLinearizedBaseContracts()) + { for (ASTPointer<FunctionDefinition> const& function: contract->getDefinedFunctions()) { if (function->isConstructor()) - continue; // constructors can neither be overriden nor override anything - FunctionDefinition const*& override = functions[function->getName()]; + continue; // constructors can neither be overridden nor override anything + string const& name = function->getName(); + if (modifiers.count(name)) + BOOST_THROW_EXCEPTION(modifiers[name]->createTypeError("Override changes function to modifier.")); + FunctionDefinition const*& override = functions[name]; if (!override) override = function.get(); else if (override->isPublic() != function->isPublic() || @@ -105,6 +120,18 @@ void ContractDefinition::checkIllegalOverrides() const FunctionType(*override) != FunctionType(*function)) BOOST_THROW_EXCEPTION(override->createTypeError("Override changes extended function signature.")); } + for (ASTPointer<ModifierDefinition> const& modifier: contract->getFunctionModifiers()) + { + string const& name = modifier->getName(); + if (functions.count(name)) + BOOST_THROW_EXCEPTION(functions[name]->createTypeError("Override changes modifier to function.")); + ModifierDefinition const*& override = modifiers[name]; + if (!override) + override = modifier.get(); + else if (ModifierType(*override) != ModifierType(*modifier)) + BOOST_THROW_EXCEPTION(override->createTypeError("Override changes modifier signature.")); + } + } } vector<pair<FixedHash<4>, FunctionDefinition const*>> const& ContractDefinition::getInterfaceFunctionList() const @@ -141,6 +168,11 @@ void InheritanceSpecifier::checkTypeRequirements() BOOST_THROW_EXCEPTION(createTypeError("Invalid type for argument in constructer call.")); } +TypePointer StructDefinition::getType(ContractDefinition const*) const +{ + return make_shared<TypeType>(make_shared<StructType>(*this)); +} + void StructDefinition::checkMemberTypes() const { for (ASTPointer<VariableDeclaration> const& member: getMembers()) @@ -169,11 +201,18 @@ void StructDefinition::checkRecursion() const } } +TypePointer FunctionDefinition::getType(ContractDefinition const*) const +{ + return make_shared<FunctionType>(*this); +} + void FunctionDefinition::checkTypeRequirements() { for (ASTPointer<VariableDeclaration> const& var: getParameters() + getReturnParameters()) if (!var->getType()->canLiveOutsideStorage()) BOOST_THROW_EXCEPTION(var->createTypeError("Type is required to live outside storage.")); + for (ASTPointer<ModifierInvocation> const& modifier: m_functionModifiers) + modifier->checkTypeRequirements(); m_body->checkTypeRequirements(); } @@ -183,6 +222,40 @@ string FunctionDefinition::getCanonicalSignature() const return getName() + FunctionType(*this).getCanonicalSignature(); } +Declaration::LValueType VariableDeclaration::getLValueType() const +{ + if (dynamic_cast<FunctionDefinition const*>(getScope()) || dynamic_cast<ModifierDefinition const*>(getScope())) + return Declaration::LValueType::LOCAL; + else + return Declaration::LValueType::STORAGE; +} + +TypePointer ModifierDefinition::getType(ContractDefinition const*) const +{ + return make_shared<ModifierType>(*this); +} + +void ModifierDefinition::checkTypeRequirements() +{ + m_body->checkTypeRequirements(); +} + +void ModifierInvocation::checkTypeRequirements() +{ + m_modifierName->checkTypeRequirements(); + for (ASTPointer<Expression> const& argument: m_arguments) + argument->checkTypeRequirements(); + + ModifierDefinition const* modifier = dynamic_cast<ModifierDefinition const*>(m_modifierName->getReferencedDeclaration()); + solAssert(modifier, "Function modifier not found."); + vector<ASTPointer<VariableDeclaration>> const& parameters = modifier->getParameters(); + if (parameters.size() != m_arguments.size()) + BOOST_THROW_EXCEPTION(createTypeError("Wrong argument count for modifier invocation.")); + for (size_t i = 0; i < m_arguments.size(); ++i) + if (!m_arguments[i]->getType()->isImplicitlyConvertibleTo(*parameters[i]->getType())) + BOOST_THROW_EXCEPTION(createTypeError("Invalid type for argument in modifier invocation.")); +} + void Block::checkTypeRequirements() { for (shared_ptr<Statement> const& statement: m_statements) @@ -218,7 +291,8 @@ void Return::checkTypeRequirements() { if (!m_expression) return; - solAssert(m_returnParameters, "Return parameters not assigned."); + if (!m_returnParameters) + BOOST_THROW_EXCEPTION(createTypeError("Return arguments not allowed.")); if (m_returnParameters->getParameters().size() != 1) BOOST_THROW_EXCEPTION(createTypeError("Different number of arguments in return statement " "than in returns declaration.")); @@ -394,7 +468,7 @@ void MemberAccess::checkTypeRequirements() BOOST_THROW_EXCEPTION(createTypeError("Member \"" + *m_memberName + "\" not found or not " "visible in " + type.toString())); //@todo later, this will not always be STORAGE - m_lvalue = type.getCategory() == Type::Category::STRUCT ? LValueType::STORAGE : LValueType::NONE; + m_lvalue = type.getCategory() == Type::Category::STRUCT ? Declaration::LValueType::STORAGE : Declaration::LValueType::NONE; } void IndexAccess::checkTypeRequirements() @@ -406,52 +480,17 @@ void IndexAccess::checkTypeRequirements() MappingType const& type = dynamic_cast<MappingType const&>(*m_base->getType()); m_index->expectType(*type.getKeyType()); m_type = type.getValueType(); - m_lvalue = LValueType::STORAGE; + m_lvalue = Declaration::LValueType::STORAGE; } void Identifier::checkTypeRequirements() { solAssert(m_referencedDeclaration, "Identifier not resolved."); - VariableDeclaration const* variable = dynamic_cast<VariableDeclaration const*>(m_referencedDeclaration); - if (variable) - { - if (!variable->getType()) - BOOST_THROW_EXCEPTION(createTypeError("Variable referenced before type could be determined.")); - m_type = variable->getType(); - m_lvalue = variable->isLocalVariable() ? LValueType::LOCAL : LValueType::STORAGE; - return; - } - //@todo can we unify these with TypeName::toType()? - StructDefinition const* structDef = dynamic_cast<StructDefinition const*>(m_referencedDeclaration); - if (structDef) - { - // note that we do not have a struct type here - m_type = make_shared<TypeType>(make_shared<StructType>(*structDef)); - return; - } - FunctionDefinition const* functionDef = dynamic_cast<FunctionDefinition const*>(m_referencedDeclaration); - if (functionDef) - { - // a function reference is not a TypeType, because calling a TypeType converts to the type. - // Calling a function (e.g. function(12), otherContract.function(34)) does not do a type - // conversion. - m_type = make_shared<FunctionType>(*functionDef); - return; - } - ContractDefinition const* contractDef = dynamic_cast<ContractDefinition const*>(m_referencedDeclaration); - if (contractDef) - { - m_type = make_shared<TypeType>(make_shared<ContractType>(*contractDef), m_currentContract); - return; - } - MagicVariableDeclaration const* magicVariable = dynamic_cast<MagicVariableDeclaration const*>(m_referencedDeclaration); - if (magicVariable) - { - m_type = magicVariable->getType(); - return; - } - BOOST_THROW_EXCEPTION(InternalCompilerError() << errinfo_comment("Declaration reference of unknown/forbidden type.")); + m_lvalue = m_referencedDeclaration->getLValueType(); + m_type = m_referencedDeclaration->getType(m_currentContract); + if (!m_type) + BOOST_THROW_EXCEPTION(createTypeError("Declaration referenced before type could be determined.")); } void ElementaryTypeNameExpression::checkTypeRequirements() @@ -132,6 +132,8 @@ private: class Declaration: public ASTNode { public: + enum class LValueType { NONE, LOCAL, STORAGE }; + Declaration(Location const& _location, ASTPointer<ASTString> const& _name): ASTNode(_location), m_name(_name), m_scope(nullptr) {} @@ -142,6 +144,13 @@ public: Declaration const* getScope() const { return m_scope; } void setScope(Declaration const* _scope) { m_scope = _scope; } + /// @returns the type of expressions referencing this declaration. + /// The current contract has to be given since this context can change the type, especially of + /// contract types. + virtual TypePointer getType(ContractDefinition const* m_currentContract = nullptr) const = 0; + /// @returns the lvalue type of expressions referencing this declaration + virtual LValueType getLValueType() const { return LValueType::NONE; } + private: ASTPointer<ASTString> m_name; Declaration const* m_scope; @@ -161,12 +170,14 @@ public: std::vector<ASTPointer<InheritanceSpecifier>> const& _baseContracts, std::vector<ASTPointer<StructDefinition>> const& _definedStructs, std::vector<ASTPointer<VariableDeclaration>> const& _stateVariables, - std::vector<ASTPointer<FunctionDefinition>> const& _definedFunctions): + std::vector<ASTPointer<FunctionDefinition>> const& _definedFunctions, + std::vector<ASTPointer<ModifierDefinition>> const& _functionModifiers): Declaration(_location, _name), m_baseContracts(_baseContracts), m_definedStructs(_definedStructs), m_stateVariables(_stateVariables), m_definedFunctions(_definedFunctions), + m_functionModifiers(_functionModifiers), m_documentation(_documentation) {} @@ -176,8 +187,11 @@ public: std::vector<ASTPointer<InheritanceSpecifier>> const& getBaseContracts() const { return m_baseContracts; } std::vector<ASTPointer<StructDefinition>> const& getDefinedStructs() const { return m_definedStructs; } std::vector<ASTPointer<VariableDeclaration>> const& getStateVariables() const { return m_stateVariables; } + std::vector<ASTPointer<ModifierDefinition>> const& getFunctionModifiers() const { return m_functionModifiers; } std::vector<ASTPointer<FunctionDefinition>> const& getDefinedFunctions() const { return m_definedFunctions; } + virtual TypePointer getType(ContractDefinition const* m_currentContract) const override; + /// Checks that there are no illegal overrides, that the constructor does not have a "returns" /// and calls checkTypeRequirements on all its functions. void checkTypeRequirements(); @@ -207,6 +221,7 @@ private: std::vector<ASTPointer<StructDefinition>> m_definedStructs; std::vector<ASTPointer<VariableDeclaration>> m_stateVariables; std::vector<ASTPointer<FunctionDefinition>> m_definedFunctions; + std::vector<ASTPointer<ModifierDefinition>> m_functionModifiers; ASTPointer<ASTString> m_documentation; std::vector<ContractDefinition const*> m_linearizedBaseContracts; @@ -245,6 +260,8 @@ public: std::vector<ASTPointer<VariableDeclaration>> const& getMembers() const { return m_members; } + virtual TypePointer getType(ContractDefinition const*) const override; + /// Checks that the members do not include any recursive structs and have valid types /// (e.g. no functions). void checkValidityOfMembers() const; @@ -276,7 +293,20 @@ private: std::vector<ASTPointer<VariableDeclaration>> m_parameters; }; -class FunctionDefinition: public Declaration +/** + * Abstract class that is added to each AST node that can store local variables. + */ +class VariableScope +{ +public: + void addLocalVariable(VariableDeclaration const& _localVariable) { m_localVariables.push_back(&_localVariable); } + std::vector<VariableDeclaration const*> const& getLocalVariables() const { return m_localVariables; } + +private: + std::vector<VariableDeclaration const*> m_localVariables; +}; + +class FunctionDefinition: public Declaration, public VariableScope { public: FunctionDefinition(Location const& _location, ASTPointer<ASTString> const& _name, @@ -285,11 +315,13 @@ public: ASTPointer<ASTString> const& _documentation, ASTPointer<ParameterList> const& _parameters, bool _isDeclaredConst, + std::vector<ASTPointer<ModifierInvocation>> const& _modifiers, ASTPointer<ParameterList> const& _returnParameters, ASTPointer<Block> const& _body): Declaration(_location, _name), m_isPublic(_isPublic), m_isConstructor(_isConstructor), m_parameters(_parameters), m_isDeclaredConst(_isDeclaredConst), + m_functionModifiers(_modifiers), m_returnParameters(_returnParameters), m_body(_body), m_documentation(_documentation) @@ -301,6 +333,7 @@ public: bool isPublic() const { return m_isPublic; } bool isConstructor() const { return m_isConstructor; } bool isDeclaredConst() const { return m_isDeclaredConst; } + std::vector<ASTPointer<ModifierInvocation>> const& getModifiers() const { return m_functionModifiers; } std::vector<ASTPointer<VariableDeclaration>> const& getParameters() const { return m_parameters->getParameters(); } ParameterList const& getParameterList() const { return *m_parameters; } std::vector<ASTPointer<VariableDeclaration>> const& getReturnParameters() const { return m_returnParameters->getParameters(); } @@ -310,8 +343,7 @@ public: /// Can contain a nullptr in which case indicates absence of documentation ASTPointer<ASTString> const& getDocumentation() const { return m_documentation; } - void addLocalVariable(VariableDeclaration const& _localVariable) { m_localVariables.push_back(&_localVariable); } - std::vector<VariableDeclaration const*> const& getLocalVariables() const { return m_localVariables; } + virtual TypePointer getType(ContractDefinition const*) const override; /// Checks that all parameters have allowed types and calls checkTypeRequirements on the body. void checkTypeRequirements(); @@ -326,11 +358,10 @@ private: bool m_isConstructor; ASTPointer<ParameterList> m_parameters; bool m_isDeclaredConst; + std::vector<ASTPointer<ModifierInvocation>> m_functionModifiers; ASTPointer<ParameterList> m_returnParameters; ASTPointer<Block> m_body; ASTPointer<ASTString> m_documentation; - - std::vector<VariableDeclaration const*> m_localVariables; }; /** @@ -350,10 +381,10 @@ public: /// Returns the declared or inferred type. Can be an empty pointer if no type was explicitly /// declared and there is no assignment to the variable that fixes the type. - std::shared_ptr<Type const> const& getType() const { return m_type; } + TypePointer getType(ContractDefinition const* = nullptr) const { return m_type; } void setType(std::shared_ptr<Type const> const& _type) { m_type = _type; } - bool isLocalVariable() const { return !!dynamic_cast<FunctionDefinition const*>(getScope()); } + virtual LValueType getLValueType() const override; private: ASTPointer<TypeName> m_typeName; ///< can be empty ("var") @@ -362,6 +393,64 @@ private: }; /** + * Definition of a function modifier. + */ +class ModifierDefinition: public Declaration, public VariableScope +{ +public: + ModifierDefinition(Location const& _location, + ASTPointer<ASTString> const& _name, + ASTPointer<ASTString> const& _documentation, + ASTPointer<ParameterList> const& _parameters, + ASTPointer<Block> const& _body): + Declaration(_location, _name), m_documentation(_documentation), + m_parameters(_parameters), m_body(_body) {} + + virtual void accept(ASTVisitor& _visitor) override; + virtual void accept(ASTConstVisitor& _visitor) const override; + + std::vector<ASTPointer<VariableDeclaration>> const& getParameters() const { return m_parameters->getParameters(); } + ParameterList const& getParameterList() const { return *m_parameters; } + Block const& getBody() const { return *m_body; } + + virtual TypePointer getType(ContractDefinition const* = nullptr) const override; + + /// @return A shared pointer of an ASTString. + /// Can contain a nullptr in which case indicates absence of documentation + ASTPointer<ASTString> const& getDocumentation() const { return m_documentation; } + + void checkTypeRequirements(); + +private: + ASTPointer<ASTString> m_documentation; + ASTPointer<ParameterList> m_parameters; + ASTPointer<Block> m_body; +}; + +/** + * Invocation/usage of a modifier in a function header. + */ +class ModifierInvocation: public ASTNode +{ +public: + ModifierInvocation(Location const& _location, ASTPointer<Identifier> const& _name, + std::vector<ASTPointer<Expression>> _arguments): + ASTNode(_location), m_modifierName(_name), m_arguments(_arguments) {} + + virtual void accept(ASTVisitor& _visitor) override; + virtual void accept(ASTConstVisitor& _visitor) const override; + + ASTPointer<Identifier> const& getName() const { return m_modifierName; } + std::vector<ASTPointer<Expression>> const& getArguments() const { return m_arguments; } + + void checkTypeRequirements(); + +private: + ASTPointer<Identifier> m_modifierName; + std::vector<ASTPointer<Expression>> m_arguments; +}; + +/** * Pseudo AST node that is used as declaration for "this", "msg", "tx", "block" and the global * functions when such an identifier is encountered. Will never have a valid location in the source code. */ @@ -375,7 +464,7 @@ public: virtual void accept(ASTConstVisitor&) const override { BOOST_THROW_EXCEPTION(InternalCompilerError() << errinfo_comment("MagicVariableDeclaration used inside real AST.")); } - std::shared_ptr<Type const> const& getType() const { return m_type; } + virtual TypePointer getType(ContractDefinition const* = nullptr) const override { return m_type; } private: std::shared_ptr<Type const> m_type; @@ -503,6 +592,21 @@ private: }; /** + * Special placeholder statement denoted by "_" used in function modifiers. This is replaced by + * the original function when the modifier is applied. + */ +class PlaceholderStatement: public Statement +{ +public: + PlaceholderStatement(Location const& _location): Statement(_location) {} + + virtual void accept(ASTVisitor& _visitor) override; + virtual void accept(ASTConstVisitor& _visitor) const override; + + virtual void checkTypeRequirements() override { } +}; + +/** * If-statement with an optional "else" part. Note that "else if" is modeled by having a new * if-statement as the false (else) body. */ @@ -618,12 +722,8 @@ public: virtual void accept(ASTConstVisitor& _visitor) const override; virtual void checkTypeRequirements() override; - void setFunctionReturnParameters(ParameterList const& _parameters) { m_returnParameters = &_parameters; } - ParameterList const& getFunctionReturnParameters() const - { - solAssert(m_returnParameters, ""); - return *m_returnParameters; - } + void setFunctionReturnParameters(ParameterList const* _parameters) { m_returnParameters = _parameters; } + ParameterList const* getFunctionReturnParameters() const { return m_returnParameters; } Expression const* getExpression() const { return m_expression.get(); } private: @@ -686,16 +786,13 @@ private: */ class Expression: public ASTNode { -protected: - enum class LValueType { NONE, LOCAL, STORAGE }; - public: - Expression(Location const& _location): ASTNode(_location), m_lvalue(LValueType::NONE), m_lvalueRequested(false) {} + Expression(Location const& _location): ASTNode(_location) {} virtual void checkTypeRequirements() = 0; std::shared_ptr<Type const> const& getType() const { return m_type; } - bool isLValue() const { return m_lvalue != LValueType::NONE; } - bool isLocalLValue() const { return m_lvalue == LValueType::LOCAL; } + bool isLValue() const { return m_lvalue != Declaration::LValueType::NONE; } + bool isLocalLValue() const { return m_lvalue == Declaration::LValueType::LOCAL; } /// Helper function, infer the type via @ref checkTypeRequirements and then check that it /// is implicitly convertible to @a _expectedType. If not, throw exception. @@ -712,9 +809,9 @@ protected: std::shared_ptr<Type const> m_type; //! If this expression is an lvalue (i.e. something that can be assigned to) and is stored //! locally or in storage. This is set during calls to @a checkTypeRequirements() - LValueType m_lvalue; + Declaration::LValueType m_lvalue = Declaration::LValueType::NONE; //! Whether the outer expression requested the address (true) or the value (false) of this expression. - bool m_lvalueRequested; + bool m_lvalueRequested = false; }; /// Assignment, can also be a compound assignment. diff --git a/ASTForward.h b/ASTForward.h index da0a8812..aa5cd49c 100644 --- a/ASTForward.h +++ b/ASTForward.h @@ -43,6 +43,8 @@ class StructDefinition; class ParameterList; class FunctionDefinition; class VariableDeclaration; +class ModifierDefinition; +class ModifierInvocation; class MagicVariableDeclaration; class TypeName; class ElementaryTypeName; @@ -50,6 +52,7 @@ class UserDefinedTypeName; class Mapping; class Statement; class Block; +class PlaceholderStatement; class IfStatement; class BreakableStatement; class WhileStatement; @@ -72,6 +75,8 @@ class Identifier; class ElementaryTypeNameExpression; class Literal; +class VariableScope; + // Used as pointers to AST nodes, to be replaced by more clever pointers, e.g. pointers which do // not do reference counting but point to a special memory area that is completely released // explicitly. diff --git a/ASTJsonConverter.cpp b/ASTJsonConverter.cpp index 04ddee0a..d9332990 100644 --- a/ASTJsonConverter.cpp +++ b/ASTJsonConverter.cpp @@ -118,9 +118,10 @@ bool ASTJsonConverter::visit(FunctionDefinition const& _node) bool ASTJsonConverter::visit(VariableDeclaration const& _node) { + bool isLocalVariable = (_node.getLValueType() == VariableDeclaration::LValueType::LOCAL); addJsonNode("VariableDeclaration", { make_pair("name", _node.getName()), - make_pair("local", boost::lexical_cast<std::string>(_node.isLocalVariable()))}, + make_pair("local", boost::lexical_cast<std::string>(isLocalVariable))}, true); return true; } diff --git a/ASTPrinter.cpp b/ASTPrinter.cpp index 916fca1e..85bc8825 100644 --- a/ASTPrinter.cpp +++ b/ASTPrinter.cpp @@ -87,6 +87,20 @@ bool ASTPrinter::visit(VariableDeclaration const& _node) return goDeeper(); } +bool ASTPrinter::visit(ModifierDefinition const& _node) +{ + writeLine("ModifierDefinition \"" + _node.getName() + "\""); + printSourcePart(_node); + return goDeeper(); +} + +bool ASTPrinter::visit(ModifierInvocation const& _node) +{ + writeLine("ModifierInvocation \"" + _node.getName()->getName() + "\""); + printSourcePart(_node); + return goDeeper(); +} + bool ASTPrinter::visit(TypeName const& _node) { writeLine("TypeName"); @@ -129,6 +143,13 @@ bool ASTPrinter::visit(Block const& _node) return goDeeper(); } +bool ASTPrinter::visit(PlaceholderStatement const& _node) +{ + writeLine("PlaceholderStatement"); + printSourcePart(_node); + return goDeeper(); +} + bool ASTPrinter::visit(IfStatement const& _node) { writeLine("IfStatement"); @@ -322,6 +343,16 @@ void ASTPrinter::endVisit(VariableDeclaration const&) m_indentation--; } +void ASTPrinter::endVisit(ModifierDefinition const&) +{ + m_indentation--; +} + +void ASTPrinter::endVisit(ModifierInvocation const&) +{ + m_indentation--; +} + void ASTPrinter::endVisit(TypeName const&) { m_indentation--; @@ -352,6 +383,11 @@ void ASTPrinter::endVisit(Block const&) m_indentation--; } +void ASTPrinter::endVisit(PlaceholderStatement const&) +{ + m_indentation--; +} + void ASTPrinter::endVisit(IfStatement const&) { m_indentation--; diff --git a/ASTPrinter.h b/ASTPrinter.h index fc5fb4ac..7f267bdf 100644 --- a/ASTPrinter.h +++ b/ASTPrinter.h @@ -48,12 +48,15 @@ public: bool visit(ParameterList const& _node) override; bool visit(FunctionDefinition const& _node) override; bool visit(VariableDeclaration const& _node) override; + bool visit(ModifierDefinition const& _node) override; + bool visit(ModifierInvocation const& _node) override; bool visit(TypeName const& _node) override; bool visit(ElementaryTypeName const& _node) override; bool visit(UserDefinedTypeName const& _node) override; bool visit(Mapping const& _node) override; bool visit(Statement const& _node) override; bool visit(Block const& _node) override; + bool visit(PlaceholderStatement const& _node) override; bool visit(IfStatement const& _node) override; bool visit(BreakableStatement const& _node) override; bool visit(WhileStatement const& _node) override; @@ -82,12 +85,15 @@ public: void endVisit(ParameterList const&) override; void endVisit(FunctionDefinition const&) override; void endVisit(VariableDeclaration const&) override; + void endVisit(ModifierDefinition const&) override; + void endVisit(ModifierInvocation const&) override; void endVisit(TypeName const&) override; void endVisit(ElementaryTypeName const&) override; void endVisit(UserDefinedTypeName const&) override; void endVisit(Mapping const&) override; void endVisit(Statement const&) override; void endVisit(Block const&) override; + void endVisit(PlaceholderStatement const&) override; void endVisit(IfStatement const&) override; void endVisit(BreakableStatement const&) override; void endVisit(WhileStatement const&) override; diff --git a/ASTVisitor.h b/ASTVisitor.h index 33a8a338..ecab00c3 100644 --- a/ASTVisitor.h +++ b/ASTVisitor.h @@ -49,12 +49,15 @@ public: virtual bool visit(ParameterList&) { return true; } virtual bool visit(FunctionDefinition&) { return true; } virtual bool visit(VariableDeclaration&) { return true; } + virtual bool visit(ModifierDefinition&) { return true; } + virtual bool visit(ModifierInvocation&) { return true; } virtual bool visit(TypeName&) { return true; } virtual bool visit(ElementaryTypeName&) { return true; } virtual bool visit(UserDefinedTypeName&) { return true; } virtual bool visit(Mapping&) { return true; } virtual bool visit(Statement&) { return true; } virtual bool visit(Block&) { return true; } + virtual bool visit(PlaceholderStatement&) { return true; } virtual bool visit(IfStatement&) { return true; } virtual bool visit(BreakableStatement&) { return true; } virtual bool visit(WhileStatement&) { return true; } @@ -85,12 +88,15 @@ public: virtual void endVisit(ParameterList&) { } virtual void endVisit(FunctionDefinition&) { } virtual void endVisit(VariableDeclaration&) { } + virtual void endVisit(ModifierDefinition&) { } + virtual void endVisit(ModifierInvocation&) { } virtual void endVisit(TypeName&) { } virtual void endVisit(ElementaryTypeName&) { } virtual void endVisit(UserDefinedTypeName&) { } virtual void endVisit(Mapping&) { } virtual void endVisit(Statement&) { } virtual void endVisit(Block&) { } + virtual void endVisit(PlaceholderStatement&) { } virtual void endVisit(IfStatement&) { } virtual void endVisit(BreakableStatement&) { } virtual void endVisit(WhileStatement&) { } @@ -125,12 +131,15 @@ public: virtual bool visit(ParameterList const&) { return true; } virtual bool visit(FunctionDefinition const&) { return true; } virtual bool visit(VariableDeclaration const&) { return true; } + virtual bool visit(ModifierDefinition const&) { return true; } + virtual bool visit(ModifierInvocation const&) { return true; } virtual bool visit(TypeName const&) { return true; } virtual bool visit(ElementaryTypeName const&) { return true; } virtual bool visit(UserDefinedTypeName const&) { return true; } virtual bool visit(Mapping const&) { return true; } virtual bool visit(Statement const&) { return true; } virtual bool visit(Block const&) { return true; } + virtual bool visit(PlaceholderStatement const&) { return true; } virtual bool visit(IfStatement const&) { return true; } virtual bool visit(BreakableStatement const&) { return true; } virtual bool visit(WhileStatement const&) { return true; } @@ -161,12 +170,15 @@ public: virtual void endVisit(ParameterList const&) { } virtual void endVisit(FunctionDefinition const&) { } virtual void endVisit(VariableDeclaration const&) { } + virtual void endVisit(ModifierDefinition const&) { } + virtual void endVisit(ModifierInvocation const&) { } virtual void endVisit(TypeName const&) { } virtual void endVisit(ElementaryTypeName const&) { } virtual void endVisit(UserDefinedTypeName const&) { } virtual void endVisit(Mapping const&) { } virtual void endVisit(Statement const&) { } virtual void endVisit(Block const&) { } + virtual void endVisit(PlaceholderStatement const&) { } virtual void endVisit(IfStatement const&) { } virtual void endVisit(BreakableStatement const&) { } virtual void endVisit(WhileStatement const&) { } diff --git a/AST_accept.h b/AST_accept.h index b77cfe1c..481b150b 100644 --- a/AST_accept.h +++ b/AST_accept.h @@ -65,6 +65,7 @@ void ContractDefinition::accept(ASTVisitor& _visitor) listAccept(m_definedStructs, _visitor); listAccept(m_stateVariables, _visitor); listAccept(m_definedFunctions, _visitor); + listAccept(m_functionModifiers, _visitor); } _visitor.endVisit(*this); } @@ -77,6 +78,7 @@ void ContractDefinition::accept(ASTConstVisitor& _visitor) const listAccept(m_definedStructs, _visitor); listAccept(m_stateVariables, _visitor); listAccept(m_definedFunctions, _visitor); + listAccept(m_functionModifiers, _visitor); } _visitor.endVisit(*this); } @@ -142,6 +144,7 @@ void FunctionDefinition::accept(ASTVisitor& _visitor) m_parameters->accept(_visitor); if (m_returnParameters) m_returnParameters->accept(_visitor); + listAccept(m_functionModifiers, _visitor); m_body->accept(_visitor); } _visitor.endVisit(*this); @@ -154,6 +157,7 @@ void FunctionDefinition::accept(ASTConstVisitor& _visitor) const m_parameters->accept(_visitor); if (m_returnParameters) m_returnParameters->accept(_visitor); + listAccept(m_functionModifiers, _visitor); m_body->accept(_visitor); } _visitor.endVisit(*this); @@ -175,6 +179,46 @@ void VariableDeclaration::accept(ASTConstVisitor& _visitor) const _visitor.endVisit(*this); } +void ModifierDefinition::accept(ASTVisitor& _visitor) +{ + if (_visitor.visit(*this)) + { + m_parameters->accept(_visitor); + m_body->accept(_visitor); + } + _visitor.endVisit(*this); +} + +void ModifierDefinition::accept(ASTConstVisitor& _visitor) const +{ + if (_visitor.visit(*this)) + { + m_parameters->accept(_visitor); + m_body->accept(_visitor); + } + _visitor.endVisit(*this); +} + +void ModifierInvocation::accept(ASTVisitor& _visitor) +{ + if (_visitor.visit(*this)) + { + m_modifierName->accept(_visitor); + listAccept(m_arguments, _visitor); + } + _visitor.endVisit(*this); +} + +void ModifierInvocation::accept(ASTConstVisitor& _visitor) const +{ + if (_visitor.visit(*this)) + { + m_modifierName->accept(_visitor); + listAccept(m_arguments, _visitor); + } + _visitor.endVisit(*this); +} + void TypeName::accept(ASTVisitor& _visitor) { _visitor.visit(*this); @@ -245,6 +289,18 @@ void Block::accept(ASTConstVisitor& _visitor) const _visitor.endVisit(*this); } +void PlaceholderStatement::accept(ASTVisitor& _visitor) +{ + _visitor.visit(*this); + _visitor.endVisit(*this); +} + +void PlaceholderStatement::accept(ASTConstVisitor& _visitor) const +{ + _visitor.visit(*this); + _visitor.endVisit(*this); +} + void IfStatement::accept(ASTVisitor& _visitor) { if (_visitor.visit(*this)) diff --git a/CallGraph.cpp b/CallGraph.cpp index a671796b..5f8fc547 100644 --- a/CallGraph.cpp +++ b/CallGraph.cpp @@ -33,7 +33,11 @@ namespace solidity void CallGraph::addNode(ASTNode const& _node) { - _node.accept(*this); + if (!m_nodesSeen.count(&_node)) + { + m_workQueue.push(&_node); + m_nodesSeen.insert(&_node); + } } set<FunctionDefinition const*> const& CallGraph::getCalls() @@ -53,20 +57,26 @@ void CallGraph::computeCallGraph() bool CallGraph::visit(Identifier const& _identifier) { - FunctionDefinition const* fun = dynamic_cast<FunctionDefinition const*>(_identifier.getReferencedDeclaration()); - if (fun) + if (auto fun = dynamic_cast<FunctionDefinition const*>(_identifier.getReferencedDeclaration())) { - if (m_overrideResolver) - fun = (*m_overrideResolver)(fun->getName()); + if (m_functionOverrideResolver) + fun = (*m_functionOverrideResolver)(fun->getName()); solAssert(fun, "Error finding override for function " + fun->getName()); - addFunction(*fun); + addNode(*fun); + } + if (auto modifier = dynamic_cast<ModifierDefinition const*>(_identifier.getReferencedDeclaration())) + { + if (m_modifierOverrideResolver) + modifier = (*m_modifierOverrideResolver)(modifier->getName()); + solAssert(modifier, "Error finding override for modifier " + modifier->getName()); + addNode(*modifier); } return true; } bool CallGraph::visit(FunctionDefinition const& _function) { - addFunction(_function); + m_functionsSeen.insert(&_function); return true; } @@ -83,7 +93,7 @@ bool CallGraph::visit(MemberAccess const& _memberAccess) for (ASTPointer<FunctionDefinition> const& function: contract.getDefinedFunctions()) if (function->getName() == _memberAccess.getMemberName()) { - addFunction(*function); + addNode(*function); return true; } } @@ -91,14 +101,5 @@ bool CallGraph::visit(MemberAccess const& _memberAccess) return true; } -void CallGraph::addFunction(FunctionDefinition const& _function) -{ - if (!m_functionsSeen.count(&_function)) - { - m_functionsSeen.insert(&_function); - m_workQueue.push(&_function); - } -} - } } diff --git a/CallGraph.h b/CallGraph.h index 90176e7e..9af5cdf9 100644 --- a/CallGraph.h +++ b/CallGraph.h @@ -39,9 +39,13 @@ namespace solidity class CallGraph: private ASTConstVisitor { public: - using OverrideResolver = std::function<FunctionDefinition const*(std::string const&)>; + using FunctionOverrideResolver = std::function<FunctionDefinition const*(std::string const&)>; + using ModifierOverrideResolver = std::function<ModifierDefinition const*(std::string const&)>; - CallGraph(OverrideResolver const& _overrideResolver): m_overrideResolver(&_overrideResolver) {} + CallGraph(FunctionOverrideResolver const& _functionOverrideResolver, + ModifierOverrideResolver const& _modifierOverrideResolver): + m_functionOverrideResolver(&_functionOverrideResolver), + m_modifierOverrideResolver(&_modifierOverrideResolver) {} void addNode(ASTNode const& _node); @@ -53,11 +57,12 @@ private: virtual bool visit(MemberAccess const& _memberAccess) override; void computeCallGraph(); - void addFunction(FunctionDefinition const& _function); - OverrideResolver const* m_overrideResolver; + FunctionOverrideResolver const* m_functionOverrideResolver; + ModifierOverrideResolver const* m_modifierOverrideResolver; + std::set<ASTNode const*> m_nodesSeen; std::set<FunctionDefinition const*> m_functionsSeen; - std::queue<FunctionDefinition const*> m_workQueue; + std::queue<ASTNode const*> m_workQueue; }; } diff --git a/Compiler.cpp b/Compiler.cpp index 5a434a71..5190f93f 100644 --- a/Compiler.cpp +++ b/Compiler.cpp @@ -42,9 +42,13 @@ void Compiler::compileContract(ContractDefinition const& _contract, initializeContext(_contract, _contracts); for (ContractDefinition const* contract: _contract.getLinearizedBaseContracts()) + { for (ASTPointer<FunctionDefinition> const& function: contract->getDefinedFunctions()) if (!function->isConstructor()) m_context.addFunction(*function); + for (ASTPointer<ModifierDefinition> const& modifier: contract->getFunctionModifiers()) + m_context.addModifier(*modifier); + } appendFunctionSelector(_contract); for (ContractDefinition const* contract: _contract.getLinearizedBaseContracts()) @@ -67,6 +71,13 @@ void Compiler::initializeContext(ContractDefinition const& _contract, void Compiler::packIntoContractCreator(ContractDefinition const& _contract, CompilerContext const& _runtimeContext) { + std::vector<ContractDefinition const*> const& bases = _contract.getLinearizedBaseContracts(); + + // Make all modifiers known to the context. + for (ContractDefinition const* contract: bases) + for (ASTPointer<ModifierDefinition> const& modifier: contract->getFunctionModifiers()) + m_context.addModifier(*modifier); + // arguments for base constructors, filled in derived-to-base order map<ContractDefinition const*, vector<ASTPointer<Expression>> const*> baseArguments; set<FunctionDefinition const*> neededFunctions; @@ -74,7 +85,6 @@ void Compiler::packIntoContractCreator(ContractDefinition const& _contract, Comp // Determine the arguments that are used for the base constructors and also which functions // are needed at compile time. - std::vector<ContractDefinition const*> const& bases = _contract.getLinearizedBaseContracts(); for (ContractDefinition const* contract: bases) { if (FunctionDefinition const* constructor = contract->getConstructor()) @@ -93,7 +103,7 @@ void Compiler::packIntoContractCreator(ContractDefinition const& _contract, Comp } } - auto overrideResolver = [&](string const& _name) -> FunctionDefinition const* + auto functionOverrideResolver = [&](string const& _name) -> FunctionDefinition const* { for (ContractDefinition const* contract: bases) for (ASTPointer<FunctionDefinition> const& function: contract->getDefinedFunctions()) @@ -101,21 +111,26 @@ void Compiler::packIntoContractCreator(ContractDefinition const& _contract, Comp return function.get(); return nullptr; }; + auto modifierOverrideResolver = [&](string const& _name) -> ModifierDefinition const* + { + return &m_context.getFunctionModifier(_name); + }; - neededFunctions = getFunctionsCalled(nodesUsedInConstructors, overrideResolver); + neededFunctions = getFunctionsCalled(nodesUsedInConstructors, functionOverrideResolver, + modifierOverrideResolver); // First add all overrides (or the functions themselves if there is no override) for (FunctionDefinition const* fun: neededFunctions) { FunctionDefinition const* override = nullptr; if (!fun->isConstructor()) - override = overrideResolver(fun->getName()); + override = functionOverrideResolver(fun->getName()); if (!!override && neededFunctions.count(override)) m_context.addFunction(*override); } // now add the rest for (FunctionDefinition const* fun: neededFunctions) - if (fun->isConstructor() || overrideResolver(fun->getName()) != fun) + if (fun->isConstructor() || functionOverrideResolver(fun->getName()) != fun) m_context.addFunction(*fun); // Call constructors in base-to-derived order. @@ -150,11 +165,7 @@ void Compiler::appendBaseConstructorCall(FunctionDefinition const& _constructor, FunctionType constructorType(_constructor); eth::AssemblyItem returnLabel = m_context.pushNewTag(); for (unsigned i = 0; i < _arguments.size(); ++i) - { - compileExpression(*_arguments[i]); - ExpressionCompiler::appendTypeConversion(m_context, *_arguments[i]->getType(), - *constructorType.getParameterTypes()[i]); - } + compileExpression(*_arguments[i], constructorType.getParameterTypes()[i]); m_context.appendJumpTo(m_context.getFunctionEntryLabel(_constructor)); m_context << returnLabel; } @@ -179,9 +190,10 @@ void Compiler::appendConstructorCall(FunctionDefinition const& _constructor) } set<FunctionDefinition const*> Compiler::getFunctionsCalled(set<ASTNode const*> const& _nodes, - function<FunctionDefinition const*(string const&)> const& _resolveOverrides) + function<FunctionDefinition const*(string const&)> const& _resolveFunctionOverrides, + function<ModifierDefinition const*(string const&)> const& _resolveModifierOverrides) { - CallGraph callgraph(_resolveOverrides); + CallGraph callgraph(_resolveFunctionOverrides, _resolveModifierOverrides); for (ASTNode const* node: _nodes) callgraph.addNode(*node); return callgraph.getCalls(); @@ -280,20 +292,28 @@ bool Compiler::visit(FunctionDefinition const& _function) m_returnTag = m_context.newTag(); m_breakTags.clear(); m_continueTags.clear(); + m_stackCleanupForReturn = 0; + m_currentFunction = &_function; + m_modifierDepth = 0; m_context << m_context.getFunctionEntryLabel(_function); // stack upon entry: [return address] [arg0] [arg1] ... [argn] // reserve additional slots: [retarg0] ... [retargm] [localvar0] ... [localvarp] + unsigned parametersSize = CompilerUtils::getSizeOnStack(_function.getParameters()); + m_context.adjustStackOffset(parametersSize); for (ASTPointer<VariableDeclaration const> const& variable: _function.getParameters()) - m_context.addVariable(*variable); + { + m_context.addVariable(*variable, parametersSize); + parametersSize -= variable->getType()->getSizeOnStack(); + } for (ASTPointer<VariableDeclaration const> const& variable: _function.getReturnParameters()) m_context.addAndInitializeVariable(*variable); for (VariableDeclaration const* localVariable: _function.getLocalVariables()) m_context.addAndInitializeVariable(*localVariable); - _function.getBody().accept(*this); + appendModifierOrFunctionCode(); m_context << m_returnTag; @@ -420,13 +440,15 @@ bool Compiler::visit(Return const& _return) //@todo modifications are needed to make this work with functions returning multiple values if (Expression const* expression = _return.getExpression()) { - compileExpression(*expression); - VariableDeclaration const& firstVariable = *_return.getFunctionReturnParameters().getParameters().front(); - ExpressionCompiler::appendTypeConversion(m_context, *expression->getType(), *firstVariable.getType()); - + solAssert(_return.getFunctionReturnParameters(), "Invalid return parameters pointer."); + VariableDeclaration const& firstVariable = *_return.getFunctionReturnParameters()->getParameters().front(); + compileExpression(*expression, firstVariable.getType()); CompilerUtils(m_context).moveToStackVariable(firstVariable); } + for (unsigned i = 0; i < m_stackCleanupForReturn; ++i) + m_context << eth::Instruction::POP; m_context.appendJumpTo(m_returnTag); + m_context.adjustStackOffset(m_stackCleanupForReturn); return false; } @@ -434,10 +456,7 @@ bool Compiler::visit(VariableDefinition const& _variableDefinition) { if (Expression const* expression = _variableDefinition.getExpression()) { - compileExpression(*expression); - ExpressionCompiler::appendTypeConversion(m_context, - *expression->getType(), - *_variableDefinition.getDeclaration().getType()); + compileExpression(*expression, _variableDefinition.getDeclaration().getType()); CompilerUtils(m_context).moveToStackVariable(_variableDefinition.getDeclaration()); } return false; @@ -451,9 +470,51 @@ bool Compiler::visit(ExpressionStatement const& _expressionStatement) return false; } -void Compiler::compileExpression(Expression const& _expression) +bool Compiler::visit(PlaceholderStatement const&) +{ + ++m_modifierDepth; + appendModifierOrFunctionCode(); + --m_modifierDepth; + return true; +} + +void Compiler::appendModifierOrFunctionCode() +{ + solAssert(m_currentFunction, ""); + if (m_modifierDepth >= m_currentFunction->getModifiers().size()) + m_currentFunction->getBody().accept(*this); + else + { + ASTPointer<ModifierInvocation> const& modifierInvocation = m_currentFunction->getModifiers()[m_modifierDepth]; + + ModifierDefinition const& modifier = m_context.getFunctionModifier(modifierInvocation->getName()->getName()); + solAssert(modifier.getParameters().size() == modifierInvocation->getArguments().size(), ""); + for (unsigned i = 0; i < modifier.getParameters().size(); ++i) + { + m_context.addVariable(*modifier.getParameters()[i]); + compileExpression(*modifierInvocation->getArguments()[i], + modifier.getParameters()[i]->getType()); + } + for (VariableDeclaration const* localVariable: modifier.getLocalVariables()) + m_context.addAndInitializeVariable(*localVariable); + + unsigned const c_stackSurplus = CompilerUtils::getSizeOnStack(modifier.getParameters()) + + CompilerUtils::getSizeOnStack(modifier.getLocalVariables()); + m_stackCleanupForReturn += c_stackSurplus; + + modifier.getBody().accept(*this); + + for (unsigned i = 0; i < c_stackSurplus; ++i) + m_context << eth::Instruction::POP; + m_stackCleanupForReturn -= c_stackSurplus; + } +} + +void Compiler::compileExpression(Expression const& _expression, TypePointer const& _targetType) { ExpressionCompiler::compileExpression(m_context, _expression, m_optimize); + if (_targetType) + ExpressionCompiler::appendTypeConversion(m_context, *_expression.getType(), *_targetType); } } @@ -31,7 +31,8 @@ namespace solidity { class Compiler: private ASTConstVisitor { public: - explicit Compiler(bool _optimize = false): m_optimize(_optimize), m_context(), m_returnTag(m_context.newTag()) {} + explicit Compiler(bool _optimize = false): m_optimize(_optimize), m_context(), + m_returnTag(m_context.newTag()) {} void compileContract(ContractDefinition const& _contract, std::map<ContractDefinition const*, bytes const*> const& _contracts); @@ -52,7 +53,8 @@ private: /// Recursively searches the call graph and returns all functions referenced inside _nodes. /// _resolveOverride is called to resolve virtual function overrides. std::set<FunctionDefinition const*> getFunctionsCalled(std::set<ASTNode const*> const& _nodes, - std::function<FunctionDefinition const*(std::string const&)> const& _resolveOverride); + std::function<FunctionDefinition const*(std::string const&)> const& _resolveFunctionOverride, + std::function<ModifierDefinition const*(std::string const&)> const& _resolveModifierOverride); void appendFunctionSelector(ContractDefinition const& _contract); /// Creates code that unpacks the arguments for the given function, from memory if /// @a _fromMemory is true, otherwise from call data. @returns the size of the data in bytes. @@ -70,8 +72,13 @@ private: virtual bool visit(Return const& _return) override; virtual bool visit(VariableDefinition const& _variableDefinition) override; virtual bool visit(ExpressionStatement const& _expressionStatement) override; + virtual bool visit(PlaceholderStatement const&) override; - void compileExpression(Expression const& _expression); + /// Appends one layer of function modifier code of the current function, or the function + /// body itself if the last modifier was reached. + void appendModifierOrFunctionCode(); + + void compileExpression(Expression const& _expression, TypePointer const& _targetType = TypePointer()); bool const m_optimize; CompilerContext m_context; @@ -79,6 +86,9 @@ private: std::vector<eth::AssemblyItem> m_breakTags; ///< tag to jump to for a "break" statement std::vector<eth::AssemblyItem> m_continueTags; ///< tag to jump to for a "continue" statement eth::AssemblyItem m_returnTag; ///< tag to jump to for a "return" statement + unsigned m_modifierDepth = 0; + FunctionDefinition const* m_currentFunction; + unsigned m_stackCleanupForReturn; ///< this number of stack elements need to be removed before jump to m_returnTag }; } diff --git a/CompilerContext.cpp b/CompilerContext.cpp index 27ec3efd..ad1877ba 100644 --- a/CompilerContext.cpp +++ b/CompilerContext.cpp @@ -43,10 +43,11 @@ void CompilerContext::addStateVariable(VariableDeclaration const& _declaration) m_stateVariablesSize += _declaration.getType()->getStorageSize(); } -void CompilerContext::addVariable(VariableDeclaration const& _declaration) +void CompilerContext::addVariable(VariableDeclaration const& _declaration, + unsigned _offsetToCurrent) { - m_localVariables[&_declaration] = m_localVariablesSize; - m_localVariablesSize += _declaration.getType()->getSizeOnStack(); + solAssert(m_asm.deposit() >= 0 && unsigned(m_asm.deposit()) >= _offsetToCurrent, ""); + m_localVariables[&_declaration] = unsigned(m_asm.deposit()) - _offsetToCurrent; } void CompilerContext::addAndInitializeVariable(VariableDeclaration const& _declaration) @@ -56,7 +57,6 @@ void CompilerContext::addAndInitializeVariable(VariableDeclaration const& _decla int const size = _declaration.getType()->getSizeOnStack(); for (int i = 0; i < size; ++i) *this << u256(0); - m_asm.adjustDeposit(-size); } void CompilerContext::addFunction(FunctionDefinition const& _function) @@ -66,6 +66,11 @@ void CompilerContext::addFunction(FunctionDefinition const& _function) m_virtualFunctionEntryLabels.insert(make_pair(_function.getName(), tag)); } +void CompilerContext::addModifier(ModifierDefinition const& _modifier) +{ + m_functionModifiers.insert(make_pair(_modifier.getName(), &_modifier)); +} + bytes const& CompilerContext::getCompiledContract(const ContractDefinition& _contract) const { auto ret = m_compiledContracts.find(&_contract); @@ -75,7 +80,7 @@ bytes const& CompilerContext::getCompiledContract(const ContractDefinition& _con bool CompilerContext::isLocalVariable(Declaration const* _declaration) const { - return m_localVariables.count(_declaration) > 0; + return m_localVariables.count(_declaration); } eth::AssemblyItem CompilerContext::getFunctionEntryLabel(FunctionDefinition const& _function) const @@ -92,21 +97,28 @@ eth::AssemblyItem CompilerContext::getVirtualFunctionEntryLabel(FunctionDefiniti return res->second.tag(); } +ModifierDefinition const& CompilerContext::getFunctionModifier(string const& _name) const +{ + auto res = m_functionModifiers.find(_name); + solAssert(res != m_functionModifiers.end(), "Function modifier override not found."); + return *res->second; +} + unsigned CompilerContext::getBaseStackOffsetOfVariable(Declaration const& _declaration) const { auto res = m_localVariables.find(&_declaration); solAssert(res != m_localVariables.end(), "Variable not found on stack."); - return m_localVariablesSize - res->second - 1; + return res->second; } unsigned CompilerContext::baseToCurrentStackOffset(unsigned _baseOffset) const { - return _baseOffset + m_asm.deposit(); + return m_asm.deposit() - _baseOffset - 1; } unsigned CompilerContext::currentToBaseStackOffset(unsigned _offset) const { - return -baseToCurrentStackOffset(-_offset); + return m_asm.deposit() - _offset - 1; } u256 CompilerContext::getStorageLocationOfVariable(const Declaration& _declaration) const diff --git a/CompilerContext.h b/CompilerContext.h index cde992d5..d82dfe51 100644 --- a/CompilerContext.h +++ b/CompilerContext.h @@ -42,9 +42,11 @@ public: void addMagicGlobal(MagicVariableDeclaration const& _declaration); void addStateVariable(VariableDeclaration const& _declaration); void startNewFunction() { m_localVariables.clear(); m_asm.setDeposit(0); } - void addVariable(VariableDeclaration const& _declaration); + void addVariable(VariableDeclaration const& _declaration, unsigned _offsetToCurrent = 0); void addAndInitializeVariable(VariableDeclaration const& _declaration); void addFunction(FunctionDefinition const& _function); + /// Adds the given modifier to the list by name if the name is not present already. + void addModifier(ModifierDefinition const& _modifier); void setCompiledContracts(std::map<ContractDefinition const*, bytes const*> const& _contracts) { m_compiledContracts = _contracts; } bytes const& getCompiledContract(ContractDefinition const& _contract) const; @@ -59,7 +61,8 @@ public: eth::AssemblyItem getFunctionEntryLabel(FunctionDefinition const& _function) const; /// @returns the entry label of the given function and takes overrides into account. eth::AssemblyItem getVirtualFunctionEntryLabel(FunctionDefinition const& _function) const; - /// Returns the distance of the given local variable from the top of the local variable stack. + ModifierDefinition const& getFunctionModifier(std::string const& _name) const; + /// Returns the distance of the given local variable from the bottom of the stack (of the current function). unsigned getBaseStackOffsetOfVariable(Declaration const& _declaration) const; /// If supplied by a value returned by @ref getBaseStackOffsetOfVariable(variable), returns /// the distance of that variable from the current top of the stack. @@ -112,14 +115,14 @@ private: u256 m_stateVariablesSize = 0; /// Storage offsets of state variables std::map<Declaration const*, u256> m_stateVariables; - /// Offsets of local variables on the stack (relative to stack base). + /// Positions of local variables on the stack. std::map<Declaration const*, unsigned> m_localVariables; - /// Sum of stack sizes of local variables - unsigned m_localVariablesSize; /// Labels pointing to the entry points of funcitons. std::map<Declaration const*, eth::AssemblyItem> m_functionEntryLabels; /// Labels pointing to the entry points of function overrides. std::map<std::string, eth::AssemblyItem> m_virtualFunctionEntryLabels; + /// Mapping to obtain function modifiers by name. Should be filled from derived to base. + std::map<std::string, ModifierDefinition const*> m_functionModifiers; }; } diff --git a/NameAndTypeResolver.cpp b/NameAndTypeResolver.cpp index c67cd727..43201fe1 100644 --- a/NameAndTypeResolver.cpp +++ b/NameAndTypeResolver.cpp @@ -60,6 +60,11 @@ void NameAndTypeResolver::resolveNamesAndTypes(ContractDefinition& _contract) ReferencesResolver resolver(*structDef, *this, &_contract, nullptr); for (ASTPointer<VariableDeclaration> const& variable: _contract.getStateVariables()) ReferencesResolver resolver(*variable, *this, &_contract, nullptr); + for (ASTPointer<ModifierDefinition> const& modifier: _contract.getFunctionModifiers()) + { + m_currentScope = &m_scopes[modifier.get()]; + ReferencesResolver resolver(*modifier, *this, &_contract, nullptr); + } for (ASTPointer<FunctionDefinition> const& function: _contract.getDefinedFunctions()) { m_currentScope = &m_scopes[function.get()]; @@ -227,6 +232,19 @@ void DeclarationRegistrationHelper::endVisit(FunctionDefinition&) closeCurrentScope(); } +bool DeclarationRegistrationHelper::visit(ModifierDefinition& _modifier) +{ + registerDeclaration(_modifier, true); + m_currentFunction = &_modifier; + return true; +} + +void DeclarationRegistrationHelper::endVisit(ModifierDefinition&) +{ + m_currentFunction = nullptr; + closeCurrentScope(); +} + void DeclarationRegistrationHelper::endVisit(VariableDefinition& _variableDefinition) { // Register the local variables with the function @@ -293,8 +311,7 @@ void ReferencesResolver::endVisit(VariableDeclaration& _variable) bool ReferencesResolver::visit(Return& _return) { - solAssert(m_returnParameters, "Return parameters not set."); - _return.setFunctionReturnParameters(*m_returnParameters); + _return.setFunctionReturnParameters(m_returnParameters); return true; } diff --git a/NameAndTypeResolver.h b/NameAndTypeResolver.h index f97c7ae5..ba327a59 100644 --- a/NameAndTypeResolver.h +++ b/NameAndTypeResolver.h @@ -100,6 +100,8 @@ private: void endVisit(StructDefinition& _struct); bool visit(FunctionDefinition& _function); void endVisit(FunctionDefinition& _function); + bool visit(ModifierDefinition& _modifier); + void endVisit(ModifierDefinition& _modifier); void endVisit(VariableDefinition& _variableDefinition); bool visit(VariableDeclaration& _declaration); @@ -109,7 +111,7 @@ private: std::map<ASTNode const*, DeclarationContainer>& m_scopes; Declaration const* m_currentScope; - FunctionDefinition* m_currentFunction; + VariableScope* m_currentFunction; }; /** @@ -121,6 +121,7 @@ ASTPointer<ContractDefinition> Parser::parseContractDefinition() vector<ASTPointer<StructDefinition>> structs; vector<ASTPointer<VariableDeclaration>> stateVariables; vector<ASTPointer<FunctionDefinition>> functions; + vector<ASTPointer<ModifierDefinition>> modifiers; if (m_scanner->getCurrentToken() == Token::IS) do { @@ -152,13 +153,15 @@ ASTPointer<ContractDefinition> Parser::parseContractDefinition() stateVariables.push_back(parseVariableDeclaration(allowVar)); expectToken(Token::SEMICOLON); } + else if (currentToken == Token::MODIFIER) + modifiers.push_back(parseModifierDefinition()); else - BOOST_THROW_EXCEPTION(createParserError("Function, variable or struct declaration expected.")); + BOOST_THROW_EXCEPTION(createParserError("Function, variable, struct or modifier declaration expected.")); } nodeFactory.markEndPosition(); expectToken(Token::RBRACE); return nodeFactory.createNode<ContractDefinition>(name, docString, baseContracts, structs, - stateVariables, functions); + stateVariables, functions, modifiers); } ASTPointer<InheritanceSpecifier> Parser::parseInheritanceSpecifier() @@ -189,10 +192,18 @@ ASTPointer<FunctionDefinition> Parser::parseFunctionDefinition(bool _isPublic, A ASTPointer<ASTString> name(expectIdentifierToken()); ASTPointer<ParameterList> parameters(parseParameterList()); bool isDeclaredConst = false; - if (m_scanner->getCurrentToken() == Token::CONST) + vector<ASTPointer<ModifierInvocation>> modifiers; + while (true) { - isDeclaredConst = true; - m_scanner->next(); + if (m_scanner->getCurrentToken() == Token::CONST) + { + isDeclaredConst = true; + m_scanner->next(); + } + else if (m_scanner->getCurrentToken() == Token::IDENTIFIER) + modifiers.push_back(parseModifierInvocation()); + else + break; } ASTPointer<ParameterList> returnParameters; if (m_scanner->getCurrentToken() == Token::RETURNS) @@ -212,8 +223,8 @@ ASTPointer<FunctionDefinition> Parser::parseFunctionDefinition(bool _isPublic, A nodeFactory.setEndPositionFromNode(block); bool const c_isConstructor = (_contractName && *name == *_contractName); return nodeFactory.createNode<FunctionDefinition>(name, _isPublic, c_isConstructor, docstring, - parameters, - isDeclaredConst, returnParameters, block); + parameters, isDeclaredConst, modifiers, + returnParameters, block); } ASTPointer<StructDefinition> Parser::parseStructDefinition() @@ -242,6 +253,50 @@ ASTPointer<VariableDeclaration> Parser::parseVariableDeclaration(bool _allowVar) return nodeFactory.createNode<VariableDeclaration>(type, expectIdentifierToken()); } +ASTPointer<ModifierDefinition> Parser::parseModifierDefinition() +{ + ScopeGuard resetModifierFlag([this]() { m_insideModifier = false; }); + m_insideModifier = true; + + ASTNodeFactory nodeFactory(*this); + ASTPointer<ASTString> docstring; + if (m_scanner->getCurrentCommentLiteral() != "") + docstring = make_shared<ASTString>(m_scanner->getCurrentCommentLiteral()); + + expectToken(Token::MODIFIER); + ASTPointer<ASTString> name(expectIdentifierToken()); + ASTPointer<ParameterList> parameters; + if (m_scanner->getCurrentToken() == Token::LPAREN) + parameters = parseParameterList(); + else + { + // create an empty parameter list at a zero-length location + ASTNodeFactory nodeFactory(*this); + nodeFactory.setLocationEmpty(); + parameters = nodeFactory.createNode<ParameterList>(vector<ASTPointer<VariableDeclaration>>()); + } + ASTPointer<Block> block = parseBlock(); + nodeFactory.setEndPositionFromNode(block); + return nodeFactory.createNode<ModifierDefinition>(name, docstring, parameters, block); +} + +ASTPointer<ModifierInvocation> Parser::parseModifierInvocation() +{ + ASTNodeFactory nodeFactory(*this); + ASTPointer<Identifier> name = ASTNodeFactory(*this).createNode<Identifier>(expectIdentifierToken()); + vector<ASTPointer<Expression>> arguments; + if (m_scanner->getCurrentToken() == Token::LPAREN) + { + m_scanner->next(); + arguments = parseFunctionCallArguments(); + nodeFactory.markEndPosition(); + expectToken(Token::RPAREN); + } + else + nodeFactory.setEndPositionFromNode(name); + return nodeFactory.createNode<ModifierInvocation>(name, arguments); +} + ASTPointer<TypeName> Parser::parseTypeName(bool _allowVar) { ASTPointer<TypeName> type; @@ -354,8 +409,16 @@ ASTPointer<Statement> Parser::parseStatement() nodeFactory.setEndPositionFromNode(expression); } statement = nodeFactory.createNode<Return>(expression); + break; } - break; + case Token::IDENTIFIER: + if (m_insideModifier && m_scanner->getCurrentLiteral() == "_") + { + statement = ASTNodeFactory(*this).createNode<PlaceholderStatement>(); + m_scanner->next(); + return statement; + } + // fall-through default: statement = parseVarDefOrExprStmt(); } @@ -53,6 +53,8 @@ private: ASTPointer<FunctionDefinition> parseFunctionDefinition(bool _isPublic, ASTString const* _contractName); ASTPointer<StructDefinition> parseStructDefinition(); ASTPointer<VariableDeclaration> parseVariableDeclaration(bool _allowVar); + ASTPointer<ModifierDefinition> parseModifierDefinition(); + ASTPointer<ModifierInvocation> parseModifierInvocation(); ASTPointer<TypeName> parseTypeName(bool _allowVar); ASTPointer<Mapping> parseMapping(); ASTPointer<ParameterList> parseParameterList(bool _allowEmpty = true); @@ -90,6 +92,8 @@ private: ParserError createParserError(std::string const& _description) const; std::shared_ptr<Scanner> m_scanner; + /// Flag that signifies whether '_' is parsed as a PlaceholderStatement or a regular identifier. + bool m_insideModifier = false; }; } @@ -159,6 +159,7 @@ namespace solidity K(IF, "if", 0) \ K(IMPORT, "import", 0) \ K(MAPPING, "mapping", 0) \ + K(MODIFIER, "modifier", 0) \ K(NEW, "new", 0) \ K(PUBLIC, "public", 0) \ K(PRIVATE, "private", 0) \ @@ -724,6 +724,38 @@ MemberList const& TypeType::getMembers() const return *m_members; } +ModifierType::ModifierType(const ModifierDefinition& _modifier) +{ + TypePointers params; + params.reserve(_modifier.getParameters().size()); + for (ASTPointer<VariableDeclaration> const& var: _modifier.getParameters()) + params.push_back(var->getType()); + swap(params, m_parameterTypes); +} + +bool ModifierType::operator==(Type const& _other) const +{ + if (_other.getCategory() != getCategory()) + return false; + ModifierType const& other = dynamic_cast<ModifierType const&>(_other); + + if (m_parameterTypes.size() != other.m_parameterTypes.size()) + return false; + auto typeCompare = [](TypePointer const& _a, TypePointer const& _b) -> bool { return *_a == *_b; }; + + if (!equal(m_parameterTypes.cbegin(), m_parameterTypes.cend(), + other.m_parameterTypes.cbegin(), typeCompare)) + return false; + return true; +} + +string ModifierType::toString() const +{ + string name = "modifier ("; + for (auto it = m_parameterTypes.begin(); it != m_parameterTypes.end(); ++it) + name += (*it)->toString() + (it + 1 == m_parameterTypes.end() ? "" : ","); + return name + ")"; +} MagicType::MagicType(MagicType::Kind _kind): m_kind(_kind) @@ -75,7 +75,7 @@ class Type: private boost::noncopyable, public std::enable_shared_from_this<Type public: enum class Category { - INTEGER, INTEGER_CONSTANT, BOOL, REAL, STRING, CONTRACT, STRUCT, FUNCTION, MAPPING, VOID, TYPE, MAGIC + INTEGER, INTEGER_CONSTANT, BOOL, REAL, STRING, CONTRACT, STRUCT, FUNCTION, MAPPING, VOID, TYPE, MODIFIER, MAGIC }; ///@{ @@ -464,6 +464,27 @@ private: /** + * The type of a function modifier. Not used for anything for now. + */ +class ModifierType: public Type +{ +public: + virtual Category getCategory() const override { return Category::MODIFIER; } + explicit ModifierType(ModifierDefinition const& _modifier); + + virtual TypePointer binaryOperatorResult(Token::Value, TypePointer const&) const override { return TypePointer(); } + virtual bool canBeStored() const override { return false; } + virtual u256 getStorageSize() const override { BOOST_THROW_EXCEPTION(InternalCompilerError() << errinfo_comment("Storage size of non-storable type type requested.")); } + virtual bool canLiveOutsideStorage() const override { return false; } + virtual bool operator==(Type const& _other) const override; + virtual std::string toString() const override; + +private: + TypePointers m_parameterTypes; +}; + + +/** * Special type for magic variables (block, msg, tx), similar to a struct but without any reference * (it always references a global singleton by name). */ diff --git a/grammar.txt b/grammar.txt index 11d99854..b97dac5d 100644 --- a/grammar.txt +++ b/grammar.txt @@ -1,14 +1,14 @@ ContractDefinition = 'contract' Identifier ( 'is' InheritanceSpecifier (',' InheritanceSpecifier )* )? '{' ContractPart* '}' -ContractPart = VariableDeclaration ';' | StructDefinition | +ContractPart = VariableDeclaration ';' | StructDefinition | ModifierDefinition | FunctionDefinition | 'public:' | 'private:' InheritanceSpecifier = Identifier ( '(' Expression ( ',' Expression )* ')' )? StructDefinition = 'struct' Identifier '{' ( VariableDeclaration (';' VariableDeclaration)* )? '} - -FunctionDefinition = 'function' Identifier ParameterList 'const'? +ModifierDefinition = 'modifier' Identifier ParameterList? Block +FunctionDefinition = 'function' Identifier ParameterList ( Identifier | 'constant' )* ( 'returns' ParameterList )? Block ParameterList = '(' ( VariableDeclaration (',' VariableDeclaration)* )? ')' // semantic restriction: mappings and structs (recursively) containing mappings |