From ddf5e20d100535c10315d0ae73ba4ed753ef3397 Mon Sep 17 00:00:00 2001 From: Christian Date: Mon, 19 Jan 2015 23:08:48 +0100 Subject: Call constructors of base classes. --- CallGraph.cpp | 44 +++++++++++++++++++++++++----- CallGraph.h | 7 +++-- Compiler.cpp | 71 +++++++++++++++++++++++++++++++++++++++++-------- Compiler.h | 9 ++++--- NameAndTypeResolver.cpp | 1 - 5 files changed, 108 insertions(+), 24 deletions(-) diff --git a/CallGraph.cpp b/CallGraph.cpp index b30afb61..88d874f3 100644 --- a/CallGraph.cpp +++ b/CallGraph.cpp @@ -31,13 +31,9 @@ namespace dev namespace solidity { -void CallGraph::addFunction(FunctionDefinition const& _function) +void CallGraph::addNode(ASTNode const& _node) { - if (!m_functionsSeen.count(&_function)) - { - m_functionsSeen.insert(&_function); - m_workQueue.push(&_function); - } + _node.accept(*this); } set const& CallGraph::getCalls() @@ -63,5 +59,41 @@ bool CallGraph::visit(Identifier const& _identifier) return true; } +bool CallGraph::visit(FunctionDefinition const& _function) +{ + addFunction(_function); + return true; +} + +bool CallGraph::visit(MemberAccess const& _memberAccess) +{ + // used for "BaseContract.baseContractFunction" + if (_memberAccess.getExpression().getType()->getCategory() == Type::Category::TYPE) + { + TypeType const& type = dynamic_cast(*_memberAccess.getExpression().getType()); + if (type.getMembers().getMemberType(_memberAccess.getMemberName())) + { + ContractDefinition const& contract = dynamic_cast(*type.getActualType()) + .getContractDefinition(); + for (ASTPointer const& function: contract.getDefinedFunctions()) + if (function->getName() == _memberAccess.getMemberName()) + { + addFunction(*function); + return true; + } + } + } + return true; +} + +void CallGraph::addFunction(FunctionDefinition const& _function) +{ + if (!m_functionsSeen.count(&_function)) + { + m_functionsSeen.insert(&_function); + m_workQueue.push(&_function); + } +} + } } diff --git a/CallGraph.h b/CallGraph.h index f7af64bf..e3558fc2 100644 --- a/CallGraph.h +++ b/CallGraph.h @@ -38,14 +38,17 @@ namespace solidity class CallGraph: private ASTConstVisitor { public: - void addFunction(FunctionDefinition const& _function); + void addNode(ASTNode const& _node); void computeCallGraph(); std::set const& getCalls(); private: - void addFunctionToQueue(FunctionDefinition const& _function); + virtual bool visit(FunctionDefinition const& _function) override; virtual bool visit(Identifier const& _identifier) override; + virtual bool visit(MemberAccess const& _memberAccess) override; + + void addFunction(FunctionDefinition const& _function); std::set m_functionsSeen; std::queue m_workQueue; diff --git a/Compiler.cpp b/Compiler.cpp index aa3022aa..36316b9a 100644 --- a/Compiler.cpp +++ b/Compiler.cpp @@ -67,19 +67,52 @@ void Compiler::initializeContext(ContractDefinition const& _contract, void Compiler::packIntoContractCreator(ContractDefinition const& _contract, CompilerContext const& _runtimeContext) { + // arguments for base constructors, filled in derived-to-base order + map> const*> baseArguments; set neededFunctions; - // TODO constructors of base classes - FunctionDefinition const* constructor = _contract.getConstructor(); - if (constructor) - neededFunctions = getFunctionsNeededByConstructor(*constructor); + set nodesUsedInConstructors; + + // Determine the arguments that are used for the base constructors and also which functions + // are needed at compile time. + std::vector const& bases = _contract.getLinearizedBaseContracts(); + for (ContractDefinition const* contract: bases) + { + if (FunctionDefinition const* constructor = contract->getConstructor()) + nodesUsedInConstructors.insert(constructor); + for (ASTPointer const& base: contract->getBaseContracts()) + { + ContractDefinition const* baseContract = dynamic_cast( + base->getName()->getReferencedDeclaration()); + solAssert(baseContract, ""); + if (baseArguments.count(baseContract) == 0) + { + baseArguments[baseContract] = &base->getArguments(); + for (ASTPointer const& arg: base->getArguments()) + nodesUsedInConstructors.insert(arg.get()); + } + } + } + + //@TODO add virtual functions + neededFunctions = getFunctionsCalled(nodesUsedInConstructors); - // TODO we should add the overridden functions for (FunctionDefinition const* fun: neededFunctions) m_context.addFunction(*fun); - // we have many of them now - if (constructor) - appendConstructorCall(*constructor); + // Call constructors in base-to-derived order. + // The Constructor for the most derived contract is called later. + for (unsigned i = 1; i < bases.size(); i++) + { + ContractDefinition const* base = bases[bases.size() - i]; + solAssert(base, ""); + FunctionDefinition const* baseConstructor = base->getConstructor(); + if (!baseConstructor) + continue; + solAssert(baseArguments[base], ""); + appendBaseConstructorCall(*baseConstructor, *baseArguments[base]); + } + if (_contract.getConstructor()) + appendConstructorCall(*_contract.getConstructor()); eth::AssemblyItem sub = m_context.addSubroutine(_runtimeContext.getAssembly()); // stack contains sub size @@ -92,6 +125,21 @@ void Compiler::packIntoContractCreator(ContractDefinition const& _contract, Comp fun->accept(*this); } +void Compiler::appendBaseConstructorCall(FunctionDefinition const& _constructor, + vector> const& _arguments) +{ + FunctionType constructorType(_constructor); + eth::AssemblyItem returnLabel = m_context.pushNewTag(); + for (unsigned i = 0; i < _arguments.size(); ++i) + { + compileExpression(*_arguments[i]); + ExpressionCompiler::appendTypeConversion(m_context, *_arguments[i]->getType(), + *constructorType.getParameterTypes()[i]); + } + m_context.appendJumpTo(m_context.getFunctionEntryLabel(_constructor)); + m_context << returnLabel; +} + void Compiler::appendConstructorCall(FunctionDefinition const& _constructor) { eth::AssemblyItem returnTag = m_context.pushNewTag(); @@ -111,11 +159,12 @@ void Compiler::appendConstructorCall(FunctionDefinition const& _constructor) m_context << returnTag; } -set Compiler::getFunctionsNeededByConstructor(FunctionDefinition const& _constructor) +set Compiler::getFunctionsCalled(set const& _nodes) { + // TODO this does not add virtual functions CallGraph callgraph; - callgraph.addFunction(_constructor); - callgraph.computeCallGraph(); + for (ASTNode const* node: _nodes) + callgraph.addNode(*node); return callgraph.getCalls(); } diff --git a/Compiler.h b/Compiler.h index 073721e0..ea05f38e 100644 --- a/Compiler.h +++ b/Compiler.h @@ -43,12 +43,13 @@ private: void initializeContext(ContractDefinition const& _contract, std::map const& _contracts); /// Adds the code that is run at creation time. Should be run after exchanging the run-time context - /// with a new and initialized context. - /// adds the constructor code. + /// with a new and initialized context. Adds the constructor code. void packIntoContractCreator(ContractDefinition const& _contract, CompilerContext const& _runtimeContext); + void appendBaseConstructorCall(FunctionDefinition const& _constructor, + std::vector> const& _arguments); void appendConstructorCall(FunctionDefinition const& _constructor); - /// Recursively searches the call graph and returns all functions needed by the constructor (including itself). - std::set getFunctionsNeededByConstructor(FunctionDefinition const& _constructor); + /// Recursively searches the call graph and returns all functions referenced inside _nodes. + std::set getFunctionsCalled(std::set const& _nodes); void appendFunctionSelector(ContractDefinition const& _contract); /// Creates code that unpacks the arguments for the given function, from memory if /// @a _fromMemory is true, otherwise from call data. @returns the size of the data in bytes. diff --git a/NameAndTypeResolver.cpp b/NameAndTypeResolver.cpp index 7df51566..ba5ca134 100644 --- a/NameAndTypeResolver.cpp +++ b/NameAndTypeResolver.cpp @@ -31,7 +31,6 @@ namespace dev namespace solidity { - NameAndTypeResolver::NameAndTypeResolver(std::vector const& _globals) { for (Declaration const* declaration: _globals) -- cgit v1.2.3