|
14 | 14 | from devito.types.args import ArgProvider |
15 | 15 | from devito.types.basic import Symbol, DataSymbol, Scalar |
16 | 16 | from devito.types.constant import Constant |
| 17 | +from devito.types.relational import relational_min, relational_max |
17 | 18 |
|
18 | 19 |
|
19 | 20 | __all__ = ['Dimension', 'SpaceDimension', 'TimeDimension', 'DefaultDimension', |
@@ -1015,7 +1016,20 @@ def _arg_defaults(self, _min=None, size=None, alias=None): |
1015 | 1016 | dim = alias or self |
1016 | 1017 | if dim.uses_symbolic_factor: |
1017 | 1018 | factor = defaults[dim.symbolic_factor.name] = self.factor_data |
1018 | | - defaults[dim.parent.max_name] = range(0, factor*size - 1) |
| 1019 | + if dim.condition is None: |
| 1020 | + d0 = 0 |
| 1021 | + d1 = sympy.S.Infinity |
| 1022 | + else: |
| 1023 | + d0 = relational_min(dim.condition, dim.parent) |
| 1024 | + d1 = relational_max(dim.condition, dim.parent) |
| 1025 | + if d1 < sympy.S.Infinity: |
| 1026 | + # We make sure the condition size matches the input size |
| 1027 | + size0 = (d1 - d0 + factor) // factor |
| 1028 | + if size < size0: |
| 1029 | + raise ValueError(f"Incompatible size for ConditionalDimension " |
| 1030 | + f"{self.name}: {size} < {size0}") |
| 1031 | + else: |
| 1032 | + defaults[dim.parent.max_name] = range(d0, d0 + factor*size - 1) |
1019 | 1033 |
|
1020 | 1034 | return defaults |
1021 | 1035 |
|
@@ -1482,7 +1496,7 @@ def _arg_defaults(self, **kwargs): |
1482 | 1496 | def _arg_values(self, *args, **kwargs): |
1483 | 1497 | return {} |
1484 | 1498 |
|
1485 | | - def _arg_check(self, *args): |
| 1499 | + def _arg_check(self, *args, **kwargs): |
1486 | 1500 | """A CustomDimension performs no runtime checks.""" |
1487 | 1501 | return |
1488 | 1502 |
|
|
0 commit comments