aboutsummaryrefslogtreecommitdiffstats
path: root/libsolidity
diff options
context:
space:
mode:
Diffstat (limited to 'libsolidity')
-rw-r--r--libsolidity/analysis/ConstantEvaluator.cpp85
-rw-r--r--libsolidity/analysis/ConstantEvaluator.h18
-rw-r--r--libsolidity/analysis/ReferencesResolver.cpp9
-rw-r--r--libsolidity/ast/Types.h2
-rw-r--r--libsolidity/codegen/ABIFunctions.cpp28
-rw-r--r--libsolidity/formal/SMTChecker.cpp45
-rw-r--r--libsolidity/formal/SMTChecker.h13
-rw-r--r--libsolidity/formal/SolverInterface.h5
8 files changed, 135 insertions, 70 deletions
diff --git a/libsolidity/analysis/ConstantEvaluator.cpp b/libsolidity/analysis/ConstantEvaluator.cpp
index 4d546e68..83f37f47 100644
--- a/libsolidity/analysis/ConstantEvaluator.cpp
+++ b/libsolidity/analysis/ConstantEvaluator.cpp
@@ -28,51 +28,42 @@ using namespace std;
using namespace dev;
using namespace dev::solidity;
-/// FIXME: this is pretty much a copy of TypeChecker::endVisit(BinaryOperation)
void ConstantEvaluator::endVisit(UnaryOperation const& _operation)
{
- TypePointer const& subType = _operation.subExpression().annotation().type;
- if (!dynamic_cast<RationalNumberType const*>(subType.get()))
- m_errorReporter.fatalTypeError(_operation.subExpression().location(), "Invalid constant expression.");
- TypePointer t = subType->unaryOperatorResult(_operation.getOperator());
- _operation.annotation().type = t;
+ auto sub = type(_operation.subExpression());
+ if (sub)
+ setType(_operation, sub->unaryOperatorResult(_operation.getOperator()));
}
-/// FIXME: this is pretty much a copy of TypeChecker::endVisit(BinaryOperation)
void ConstantEvaluator::endVisit(BinaryOperation const& _operation)
{
- TypePointer const& leftType = _operation.leftExpression().annotation().type;
- TypePointer const& rightType = _operation.rightExpression().annotation().type;
- if (!dynamic_cast<RationalNumberType const*>(leftType.get()))
- m_errorReporter.fatalTypeError(_operation.leftExpression().location(), "Invalid constant expression.");
- if (!dynamic_cast<RationalNumberType const*>(rightType.get()))
- m_errorReporter.fatalTypeError(_operation.rightExpression().location(), "Invalid constant expression.");
- TypePointer commonType = leftType->binaryOperatorResult(_operation.getOperator(), rightType);
- if (!commonType)
+ auto left = type(_operation.leftExpression());
+ auto right = type(_operation.rightExpression());
+ if (left && right)
{
- m_errorReporter.typeError(
- _operation.location(),
- "Operator " +
- string(Token::toString(_operation.getOperator())) +
- " not compatible with types " +
- leftType->toString() +
- " and " +
- rightType->toString()
+ auto commonType = left->binaryOperatorResult(_operation.getOperator(), right);
+ if (!commonType)
+ m_errorReporter.fatalTypeError(
+ _operation.location(),
+ "Operator " +
+ string(Token::toString(_operation.getOperator())) +
+ " not compatible with types " +
+ left->toString() +
+ " and " +
+ right->toString()
+ );
+ setType(
+ _operation,
+ Token::isCompareOp(_operation.getOperator()) ?
+ make_shared<BoolType>() :
+ commonType
);
- commonType = leftType;
}
- _operation.annotation().commonType = commonType;
- _operation.annotation().type =
- Token::isCompareOp(_operation.getOperator()) ?
- make_shared<BoolType>() :
- commonType;
}
void ConstantEvaluator::endVisit(Literal const& _literal)
{
- _literal.annotation().type = Type::forLiteral(_literal);
- if (!_literal.annotation().type)
- m_errorReporter.fatalTypeError(_literal.location(), "Invalid literal value.");
+ setType(_literal, Type::forLiteral(_literal));
}
void ConstantEvaluator::endVisit(Identifier const& _identifier)
@@ -81,18 +72,34 @@ void ConstantEvaluator::endVisit(Identifier const& _identifier)
if (!variableDeclaration)
return;
if (!variableDeclaration->isConstant())
- m_errorReporter.fatalTypeError(_identifier.location(), "Identifier must be declared constant.");
+ return;
- ASTPointer<Expression> value = variableDeclaration->value();
+ ASTPointer<Expression> const& value = variableDeclaration->value();
if (!value)
- m_errorReporter.fatalTypeError(_identifier.location(), "Constant identifier declaration must have a constant value.");
-
- if (!value->annotation().type)
+ return;
+ else if (!m_types->count(value.get()))
{
if (m_depth > 32)
m_errorReporter.fatalTypeError(_identifier.location(), "Cyclic constant definition (or maximum recursion depth exhausted).");
- ConstantEvaluator e(*value, m_errorReporter, m_depth + 1);
+ ConstantEvaluator(m_errorReporter, m_depth + 1, m_types).evaluate(*value);
}
- _identifier.annotation().type = value->annotation().type;
+ setType(_identifier, type(*value));
+}
+
+void ConstantEvaluator::setType(ASTNode const& _node, TypePointer const& _type)
+{
+ if (_type && _type->category() == Type::Category::RationalNumber)
+ (*m_types)[&_node] = _type;
+}
+
+TypePointer ConstantEvaluator::type(ASTNode const& _node)
+{
+ return (*m_types)[&_node];
+}
+
+TypePointer ConstantEvaluator::evaluate(Expression const& _expr)
+{
+ _expr.accept(*this);
+ return type(_expr);
}
diff --git a/libsolidity/analysis/ConstantEvaluator.h b/libsolidity/analysis/ConstantEvaluator.h
index 6725d610..77a357b6 100644
--- a/libsolidity/analysis/ConstantEvaluator.h
+++ b/libsolidity/analysis/ConstantEvaluator.h
@@ -38,22 +38,32 @@ class TypeChecker;
class ConstantEvaluator: private ASTConstVisitor
{
public:
- ConstantEvaluator(Expression const& _expr, ErrorReporter& _errorReporter, size_t _newDepth = 0):
+ ConstantEvaluator(
+ ErrorReporter& _errorReporter,
+ size_t _newDepth = 0,
+ std::shared_ptr<std::map<ASTNode const*, TypePointer>> _types = std::make_shared<std::map<ASTNode const*, TypePointer>>()
+ ):
m_errorReporter(_errorReporter),
- m_depth(_newDepth)
+ m_depth(_newDepth),
+ m_types(_types)
{
- _expr.accept(*this);
}
+ TypePointer evaluate(Expression const& _expr);
+
private:
virtual void endVisit(BinaryOperation const& _operation);
virtual void endVisit(UnaryOperation const& _operation);
virtual void endVisit(Literal const& _literal);
virtual void endVisit(Identifier const& _identifier);
+ void setType(ASTNode const& _node, TypePointer const& _type);
+ TypePointer type(ASTNode const& _node);
+
ErrorReporter& m_errorReporter;
/// Current recursion depth.
- size_t m_depth;
+ size_t m_depth = 0;
+ std::shared_ptr<std::map<ASTNode const*, TypePointer>> m_types;
};
}
diff --git a/libsolidity/analysis/ReferencesResolver.cpp b/libsolidity/analysis/ReferencesResolver.cpp
index f22c95cc..9eee16af 100644
--- a/libsolidity/analysis/ReferencesResolver.cpp
+++ b/libsolidity/analysis/ReferencesResolver.cpp
@@ -146,11 +146,12 @@ void ReferencesResolver::endVisit(ArrayTypeName const& _typeName)
fatalTypeError(_typeName.baseType().location(), "Illegal base type of storage size zero for array.");
if (Expression const* length = _typeName.length())
{
- if (!length->annotation().type)
- ConstantEvaluator e(*length, m_errorReporter);
- auto const* lengthType = dynamic_cast<RationalNumberType const*>(length->annotation().type.get());
+ TypePointer lengthTypeGeneric = length->annotation().type;
+ if (!lengthTypeGeneric)
+ lengthTypeGeneric = ConstantEvaluator(m_errorReporter).evaluate(*length);
+ RationalNumberType const* lengthType = dynamic_cast<RationalNumberType const*>(lengthTypeGeneric.get());
if (!lengthType || !lengthType->mobileType())
- fatalTypeError(length->location(), "Invalid array length, expected integer literal.");
+ fatalTypeError(length->location(), "Invalid array length, expected integer literal or constant expression.");
else if (lengthType->isFractional())
fatalTypeError(length->location(), "Array with fractional length specified.");
else if (lengthType->isNegative())
diff --git a/libsolidity/ast/Types.h b/libsolidity/ast/Types.h
index 635279ab..a54e4e09 100644
--- a/libsolidity/ast/Types.h
+++ b/libsolidity/ast/Types.h
@@ -257,7 +257,7 @@ public:
}
virtual u256 literalValue(Literal const*) const
{
- solAssert(false, "Literal value requested for type without literals.");
+ solAssert(false, "Literal value requested for type without literals: " + toString(false));
}
/// @returns a (simpler) type that is encoded in the same way for external function calls.
diff --git a/libsolidity/codegen/ABIFunctions.cpp b/libsolidity/codegen/ABIFunctions.cpp
index 6648be06..00f59065 100644
--- a/libsolidity/codegen/ABIFunctions.cpp
+++ b/libsolidity/codegen/ABIFunctions.cpp
@@ -120,7 +120,7 @@ string ABIFunctions::tupleDecoder(TypePointers const& _types, bool _fromMemory)
Whiskers templ(R"(
function <functionName>(headStart, dataEnd) -> <valueReturnParams> {
- switch slt(sub(dataEnd, headStart), <minimumSize>) case 1 { revert(0, 0) }
+ if slt(sub(dataEnd, headStart), <minimumSize>) { revert(0, 0) }
<decodeElements>
}
)");
@@ -151,7 +151,7 @@ string ABIFunctions::tupleDecoder(TypePointers const& _types, bool _fromMemory)
R"(
{
let offset := <load>(add(headStart, <pos>))
- switch gt(offset, 0xffffffffffffffff) case 1 { revert(0, 0) }
+ if gt(offset, 0xffffffffffffffff) { revert(0, 0) }
<values> := <abiDecode>(add(headStart, offset), dataEnd)
}
)" :
@@ -1134,7 +1134,7 @@ string ABIFunctions::abiDecodingFunctionArray(ArrayType const& _type, bool _from
R"(
// <readableTypeName>
function <functionName>(offset, end) -> array {
- switch slt(add(offset, 0x1f), end) case 0 { revert(0, 0) }
+ if iszero(slt(add(offset, 0x1f), end)) { revert(0, 0) }
let length := <retrieveLength>
array := <allocate>(<allocationSize>(length))
let dst := array
@@ -1169,7 +1169,7 @@ string ABIFunctions::abiDecodingFunctionArray(ArrayType const& _type, bool _from
else
{
string baseEncodedSize = toCompactHexWithPrefix(_type.baseType()->calldataEncodedSize());
- templ("staticBoundsCheck", "switch gt(add(src, mul(length, " + baseEncodedSize + ")), end) case 1 { revert(0, 0) }");
+ templ("staticBoundsCheck", "if gt(add(src, mul(length, " + baseEncodedSize + ")), end) { revert(0, 0) }");
templ("retrieveElementPos", "src");
templ("baseEncodedSize", baseEncodedSize);
}
@@ -1197,11 +1197,11 @@ string ABIFunctions::abiDecodingFunctionCalldataArray(ArrayType const& _type)
templ = R"(
// <readableTypeName>
function <functionName>(offset, end) -> arrayPos, length {
- switch slt(add(offset, 0x1f), end) case 0 { revert(0, 0) }
+ if iszero(slt(add(offset, 0x1f), end)) { revert(0, 0) }
length := calldataload(offset)
- switch gt(length, 0xffffffffffffffff) case 1 { revert(0, 0) }
+ if gt(length, 0xffffffffffffffff) { revert(0, 0) }
arrayPos := add(offset, 0x20)
- switch gt(add(arrayPos, mul(<length>, <baseEncodedSize>)), end) case 1 { revert(0, 0) }
+ if gt(add(arrayPos, mul(<length>, <baseEncodedSize>)), end) { revert(0, 0) }
}
)";
else
@@ -1209,7 +1209,7 @@ string ABIFunctions::abiDecodingFunctionCalldataArray(ArrayType const& _type)
// <readableTypeName>
function <functionName>(offset, end) -> arrayPos {
arrayPos := offset
- switch gt(add(arrayPos, mul(<length>, <baseEncodedSize>)), end) case 1 { revert(0, 0) }
+ if gt(add(arrayPos, mul(<length>, <baseEncodedSize>)), end) { revert(0, 0) }
}
)";
Whiskers w{templ};
@@ -1235,13 +1235,13 @@ string ABIFunctions::abiDecodingFunctionByteArray(ArrayType const& _type, bool _
Whiskers templ(
R"(
function <functionName>(offset, end) -> array {
- switch slt(add(offset, 0x1f), end) case 0 { revert(0, 0) }
+ if iszero(slt(add(offset, 0x1f), end)) { revert(0, 0) }
let length := <load>(offset)
array := <allocate>(<allocationSize>(length))
mstore(array, length)
let src := add(offset, 0x20)
let dst := add(array, 0x20)
- switch gt(add(src, length), end) case 1 { revert(0, 0) }
+ if gt(add(src, length), end) { revert(0, 0) }
<copyToMemFun>(src, dst, length)
}
)"
@@ -1268,7 +1268,7 @@ string ABIFunctions::abiDecodingFunctionStruct(StructType const& _type, bool _fr
Whiskers templ(R"(
// <readableTypeName>
function <functionName>(headStart, end) -> value {
- switch slt(sub(end, headStart), <minimumSize>) case 1 { revert(0, 0) }
+ if slt(sub(end, headStart), <minimumSize>) { revert(0, 0) }
value := <allocate>(<memorySize>)
<#members>
{
@@ -1296,7 +1296,7 @@ string ABIFunctions::abiDecodingFunctionStruct(StructType const& _type, bool _fr
dynamic ?
R"(
let offset := <load>(add(headStart, <pos>))
- switch gt(offset, 0xffffffffffffffff) case 1 { revert(0, 0) }
+ if gt(offset, 0xffffffffffffffff) { revert(0, 0) }
mstore(add(value, <memoryOffset>), <abiDecode>(add(headStart, offset), end))
)" :
R"(
@@ -1501,7 +1501,7 @@ string ABIFunctions::arrayAllocationSizeFunction(ArrayType const& _type)
Whiskers w(R"(
function <functionName>(length) -> size {
// Make sure we can allocate memory without overflow
- switch gt(length, 0xffffffffffffffff) case 1 { revert(0, 0) }
+ if gt(length, 0xffffffffffffffff) { revert(0, 0) }
size := <allocationSize>
<addLengthSlot>
}
@@ -1620,7 +1620,7 @@ string ABIFunctions::allocationFunction()
memPtr := mload(<freeMemoryPointer>)
let newFreePtr := add(memPtr, size)
// protect against overflow
- switch or(gt(newFreePtr, 0xffffffffffffffff), lt(newFreePtr, memPtr)) case 1 { revert(0, 0) }
+ if or(gt(newFreePtr, 0xffffffffffffffff), lt(newFreePtr, memPtr)) { revert(0, 0) }
mstore(<freeMemoryPointer>, newFreePtr)
}
)")
diff --git a/libsolidity/formal/SMTChecker.cpp b/libsolidity/formal/SMTChecker.cpp
index a22e35d6..d4887a3d 100644
--- a/libsolidity/formal/SMTChecker.cpp
+++ b/libsolidity/formal/SMTChecker.cpp
@@ -71,6 +71,7 @@ bool SMTChecker::visit(FunctionDefinition const& _function)
m_interface->reset();
m_currentSequenceCounter.clear();
m_nextFreeSequenceCounter.clear();
+ m_pathConditions.clear();
m_conditionalExecutionHappened = false;
initializeLocalVariables(_function);
return true;
@@ -344,14 +345,14 @@ void SMTChecker::endVisit(FunctionCall const& _funCall)
solAssert(args.size() == 1, "");
solAssert(args[0]->annotation().type->category() == Type::Category::Bool, "");
checkCondition(!(expr(*args[0])), _funCall.location(), "Assertion violation");
- m_interface->addAssertion(expr(*args[0]));
+ addPathImpliedExpression(expr(*args[0]));
}
else if (funType.kind() == FunctionType::Kind::Require)
{
solAssert(args.size() == 1, "");
solAssert(args[0]->annotation().type->category() == Type::Category::Bool, "");
checkBooleanNotConstant(*args[0], "Condition is always $VALUE.");
- m_interface->addAssertion(expr(*args[0]));
+ addPathImpliedExpression(expr(*args[0]));
}
}
@@ -514,11 +515,11 @@ void SMTChecker::visitBranch(Statement const& _statement, smt::Expression const*
{
VariableSequenceCounters sequenceCountersStart = m_currentSequenceCounter;
- m_interface->push();
if (_condition)
- m_interface->addAssertion(*_condition);
+ pushPathCondition(*_condition);
_statement.accept(*this);
- m_interface->pop();
+ if (_condition)
+ popPathCondition();
m_conditionalExecutionHappened = true;
m_currentSequenceCounter = sequenceCountersStart;
@@ -533,7 +534,7 @@ void SMTChecker::checkCondition(
)
{
m_interface->push();
- m_interface->addAssertion(_condition);
+ addPathConjoinedExpression(_condition);
vector<smt::Expression> expressionsToEvaluate;
vector<string> expressionNames;
@@ -605,12 +606,12 @@ void SMTChecker::checkBooleanNotConstant(Expression const& _condition, string co
return;
m_interface->push();
- m_interface->addAssertion(expr(_condition));
+ addPathConjoinedExpression(expr(_condition));
auto positiveResult = checkSatisifable();
m_interface->pop();
m_interface->push();
- m_interface->addAssertion(!expr(_condition));
+ addPathConjoinedExpression(!expr(_condition));
auto negatedResult = checkSatisifable();
m_interface->pop();
@@ -828,3 +829,31 @@ smt::Expression SMTChecker::var(Declaration const& _decl)
solAssert(m_variables.count(&_decl), "");
return m_variables.at(&_decl);
}
+
+void SMTChecker::popPathCondition()
+{
+ solAssert(m_pathConditions.size() > 0, "Cannot pop path condition, empty.");
+ m_pathConditions.pop_back();
+}
+
+void SMTChecker::pushPathCondition(smt::Expression const& _e)
+{
+ m_pathConditions.push_back(currentPathConditions() && _e);
+}
+
+smt::Expression SMTChecker::currentPathConditions()
+{
+ if (m_pathConditions.size() == 0)
+ return smt::Expression(true);
+ return m_pathConditions.back();
+}
+
+void SMTChecker::addPathConjoinedExpression(smt::Expression const& _e)
+{
+ m_interface->addAssertion(currentPathConditions() && _e);
+}
+
+void SMTChecker::addPathImpliedExpression(smt::Expression const& _e)
+{
+ m_interface->addAssertion(smt::Expression::implies(currentPathConditions(), _e));
+}
diff --git a/libsolidity/formal/SMTChecker.h b/libsolidity/formal/SMTChecker.h
index e7481cca..539221cc 100644
--- a/libsolidity/formal/SMTChecker.h
+++ b/libsolidity/formal/SMTChecker.h
@@ -26,6 +26,7 @@
#include <map>
#include <string>
+#include <vector>
namespace dev
{
@@ -145,6 +146,17 @@ private:
/// The function takes one argument which is the "sequence number".
smt::Expression var(Declaration const& _decl);
+ /// Adds a new path condition
+ void pushPathCondition(smt::Expression const& _e);
+ /// Remove the last path condition
+ void popPathCondition();
+ /// Returns the conjunction of all path conditions or True if empty
+ smt::Expression currentPathConditions();
+ /// Conjoin the current path conditions with the given parameter and add to the solver
+ void addPathConjoinedExpression(smt::Expression const& _e);
+ /// Add to the solver: the given expression implied by the current path conditions
+ void addPathImpliedExpression(smt::Expression const& _e);
+
std::shared_ptr<smt::SolverInterface> m_interface;
std::shared_ptr<VariableUsage> m_variableUsage;
bool m_conditionalExecutionHappened = false;
@@ -152,6 +164,7 @@ private:
std::map<Declaration const*, int> m_nextFreeSequenceCounter;
std::map<Expression const*, smt::Expression> m_expressions;
std::map<Declaration const*, smt::Expression> m_variables;
+ std::vector<smt::Expression> m_pathConditions;
ErrorReporter& m_errorReporter;
FunctionDefinition const* m_currentFunction = nullptr;
diff --git a/libsolidity/formal/SolverInterface.h b/libsolidity/formal/SolverInterface.h
index 74c993e8..88487310 100644
--- a/libsolidity/formal/SolverInterface.h
+++ b/libsolidity/formal/SolverInterface.h
@@ -72,6 +72,11 @@ public:
}, _trueValue.sort);
}
+ static Expression implies(Expression _a, Expression _b)
+ {
+ return !std::move(_a) || std::move(_b);
+ }
+
friend Expression operator!(Expression _a)
{
return Expression("not", std::move(_a), Sort::Bool);