diff options
Diffstat (limited to 'libsolidity')
-rw-r--r-- | libsolidity/analysis/ConstantEvaluator.cpp | 85 | ||||
-rw-r--r-- | libsolidity/analysis/ConstantEvaluator.h | 18 | ||||
-rw-r--r-- | libsolidity/analysis/ReferencesResolver.cpp | 9 | ||||
-rw-r--r-- | libsolidity/ast/Types.h | 2 | ||||
-rw-r--r-- | libsolidity/codegen/ABIFunctions.cpp | 28 | ||||
-rw-r--r-- | libsolidity/formal/SMTChecker.cpp | 45 | ||||
-rw-r--r-- | libsolidity/formal/SMTChecker.h | 13 | ||||
-rw-r--r-- | libsolidity/formal/SolverInterface.h | 5 |
8 files changed, 135 insertions, 70 deletions
diff --git a/libsolidity/analysis/ConstantEvaluator.cpp b/libsolidity/analysis/ConstantEvaluator.cpp index 4d546e68..83f37f47 100644 --- a/libsolidity/analysis/ConstantEvaluator.cpp +++ b/libsolidity/analysis/ConstantEvaluator.cpp @@ -28,51 +28,42 @@ using namespace std; using namespace dev; using namespace dev::solidity; -/// FIXME: this is pretty much a copy of TypeChecker::endVisit(BinaryOperation) void ConstantEvaluator::endVisit(UnaryOperation const& _operation) { - TypePointer const& subType = _operation.subExpression().annotation().type; - if (!dynamic_cast<RationalNumberType const*>(subType.get())) - m_errorReporter.fatalTypeError(_operation.subExpression().location(), "Invalid constant expression."); - TypePointer t = subType->unaryOperatorResult(_operation.getOperator()); - _operation.annotation().type = t; + auto sub = type(_operation.subExpression()); + if (sub) + setType(_operation, sub->unaryOperatorResult(_operation.getOperator())); } -/// FIXME: this is pretty much a copy of TypeChecker::endVisit(BinaryOperation) void ConstantEvaluator::endVisit(BinaryOperation const& _operation) { - TypePointer const& leftType = _operation.leftExpression().annotation().type; - TypePointer const& rightType = _operation.rightExpression().annotation().type; - if (!dynamic_cast<RationalNumberType const*>(leftType.get())) - m_errorReporter.fatalTypeError(_operation.leftExpression().location(), "Invalid constant expression."); - if (!dynamic_cast<RationalNumberType const*>(rightType.get())) - m_errorReporter.fatalTypeError(_operation.rightExpression().location(), "Invalid constant expression."); - TypePointer commonType = leftType->binaryOperatorResult(_operation.getOperator(), rightType); - if (!commonType) + auto left = type(_operation.leftExpression()); + auto right = type(_operation.rightExpression()); + if (left && right) { - m_errorReporter.typeError( - _operation.location(), - "Operator " + - string(Token::toString(_operation.getOperator())) + - " not compatible with types " + - leftType->toString() + - " and " + - rightType->toString() + auto commonType = left->binaryOperatorResult(_operation.getOperator(), right); + if (!commonType) + m_errorReporter.fatalTypeError( + _operation.location(), + "Operator " + + string(Token::toString(_operation.getOperator())) + + " not compatible with types " + + left->toString() + + " and " + + right->toString() + ); + setType( + _operation, + Token::isCompareOp(_operation.getOperator()) ? + make_shared<BoolType>() : + commonType ); - commonType = leftType; } - _operation.annotation().commonType = commonType; - _operation.annotation().type = - Token::isCompareOp(_operation.getOperator()) ? - make_shared<BoolType>() : - commonType; } void ConstantEvaluator::endVisit(Literal const& _literal) { - _literal.annotation().type = Type::forLiteral(_literal); - if (!_literal.annotation().type) - m_errorReporter.fatalTypeError(_literal.location(), "Invalid literal value."); + setType(_literal, Type::forLiteral(_literal)); } void ConstantEvaluator::endVisit(Identifier const& _identifier) @@ -81,18 +72,34 @@ void ConstantEvaluator::endVisit(Identifier const& _identifier) if (!variableDeclaration) return; if (!variableDeclaration->isConstant()) - m_errorReporter.fatalTypeError(_identifier.location(), "Identifier must be declared constant."); + return; - ASTPointer<Expression> value = variableDeclaration->value(); + ASTPointer<Expression> const& value = variableDeclaration->value(); if (!value) - m_errorReporter.fatalTypeError(_identifier.location(), "Constant identifier declaration must have a constant value."); - - if (!value->annotation().type) + return; + else if (!m_types->count(value.get())) { if (m_depth > 32) m_errorReporter.fatalTypeError(_identifier.location(), "Cyclic constant definition (or maximum recursion depth exhausted)."); - ConstantEvaluator e(*value, m_errorReporter, m_depth + 1); + ConstantEvaluator(m_errorReporter, m_depth + 1, m_types).evaluate(*value); } - _identifier.annotation().type = value->annotation().type; + setType(_identifier, type(*value)); +} + +void ConstantEvaluator::setType(ASTNode const& _node, TypePointer const& _type) +{ + if (_type && _type->category() == Type::Category::RationalNumber) + (*m_types)[&_node] = _type; +} + +TypePointer ConstantEvaluator::type(ASTNode const& _node) +{ + return (*m_types)[&_node]; +} + +TypePointer ConstantEvaluator::evaluate(Expression const& _expr) +{ + _expr.accept(*this); + return type(_expr); } diff --git a/libsolidity/analysis/ConstantEvaluator.h b/libsolidity/analysis/ConstantEvaluator.h index 6725d610..77a357b6 100644 --- a/libsolidity/analysis/ConstantEvaluator.h +++ b/libsolidity/analysis/ConstantEvaluator.h @@ -38,22 +38,32 @@ class TypeChecker; class ConstantEvaluator: private ASTConstVisitor { public: - ConstantEvaluator(Expression const& _expr, ErrorReporter& _errorReporter, size_t _newDepth = 0): + ConstantEvaluator( + ErrorReporter& _errorReporter, + size_t _newDepth = 0, + std::shared_ptr<std::map<ASTNode const*, TypePointer>> _types = std::make_shared<std::map<ASTNode const*, TypePointer>>() + ): m_errorReporter(_errorReporter), - m_depth(_newDepth) + m_depth(_newDepth), + m_types(_types) { - _expr.accept(*this); } + TypePointer evaluate(Expression const& _expr); + private: virtual void endVisit(BinaryOperation const& _operation); virtual void endVisit(UnaryOperation const& _operation); virtual void endVisit(Literal const& _literal); virtual void endVisit(Identifier const& _identifier); + void setType(ASTNode const& _node, TypePointer const& _type); + TypePointer type(ASTNode const& _node); + ErrorReporter& m_errorReporter; /// Current recursion depth. - size_t m_depth; + size_t m_depth = 0; + std::shared_ptr<std::map<ASTNode const*, TypePointer>> m_types; }; } diff --git a/libsolidity/analysis/ReferencesResolver.cpp b/libsolidity/analysis/ReferencesResolver.cpp index f22c95cc..9eee16af 100644 --- a/libsolidity/analysis/ReferencesResolver.cpp +++ b/libsolidity/analysis/ReferencesResolver.cpp @@ -146,11 +146,12 @@ void ReferencesResolver::endVisit(ArrayTypeName const& _typeName) fatalTypeError(_typeName.baseType().location(), "Illegal base type of storage size zero for array."); if (Expression const* length = _typeName.length()) { - if (!length->annotation().type) - ConstantEvaluator e(*length, m_errorReporter); - auto const* lengthType = dynamic_cast<RationalNumberType const*>(length->annotation().type.get()); + TypePointer lengthTypeGeneric = length->annotation().type; + if (!lengthTypeGeneric) + lengthTypeGeneric = ConstantEvaluator(m_errorReporter).evaluate(*length); + RationalNumberType const* lengthType = dynamic_cast<RationalNumberType const*>(lengthTypeGeneric.get()); if (!lengthType || !lengthType->mobileType()) - fatalTypeError(length->location(), "Invalid array length, expected integer literal."); + fatalTypeError(length->location(), "Invalid array length, expected integer literal or constant expression."); else if (lengthType->isFractional()) fatalTypeError(length->location(), "Array with fractional length specified."); else if (lengthType->isNegative()) diff --git a/libsolidity/ast/Types.h b/libsolidity/ast/Types.h index 635279ab..a54e4e09 100644 --- a/libsolidity/ast/Types.h +++ b/libsolidity/ast/Types.h @@ -257,7 +257,7 @@ public: } virtual u256 literalValue(Literal const*) const { - solAssert(false, "Literal value requested for type without literals."); + solAssert(false, "Literal value requested for type without literals: " + toString(false)); } /// @returns a (simpler) type that is encoded in the same way for external function calls. diff --git a/libsolidity/codegen/ABIFunctions.cpp b/libsolidity/codegen/ABIFunctions.cpp index 6648be06..00f59065 100644 --- a/libsolidity/codegen/ABIFunctions.cpp +++ b/libsolidity/codegen/ABIFunctions.cpp @@ -120,7 +120,7 @@ string ABIFunctions::tupleDecoder(TypePointers const& _types, bool _fromMemory) Whiskers templ(R"( function <functionName>(headStart, dataEnd) -> <valueReturnParams> { - switch slt(sub(dataEnd, headStart), <minimumSize>) case 1 { revert(0, 0) } + if slt(sub(dataEnd, headStart), <minimumSize>) { revert(0, 0) } <decodeElements> } )"); @@ -151,7 +151,7 @@ string ABIFunctions::tupleDecoder(TypePointers const& _types, bool _fromMemory) R"( { let offset := <load>(add(headStart, <pos>)) - switch gt(offset, 0xffffffffffffffff) case 1 { revert(0, 0) } + if gt(offset, 0xffffffffffffffff) { revert(0, 0) } <values> := <abiDecode>(add(headStart, offset), dataEnd) } )" : @@ -1134,7 +1134,7 @@ string ABIFunctions::abiDecodingFunctionArray(ArrayType const& _type, bool _from R"( // <readableTypeName> function <functionName>(offset, end) -> array { - switch slt(add(offset, 0x1f), end) case 0 { revert(0, 0) } + if iszero(slt(add(offset, 0x1f), end)) { revert(0, 0) } let length := <retrieveLength> array := <allocate>(<allocationSize>(length)) let dst := array @@ -1169,7 +1169,7 @@ string ABIFunctions::abiDecodingFunctionArray(ArrayType const& _type, bool _from else { string baseEncodedSize = toCompactHexWithPrefix(_type.baseType()->calldataEncodedSize()); - templ("staticBoundsCheck", "switch gt(add(src, mul(length, " + baseEncodedSize + ")), end) case 1 { revert(0, 0) }"); + templ("staticBoundsCheck", "if gt(add(src, mul(length, " + baseEncodedSize + ")), end) { revert(0, 0) }"); templ("retrieveElementPos", "src"); templ("baseEncodedSize", baseEncodedSize); } @@ -1197,11 +1197,11 @@ string ABIFunctions::abiDecodingFunctionCalldataArray(ArrayType const& _type) templ = R"( // <readableTypeName> function <functionName>(offset, end) -> arrayPos, length { - switch slt(add(offset, 0x1f), end) case 0 { revert(0, 0) } + if iszero(slt(add(offset, 0x1f), end)) { revert(0, 0) } length := calldataload(offset) - switch gt(length, 0xffffffffffffffff) case 1 { revert(0, 0) } + if gt(length, 0xffffffffffffffff) { revert(0, 0) } arrayPos := add(offset, 0x20) - switch gt(add(arrayPos, mul(<length>, <baseEncodedSize>)), end) case 1 { revert(0, 0) } + if gt(add(arrayPos, mul(<length>, <baseEncodedSize>)), end) { revert(0, 0) } } )"; else @@ -1209,7 +1209,7 @@ string ABIFunctions::abiDecodingFunctionCalldataArray(ArrayType const& _type) // <readableTypeName> function <functionName>(offset, end) -> arrayPos { arrayPos := offset - switch gt(add(arrayPos, mul(<length>, <baseEncodedSize>)), end) case 1 { revert(0, 0) } + if gt(add(arrayPos, mul(<length>, <baseEncodedSize>)), end) { revert(0, 0) } } )"; Whiskers w{templ}; @@ -1235,13 +1235,13 @@ string ABIFunctions::abiDecodingFunctionByteArray(ArrayType const& _type, bool _ Whiskers templ( R"( function <functionName>(offset, end) -> array { - switch slt(add(offset, 0x1f), end) case 0 { revert(0, 0) } + if iszero(slt(add(offset, 0x1f), end)) { revert(0, 0) } let length := <load>(offset) array := <allocate>(<allocationSize>(length)) mstore(array, length) let src := add(offset, 0x20) let dst := add(array, 0x20) - switch gt(add(src, length), end) case 1 { revert(0, 0) } + if gt(add(src, length), end) { revert(0, 0) } <copyToMemFun>(src, dst, length) } )" @@ -1268,7 +1268,7 @@ string ABIFunctions::abiDecodingFunctionStruct(StructType const& _type, bool _fr Whiskers templ(R"( // <readableTypeName> function <functionName>(headStart, end) -> value { - switch slt(sub(end, headStart), <minimumSize>) case 1 { revert(0, 0) } + if slt(sub(end, headStart), <minimumSize>) { revert(0, 0) } value := <allocate>(<memorySize>) <#members> { @@ -1296,7 +1296,7 @@ string ABIFunctions::abiDecodingFunctionStruct(StructType const& _type, bool _fr dynamic ? R"( let offset := <load>(add(headStart, <pos>)) - switch gt(offset, 0xffffffffffffffff) case 1 { revert(0, 0) } + if gt(offset, 0xffffffffffffffff) { revert(0, 0) } mstore(add(value, <memoryOffset>), <abiDecode>(add(headStart, offset), end)) )" : R"( @@ -1501,7 +1501,7 @@ string ABIFunctions::arrayAllocationSizeFunction(ArrayType const& _type) Whiskers w(R"( function <functionName>(length) -> size { // Make sure we can allocate memory without overflow - switch gt(length, 0xffffffffffffffff) case 1 { revert(0, 0) } + if gt(length, 0xffffffffffffffff) { revert(0, 0) } size := <allocationSize> <addLengthSlot> } @@ -1620,7 +1620,7 @@ string ABIFunctions::allocationFunction() memPtr := mload(<freeMemoryPointer>) let newFreePtr := add(memPtr, size) // protect against overflow - switch or(gt(newFreePtr, 0xffffffffffffffff), lt(newFreePtr, memPtr)) case 1 { revert(0, 0) } + if or(gt(newFreePtr, 0xffffffffffffffff), lt(newFreePtr, memPtr)) { revert(0, 0) } mstore(<freeMemoryPointer>, newFreePtr) } )") diff --git a/libsolidity/formal/SMTChecker.cpp b/libsolidity/formal/SMTChecker.cpp index a22e35d6..d4887a3d 100644 --- a/libsolidity/formal/SMTChecker.cpp +++ b/libsolidity/formal/SMTChecker.cpp @@ -71,6 +71,7 @@ bool SMTChecker::visit(FunctionDefinition const& _function) m_interface->reset(); m_currentSequenceCounter.clear(); m_nextFreeSequenceCounter.clear(); + m_pathConditions.clear(); m_conditionalExecutionHappened = false; initializeLocalVariables(_function); return true; @@ -344,14 +345,14 @@ void SMTChecker::endVisit(FunctionCall const& _funCall) solAssert(args.size() == 1, ""); solAssert(args[0]->annotation().type->category() == Type::Category::Bool, ""); checkCondition(!(expr(*args[0])), _funCall.location(), "Assertion violation"); - m_interface->addAssertion(expr(*args[0])); + addPathImpliedExpression(expr(*args[0])); } else if (funType.kind() == FunctionType::Kind::Require) { solAssert(args.size() == 1, ""); solAssert(args[0]->annotation().type->category() == Type::Category::Bool, ""); checkBooleanNotConstant(*args[0], "Condition is always $VALUE."); - m_interface->addAssertion(expr(*args[0])); + addPathImpliedExpression(expr(*args[0])); } } @@ -514,11 +515,11 @@ void SMTChecker::visitBranch(Statement const& _statement, smt::Expression const* { VariableSequenceCounters sequenceCountersStart = m_currentSequenceCounter; - m_interface->push(); if (_condition) - m_interface->addAssertion(*_condition); + pushPathCondition(*_condition); _statement.accept(*this); - m_interface->pop(); + if (_condition) + popPathCondition(); m_conditionalExecutionHappened = true; m_currentSequenceCounter = sequenceCountersStart; @@ -533,7 +534,7 @@ void SMTChecker::checkCondition( ) { m_interface->push(); - m_interface->addAssertion(_condition); + addPathConjoinedExpression(_condition); vector<smt::Expression> expressionsToEvaluate; vector<string> expressionNames; @@ -605,12 +606,12 @@ void SMTChecker::checkBooleanNotConstant(Expression const& _condition, string co return; m_interface->push(); - m_interface->addAssertion(expr(_condition)); + addPathConjoinedExpression(expr(_condition)); auto positiveResult = checkSatisifable(); m_interface->pop(); m_interface->push(); - m_interface->addAssertion(!expr(_condition)); + addPathConjoinedExpression(!expr(_condition)); auto negatedResult = checkSatisifable(); m_interface->pop(); @@ -828,3 +829,31 @@ smt::Expression SMTChecker::var(Declaration const& _decl) solAssert(m_variables.count(&_decl), ""); return m_variables.at(&_decl); } + +void SMTChecker::popPathCondition() +{ + solAssert(m_pathConditions.size() > 0, "Cannot pop path condition, empty."); + m_pathConditions.pop_back(); +} + +void SMTChecker::pushPathCondition(smt::Expression const& _e) +{ + m_pathConditions.push_back(currentPathConditions() && _e); +} + +smt::Expression SMTChecker::currentPathConditions() +{ + if (m_pathConditions.size() == 0) + return smt::Expression(true); + return m_pathConditions.back(); +} + +void SMTChecker::addPathConjoinedExpression(smt::Expression const& _e) +{ + m_interface->addAssertion(currentPathConditions() && _e); +} + +void SMTChecker::addPathImpliedExpression(smt::Expression const& _e) +{ + m_interface->addAssertion(smt::Expression::implies(currentPathConditions(), _e)); +} diff --git a/libsolidity/formal/SMTChecker.h b/libsolidity/formal/SMTChecker.h index e7481cca..539221cc 100644 --- a/libsolidity/formal/SMTChecker.h +++ b/libsolidity/formal/SMTChecker.h @@ -26,6 +26,7 @@ #include <map> #include <string> +#include <vector> namespace dev { @@ -145,6 +146,17 @@ private: /// The function takes one argument which is the "sequence number". smt::Expression var(Declaration const& _decl); + /// Adds a new path condition + void pushPathCondition(smt::Expression const& _e); + /// Remove the last path condition + void popPathCondition(); + /// Returns the conjunction of all path conditions or True if empty + smt::Expression currentPathConditions(); + /// Conjoin the current path conditions with the given parameter and add to the solver + void addPathConjoinedExpression(smt::Expression const& _e); + /// Add to the solver: the given expression implied by the current path conditions + void addPathImpliedExpression(smt::Expression const& _e); + std::shared_ptr<smt::SolverInterface> m_interface; std::shared_ptr<VariableUsage> m_variableUsage; bool m_conditionalExecutionHappened = false; @@ -152,6 +164,7 @@ private: std::map<Declaration const*, int> m_nextFreeSequenceCounter; std::map<Expression const*, smt::Expression> m_expressions; std::map<Declaration const*, smt::Expression> m_variables; + std::vector<smt::Expression> m_pathConditions; ErrorReporter& m_errorReporter; FunctionDefinition const* m_currentFunction = nullptr; diff --git a/libsolidity/formal/SolverInterface.h b/libsolidity/formal/SolverInterface.h index 74c993e8..88487310 100644 --- a/libsolidity/formal/SolverInterface.h +++ b/libsolidity/formal/SolverInterface.h @@ -72,6 +72,11 @@ public: }, _trueValue.sort); } + static Expression implies(Expression _a, Expression _b) + { + return !std::move(_a) || std::move(_b); + } + friend Expression operator!(Expression _a) { return Expression("not", std::move(_a), Sort::Bool); |