diff options
-rw-r--r-- | AST.cpp | 19 | ||||
-rwxr-xr-x | AST.h | 26 | ||||
-rw-r--r-- | ASTForward.h | 1 | ||||
-rw-r--r-- | AST_accept.h | 22 | ||||
-rw-r--r-- | NameAndTypeResolver.cpp | 11 | ||||
-rw-r--r-- | Parser.cpp | 21 | ||||
-rw-r--r-- | Parser.h | 1 | ||||
-rw-r--r-- | grammar.txt | 5 |
8 files changed, 95 insertions, 11 deletions
@@ -43,6 +43,9 @@ TypeError ASTNode::createTypeError(string const& _description) const void ContractDefinition::checkTypeRequirements() { + for (ASTPointer<InheritanceSpecifier> const& base: getBaseContracts()) + base->checkTypeRequirements(); + checkIllegalOverrides(); FunctionDefinition const* constructor = getConstructor(); @@ -123,6 +126,22 @@ vector<pair<FixedHash<4>, FunctionDefinition const*>> const& ContractDefinition: return *m_interfaceFunctionList; } +void InheritanceSpecifier::checkTypeRequirements() +{ + m_baseName->checkTypeRequirements(); + for (ASTPointer<Expression> const& argument: m_arguments) + argument->checkTypeRequirements(); + + ContractDefinition const* base = dynamic_cast<ContractDefinition const*>(m_baseName->getReferencedDeclaration()); + solAssert(base, "Base contract not available."); + TypePointers parameterTypes = ContractType(*base).getConstructorType()->getParameterTypes(); + if (parameterTypes.size() != m_arguments.size()) + BOOST_THROW_EXCEPTION(createTypeError("Wrong argument count for constructor call.")); + for (size_t i = 0; i < m_arguments.size(); ++i) + if (!m_arguments[i]->getType()->isImplicitlyConvertibleTo(*parameterTypes[i])) + BOOST_THROW_EXCEPTION(createTypeError("Invalid type for argument in constructer call.")); +} + void StructDefinition::checkMemberTypes() const { for (ASTPointer<VariableDeclaration> const& member: getMembers()) @@ -158,7 +158,7 @@ public: ContractDefinition(Location const& _location, ASTPointer<ASTString> const& _name, ASTPointer<ASTString> const& _documentation, - std::vector<ASTPointer<Identifier>> const& _baseContracts, + std::vector<ASTPointer<InheritanceSpecifier>> const& _baseContracts, std::vector<ASTPointer<StructDefinition>> const& _definedStructs, std::vector<ASTPointer<VariableDeclaration>> const& _stateVariables, std::vector<ASTPointer<FunctionDefinition>> const& _definedFunctions): @@ -173,7 +173,7 @@ public: virtual void accept(ASTVisitor& _visitor) override; virtual void accept(ASTConstVisitor& _visitor) const override; - std::vector<ASTPointer<Identifier>> const& getBaseContracts() const { return m_baseContracts; } + std::vector<ASTPointer<InheritanceSpecifier>> const& getBaseContracts() const { return m_baseContracts; } std::vector<ASTPointer<StructDefinition>> const& getDefinedStructs() const { return m_definedStructs; } std::vector<ASTPointer<VariableDeclaration>> const& getStateVariables() const { return m_stateVariables; } std::vector<ASTPointer<FunctionDefinition>> const& getDefinedFunctions() const { return m_definedFunctions; } @@ -203,7 +203,7 @@ private: std::vector<std::pair<FixedHash<4>, FunctionDefinition const*>> const& getInterfaceFunctionList() const; - std::vector<ASTPointer<Identifier>> m_baseContracts; + std::vector<ASTPointer<InheritanceSpecifier>> m_baseContracts; std::vector<ASTPointer<StructDefinition>> m_definedStructs; std::vector<ASTPointer<VariableDeclaration>> m_stateVariables; std::vector<ASTPointer<FunctionDefinition>> m_definedFunctions; @@ -213,6 +213,26 @@ private: mutable std::unique_ptr<std::vector<std::pair<FixedHash<4>, FunctionDefinition const*>>> m_interfaceFunctionList; }; +class InheritanceSpecifier: public ASTNode +{ +public: + InheritanceSpecifier(Location const& _location, ASTPointer<Identifier> const& _baseName, + std::vector<ASTPointer<Expression>> _arguments): + ASTNode(_location), m_baseName(_baseName), m_arguments(_arguments) {} + + virtual void accept(ASTVisitor& _visitor) override; + virtual void accept(ASTConstVisitor& _visitor) const override; + + ASTPointer<Identifier> const& getName() const { return m_baseName; } + std::vector<ASTPointer<Expression>> const& getArguments() const { return m_arguments; } + + void checkTypeRequirements(); + +private: + ASTPointer<Identifier> m_baseName; + std::vector<ASTPointer<Expression>> m_arguments; +}; + class StructDefinition: public Declaration { public: diff --git a/ASTForward.h b/ASTForward.h index c960fc8f..da0a8812 100644 --- a/ASTForward.h +++ b/ASTForward.h @@ -38,6 +38,7 @@ class SourceUnit; class ImportDirective; class Declaration; class ContractDefinition; +class InheritanceSpecifier; class StructDefinition; class ParameterList; class FunctionDefinition; diff --git a/AST_accept.h b/AST_accept.h index 7f3db85a..b77cfe1c 100644 --- a/AST_accept.h +++ b/AST_accept.h @@ -61,6 +61,7 @@ void ContractDefinition::accept(ASTVisitor& _visitor) { if (_visitor.visit(*this)) { + listAccept(m_baseContracts, _visitor); listAccept(m_definedStructs, _visitor); listAccept(m_stateVariables, _visitor); listAccept(m_definedFunctions, _visitor); @@ -72,6 +73,7 @@ void ContractDefinition::accept(ASTConstVisitor& _visitor) const { if (_visitor.visit(*this)) { + listAccept(m_baseContracts, _visitor); listAccept(m_definedStructs, _visitor); listAccept(m_stateVariables, _visitor); listAccept(m_definedFunctions, _visitor); @@ -79,6 +81,26 @@ void ContractDefinition::accept(ASTConstVisitor& _visitor) const _visitor.endVisit(*this); } +void InheritanceSpecifier::accept(ASTVisitor& _visitor) +{ + if (_visitor.visit(*this)) + { + m_baseName->accept(_visitor); + listAccept(m_arguments, _visitor); + } + _visitor.endVisit(*this); +} + +void InheritanceSpecifier::accept(ASTConstVisitor& _visitor) const +{ + if (_visitor.visit(*this)) + { + m_baseName->accept(_visitor); + listAccept(m_arguments, _visitor); + } + _visitor.endVisit(*this); +} + void StructDefinition::accept(ASTVisitor& _visitor) { if (_visitor.visit(*this)) diff --git a/NameAndTypeResolver.cpp b/NameAndTypeResolver.cpp index f208dc78..7df51566 100644 --- a/NameAndTypeResolver.cpp +++ b/NameAndTypeResolver.cpp @@ -48,7 +48,7 @@ void NameAndTypeResolver::resolveNamesAndTypes(ContractDefinition& _contract) { m_currentScope = &m_scopes[nullptr]; - for (ASTPointer<Identifier> const& baseContract: _contract.getBaseContracts()) + for (ASTPointer<InheritanceSpecifier> const& baseContract: _contract.getBaseContracts()) ReferencesResolver resolver(*baseContract, *this, &_contract, nullptr); m_currentScope = &m_scopes[&_contract]; @@ -113,18 +113,19 @@ void NameAndTypeResolver::linearizeBaseContracts(ContractDefinition& _contract) // order in the lists is from derived to base // list of lists to linearize, the last element is the list of direct bases list<list<ContractDefinition const*>> input(1, {&_contract}); - for (ASTPointer<Identifier> const& baseIdentifier: _contract.getBaseContracts()) + for (ASTPointer<InheritanceSpecifier> const& baseSpecifier: _contract.getBaseContracts()) { + ASTPointer<Identifier> baseName = baseSpecifier->getName(); ContractDefinition const* base = dynamic_cast<ContractDefinition const*>( - baseIdentifier->getReferencedDeclaration()); + baseName->getReferencedDeclaration()); if (!base) - BOOST_THROW_EXCEPTION(baseIdentifier->createTypeError("Contract expected.")); + BOOST_THROW_EXCEPTION(baseName->createTypeError("Contract expected.")); // "push_back" has the effect that bases mentioned earlier can overwrite members of bases // mentioned later input.back().push_back(base); vector<ContractDefinition const*> const& basesBases = base->getLinearizedBaseContracts(); if (basesBases.empty()) - BOOST_THROW_EXCEPTION(baseIdentifier->createTypeError("Definition of base has to precede definition of derived contract")); + BOOST_THROW_EXCEPTION(baseName->createTypeError("Definition of base has to precede definition of derived contract")); input.push_front(list<ContractDefinition const*>(basesBases.begin(), basesBases.end())); } vector<ContractDefinition const*> result = cThreeMerge(input); @@ -117,7 +117,7 @@ ASTPointer<ContractDefinition> Parser::parseContractDefinition() docstring = make_shared<ASTString>(m_scanner->getCurrentCommentLiteral()); expectToken(Token::CONTRACT); ASTPointer<ASTString> name = expectIdentifierToken(); - vector<ASTPointer<Identifier>> baseContracts; + vector<ASTPointer<InheritanceSpecifier>> baseContracts; vector<ASTPointer<StructDefinition>> structs; vector<ASTPointer<VariableDeclaration>> stateVariables; vector<ASTPointer<FunctionDefinition>> functions; @@ -125,7 +125,7 @@ ASTPointer<ContractDefinition> Parser::parseContractDefinition() do { m_scanner->next(); - baseContracts.push_back(ASTNodeFactory(*this).createNode<Identifier>(expectIdentifierToken())); + baseContracts.push_back(parseInheritanceSpecifier()); } while (m_scanner->getCurrentToken() == Token::COMMA); expectToken(Token::LBRACE); @@ -161,6 +161,23 @@ ASTPointer<ContractDefinition> Parser::parseContractDefinition() stateVariables, functions); } +ASTPointer<InheritanceSpecifier> Parser::parseInheritanceSpecifier() +{ + ASTNodeFactory nodeFactory(*this); + ASTPointer<Identifier> name = ASTNodeFactory(*this).createNode<Identifier>(expectIdentifierToken()); + vector<ASTPointer<Expression>> arguments; + if (m_scanner->getCurrentToken() == Token::LPAREN) + { + m_scanner->next(); + arguments = parseFunctionCallArguments(); + nodeFactory.markEndPosition(); + expectToken(Token::RPAREN); + } + else + nodeFactory.setEndPositionFromNode(name); + return nodeFactory.createNode<InheritanceSpecifier>(name, arguments); +} + ASTPointer<FunctionDefinition> Parser::parseFunctionDefinition(bool _isPublic) { ASTNodeFactory nodeFactory(*this); @@ -49,6 +49,7 @@ private: ///@name Parsing functions for the AST nodes ASTPointer<ImportDirective> parseImportDirective(); ASTPointer<ContractDefinition> parseContractDefinition(); + ASTPointer<InheritanceSpecifier> parseInheritanceSpecifier(); ASTPointer<FunctionDefinition> parseFunctionDefinition(bool _isPublic); ASTPointer<StructDefinition> parseStructDefinition(); ASTPointer<VariableDeclaration> parseVariableDeclaration(bool _allowVar); diff --git a/grammar.txt b/grammar.txt index f06d4def..11d99854 100644 --- a/grammar.txt +++ b/grammar.txt @@ -1,7 +1,10 @@ -ContractDefinition = 'contract' Identifier '{' ContractPart* '}' +ContractDefinition = 'contract' Identifier + ( 'is' InheritanceSpecifier (',' InheritanceSpecifier )* )? + '{' ContractPart* '}' ContractPart = VariableDeclaration ';' | StructDefinition | FunctionDefinition | 'public:' | 'private:' +InheritanceSpecifier = Identifier ( '(' Expression ( ',' Expression )* ')' )? StructDefinition = 'struct' Identifier '{' ( VariableDeclaration (';' VariableDeclaration)* )? '} |