Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -142,23 +142,17 @@ def _initialize_warp_meshes(self):
self._initialize_warp_meshes_from_clone_plan(plan)

def _initialize_warp_meshes_from_clone_plan(self, plan) -> None:
"""Initialize rectangular mesh buffers from ClonePlan source rows.

The current PR keeps the existing rectangular kernel ABI. Environments
with fewer meshes than the target maximum are padded with a valid mesh
placed far outside the ray-cast range. A follow-up PR can replace this
padded representation with a new kernel.
"""
target_records_by_expr: dict[
str, list[list[tuple[int, tuple[float, float, float], tuple[float, float, float, float]]]]
] = {}
dummy_mesh_id: int | None = None
"""Initialize flat mesh slots from ClonePlan source rows."""
slot_env_ids: list[int] = []
slot_mesh_ids: list[int] = []
slot_positions: list[tuple[float, float, float]] = []
slot_orientations: list[tuple[float, float, float, float]] = []
self._mesh_views = []
self._slot_ranges_by_target_expr: dict[str, tuple[int, int]] = {}

for target_cfg in self._raycast_targets_cfg:
records_per_env: list[list[tuple[int, tuple[float, float, float], tuple[float, float, float, float]]]] = [
[] for _ in range(self._num_envs)
]
target_start = len(slot_mesh_ids)
meshes_added_per_env = [0 for _ in range(self._num_envs)]
matches = self._collect_clone_plan_matches(plan, target_cfg.prim_expr)
if matches:
for row, source_root, source_expr in matches:
Expand All @@ -171,7 +165,6 @@ def _initialize_warp_meshes_from_clone_plan(self, plan) -> None:
prototype_records = []
for target_prim in target_prims:
mesh_id = self._load_target_prim_warp_mesh(target_prim, target_cfg)
dummy_mesh_id = mesh_id if dummy_mesh_id is None else dummy_mesh_id
source_root_prim = self.stage.GetPrimAtPath(source_root)
local_pos, local_quat = sim_utils.resolve_prim_pose(target_prim, source_root_prim)
prototype_records.append((mesh_id, local_pos, local_quat))
Expand All @@ -187,33 +180,35 @@ def _initialize_warp_meshes_from_clone_plan(self, plan) -> None:
mesh_pos_t, mesh_quat_t = math_utils.combine_frame_transforms(
root_pos_t, root_quat_t, local_pos_t, local_quat_t
)
records_per_env[env_id].append(
(
mesh_id,
tuple(float(v) for v in mesh_pos_t[0].tolist()),
tuple(float(v) for v in mesh_quat_t[0].tolist()),
)
)
slot_env_ids.append(env_id)
slot_mesh_ids.append(mesh_id)
slot_positions.append(tuple(float(v) for v in mesh_pos_t[0].tolist()))
slot_orientations.append(tuple(float(v) for v in mesh_quat_t[0].tolist()))
meshes_added_per_env[env_id] += 1
else:
target_prims = sim_utils.find_matching_prims(target_cfg.prim_expr)
if len(target_prims) == 0:
raise RuntimeError(f"Failed to find a prim at path expression: {target_cfg.prim_expr}")
records = []
for target_prim in target_prims:
mesh_id = self._load_target_prim_warp_mesh(target_prim, target_cfg)
dummy_mesh_id = mesh_id if dummy_mesh_id is None else dummy_mesh_id
pos, quat = sim_utils.resolve_prim_pose(target_prim)
records.append((mesh_id, tuple(float(v) for v in pos), tuple(float(v) for v in quat)))
for env_id in range(self._num_envs):
records_per_env[env_id].extend(records)

self._num_meshes_per_env[target_cfg.prim_expr] = max(len(records) for records in records_per_env)
target_records_by_expr[target_cfg.prim_expr] = records_per_env
for mesh_id, pos, quat in records:
slot_env_ids.append(env_id)
slot_mesh_ids.append(mesh_id)
slot_positions.append(pos)
slot_orientations.append(quat)
meshes_added_per_env[env_id] += 1

self._num_meshes_per_env[target_cfg.prim_expr] = max(meshes_added_per_env)
self._slot_ranges_by_target_expr[target_cfg.prim_expr] = (target_start, len(slot_mesh_ids))
self._mesh_views.append(
self._create_tracked_target_view(target_cfg.prim_expr) if target_cfg.track_mesh_transforms else None
)

self._install_rectangular_mesh_table(target_records_by_expr, dummy_mesh_id)
self._install_flat_mesh_slots(slot_env_ids, slot_mesh_ids, slot_positions, slot_orientations)

