diff options
-rw-r--r-- | AST.cpp | 69 | ||||
-rw-r--r-- | AST.h | 34 | ||||
-rw-r--r-- | ASTForward.h | 2 | ||||
-rw-r--r-- | ASTPrinter.cpp | 9 | ||||
-rw-r--r-- | ASTPrinter.h | 3 | ||||
-rw-r--r-- | ASTVisitor.h | 4 | ||||
-rw-r--r-- | CompilerStack.cpp | 87 | ||||
-rw-r--r-- | CompilerStack.h | 5 | ||||
-rw-r--r-- | Parser.cpp | 43 | ||||
-rw-r--r-- | Parser.h | 3 |
10 files changed, 182 insertions, 77 deletions
@@ -33,6 +33,19 @@ namespace dev namespace solidity { +void SourceUnit::accept(ASTVisitor& _visitor) +{ + if (_visitor.visit(*this)) + listAccept(m_nodes, _visitor); + _visitor.endVisit(*this); +} + +void ImportDirective::accept(ASTVisitor& _visitor) +{ + _visitor.visit(*this); + _visitor.endVisit(*this); +} + void ContractDefinition::accept(ASTVisitor& _visitor) { if (_visitor.visit(*this)) @@ -57,34 +70,6 @@ void StructDefinition::checkValidityOfMembers() checkRecursion(); } -void StructDefinition::checkMemberTypes() -{ - for (ASTPointer<VariableDeclaration> const& member: getMembers()) - if (!member->getType()->canBeStored()) - BOOST_THROW_EXCEPTION(member->createTypeError("Type cannot be used in struct.")); -} - -void StructDefinition::checkRecursion() -{ - set<StructDefinition const*> definitionsSeen; - vector<StructDefinition const*> queue = {this}; - while (!queue.empty()) - { - StructDefinition const* def = queue.back(); - queue.pop_back(); - if (definitionsSeen.count(def)) - BOOST_THROW_EXCEPTION(ParserError() << errinfo_sourceLocation(def->getLocation()) - << errinfo_comment("Recursive struct definition.")); - definitionsSeen.insert(def); - for (ASTPointer<VariableDeclaration> const& member: def->getMembers()) - if (member->getType()->getCategory() == Type::Category::STRUCT) - { - UserDefinedTypeName const& typeName = dynamic_cast<UserDefinedTypeName&>(*member->getTypeName()); - queue.push_back(&dynamic_cast<StructDefinition const&>(*typeName.getReferencedDeclaration())); - } - } -} - void ParameterList::accept(ASTVisitor& _visitor) { if (_visitor.visit(*this)) @@ -312,6 +297,34 @@ vector<FunctionDefinition const*> ContractDefinition::getInterfaceFunctions() co return exportedFunctions; } +void StructDefinition::checkMemberTypes() +{ + for (ASTPointer<VariableDeclaration> const& member: getMembers()) + if (!member->getType()->canBeStored()) + BOOST_THROW_EXCEPTION(member->createTypeError("Type cannot be used in struct.")); +} + +void StructDefinition::checkRecursion() +{ + set<StructDefinition const*> definitionsSeen; + vector<StructDefinition const*> queue = {this}; + while (!queue.empty()) + { + StructDefinition const* def = queue.back(); + queue.pop_back(); + if (definitionsSeen.count(def)) + BOOST_THROW_EXCEPTION(ParserError() << errinfo_sourceLocation(def->getLocation()) + << errinfo_comment("Recursive struct definition.")); + definitionsSeen.insert(def); + for (ASTPointer<VariableDeclaration> const& member: def->getMembers()) + if (member->getType()->getCategory() == Type::Category::STRUCT) + { + UserDefinedTypeName const& typeName = dynamic_cast<UserDefinedTypeName&>(*member->getTypeName()); + queue.push_back(&dynamic_cast<StructDefinition const&>(*typeName.getReferencedDeclaration())); + } + } +} + void FunctionDefinition::checkTypeRequirements() { for (ASTPointer<VariableDeclaration> const& var: getParameters() + getReturnParameters()) @@ -80,6 +80,40 @@ private: }; /** + * Source unit containing import directives and contract definitions. + */ +class SourceUnit: public ASTNode +{ +public: + SourceUnit(Location const& _location, std::vector<ASTPointer<ASTNode>> const& _nodes): + ASTNode(_location), m_nodes(_nodes) {} + + virtual void accept(ASTVisitor& _visitor) override; + + std::vector<ASTPointer<ASTNode>> getNodes() const { return m_nodes; } + +private: + std::vector<ASTPointer<ASTNode>> m_nodes; +}; + +/** + * Import directive for referencing other files / source objects. + */ +class ImportDirective: public ASTNode +{ +public: + ImportDirective(Location const& _location, ASTPointer<ASTString> const& _url): + ASTNode(_location), m_url(_url) {} + + virtual void accept(ASTVisitor& _visitor) override; + + ASTString const& getURL() const { return *m_url; } + +private: + ASTPointer<ASTString> m_url; +}; + +/** * Abstract AST class for a declaration (contract, function, struct, variable). */ class Declaration: public ASTNode diff --git a/ASTForward.h b/ASTForward.h index a369c8a7..8b4bac1c 100644 --- a/ASTForward.h +++ b/ASTForward.h @@ -34,6 +34,8 @@ namespace solidity { class ASTNode; +class SourceUnit; +class ImportDirective; class Declaration; class ContractDefinition; class StructDefinition; diff --git a/ASTPrinter.cpp b/ASTPrinter.cpp index 987ad11c..c62378fd 100644 --- a/ASTPrinter.cpp +++ b/ASTPrinter.cpp @@ -43,6 +43,13 @@ void ASTPrinter::print(ostream& _stream) } +bool ASTPrinter::visit(ImportDirective& _node) +{ + writeLine("ImportDirective \"" + _node.getURL() + "\""); + printSourcePart(_node); + return goDeeper(); +} + bool ASTPrinter::visit(ContractDefinition& _node) { writeLine("ContractDefinition \"" + _node.getName() + "\""); @@ -270,7 +277,7 @@ bool ASTPrinter::visit(Literal& _node) return goDeeper(); } -void ASTPrinter::endVisit(ASTNode&) +void ASTPrinter::endVisit(ImportDirective&) { m_indentation--; } diff --git a/ASTPrinter.h b/ASTPrinter.h index e0757fbc..1a18fc4a 100644 --- a/ASTPrinter.h +++ b/ASTPrinter.h @@ -42,6 +42,7 @@ public: /// Output the string representation of the AST to _stream. void print(std::ostream& _stream); + bool visit(ImportDirective& _node) override; bool visit(ContractDefinition& _node) override; bool visit(StructDefinition& _node) override; bool visit(ParameterList& _node) override; @@ -73,7 +74,7 @@ public: bool visit(ElementaryTypeNameExpression& _node) override; bool visit(Literal& _node) override; - void endVisit(ASTNode& _node) override; + void endVisit(ImportDirective&) override; void endVisit(ContractDefinition&) override; void endVisit(StructDefinition&) override; void endVisit(ParameterList&) override; diff --git a/ASTVisitor.h b/ASTVisitor.h index 6e579f35..bf1ccc41 100644 --- a/ASTVisitor.h +++ b/ASTVisitor.h @@ -42,6 +42,8 @@ class ASTVisitor { public: virtual bool visit(ASTNode&) { return true; } + virtual bool visit(SourceUnit&) { return true; } + virtual bool visit(ImportDirective&) { return true; } virtual bool visit(ContractDefinition&) { return true; } virtual bool visit(StructDefinition&) { return true; } virtual bool visit(ParameterList&) { return true; } @@ -74,6 +76,8 @@ public: virtual bool visit(Literal&) { return true; } virtual void endVisit(ASTNode&) { } + virtual void endVisit(SourceUnit&) { } + virtual void endVisit(ImportDirective&) { } virtual void endVisit(ContractDefinition&) { } virtual void endVisit(StructDefinition&) { } virtual void endVisit(ParameterList&) { } diff --git a/CompilerStack.cpp b/CompilerStack.cpp index 6535e00d..8f8c84fe 100644 --- a/CompilerStack.cpp +++ b/CompilerStack.cpp @@ -45,10 +45,14 @@ void CompilerStack::parse() { if (!m_scanner) BOOST_THROW_EXCEPTION(CompilerError() << errinfo_comment("Source not available.")); - m_contractASTNode = Parser().parse(m_scanner); + m_sourceUnitASTNode = Parser().parse(m_scanner); m_globalContext = make_shared<GlobalContext>(); - m_globalContext->setCurrentContract(*m_contractASTNode); - NameAndTypeResolver(m_globalContext->getDeclarations()).resolveNamesAndTypes(*m_contractASTNode); + for (ASTPointer<ASTNode> const& node: m_sourceUnitASTNode->getNodes()) + if (ContractDefinition* contract = dynamic_cast<ContractDefinition*>(node.get())) + { + m_globalContext->setCurrentContract(*contract); + NameAndTypeResolver(m_globalContext->getDeclarations()).resolveNamesAndTypes(*contract); + } m_parseSuccessful = true; } @@ -62,10 +66,16 @@ bytes const& CompilerStack::compile(bool _optimize) { if (!m_parseSuccessful) BOOST_THROW_EXCEPTION(CompilerError() << errinfo_comment("Parsing was not successful.")); - m_bytecode.clear(); - m_compiler = make_shared<Compiler>(); - m_compiler->compileContract(*m_contractASTNode, m_globalContext->getMagicVariables()); - return m_bytecode = m_compiler->getAssembledBytecode(_optimize); + //@todo returns only the last contract for now + for (ASTPointer<ASTNode> const& node: m_sourceUnitASTNode->getNodes()) + if (ContractDefinition* contract = dynamic_cast<ContractDefinition*>(node.get())) + { + m_bytecode.clear(); + m_compiler = make_shared<Compiler>(); + m_compiler->compileContract(*contract, m_globalContext->getMagicVariables()); + m_bytecode = m_compiler->getAssembledBytecode(_optimize); + } + return m_bytecode; } bytes const& CompilerStack::compile(string const& _sourceCode, bool _optimize) @@ -87,40 +97,45 @@ string const& CompilerStack::getInterface() BOOST_THROW_EXCEPTION(CompilerError() << errinfo_comment("Parsing was not successful.")); if (m_interface.empty()) { - stringstream interface; - interface << '['; - vector<FunctionDefinition const*> exportedFunctions = m_contractASTNode->getInterfaceFunctions(); - unsigned functionsCount = exportedFunctions.size(); - for (FunctionDefinition const* f: exportedFunctions) - { - auto streamVariables = [&](vector<ASTPointer<VariableDeclaration>> const& _vars) + //@todo returns only the last contract for now + for (ASTPointer<ASTNode> const& node: m_sourceUnitASTNode->getNodes()) + if (ContractDefinition const* contract = dynamic_cast<ContractDefinition*>(node.get())) { - unsigned varCount = _vars.size(); - for (ASTPointer<VariableDeclaration> const& var: _vars) + stringstream interface; + interface << '['; + vector<FunctionDefinition const*> exportedFunctions = contract->getInterfaceFunctions(); + unsigned functionsCount = exportedFunctions.size(); + for (FunctionDefinition const* f: exportedFunctions) { - interface << "{" - << "\"name\":" << escaped(var->getName(), false) << "," - << "\"type\":" << escaped(var->getType()->toString(), false) + auto streamVariables = [&](vector<ASTPointer<VariableDeclaration>> const& _vars) + { + unsigned varCount = _vars.size(); + for (ASTPointer<VariableDeclaration> const& var: _vars) + { + interface << "{" + << "\"name\":" << escaped(var->getName(), false) << "," + << "\"type\":" << escaped(var->getType()->toString(), false) + << "}"; + if (--varCount > 0) + interface << ","; + } + }; + + interface << '{' + << "\"name\":" << escaped(f->getName(), false) << "," + << "\"inputs\":["; + streamVariables(f->getParameters()); + interface << "]," + << "\"outputs\":["; + streamVariables(f->getReturnParameters()); + interface << "]" << "}"; - if (--varCount > 0) + if (--functionsCount > 0) interface << ","; } - }; - - interface << '{' - << "\"name\":" << escaped(f->getName(), false) << "," - << "\"inputs\":["; - streamVariables(f->getParameters()); - interface << "]," - << "\"outputs\":["; - streamVariables(f->getReturnParameters()); - interface << "]" - << "}"; - if (--functionsCount > 0) - interface << ","; - } - interface << ']'; - m_interface = interface.str(); + interface << ']'; + m_interface = interface.str(); + } } return m_interface; } diff --git a/CompilerStack.h b/CompilerStack.h index 6cae8660..19f3cf99 100644 --- a/CompilerStack.h +++ b/CompilerStack.h @@ -32,6 +32,7 @@ namespace solidity { // forward declarations class Scanner; +class SourceUnit; class ContractDefinition; class Compiler; class GlobalContext; @@ -65,7 +66,7 @@ public: /// Returns the previously used scanner, useful for counting lines during error reporting. Scanner const& getScanner() const { return *m_scanner; } - ContractDefinition& getAST() const { return *m_contractASTNode; } + SourceUnit& getAST() const { return *m_sourceUnitASTNode; } /// Compile the given @a _sourceCode to bytecode. If a scanner is provided, it is used for /// scanning the source code - this is useful for printing exception information. @@ -74,7 +75,7 @@ public: private: std::shared_ptr<Scanner> m_scanner; std::shared_ptr<GlobalContext> m_globalContext; - std::shared_ptr<ContractDefinition> m_contractASTNode; + std::shared_ptr<SourceUnit> m_sourceUnitASTNode; bool m_parseSuccessful; std::string m_interface; std::shared_ptr<Compiler> m_compiler; @@ -20,6 +20,7 @@ * Solidity parser. */ +#include <vector> #include <libdevcore/Log.h> #include <libsolidity/BaseTypes.h> #include <libsolidity/Parser.h> @@ -33,13 +34,6 @@ namespace dev namespace solidity { -ASTPointer<ContractDefinition> Parser::parse(shared_ptr<Scanner> const& _scanner) -{ - m_scanner = _scanner; - return parseContractDefinition(); -} - - /// AST node factory that also tracks the begin and end position of an AST node /// while it is being parsed class Parser::ASTNodeFactory @@ -65,6 +59,28 @@ private: Location m_location; }; +ASTPointer<SourceUnit> Parser::parse(shared_ptr<Scanner> const& _scanner) +{ + m_scanner = _scanner; + ASTNodeFactory nodeFactory(*this); + vector<ASTPointer<ASTNode>> nodes; + while (_scanner->getCurrentToken() != Token::EOS) + { + switch (m_scanner->getCurrentToken()) + { + case Token::IMPORT: + nodes.push_back(parseImportDirective()); + break; + case Token::CONTRACT: + nodes.push_back(parseContractDefinition()); + break; + default: + BOOST_THROW_EXCEPTION(createParserError(std::string("Expected import directive or contract definition."))); + } + } + return nodeFactory.createNode<SourceUnit>(nodes); +} + int Parser::getPosition() const { return m_scanner->getCurrentLocation().start; @@ -75,6 +91,18 @@ int Parser::getEndPosition() const return m_scanner->getCurrentLocation().end; } +ASTPointer<ImportDirective> Parser::parseImportDirective() +{ + ASTNodeFactory nodeFactory(*this); + expectToken(Token::IMPORT); + if (m_scanner->getCurrentToken() != Token::STRING_LITERAL) + BOOST_THROW_EXCEPTION(createParserError("Expected string literal (URL).")); + ASTPointer<ASTString> url = getLiteralAndAdvance(); + nodeFactory.markEndPosition(); + expectToken(Token::SEMICOLON); + return nodeFactory.createNode<ImportDirective>(url); +} + ASTPointer<ContractDefinition> Parser::parseContractDefinition() { ASTNodeFactory nodeFactory(*this); @@ -112,7 +140,6 @@ ASTPointer<ContractDefinition> Parser::parseContractDefinition() } nodeFactory.markEndPosition(); expectToken(Token::RBRACE); - expectToken(Token::EOS); return nodeFactory.createNode<ContractDefinition>(name, structs, stateVariables, functions); } @@ -34,7 +34,7 @@ class Scanner; class Parser { public: - ASTPointer<ContractDefinition> parse(std::shared_ptr<Scanner> const& _scanner); + ASTPointer<SourceUnit> parse(std::shared_ptr<Scanner> const& _scanner); private: class ASTNodeFactory; @@ -46,6 +46,7 @@ private: ///@{ ///@name Parsing functions for the AST nodes + ASTPointer<ImportDirective> parseImportDirective(); ASTPointer<ContractDefinition> parseContractDefinition(); ASTPointer<FunctionDefinition> parseFunctionDefinition(bool _isPublic); ASTPointer<StructDefinition> parseStructDefinition(); |