@@ -754,7 +754,7 @@ def test_buffer_reuse():
754754 assert all (np .all (vsave .data [i - 1 ] == i + 1 ) for i in range (1 , nt + 1 ))
755755
756756
757- def test_multi_cond ():
757+ def test_multi_cond_v0 ():
758758 grid = Grid ((3 , 3 ))
759759 nt = 5
760760
@@ -774,14 +774,47 @@ def test_multi_cond():
774774 T = TimeFunction (grid = grid , name = 'T' , time_order = 0 , space_order = 0 )
775775
776776 eqs = [Eq (T , grid .time_dim )]
777- # this to save times from 0 to nt - 2
777+ # This saves
778+ # - All subsampled times since ct1 is the dimension of f
779+ # - The last time step (ntmod - 2) through ctend (since it's set as ct1 or ctend)
780+ eqs .append (Eq (f , T , implicit_dims = ctend ))
781+
782+ # run operator with buffering
783+ op = Operator (eqs , opt = 'buffering' )
784+ op .apply (time_m = 0 , time_M = ntmod - 2 )
785+
786+ for i in range (nt - 1 ):
787+ assert np .allclose (f .data [i ], i * 2 )
788+ assert np .allclose (f .data [nt - 1 ], ntmod - 2 )
789+
790+
791+ def test_multi_cond_v1 ():
792+ grid = Grid ((3 , 3 ))
793+ nt = 5
794+
795+ x , y = grid .dimensions
796+
797+ factor = 2
798+ ntmod = (nt - 1 ) * factor + 1
799+
800+ ct1 = ConditionalDimension (name = "ct1" , parent = grid .time_dim ,
801+ factor = factor , relation = Or ,
802+ condition = CondEq (grid .time_dim , ntmod - 2 ))
803+
804+ f = TimeFunction (grid = grid , name = 'f' , time_order = 0 ,
805+ space_order = 0 , save = nt , time_dim = ct1 )
806+ T = TimeFunction (grid = grid , name = 'T' , time_order = 0 , space_order = 0 )
807+
808+ eqs = [Eq (T , grid .time_dim )]
809+ # This saves
810+ # - All subsampled times since ct1 is the dimension of f with factor 2
811+ # - The last time step (ntmod - 2) since ct1 also has the condition for ntmod - 2
778812 eqs .append (Eq (f , T ))
779- # this to save the last time sample nt - 1
780- eqs .append (Eq (f .forward , T + 1 , implicit_dims = ctend ))
781813
782814 # run operator with buffering
783815 op = Operator (eqs , opt = 'buffering' )
784816 op .apply (time_m = 0 , time_M = ntmod - 2 )
785817
786- for i in range (nt ):
818+ for i in range (nt - 1 ):
787819 assert np .allclose (f .data [i ], i * 2 )
820+ assert np .allclose (f .data [nt - 1 ], ntmod - 2 )
0 commit comments