def _collect_clone_plan_matches(self, plan, target_expr: str) -> list[tuple[int, str, str]]:
target_env0 = _target_expr_for_env(target_expr, 0)
Expand Down Expand Up @@ -339,22 +334,23 @@ def _create_tracked_target_view(self, target_prim_path: str):
raise NotImplementedError("Tracked multi-mesh targets must be implemented by the active physics backend.")

def _initialize_warp_meshes_from_stage(self):
"""Parse mesh prim expressions from USD and install the rectangular mesh table."""
target_records_by_expr: dict[
str, list[list[tuple[int, tuple[float, float, float], tuple[float, float, float, float]]]]
] = {}
dummy_mesh_id: int | None = None
"""Parse mesh prim expressions from USD and install the flat slot table."""
slot_env_ids: list[int] = []
slot_mesh_ids: list[int] = []
slot_positions: list[tuple[float, float, float]] = []
slot_orientations: list[tuple[float, float, float, float]] = []
self._mesh_views = []
self._slot_ranges_by_target_expr: dict[str, tuple[int, int]] = {}

for target_cfg in self._raycast_targets_cfg:
target_start = len(slot_mesh_ids)
target_prims = sim_utils.find_matching_prims(target_cfg.prim_expr)
if len(target_prims) == 0:
raise RuntimeError(f"Failed to find a prim at path expression: {target_cfg.prim_expr}")

records = []
for target_prim in target_prims:
mesh_id = self._load_target_prim_warp_mesh(target_prim, target_cfg)
dummy_mesh_id = mesh_id if dummy_mesh_id is None else dummy_mesh_id
pos, quat = sim_utils.resolve_prim_pose(target_prim)
records.append((mesh_id, tuple(float(v) for v in pos), tuple(float(v) for v in quat)))

Expand All @@ -369,57 +365,39 @@ def _initialize_warp_meshes_from_stage(self):
n_meshes = len(records) // self._num_envs
per_env_records = [records[i * n_meshes : (i + 1) * n_meshes] for i in range(self._num_envs)]

for env_id, env_records in enumerate(per_env_records):
for mesh_id, pos, quat in env_records:
slot_env_ids.append(env_id)
slot_mesh_ids.append(mesh_id)
slot_positions.append(pos)
slot_orientations.append(quat)

self._num_meshes_per_env[target_cfg.prim_expr] = max(len(env_records) for env_records in per_env_records)
target_records_by_expr[target_cfg.prim_expr] = per_env_records
self._slot_ranges_by_target_expr[target_cfg.prim_expr] = (target_start, len(slot_mesh_ids))
self._mesh_views.append(
self._create_tracked_target_view(target_cfg.prim_expr) if target_cfg.track_mesh_transforms else None
)

self._install_rectangular_mesh_table(target_records_by_expr, dummy_mesh_id)
self._install_flat_mesh_slots(slot_env_ids, slot_mesh_ids, slot_positions, slot_orientations)

