diff --git a/src/CSET/operators/misc.py b/src/CSET/operators/misc.py index 34e59aefb..f6412dcca 100644 --- a/src/CSET/operators/misc.py +++ b/src/CSET/operators/misc.py @@ -81,6 +81,25 @@ def remove_attribute( return cubes +def remove_scalar_coords(cubes, coords): + """Remove scalar coordinates. + + examples would be: realization, forecast_reference_time from model cubes. + """ + if not isinstance(cubes, CubeList): + cubes = CubeList([cubes]) + + for cube in cubes: + for coord_name in coords: + if cube.coords(coord_name): + coord = cube.coord(coord_name) + # only remove if scalar + if cube.coord_dims(coord) == (): + cube.remove_coord(coord) + + return cubes + + def addition(addend_1, addend_2): """Addition of two fields. diff --git a/tests/operators/test_misc.py b/tests/operators/test_misc.py index 41cfc7266..156dda765 100644 --- a/tests/operators/test_misc.py +++ b/tests/operators/test_misc.py @@ -541,3 +541,41 @@ def test_extract_common_points_nocommonpoints(vertical_profile_cube): misc.extract_common_points( cubes=iris.cube.CubeList([cube1, cube2]), coordinate="pressure" ) + + +def test_remove_scalar_coord(): + """Test that scalar coordinate be removed.""" + # Create simple 1D cube + data = np.arange(5) + time = iris.coords.DimCoord( + np.arange(5), standard_name="time", units="hours since 1970-01-01" + ) + cube = iris.cube.Cube(data, dim_coords_and_dims=[(time, 0)]) + # Add a scalar coord + realization = iris.coords.AuxCoord(1, long_name="realization") + cube.add_aux_coord(realization) + # Check it's present and scalar + assert cube.coords("realization") + assert cube.coord_dims("realization") == () + # Run function + out = misc.remove_scalar_coords(cube, ["realization"]) + # Check it’s removed + cube_out = out[0] + assert not cube_out.coords("realization") + + +def test_not_remove_non_scalar_coord(): + """Test that non-scalar coordinate is not removed.""" + # Create 1D cube + data = np.arange(5) + time = iris.coords.DimCoord( + np.arange(5), standard_name="time", units="hours since 1970-01-01" + ) + cube = iris.cube.Cube(data, dim_coords_and_dims=[(time, 0)]) + # Confirm it's non-scalar + assert cube.coord_dims("time") != () + # Run function + out = misc.remove_scalar_coords(cube, ["time"]) + # Check it is still present + cube_out = out[0] + assert cube_out.coords("time")