diff options
Diffstat (limited to 'libsolidity/codegen/ContractCompiler.cpp')
-rw-r--r-- | libsolidity/codegen/ContractCompiler.cpp | 137 |
1 files changed, 116 insertions, 21 deletions
diff --git a/libsolidity/codegen/ContractCompiler.cpp b/libsolidity/codegen/ContractCompiler.cpp index aabdbb79..b051d260 100644 --- a/libsolidity/codegen/ContractCompiler.cpp +++ b/libsolidity/codegen/ContractCompiler.cpp @@ -20,19 +20,18 @@ * Solidity compiler. */ +#include <libsolidity/ast/AST.h> +#include <libsolidity/codegen/AsmCodeGen.h> +#include <libsolidity/codegen/CompilerUtils.h> #include <libsolidity/codegen/ContractCompiler.h> #include <libsolidity/codegen/ExpressionCompiler.h> -#include <libsolidity/codegen/CompilerUtils.h> -#include <libsolidity/ast/AST.h> -#include <libyul/AsmCodeGen.h> -#include <liblangutil/ErrorReporter.h> #include <libevmasm/Instruction.h> #include <libevmasm/Assembly.h> #include <libevmasm/GasMeter.h> +#include <liblangutil/ErrorReporter.h> #include <boost/range/adaptor/reversed.hpp> - #include <algorithm> using namespace std; @@ -268,6 +267,89 @@ void ContractCompiler::appendDelegatecallCheck() // "We have not been called via DELEGATECALL". } +void ContractCompiler::appendInternalSelector( + map<FixedHash<4>, eth::AssemblyItem const> const& _entryPoints, + vector<FixedHash<4>> const& _ids, + eth::AssemblyItem const& _notFoundTag, + size_t _runs +) +{ + // Code for selecting from n functions without split: + // n times: dup1, push4 <id_i>, eq, push2/3 <tag_i>, jumpi + // push2/3 <notfound> jump + // (called SELECT[n]) + // Code for selecting from n functions with split: + // dup1, push4 <pivot>, gt, push2/3<tag_less>, jumpi + // SELECT[n/2] + // tag_less: + // SELECT[n/2] + // + // This means each split adds 16-18 bytes of additional code (note the additional jump out!) + // The average execution cost if we do not split at all are: + // (3 + 3 + 3 + 3 + 10) * n/2 = 24 * n/2 = 12 * n + // If we split once: + // (3 + 3 + 3 + 3 + 10) + 24 * n/4 = 24 * (n/4 + 1) = 6 * n + 24; + // + // We should split if + // _runs * 12 * n > _runs * (6 * n + 24) + 17 * createDataGas + // <=> _runs * 6 * (n - 4) > 17 * createDataGas + // + // Which also means that the execution itself is not profitable + // unless we have at least 5 functions. + + // Start with some comparisons to avoid overflow, then do the actual comparison. + bool split = false; + if (_ids.size() <= 4) + split = false; + else if (_runs > (17 * eth::GasCosts::createDataGas) / 6) + split = true; + else + split = (_runs * 6 * (_ids.size() - 4) > 17 * eth::GasCosts::createDataGas); + + if (split) + { + size_t pivotIndex = _ids.size() / 2; + FixedHash<4> pivot{_ids.at(pivotIndex)}; + m_context << dupInstruction(1) << u256(FixedHash<4>::Arith(pivot)) << Instruction::GT; + eth::AssemblyItem lessTag{m_context.appendConditionalJump()}; + // Here, we have funid >= pivot + vector<FixedHash<4>> larger{_ids.begin() + pivotIndex, _ids.end()}; + appendInternalSelector(_entryPoints, larger, _notFoundTag, _runs); + m_context << lessTag; + // Here, we have funid < pivot + vector<FixedHash<4>> smaller{_ids.begin(), _ids.begin() + pivotIndex}; + appendInternalSelector(_entryPoints, smaller, _notFoundTag, _runs); + } + else + { + for (auto const& id: _ids) + { + m_context << dupInstruction(1) << u256(FixedHash<4>::Arith(id)) << Instruction::EQ; + m_context.appendConditionalJumpTo(_entryPoints.at(id)); + } + m_context.appendJumpTo(_notFoundTag); + } +} + +namespace +{ + +// Helper function to check if any function is payable +bool hasPayableFunctions(ContractDefinition const& _contract) +{ + FunctionDefinition const* fallback = _contract.fallbackFunction(); + if (fallback && fallback->isPayable()) + return true; + + for (auto const& it: _contract.interfaceFunctions()) + if (it.second->isPayable()) + return true; + + return false; +} + +} + void ContractCompiler::appendFunctionSelector(ContractDefinition const& _contract) { map<FixedHash<4>, FunctionTypePointer> interfaceFunctions = _contract.interfaceFunctions(); @@ -279,6 +361,15 @@ void ContractCompiler::appendFunctionSelector(ContractDefinition const& _contrac } FunctionDefinition const* fallback = _contract.fallbackFunction(); + solAssert(!_contract.isLibrary() || !fallback, "Libraries can't have fallback functions"); + + bool needToAddCallvalueCheck = true; + if (!hasPayableFunctions(_contract) && !interfaceFunctions.empty() && !_contract.isLibrary()) + { + appendCallValueCheck(); + needToAddCallvalueCheck = false; + } + eth::AssemblyItem notFound = m_context.newTag(); // directly jump to fallback if the data is too short to contain a function selector // also guards against short data @@ -287,22 +378,26 @@ void ContractCompiler::appendFunctionSelector(ContractDefinition const& _contrac // retrieve the function signature hash from the calldata if (!interfaceFunctions.empty()) + { CompilerUtils(m_context).loadFromMemory(0, IntegerType(CompilerUtils::dataStartOffset * 8), true); - // stack now is: <can-call-non-view-functions>? <funhash> - for (auto const& it: interfaceFunctions) - { - callDataUnpackerEntryPoints.insert(std::make_pair(it.first, m_context.newTag())); - m_context << dupInstruction(1) << u256(FixedHash<4>::Arith(it.first)) << Instruction::EQ; - m_context.appendConditionalJumpTo(callDataUnpackerEntryPoints.at(it.first)); + // stack now is: <can-call-non-view-functions>? <funhash> + vector<FixedHash<4>> sortedIDs; + for (auto const& it: interfaceFunctions) + { + callDataUnpackerEntryPoints.emplace(it.first, m_context.newTag()); + sortedIDs.emplace_back(it.first); + } + std::sort(sortedIDs.begin(), sortedIDs.end()); + appendInternalSelector(callDataUnpackerEntryPoints, sortedIDs, notFound, m_optimise_runs); } - m_context.appendJumpTo(notFound); m_context << notFound; + if (fallback) { solAssert(!_contract.isLibrary(), ""); - if (!fallback->isPayable()) + if (!fallback->isPayable() && needToAddCallvalueCheck) appendCallValueCheck(); solAssert(fallback->isFallback(), ""); @@ -332,7 +427,7 @@ void ContractCompiler::appendFunctionSelector(ContractDefinition const& _contrac m_context.setStackOffset(0); // We have to allow this for libraries, because value of the previous // call is still visible in the delegatecall. - if (!functionType->isPayable() && !_contract.isLibrary()) + if (!functionType->isPayable() && !_contract.isLibrary() && needToAddCallvalueCheck) appendCallValueCheck(); // Return tag is used to jump out of the function. @@ -618,7 +713,7 @@ bool ContractCompiler::visit(InlineAssembly const& _inlineAssembly) } }; solAssert(_inlineAssembly.annotation().analysisInfo, ""); - yul::CodeGenerator::assemble( + CodeGenerator::assemble( _inlineAssembly.operations(), *_inlineAssembly.annotation().analysisInfo, m_context.nonConstAssembly(), @@ -656,14 +751,14 @@ bool ContractCompiler::visit(WhileStatement const& _whileStatement) eth::AssemblyItem loopStart = m_context.newTag(); eth::AssemblyItem loopEnd = m_context.newTag(); - m_breakTags.push_back({loopEnd, m_context.stackHeight()}); + m_breakTags.emplace_back(loopEnd, m_context.stackHeight()); m_context << loopStart; if (_whileStatement.isDoWhile()) { eth::AssemblyItem condition = m_context.newTag(); - m_continueTags.push_back({condition, m_context.stackHeight()}); + m_continueTags.emplace_back(condition, m_context.stackHeight()); _whileStatement.body().accept(*this); @@ -674,7 +769,7 @@ bool ContractCompiler::visit(WhileStatement const& _whileStatement) } else { - m_continueTags.push_back({loopStart, m_context.stackHeight()}); + m_continueTags.emplace_back(loopStart, m_context.stackHeight()); compileExpression(_whileStatement.condition()); m_context << Instruction::ISZERO; m_context.appendConditionalJumpTo(loopEnd); @@ -705,8 +800,8 @@ bool ContractCompiler::visit(ForStatement const& _forStatement) if (_forStatement.initializationExpression()) _forStatement.initializationExpression()->accept(*this); - m_breakTags.push_back({loopEnd, m_context.stackHeight()}); - m_continueTags.push_back({loopNext, m_context.stackHeight()}); + m_breakTags.emplace_back(loopEnd, m_context.stackHeight()); + m_continueTags.emplace_back(loopNext, m_context.stackHeight()); m_context << loopStart; // if there is no terminating condition in for, default is to always be true @@ -932,7 +1027,7 @@ void ContractCompiler::appendModifierOrFunctionCode() if (codeBlock) { - m_returnTags.push_back({m_context.newTag(), m_context.stackHeight()}); + m_returnTags.emplace_back(m_context.newTag(), m_context.stackHeight()); codeBlock->accept(*this); solAssert(!m_returnTags.empty(), ""); |