diff --git a/source/isaaclab/isaaclab/sensors/ray_caster/base_multi_mesh_ray_caster.py b/source/isaaclab/isaaclab/sensors/ray_caster/base_multi_mesh_ray_caster.py index 4993c2f7fed3..bb04cebfc4a6 100644 --- a/source/isaaclab/isaaclab/sensors/ray_caster/base_multi_mesh_ray_caster.py +++ b/source/isaaclab/isaaclab/sensors/ray_caster/base_multi_mesh_ray_caster.py @@ -142,23 +142,17 @@ def _initialize_warp_meshes(self): self._initialize_warp_meshes_from_clone_plan(plan) def _initialize_warp_meshes_from_clone_plan(self, plan) -> None: - """Initialize rectangular mesh buffers from ClonePlan source rows. - - The current PR keeps the existing rectangular kernel ABI. Environments - with fewer meshes than the target maximum are padded with a valid mesh - placed far outside the ray-cast range. A follow-up PR can replace this - padded representation with a new kernel. - """ - target_records_by_expr: dict[ - str, list[list[tuple[int, tuple[float, float, float], tuple[float, float, float, float]]]] - ] = {} - dummy_mesh_id: int | None = None + """Initialize flat mesh slots from ClonePlan source rows.""" + slot_env_ids: list[int] = [] + slot_mesh_ids: list[int] = [] + slot_positions: list[tuple[float, float, float]] = [] + slot_orientations: list[tuple[float, float, float, float]] = [] self._mesh_views = [] + self._slot_ranges_by_target_expr: dict[str, tuple[int, int]] = {} for target_cfg in self._raycast_targets_cfg: - records_per_env: list[list[tuple[int, tuple[float, float, float], tuple[float, float, float, float]]]] = [ - [] for _ in range(self._num_envs) - ] + target_start = len(slot_mesh_ids) + meshes_added_per_env = [0 for _ in range(self._num_envs)] matches = self._collect_clone_plan_matches(plan, target_cfg.prim_expr) if matches: for row, source_root, source_expr in matches: @@ -171,7 +165,6 @@ def _initialize_warp_meshes_from_clone_plan(self, plan) -> None: prototype_records = [] for target_prim in target_prims: mesh_id = self._load_target_prim_warp_mesh(target_prim, target_cfg) - dummy_mesh_id = mesh_id if dummy_mesh_id is None else dummy_mesh_id source_root_prim = self.stage.GetPrimAtPath(source_root) local_pos, local_quat = sim_utils.resolve_prim_pose(target_prim, source_root_prim) prototype_records.append((mesh_id, local_pos, local_quat)) @@ -187,13 +180,11 @@ def _initialize_warp_meshes_from_clone_plan(self, plan) -> None: mesh_pos_t, mesh_quat_t = math_utils.combine_frame_transforms( root_pos_t, root_quat_t, local_pos_t, local_quat_t ) - records_per_env[env_id].append( - ( - mesh_id, - tuple(float(v) for v in mesh_pos_t[0].tolist()), - tuple(float(v) for v in mesh_quat_t[0].tolist()), - ) - ) + slot_env_ids.append(env_id) + slot_mesh_ids.append(mesh_id) + slot_positions.append(tuple(float(v) for v in mesh_pos_t[0].tolist())) + slot_orientations.append(tuple(float(v) for v in mesh_quat_t[0].tolist())) + meshes_added_per_env[env_id] += 1 else: target_prims = sim_utils.find_matching_prims(target_cfg.prim_expr) if len(target_prims) == 0: @@ -201,19 +192,23 @@ def _initialize_warp_meshes_from_clone_plan(self, plan) -> None: records = [] for target_prim in target_prims: mesh_id = self._load_target_prim_warp_mesh(target_prim, target_cfg) - dummy_mesh_id = mesh_id if dummy_mesh_id is None else dummy_mesh_id pos, quat = sim_utils.resolve_prim_pose(target_prim) records.append((mesh_id, tuple(float(v) for v in pos), tuple(float(v) for v in quat))) for env_id in range(self._num_envs): - records_per_env[env_id].extend(records) - - self._num_meshes_per_env[target_cfg.prim_expr] = max(len(records) for records in records_per_env) - target_records_by_expr[target_cfg.prim_expr] = records_per_env + for mesh_id, pos, quat in records: + slot_env_ids.append(env_id) + slot_mesh_ids.append(mesh_id) + slot_positions.append(pos) + slot_orientations.append(quat) + meshes_added_per_env[env_id] += 1 + + self._num_meshes_per_env[target_cfg.prim_expr] = max(meshes_added_per_env) + self._slot_ranges_by_target_expr[target_cfg.prim_expr] = (target_start, len(slot_mesh_ids)) self._mesh_views.append( self._create_tracked_target_view(target_cfg.prim_expr) if target_cfg.track_mesh_transforms else None ) - self._install_rectangular_mesh_table(target_records_by_expr, dummy_mesh_id) + self._install_flat_mesh_slots(slot_env_ids, slot_mesh_ids, slot_positions, slot_orientations) def _collect_clone_plan_matches(self, plan, target_expr: str) -> list[tuple[int, str, str]]: target_env0 = _target_expr_for_env(target_expr, 0) @@ -339,14 +334,16 @@ def _create_tracked_target_view(self, target_prim_path: str): raise NotImplementedError("Tracked multi-mesh targets must be implemented by the active physics backend.") def _initialize_warp_meshes_from_stage(self): - """Parse mesh prim expressions from USD and install the rectangular mesh table.""" - target_records_by_expr: dict[ - str, list[list[tuple[int, tuple[float, float, float], tuple[float, float, float, float]]]] - ] = {} - dummy_mesh_id: int | None = None + """Parse mesh prim expressions from USD and install the flat slot table.""" + slot_env_ids: list[int] = [] + slot_mesh_ids: list[int] = [] + slot_positions: list[tuple[float, float, float]] = [] + slot_orientations: list[tuple[float, float, float, float]] = [] self._mesh_views = [] + self._slot_ranges_by_target_expr: dict[str, tuple[int, int]] = {} for target_cfg in self._raycast_targets_cfg: + target_start = len(slot_mesh_ids) target_prims = sim_utils.find_matching_prims(target_cfg.prim_expr) if len(target_prims) == 0: raise RuntimeError(f"Failed to find a prim at path expression: {target_cfg.prim_expr}") @@ -354,7 +351,6 @@ def _initialize_warp_meshes_from_stage(self): records = [] for target_prim in target_prims: mesh_id = self._load_target_prim_warp_mesh(target_prim, target_cfg) - dummy_mesh_id = mesh_id if dummy_mesh_id is None else dummy_mesh_id pos, quat = sim_utils.resolve_prim_pose(target_prim) records.append((mesh_id, tuple(float(v) for v in pos), tuple(float(v) for v in quat))) @@ -369,57 +365,39 @@ def _initialize_warp_meshes_from_stage(self): n_meshes = len(records) // self._num_envs per_env_records = [records[i * n_meshes : (i + 1) * n_meshes] for i in range(self._num_envs)] + for env_id, env_records in enumerate(per_env_records): + for mesh_id, pos, quat in env_records: + slot_env_ids.append(env_id) + slot_mesh_ids.append(mesh_id) + slot_positions.append(pos) + slot_orientations.append(quat) + self._num_meshes_per_env[target_cfg.prim_expr] = max(len(env_records) for env_records in per_env_records) - target_records_by_expr[target_cfg.prim_expr] = per_env_records + self._slot_ranges_by_target_expr[target_cfg.prim_expr] = (target_start, len(slot_mesh_ids)) self._mesh_views.append( self._create_tracked_target_view(target_cfg.prim_expr) if target_cfg.track_mesh_transforms else None ) - self._install_rectangular_mesh_table(target_records_by_expr, dummy_mesh_id) + self._install_flat_mesh_slots(slot_env_ids, slot_mesh_ids, slot_positions, slot_orientations) - def _install_rectangular_mesh_table( + def _install_flat_mesh_slots( self, - target_records_by_expr: dict[ - str, list[list[tuple[int, tuple[float, float, float], tuple[float, float, float, float]]]] - ], - dummy_mesh_id: int | None, + slot_env_ids: list[int], + slot_mesh_ids: list[int], + slot_positions: list[tuple[float, float, float]], + slot_orientations: list[tuple[float, float, float, float]], ) -> None: - """Pack per-target mesh records into the rectangular table used by the existing kernel.""" - if dummy_mesh_id is None: + """Install the compact per-environment slot arrays used by the flat kernel.""" + if not slot_mesh_ids: raise RuntimeError( f"No meshes found for ray-casting! Please check the mesh prim paths: {self.cfg.mesh_prim_paths}" ) - - dummy_record = (dummy_mesh_id, (1.0e9, 1.0e9, 1.0e9), (0.0, 0.0, 0.0, 1.0)) - multi_mesh_ids_flattened: list[list[int]] = [] - mesh_positions: list[list[tuple[float, float, float]]] = [] - mesh_orientations: list[list[tuple[float, float, float, float]]] = [] - - for env_id in range(self._num_envs): - meshes_in_env: list[int] = [] - positions_in_env: list[tuple[float, float, float]] = [] - orientations_in_env: list[tuple[float, float, float, float]] = [] - for target_cfg in self._raycast_targets_cfg: - records = list(target_records_by_expr[target_cfg.prim_expr][env_id]) - records.extend([dummy_record] * (self._num_meshes_per_env[target_cfg.prim_expr] - len(records))) - for mesh_id, pos, quat in records: - meshes_in_env.append(mesh_id) - positions_in_env.append(pos) - orientations_in_env.append(quat) - multi_mesh_ids_flattened.append(meshes_in_env) - mesh_positions.append(positions_in_env) - mesh_orientations.append(orientations_in_env) - - total_n_meshes_per_env = len(multi_mesh_ids_flattened[0]) - self._mesh_ids_wp = wp.array2d(multi_mesh_ids_flattened, dtype=wp.uint64, device=self.device) - self._mesh_positions_w = wp.zeros((self._num_envs, total_n_meshes_per_env), dtype=wp.vec3, device=self.device) - self._mesh_orientations_w = wp.zeros( - (self._num_envs, total_n_meshes_per_env), dtype=wp.quat, device=self.device - ) - self._mesh_positions_w_torch = wp.to_torch(self._mesh_positions_w) - self._mesh_orientations_w_torch = wp.to_torch(self._mesh_orientations_w) - self._mesh_positions_w_torch[:] = torch.tensor(mesh_positions, dtype=torch.float32, device=self.device) - self._mesh_orientations_w_torch[:] = torch.tensor(mesh_orientations, dtype=torch.float32, device=self.device) + self._slot_env_ids_wp = wp.array(slot_env_ids, dtype=wp.int32, device=self.device) + self._slot_mesh_ids_wp = wp.array(slot_mesh_ids, dtype=wp.uint64, device=self.device) + self._slot_mesh_positions_w = wp.array(slot_positions, dtype=wp.vec3, device=self.device) + self._slot_mesh_orientations_w = wp.array(slot_orientations, dtype=wp.quat, device=self.device) + self._slot_mesh_positions_w_torch = wp.to_torch(self._slot_mesh_positions_w) + self._slot_mesh_orientations_w_torch = wp.to_torch(self._slot_mesh_orientations_w) def _initialize_rays_impl(self): super()._initialize_rays_impl() @@ -440,13 +418,11 @@ def _update_mesh_transforms(self) -> None: """Update world-frame mesh positions and orientations for dynamically tracked targets. Iterates over all tracked views and writes the current world poses into - the rectangular mesh pose buffers. Static (non-tracked) targets are - skipped; their initial poses were set during :meth:`_initialize_warp_meshes`. + the flat slot pose buffers. Static (non-tracked) targets are skipped; + their initial poses were set during :meth:`_initialize_warp_meshes`. """ - mesh_idx = 0 for view, target_cfg in zip(self._mesh_views, self._raycast_targets_cfg): if not target_cfg.track_mesh_transforms: - mesh_idx += self._num_meshes_per_env[target_cfg.prim_expr] continue # update position of the target meshes @@ -455,15 +431,18 @@ def _update_mesh_transforms(self) -> None: pos_w = pos_w.squeeze(0) if len(pos_w.shape) == 3 else pos_w ori_w = ori_w.squeeze(0) if len(ori_w.shape) == 3 else ori_w - count = getattr(view, "count", pos_w.shape[0]) - if count != 1: - count = count // self._num_envs - pos_w = pos_w.view(self._num_envs, count, 3) - ori_w = ori_w.view(self._num_envs, count, 4) - - self._mesh_positions_w_torch[:, mesh_idx : mesh_idx + count] = pos_w - self._mesh_orientations_w_torch[:, mesh_idx : mesh_idx + count] = ori_w - mesh_idx += self._num_meshes_per_env[target_cfg.prim_expr] + slot_start, slot_end = self._slot_ranges_by_target_expr[target_cfg.prim_expr] + slot_count = slot_end - slot_start + if pos_w.shape[0] == 1 and slot_count > 1: + pos_w = pos_w.repeat(slot_count, 1) + ori_w = ori_w.repeat(slot_count, 1) + if pos_w.shape[0] != slot_count: + raise RuntimeError( + f"Tracked target '{target_cfg.prim_expr}' produced {pos_w.shape[0]} poses, " + f"but the raycaster has {slot_count} mesh slots for that target." + ) + self._slot_mesh_positions_w_torch[slot_start:slot_end] = pos_w + self._slot_mesh_orientations_w_torch[slot_start:slot_end] = ori_w def _update_buffers_impl(self, env_mask: wp.array): """Fills the buffers of the sensor data.""" @@ -484,15 +463,14 @@ def _update_buffers_impl(self, env_mask: wp.array): device=self._device, ) - n_meshes = self._mesh_ids_wp.shape[1] - - # Ray-cast against all meshes; closest hit wins via atomic_min on ray_distance. + # Ray-cast against all mesh slots; closest hit wins via atomic_min on ray_distance. wp.launch( - warp_kernels.raycast_dynamic_meshes_kernel, - dim=(n_meshes, self._num_envs, self.num_rays), + warp_kernels.raycast_dynamic_mesh_slots_kernel, + dim=(self._slot_mesh_ids_wp.shape[0], self.num_rays), inputs=[ env_mask, - self._mesh_ids_wp, + self._slot_env_ids_wp, + self._slot_mesh_ids_wp, self._ray_starts_w, self._ray_directions_w, self._data._ray_hits_w, @@ -500,8 +478,8 @@ def _update_buffers_impl(self, env_mask: wp.array): self._dummy_normal_w, self._dummy_face_id_w, self._ray_mesh_id_w, - self._mesh_positions_w, - self._mesh_orientations_w, + self._slot_mesh_positions_w, + self._slot_mesh_orientations_w, float(self.cfg.max_distance), int(False), int(False), diff --git a/source/isaaclab/isaaclab/sensors/ray_caster/base_multi_mesh_ray_caster_camera.py b/source/isaaclab/isaaclab/sensors/ray_caster/base_multi_mesh_ray_caster_camera.py index 9c750031a83e..ab33c30664b6 100644 --- a/source/isaaclab/isaaclab/sensors/ray_caster/base_multi_mesh_ray_caster_camera.py +++ b/source/isaaclab/isaaclab/sensors/ray_caster/base_multi_mesh_ray_caster_camera.py @@ -239,15 +239,14 @@ def _update_buffers_impl(self, env_mask: wp.array): device=self._device, ) - n_meshes = self._mesh_ids_wp.shape[1] - - # Ray-cast against all meshes; closest hit wins via atomic_min on ray_distance. + # Ray-cast against all mesh slots; closest hit wins via atomic_min on ray_distance. wp.launch( - warp_kernels.raycast_dynamic_meshes_kernel, - dim=(n_meshes, self._num_envs, self.num_rays), + warp_kernels.raycast_dynamic_mesh_slots_kernel, + dim=(self._slot_mesh_ids_wp.shape[0], self.num_rays), inputs=[ env_mask, - self._mesh_ids_wp, + self._slot_env_ids_wp, + self._slot_mesh_ids_wp, self._ray_starts_w, self._ray_directions_w, self._ray_hits_w_cam, @@ -255,8 +254,8 @@ def _update_buffers_impl(self, env_mask: wp.array): self._ray_normal_w, self._ray_face_id_w, self._ray_mesh_id_w, - self._mesh_positions_w, - self._mesh_orientations_w, + self._slot_mesh_positions_w, + self._slot_mesh_orientations_w, float(CAMERA_RAYCAST_MAX_DIST), int(return_normal), int(False), diff --git a/source/isaaclab/isaaclab/utils/warp/kernels.py b/source/isaaclab/isaaclab/utils/warp/kernels.py index efcdbfe63f1e..36de50d74bb4 100644 --- a/source/isaaclab/isaaclab/utils/warp/kernels.py +++ b/source/isaaclab/isaaclab/utils/warp/kernels.py @@ -326,6 +326,55 @@ def raycast_dynamic_meshes_kernel( ray_mesh_id[tid_env, tid_ray] = wp.int16(tid_mesh_id) +@wp.kernel(enable_backward=False) +def raycast_dynamic_mesh_slots_kernel( + env_mask: wp.array(dtype=wp.bool), + slot_env_ids: wp.array(dtype=wp.int32), + mesh: wp.array(dtype=wp.uint64), + ray_starts: wp.array2d(dtype=wp.vec3), + ray_directions: wp.array2d(dtype=wp.vec3), + ray_hits: wp.array2d(dtype=wp.vec3), + ray_distance: wp.array2d(dtype=wp.float32), + ray_normal: wp.array2d(dtype=wp.vec3), + ray_face_id: wp.array2d(dtype=wp.int32), + ray_mesh_id: wp.array2d(dtype=wp.int16), + mesh_positions: wp.array(dtype=wp.vec3), + mesh_rotations: wp.array(dtype=wp.quat), + max_dist: float = 1e6, + return_normal: int = False, + return_face_id: int = False, + return_mesh_id: int = False, +): + """Ray-cast against a flat list of per-environment mesh slots. + + Launch with ``dim=(num_slots, num_rays)``. Each slot carries the owning + environment id, allowing heterogeneous scenes where environments have + different mesh counts without padding to a rectangular table. + """ + slot_id, tid_ray = wp.tid() + tid_env = slot_env_ids[slot_id] + if not env_mask[tid_env]: + return + + mesh_pose = wp.transform(mesh_positions[slot_id], mesh_rotations[slot_id]) + mesh_pose_inv = wp.transform_inverse(mesh_pose) + direction = wp.transform_vector(mesh_pose_inv, ray_directions[tid_env, tid_ray]) + start_pos = wp.transform_point(mesh_pose_inv, ray_starts[tid_env, tid_ray]) + + mesh_query_ray_t = wp.mesh_query_ray(mesh[slot_id], start_pos, direction, max_dist) + if mesh_query_ray_t.result: + wp.atomic_min(ray_distance, tid_env, tid_ray, mesh_query_ray_t.t) + if mesh_query_ray_t.t == ray_distance[tid_env, tid_ray]: + hit_pos = start_pos + mesh_query_ray_t.t * direction + ray_hits[tid_env, tid_ray] = wp.transform_point(mesh_pose, hit_pos) + if return_normal == 1: + ray_normal[tid_env, tid_ray] = wp.transform_vector(mesh_pose, mesh_query_ray_t.normal) + if return_face_id == 1: + ray_face_id[tid_env, tid_ray] = mesh_query_ray_t.face + if return_mesh_id == 1: + ray_mesh_id[tid_env, tid_ray] = wp.int16(slot_id) + + @wp.kernel(enable_backward=False) def reshape_tiled_image( tiled_image_buffer: Any, diff --git a/source/isaaclab/test/sensors/test_ray_caster_integration.py b/source/isaaclab/test/sensors/test_ray_caster_integration.py index fea9eb21fa87..0c48aa1c6dd2 100644 --- a/source/isaaclab/test/sensors/test_ray_caster_integration.py +++ b/source/isaaclab/test/sensors/test_ray_caster_integration.py @@ -375,21 +375,14 @@ def test_multi_mesh_consumes_clone_plan_without_usd_object_clones(sim_ground): env1_object = stage.GetPrimAtPath("/World/envs/env_1/Object") assert env0_object is not None and env0_object.IsValid() assert env1_object is not None and env1_object.IsValid() - # No USD clone was authored for env_2: its mesh table entry comes from ClonePlan. + # No USD clone was authored for env_2: its mesh slot comes from ClonePlan. env2_object = stage.GetPrimAtPath("/World/envs/env_2/Object") assert env2_object is None or not env2_object.IsValid() - # This PR intentionally keeps the rectangular dynamic-mesh kernel. Heterogeneous - # ClonePlan rows are represented by padding shorter environments with a dummy - # mesh pose far outside the ray-cast range; a follow-up PR can replace this - # with a new kernel. - mesh_positions = sensor._mesh_positions_w_torch.cpu() - assert sensor._mesh_ids_wp.shape == (num_envs, 2) - assert mesh_positions.shape == (num_envs, 2, 3) - assert torch.linalg.norm(mesh_positions[0, 1]) > 1.0e8 - assert torch.linalg.norm(mesh_positions[1, 0]) < 1.0e8 - assert torch.linalg.norm(mesh_positions[1, 1]) < 1.0e8 - assert torch.linalg.norm(mesh_positions[2, 1]) > 1.0e8 + # Heterogeneous ClonePlan rows should not be padded to a rectangular table: + # envs 0 and 2 use the one-part source, env 1 uses the two-part source. + slot_env_ids = wp.to_torch(sensor._slot_env_ids_wp).cpu() + assert torch.equal(torch.bincount(slot_env_ids, minlength=num_envs), torch.tensor([1, 2, 1])) hits = sensor.data.ray_hits_w.torch assert torch.isfinite(hits[0]).any(), "env_0 should hit the single-part prototype" @@ -516,9 +509,9 @@ def test_update_mesh_transforms_non_identity_offset(sim_ground): # Verify mesh position: body at (0,0,2) rotated 90deg Z, child offset (1,0,0) local # Expected: (0, 0, 2) + rotate(90degZ, (1,0,0)) = (0, 0, 2) + (0, 1, 0) = (0, 1, 2) - mesh_pos = sensor._mesh_positions_w_torch.clone() + mesh_pos = sensor._slot_mesh_positions_w_torch.clone() np.testing.assert_allclose( - mesh_pos[0, 0].cpu().numpy(), + mesh_pos[0].cpu().numpy(), [0.0, 1.0, 2.0], atol=0.15, err_msg=( diff --git a/source/isaaclab_newton/isaaclab_newton/sensors/ray_caster/ray_caster.py b/source/isaaclab_newton/isaaclab_newton/sensors/ray_caster/ray_caster.py index e4f9dd4bba67..ff36387bc9b1 100644 --- a/source/isaaclab_newton/isaaclab_newton/sensors/ray_caster/ray_caster.py +++ b/source/isaaclab_newton/isaaclab_newton/sensors/ray_caster/ray_caster.py @@ -161,13 +161,11 @@ def _create_tracked_target_view(self: Any, target_prim_path: str): return wp.array(site_indices, dtype=wp.int32, device=self._device) def _update_mesh_transforms(self: Any) -> None: - """Refresh dynamic multi-mesh targets directly from Newton sites.""" + """Refresh dynamic multi-mesh target slots directly from Newton sites.""" if not hasattr(self, "_mesh_views"): return - mesh_idx = 0 for site_indices, target_cfg in zip(self._mesh_views, self._raycast_targets_cfg): if not target_cfg.track_mesh_transforms: - mesh_idx += self._num_meshes_per_env[target_cfg.prim_expr] continue count = site_indices.shape[0] @@ -177,14 +175,18 @@ def _update_mesh_transforms(self: Any) -> None: self._update_newton_site_transforms(site_indices, pose_buf, pos_buf, quat_buf) pos_w = wp.to_torch(pos_buf) quat_w = wp.to_torch(quat_buf) - if count != 1: - count = count // self._num_envs - pos_w = pos_w.view(self._num_envs, count, 3) - quat_w = quat_w.view(self._num_envs, count, 4) - - self._mesh_positions_w_torch[:, mesh_idx : mesh_idx + count] = pos_w - self._mesh_orientations_w_torch[:, mesh_idx : mesh_idx + count] = quat_w - mesh_idx += self._num_meshes_per_env[target_cfg.prim_expr] + slot_start, slot_end = self._slot_ranges_by_target_expr[target_cfg.prim_expr] + slot_count = slot_end - slot_start + if count == 1 and slot_count > 1: + pos_w = pos_w.repeat(slot_count, 1) + quat_w = quat_w.repeat(slot_count, 1) + if pos_w.shape[0] != slot_count: + raise RuntimeError( + f"Tracked target '{target_cfg.prim_expr}' produced {pos_w.shape[0]} poses, " + f"but the raycaster has {slot_count} mesh slots for that target." + ) + self._slot_mesh_positions_w_torch[slot_start:slot_end] = pos_w + self._slot_mesh_orientations_w_torch[slot_start:slot_end] = quat_w def _update_newton_site_transforms( self: Any, diff --git a/source/isaaclab_physx/isaaclab_physx/sensors/ray_caster/ray_caster.py b/source/isaaclab_physx/isaaclab_physx/sensors/ray_caster/ray_caster.py index 45ebd828957c..1ff37ca7e4e3 100644 --- a/source/isaaclab_physx/isaaclab_physx/sensors/ray_caster/ray_caster.py +++ b/source/isaaclab_physx/isaaclab_physx/sensors/ray_caster/ray_caster.py @@ -150,13 +150,11 @@ def _create_tracked_target_view(self: Any, target_prim_path: str): return physics_sim_view.create_rigid_body_view(body_expr.replace(".*", "*")) def _update_mesh_transforms(self: Any) -> None: - """Refresh dynamic multi-mesh targets directly from PhysX views.""" + """Refresh dynamic multi-mesh target slots directly from PhysX views.""" if not hasattr(self, "_mesh_views"): return - mesh_idx = 0 for view, target_cfg in zip(self._mesh_views, self._raycast_targets_cfg): if not target_cfg.track_mesh_transforms: - mesh_idx += self._num_meshes_per_env[target_cfg.prim_expr] continue transforms = view.get_transforms() @@ -164,15 +162,18 @@ def _update_mesh_transforms(self: Any) -> None: pos_w = transforms_t[:, 0:3] ori_w = transforms_t[:, 3:7] - count = view.count - if count != 1: - count = count // self._num_envs - pos_w = pos_w.view(self._num_envs, count, 3) - ori_w = ori_w.view(self._num_envs, count, 4) - - self._mesh_positions_w_torch[:, mesh_idx : mesh_idx + count] = pos_w - self._mesh_orientations_w_torch[:, mesh_idx : mesh_idx + count] = ori_w - mesh_idx += self._num_meshes_per_env[target_cfg.prim_expr] + slot_start, slot_end = self._slot_ranges_by_target_expr[target_cfg.prim_expr] + slot_count = slot_end - slot_start + if pos_w.shape[0] == 1 and slot_count > 1: + pos_w = pos_w.repeat(slot_count, 1) + ori_w = ori_w.repeat(slot_count, 1) + if pos_w.shape[0] != slot_count: + raise RuntimeError( + f"Tracked target '{target_cfg.prim_expr}' produced {pos_w.shape[0]} poses, " + f"but the raycaster has {slot_count} mesh slots for that target." + ) + self._slot_mesh_positions_w_torch[slot_start:slot_end] = pos_w + self._slot_mesh_orientations_w_torch[slot_start:slot_end] = ori_w class RayCaster(_PhysXRayCasterMixin, BaseRayCaster):