def _install_rectangular_mesh_table(
def _install_flat_mesh_slots(
self,
target_records_by_expr: dict[
str, list[list[tuple[int, tuple[float, float, float], tuple[float, float, float, float]]]]
],
dummy_mesh_id: int | None,
slot_env_ids: list[int],
slot_mesh_ids: list[int],
slot_positions: list[tuple[float, float, float]],
slot_orientations: list[tuple[float, float, float, float]],
) -> None:
"""Pack per-target mesh records into the rectangular table used by the existing kernel."""
if dummy_mesh_id is None:
"""Install the compact per-environment slot arrays used by the flat kernel."""
if not slot_mesh_ids:
raise RuntimeError(
f"No meshes found for ray-casting! Please check the mesh prim paths: {self.cfg.mesh_prim_paths}"
)

dummy_record = (dummy_mesh_id, (1.0e9, 1.0e9, 1.0e9), (0.0, 0.0, 0.0, 1.0))
multi_mesh_ids_flattened: list[list[int]] = []
mesh_positions: list[list[tuple[float, float, float]]] = []
mesh_orientations: list[list[tuple[float, float, float, float]]] = []

for env_id in range(self._num_envs):
meshes_in_env: list[int] = []
positions_in_env: list[tuple[float, float, float]] = []
orientations_in_env: list[tuple[float, float, float, float]] = []
for target_cfg in self._raycast_targets_cfg:
records = list(target_records_by_expr[target_cfg.prim_expr][env_id])
records.extend([dummy_record] * (self._num_meshes_per_env[target_cfg.prim_expr] - len(records)))
for mesh_id, pos, quat in records:
meshes_in_env.append(mesh_id)
positions_in_env.append(pos)
orientations_in_env.append(quat)
multi_mesh_ids_flattened.append(meshes_in_env)
mesh_positions.append(positions_in_env)
mesh_orientations.append(orientations_in_env)

total_n_meshes_per_env = len(multi_mesh_ids_flattened[0])
self._mesh_ids_wp = wp.array2d(multi_mesh_ids_flattened, dtype=wp.uint64, device=self.device)
self._mesh_positions_w = wp.zeros((self._num_envs, total_n_meshes_per_env), dtype=wp.vec3, device=self.device)
self._mesh_orientations_w = wp.zeros(
(self._num_envs, total_n_meshes_per_env), dtype=wp.quat, device=self.device
)
self._mesh_positions_w_torch = wp.to_torch(self._mesh_positions_w)
self._mesh_orientations_w_torch = wp.to_torch(self._mesh_orientations_w)
self._mesh_positions_w_torch[:] = torch.tensor(mesh_positions, dtype=torch.float32, device=self.device)
self._mesh_orientations_w_torch[:] = torch.tensor(mesh_orientations, dtype=torch.float32, device=self.device)
self._slot_env_ids_wp = wp.array(slot_env_ids, dtype=wp.int32, device=self.device)
self._slot_mesh_ids_wp = wp.array(slot_mesh_ids, dtype=wp.uint64, device=self.device)
self._slot_mesh_positions_w = wp.array(slot_positions, dtype=wp.vec3, device=self.device)
self._slot_mesh_orientations_w = wp.array(slot_orientations, dtype=wp.quat, device=self.device)
self._slot_mesh_positions_w_torch = wp.to_torch(self._slot_mesh_positions_w)
self._slot_mesh_orientations_w_torch = wp.to_torch(self._slot_mesh_orientations_w)

def _initialize_rays_impl(self):
super()._initialize_rays_impl()
Expand All @@ -440,13 +418,11 @@ def _update_mesh_transforms(self) -> None:
"""Update world-frame mesh positions and orientations for dynamically tracked targets.

Iterates over all tracked views and writes the current world poses into
the rectangular mesh pose buffers. Static (non-tracked) targets are
skipped; their initial poses were set during :meth:`_initialize_warp_meshes`.
the flat slot pose buffers. Static (non-tracked) targets are skipped;
their initial poses were set during :meth:`_initialize_warp_meshes`.
"""
mesh_idx = 0
for view, target_cfg in zip(self._mesh_views, self._raycast_targets_cfg):
if not target_cfg.track_mesh_transforms:
mesh_idx += self._num_meshes_per_env[target_cfg.prim_expr]
continue

# update position of the target meshes
Expand All @@ -455,15 +431,18 @@ def _update_mesh_transforms(self) -> None:
pos_w = pos_w.squeeze(0) if len(pos_w.shape) == 3 else pos_w
ori_w = ori_w.squeeze(0) if len(ori_w.shape) == 3 else ori_w

count = getattr(view, "count", pos_w.shape[0])
if count != 1:
count = count // self._num_envs
pos_w = pos_w.view(self._num_envs, count, 3)
ori_w = ori_w.view(self._num_envs, count, 4)

self._mesh_positions_w_torch[:, mesh_idx : mesh_idx + count] = pos_w
self._mesh_orientations_w_torch[:, mesh_idx : mesh_idx + count] = ori_w
mesh_idx += self._num_meshes_per_env[target_cfg.prim_expr]
slot_start, slot_end = self._slot_ranges_by_target_expr[target_cfg.prim_expr]
slot_count = slot_end - slot_start
if pos_w.shape[0] == 1 and slot_count > 1:
pos_w = pos_w.repeat(slot_count, 1)
ori_w = ori_w.repeat(slot_count, 1)
if pos_w.shape[0] != slot_count:
raise RuntimeError(
f"Tracked target '{target_cfg.prim_expr}' produced {pos_w.shape[0]} poses, "
f"but the raycaster has {slot_count} mesh slots for that target."
)
self._slot_mesh_positions_w_torch[slot_start:slot_end] = pos_w
self._slot_mesh_orientations_w_torch[slot_start:slot_end] = ori_w

def _update_buffers_impl(self, env_mask: wp.array):
"""Fills the buffers of the sensor data."""
Expand All @@ -484,24 +463,23 @@ def _update_buffers_impl(self, env_mask: wp.array):
device=self._device,
)

n_meshes = self._mesh_ids_wp.shape[1]

