Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion src/StdCheats.sol
Original file line number Diff line number Diff line change
Expand Up @@ -785,13 +785,25 @@ abstract contract StdCheats is StdCheatsSafe {
dealERC1155(token, to, id, give, false);
}

uint256 private _reflectionRate;

function _reflectionTransform(uint256 give) internal view returns (uint256) {
return give * _reflectionRate;
}

function deal(address token, address to, uint256 give, bool adjust) internal virtual {
// get current balance
(, bytes memory balData) = token.staticcall(abi.encodeWithSelector(0x70a08231, to));
uint256 prevBal = abi.decode(balData, (uint256));

(bool isReflection, bytes memory rateData) = token.staticcall(abi.encodeWithSelector(0x4549b039, 1, false));
// update balance
stdstore.target(token).sig(0x70a08231).with_key(to).checked_write(give);
if (isReflection) {
_reflectionRate = abi.decode(rateData, (uint256));
stdstore.target(token).sig(0x70a08231).with_key(to).checked_write(give, _reflectionTransform);
} else {
stdstore.target(token).sig(0x70a08231).with_key(to).checked_write(give);
}

// update total supply
if (adjust) {
Expand Down
51 changes: 47 additions & 4 deletions src/StdStorage.sol
Original file line number Diff line number Diff line change
Expand Up @@ -93,17 +93,36 @@ library stdStorageSafe {
return (foundLeft && foundRight, offsetLeft, offsetRight);
}

function _identity(uint256 x) internal pure returns (uint256) {
return x;
}

function find(StdStorage storage self) internal returns (FindData storage) {
return find(self, true);
}

function find(StdStorage storage self, bool _clear) internal returns (FindData storage) {
return find(self, _clear, _identity);
}

function find(StdStorage storage self, function(uint256) internal view returns (uint256) transform)
internal
returns (FindData storage)
{
return find(self, true, transform);
}

/// @notice find an arbitrary storage slot given a function sig, input data, address of the contract and a value to check against
/// @dev an optional transform can be applied to the call result before comparison
// slot complexity:
// if flat, will be bytes32(uint256(uint));
// if map, will be keccak256(abi.encode(key, uint(slot)));
// if deep map, will be keccak256(abi.encode(key1, keccak256(abi.encode(key0, uint(slot)))));
// if map struct, will be bytes32(uint256(keccak256(abi.encode(key1, keccak256(abi.encode(key0, uint(slot)))))) + structFieldDepth);
function find(StdStorage storage self, bool _clear) internal returns (FindData storage) {
function find(StdStorage storage self, bool _clear, function(uint256) internal view returns (uint256) transform)
internal
returns (FindData storage)
{
address who = self._target;
bytes4 fsig = self._sig;
uint256 field_depth = self._depth;
Expand Down Expand Up @@ -147,7 +166,7 @@ library stdStorageSafe {
// Check that value between found offsets is equal to the current call result
uint256 curVal = (uint256(prev) & getMaskByOffsets(offsetLeft, offsetRight)) >> offsetRight;

if (uint256(callResult) != curVal) {
if (transform(uint256(callResult)) != curVal) {
continue;
}

Expand Down Expand Up @@ -350,6 +369,13 @@ library stdStorage {
return stdStorageSafe.find(self, _clear).slot;
}

function find(StdStorage storage self, bool _clear, function(uint256) internal view returns (uint256) transform)
internal
returns (uint256)
{
return stdStorageSafe.find(self, _clear, transform).slot;
}

function target(StdStorage storage self, address _target) internal returns (StdStorage storage) {
return stdStorageSafe.target(self, _target);
}
Expand Down Expand Up @@ -398,6 +424,14 @@ library stdStorage {
checked_write(self, bytes32(amt));
}

function checked_write(
StdStorage storage self,
uint256 amt,
function(uint256) internal view returns (uint256) transform
) internal {
checked_write(self, bytes32(amt), transform);
}

function checked_write_int(StdStorage storage self, int256 val) internal {
checked_write(self, bytes32(uint256(val)));
}
Expand All @@ -411,13 +445,21 @@ library stdStorage {
}

function checked_write(StdStorage storage self, bytes32 set) internal {
checked_write(self, set, stdStorageSafe._identity);
}

function checked_write(
StdStorage storage self,
bytes32 set,
function(uint256) internal view returns (uint256) transform
) internal {
address who = self._target;
bytes4 fsig = self._sig;
uint256 field_depth = self._depth;
bytes memory params = stdStorageSafe.getCallParams(self);

if (!self.finds[who][fsig][keccak256(abi.encodePacked(params, field_depth))].found) {
find(self, false);
find(self, false, transform);
}
FindData storage data = self.finds[who][fsig][keccak256(abi.encodePacked(params, field_depth))];
if ((data.offsetLeft + data.offsetRight) > 0) {
Expand All @@ -433,7 +475,8 @@ library stdStorage {
);
}
bytes32 curVal = vm.load(who, bytes32(data.slot));
bytes32 valToSet = stdStorageSafe.getUpdatedSlotValue(curVal, uint256(set), data.offsetLeft, data.offsetRight);
bytes32 valToSet =
stdStorageSafe.getUpdatedSlotValue(curVal, transform(uint256(set)), data.offsetLeft, data.offsetRight);

vm.store(who, bytes32(data.slot), valToSet);

Expand Down
33 changes: 33 additions & 0 deletions test/StdCheats.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,13 @@ contract StdCheatsTest is Test {
assertEq(barToken.balanceOf(address(this)), 10000e18);
}

function test_DealReflectionToken() public {
BarReflection token = new BarReflection();
uint256 dealAmount = 1000e18;
deal(address(token), address(this), dealAmount);
assertEq(token.balanceOf(address(this)), dealAmount);
}

function test_DealTokenAdjustTotalSupply() public {
Bar barToken = new Bar();
address bar = address(barToken);
Expand Down Expand Up @@ -672,6 +679,32 @@ contract BarERC721 {
mapping(address => uint256) private _balances;
}

contract BarReflection {
uint256 private _tTotal = 10000e18;
uint256 private _rTotal = type(uint256).max - (type(uint256).max % _tTotal);
mapping(address => uint256) private _rOwned;

constructor() {
_rOwned[address(this)] = _rTotal;
}

function totalSupply() public view returns (uint256) {
return _tTotal;
}

function balanceOf(address account) public view returns (uint256) {
return _rOwned[account] / _getRate();
}

function reflectionFromToken(uint256 tAmount, bool) public view returns (uint256) {
return tAmount * _getRate();
}

function _getRate() private view returns (uint256) {
return _rTotal / _tTotal;
}
}

contract RevertingContract {
constructor() {
revert();
Expand Down
57 changes: 14 additions & 43 deletions test/StdStorage.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@ contract StdStorageTest is Test {

StorageTest internal test;

function _double(uint256 x) internal pure returns (uint256) {
return x * 2;
}

function setUp() public {
test = new StorageTest();
}
Expand Down Expand Up @@ -198,6 +202,12 @@ contract StdStorageTest is Test {
assertEq(1337, test.read_struct_lower(address(1337)));
}

function test_StorageCheckedWriteTransform() public {
MockDoubledStorage ds = new MockDoubledStorage();
stdstore.target(address(ds)).sig("value()").checked_write(uint256(100), _double);
assertEq(ds.value(), 100);
}

function test_RevertStorageConst() public {
StorageTestTarget target = new StorageTestTarget(test);

Expand Down Expand Up @@ -350,16 +360,6 @@ contract StdStorageTest is Test {
stdstore.target(address(test)).sig("edgeCaseArray(uint256)").with_key(uint256(0)).checked_write(1);
assertEq(test.edgeCaseArray(0), 1);
}

// Regression test for https://github.com/foundry-rs/forge-std/issues/740
// `find()` used to infinite-loop on tokens whose `balanceOf` reads multiple
// storage slots and returns a derived value (reflection tokens).
function test_RevertFindOnReflectionToken() public {
MockReflectionToken token = new MockReflectionToken();
ReflectionTokenTarget target = new ReflectionTokenTarget(token);
vm.expectRevert("stdStorage find(StdStorage): Slot(s) not found.");
target.findBalanceOf(address(this));
}
}

contract StorageTestTarget {
Expand All @@ -377,21 +377,6 @@ contract StorageTestTarget {
}
}

contract ReflectionTokenTarget {
using stdStorage for StdStorage;

StdStorage internal stdstore;
MockReflectionToken internal token;

constructor(MockReflectionToken token_) {
token = token_;
}

function findBalanceOf(address who) public {
stdstore.target(address(token)).sig("balanceOf(address)").with_key(who).find();
}
}

contract StorageTest {
uint256 public exists = 1;
mapping(address => uint256) public map_addr;
Expand Down Expand Up @@ -509,24 +494,10 @@ contract StorageTest {
}
}

// Minimal mock of a reflection token: `balanceOf` reads many storage slots
// and always returns a constant, so no single slot mutation can change its
// return value and stdStorage can never find a matching slot.
contract MockReflectionToken {
uint256 internal _a = 1;
uint256 internal _b = 2;
uint256 internal _c = 3;
mapping(address => uint256) internal _balances;

constructor() {
_balances[msg.sender] = 1000 ether;
}
contract MockDoubledStorage {
uint256 private _doubled;

// Reads _a, _b, _c, and _balances[account] but always returns a constant.
// This means mutating any single slot won't change the return value.
function balanceOf(address account) public view returns (uint256) {
uint256 x = _a + _b + _c + _balances[account];
x; // suppress unused warning
return 42;
function value() public view returns (uint256) {
return _doubled / 2;
}
}
Loading