diff --git a/iron/operators/transpose/design.py b/iron/operators/transpose/design.py index bec4382b..ff5c325c 100644 --- a/iron/operators/transpose/design.py +++ b/iron/operators/transpose/design.py @@ -10,7 +10,9 @@ from aie.iron.controlflow import range_ -def shuffle_transpose(dev, M, N, num_columns, num_channels, m, n, s, func_prefix=""): +def shuffle_transpose( + dev, M, N, num_columns, num_channels, m, n, s, num_batches=1, func_prefix="" +): num_elements = M * N per_tile_elements = m * n dtype = bfloat16 @@ -34,8 +36,9 @@ def shuffle_transpose(dev, M, N, num_columns, num_channels, m, n, s, func_prefix if s == 8 and (m <= 16 or n <= 16): raise ValueError(f"Kernel tile {s} needs AIE tile rows > 16 and columns > 16.") - # Define tensor types - tensor_ty = np.ndarray[(num_elements,), np.dtype[dtype]] + # Define tensor types. The runtime tensor spans all batches (contiguous matrices); + # per-tile work on the cores is identical regardless of batch count. + tensor_ty = np.ndarray[(num_batches * num_elements,), np.dtype[dtype]] tile_ty = np.ndarray[(per_tile_elements,), np.dtype[dtype]] fifodepth = 1 if per_tile_elements > 4096 else 2 @@ -47,13 +50,25 @@ def shuffle_transpose(dev, M, N, num_columns, num_channels, m, n, s, func_prefix # and channels. Partially transposes the input # data so that the kernel only needs to # transpose s*s-sized sub-tiles. + # The L3 tensors hold num_batches contiguous (M,N) matrices stacked along the row + # dimension: in-dims (num_batches*M, N), out-dims (num_batches*N, M); at num_batches==1 + # these are simply (M,N)/(N,M). Each (i,j) column/channel emits one TAP per batch, offset + # by batch*num_elements; the per-batch internal sizes/strides are the same for every batch + # because each matrix is contiguous and row-major. + in_dims = (num_batches * M, N) + out_dims = (num_batches * N, M) taps_in_L3L2 = [ - TensorAccessPattern( - (M, N), - (M // num_channels) * j * N + (N // num_columns) * i, - [M // num_channels // m, N // num_columns // n, m, n], - [m * N, n, N, 1], - ) + [ + TensorAccessPattern( + in_dims, + batch * num_elements + + (M // num_channels) * j * N + + (N // num_columns) * i, + [M // num_channels // m, N // num_columns // n, m, n], + [m * N, n, N, 1], + ) + for batch in range(num_batches) + ] for i in range(num_columns) for j in range(num_channels) ] @@ -68,12 +83,17 @@ def shuffle_transpose(dev, M, N, num_columns, num_channels, m, n, s, func_prefix for j in range(num_channels) ] taps_out_L1L3 = [ - TensorAccessPattern( - (N, M), - (N // num_columns) * i * M + (M // num_channels) * j, - [M // num_channels // m, N // num_columns // n, n, m], - [m, n * M, M, 1], - ) + [ + TensorAccessPattern( + out_dims, + batch * num_elements + + (N // num_columns) * i * M + + (M // num_channels) * j, + [M // num_channels // m, N // num_columns // n, n, m], + [m, n * M, M, 1], + ) + for batch in range(num_batches) + ] for i in range(num_columns) for j in range(num_channels) ] @@ -106,14 +126,17 @@ def shuffle_transpose(dev, M, N, num_columns, num_channels, m, n, s, func_prefix # Define a task that will run on a compute tile def core_body(of_in1, of_out, transpose_kernel): - # Number of sub-matrix "tile" iterations - for _ in range_(N // n // num_columns): - for _ in range_(M // m // num_channels): - elem_in1 = of_in1.acquire(1) - elem_out = of_out.acquire(1) - transpose_kernel(elem_in1, elem_out) - of_out.release(1) - of_in1.release(1) + # Process num_batches contiguous matrices through the same FIFOs: num_batches x the per-matrix + # tile iterations. The kernel only ever sees s*s sub-tiles, so it is batch-agnostic. + for _ in range_(num_batches): + # Number of sub-matrix "tile" iterations + for _ in range_(N // n // num_columns): + for _ in range_(M // m // num_channels): + elem_in1 = of_in1.acquire(1) + elem_out = of_out.acquire(1) + transpose_kernel(elem_in1, elem_out) + of_out.release(1) + of_in1.release(1) # Create a worker to run the task on a compute tile my_workers = [ @@ -134,29 +157,32 @@ def core_body(of_in1, of_out, transpose_kernel): with rt.sequence(tensor_ty, tensor_ty) as (A, C): rt.start(*my_workers) - # Initialize a group for parallel drain tasks, with fill resources free'd when drains complete. - tg = rt.task_group() - - # Fill the input objectFIFOs with data - for i in range(num_columns): - for j in range(num_channels): - rt.fill( - of_in1s_L3L2[i * num_channels + j].prod(), - A, - taps_in_L3L2[i * num_channels + j], - task_group=tg, - ) - # Drain the output objectFIFOs with data - for i in range(num_columns): - for j in range(num_channels): - rt.drain( - of_outs[i * num_channels + j].cons(), - C, - taps_out_L1L3[i * num_channels + j], - wait=True, # wait for the transfer to complete and data to be available - task_group=tg, - ) - rt.finish_task_group(tg) + # One task group per batch (each a parallel fill+drain over all columns/channels), so the + # num_batches contiguous matrices stream through the same FIFOs in sequence. + for batch in range(num_batches): + # Initialize a group for parallel drain tasks, with fill resources free'd when drains complete. + tg = rt.task_group() + + # Fill the input objectFIFOs with data + for i in range(num_columns): + for j in range(num_channels): + rt.fill( + of_in1s_L3L2[i * num_channels + j].prod(), + A, + taps_in_L3L2[i * num_channels + j][batch], + task_group=tg, + ) + # Drain the output objectFIFOs of data + for i in range(num_columns): + for j in range(num_channels): + rt.drain( + of_outs[i * num_channels + j].cons(), + C, + taps_out_L1L3[i * num_channels + j][batch], + wait=True, # wait for the transfer to complete and data to be available + task_group=tg, + ) + rt.finish_task_group(tg) # Place program components (assign them resources on the device) and generate an MLIR module return Program(dev, rt).resolve_program(SequentialPlacer()) diff --git a/iron/operators/transpose/op.py b/iron/operators/transpose/op.py index d37e0101..699720c7 100644 --- a/iron/operators/transpose/op.py +++ b/iron/operators/transpose/op.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 from dataclasses import dataclass, field +from typing import ClassVar, Dict import aie.utils as aie_utils from iron.common import ( @@ -16,7 +17,13 @@ @dataclass class Transpose(MLIROperator): - """AIE-accelerated transpose operator""" + """AIE-accelerated transpose operator. + + ``num_batches`` > 1 performs that many independent (M,N)->(N,M) transposes on + contiguous matrices laid back-to-back in memory (results concatenated), mirroring + GEMV's batching — the per-batch tile work rides the same ObjectFifos, so B batched + transposes cost ONE dispatch instead of B unrolled ones. + """ M: int N: int @@ -25,8 +32,14 @@ class Transpose(MLIROperator): m: int n: int s: int + num_batches: int = 1 context: object = field(default=None, repr=False) + _name_aliases: ClassVar[Dict[str, str]] = { + **MLIROperator._name_aliases, + "num_batches": "batch", + } + def __post_init__(self): if self.M % self.m != 0: raise ValueError(f"Matrix rows ({self.M}) must be a multiple of {self.m}") @@ -66,6 +79,7 @@ def get_mlir_artifact(self): self.m, self.n, self.s, + self.num_batches, ), ), ) @@ -90,7 +104,8 @@ def get_kernel_artifacts(self): ] def get_arg_spec(self): + batch_dim = (self.num_batches,) if self.num_batches > 1 else () return [ - AIERuntimeArgSpec("in", (self.M * self.N,)), - AIERuntimeArgSpec("out", (self.M * self.N,)), + AIERuntimeArgSpec("in", batch_dim + (self.M * self.N,)), + AIERuntimeArgSpec("out", batch_dim + (self.N * self.M,)), ] diff --git a/iron/operators/transpose/reference.py b/iron/operators/transpose/reference.py index 672a11fd..a2a16843 100644 --- a/iron/operators/transpose/reference.py +++ b/iron/operators/transpose/reference.py @@ -5,9 +5,20 @@ from iron.common.test_utils import torch_dtype_map -def generate_golden_reference(rows: int, cols: int, dtype="bf16", seed=42): +def generate_golden_reference( + rows: int, cols: int, dtype="bf16", seed=42, num_batches=1 +): torch.manual_seed(seed) val_range = 4 - input_tensor = torch.rand(rows, cols, dtype=torch_dtype_map[dtype]) * val_range - output_tensor = torch.transpose(input_tensor, 0, 1) + # num_batches>1: B independent (rows,cols) matrices laid back-to-back; each is + # transposed independently and the results concatenated in the same order. + input_tensor = ( + torch.rand(num_batches, rows, cols, dtype=torch_dtype_map[dtype]) * val_range + ) + output_tensor = torch.stack( + [torch.transpose(input_tensor[b], 0, 1) for b in range(num_batches)] + ) + # drop batch dimension if num_batches == 1 + input_tensor = torch.squeeze(input_tensor, 0) + output_tensor = torch.squeeze(output_tensor, 0) return {"input": input_tensor, "output": output_tensor} diff --git a/iron/operators/transpose/test.py b/iron/operators/transpose/test.py index 19b6242f..1e404193 100755 --- a/iron/operators/transpose/test.py +++ b/iron/operators/transpose/test.py @@ -47,10 +47,29 @@ def get_params(): m, n, s, + 1, marks=marks, ) ) + # num_batches>1: batch B independent same-shape transposes into one dispatch + # (regular shape, single column/channel). num_batches=2 runs in the default + # suite; the larger batch is extensive. + for nb in (2, 4): + params.append( + pytest.param( + 2048, + 64, + 1, + 1, + m, + n, + 8, + nb, + marks=[] if nb == 2 else [pytest.mark.extensive], + ) + ) + return params @@ -58,9 +77,9 @@ def get_params(): Latency=r"Latency \(us\): (?P[\d\.]+)", Bandwidth=r"Effective Bandwidth: (?P[\d\.e\+-]+) GB/s", ) -@pytest.mark.parametrize("M,N,aie_columns,channels,m,n,s", get_params()) -def test_transpose(M, N, aie_columns, channels, m, n, s, aie_context): - golden_ref = generate_golden_reference(rows=M, cols=N) +@pytest.mark.parametrize("M,N,aie_columns,channels,m,n,s,num_batches", get_params()) +def test_transpose(M, N, aie_columns, channels, m, n, s, num_batches, aie_context): + golden_ref = generate_golden_reference(rows=M, cols=N, num_batches=num_batches) operator = Transpose( M=M, @@ -70,6 +89,7 @@ def test_transpose(M, N, aie_columns, channels, m, n, s, aie_context): m=m, n=n, s=s, + num_batches=num_batches, context=aie_context, )