# Ray-cast against all meshes; closest hit wins via atomic_min on ray_distance.
# Ray-cast against all mesh slots; closest hit wins via atomic_min on ray_distance.
wp.launch(
warp_kernels.raycast_dynamic_meshes_kernel,
dim=(n_meshes, self._num_envs, self.num_rays),
warp_kernels.raycast_dynamic_mesh_slots_kernel,
dim=(self._slot_mesh_ids_wp.shape[0], self.num_rays),
inputs=[
env_mask,
self._mesh_ids_wp,
self._slot_env_ids_wp,
self._slot_mesh_ids_wp,
self._ray_starts_w,
self._ray_directions_w,
self._data._ray_hits_w,
self._ray_distance_w,
self._dummy_normal_w,
self._dummy_face_id_w,
self._ray_mesh_id_w,
self._mesh_positions_w,
self._mesh_orientations_w,
self._slot_mesh_positions_w,
self._slot_mesh_orientations_w,
float(self.cfg.max_distance),
int(False),
int(False),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -239,24 +239,23 @@ def _update_buffers_impl(self, env_mask: wp.array):
device=self._device,
)

n_meshes = self._mesh_ids_wp.shape[1]

# Ray-cast against all meshes; closest hit wins via atomic_min on ray_distance.
# Ray-cast against all mesh slots; closest hit wins via atomic_min on ray_distance.
wp.launch(
warp_kernels.raycast_dynamic_meshes_kernel,
dim=(n_meshes, self._num_envs, self.num_rays),
warp_kernels.raycast_dynamic_mesh_slots_kernel,
dim=(self._slot_mesh_ids_wp.shape[0], self.num_rays),
inputs=[
env_mask,
self._mesh_ids_wp,
self._slot_env_ids_wp,
self._slot_mesh_ids_wp,
self._ray_starts_w,
self._ray_directions_w,
self._ray_hits_w_cam,
self._ray_distance_cam_w,
self._ray_normal_w,
self._ray_face_id_w,
self._ray_mesh_id_w,
self._mesh_positions_w,
self._mesh_orientations_w,
self._slot_mesh_positions_w,
self._slot_mesh_orientations_w,
float(CAMERA_RAYCAST_MAX_DIST),
int(return_normal),
int(False),
Expand Down
49 changes: 49 additions & 0 deletions source/isaaclab/isaaclab/utils/warp/kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,55 @@ def raycast_dynamic_meshes_kernel(
ray_mesh_id[tid_env, tid_ray] = wp.int16(tid_mesh_id)


@wp.kernel(enable_backward=False)
def raycast_dynamic_mesh_slots_kernel(
env_mask: wp.array(dtype=wp.bool),
slot_env_ids: wp.array(dtype=wp.int32),
mesh: wp.array(dtype=wp.uint64),
ray_starts: wp.array2d(dtype=wp.vec3),
ray_directions: wp.array2d(dtype=wp.vec3),
ray_hits: wp.array2d(dtype=wp.vec3),
ray_distance: wp.array2d(dtype=wp.float32),
ray_normal: wp.array2d(dtype=wp.vec3),
ray_face_id: wp.array2d(dtype=wp.int32),
ray_mesh_id: wp.array2d(dtype=wp.int16),
mesh_positions: wp.array(dtype=wp.vec3),
mesh_rotations: wp.array(dtype=wp.quat),
max_dist: float = 1e6,
return_normal: int = False,
return_face_id: int = False,
return_mesh_id: int = False,
):
"""Ray-cast against a flat list of per-environment mesh slots.

Launch with ``dim=(num_slots, num_rays)``. Each slot carries the owning
environment id, allowing heterogeneous scenes where environments have
different mesh counts without padding to a rectangular table.
"""
slot_id, tid_ray = wp.tid()
tid_env = slot_env_ids[slot_id]
if not env_mask[tid_env]:
return

mesh_pose = wp.transform(mesh_positions[slot_id], mesh_rotations[slot_id])
mesh_pose_inv = wp.transform_inverse(mesh_pose)
direction = wp.transform_vector(mesh_pose_inv, ray_directions[tid_env, tid_ray])
start_pos = wp.transform_point(mesh_pose_inv, ray_starts[tid_env, tid_ray])

mesh_query_ray_t = wp.mesh_query_ray(mesh[slot_id], start_pos, direction, max_dist)
if mesh_query_ray_t.result:
wp.atomic_min(ray_distance, tid_env, tid_ray, mesh_query_ray_t.t)
if mesh_query_ray_t.t == ray_distance[tid_env, tid_ray]:
hit_pos = start_pos + mesh_query_ray_t.t * direction
ray_hits[tid_env, tid_ray] = wp.transform_point(mesh_pose, hit_pos)
if return_normal == 1:
ray_normal[tid_env, tid_ray] = wp.transform_vector(mesh_pose, mesh_query_ray_t.normal)
if return_face_id == 1:
ray_face_id[tid_env, tid_ray] = mesh_query_ray_t.face
if return_mesh_id == 1:
ray_mesh_id[tid_env, tid_ray] = wp.int16(slot_id)


@wp.kernel(enable_backward=False)
def reshape_tiled_image(
tiled_image_buffer: Any,
Expand Down
Loading
Loading