Jasmine Tang

[ONGOING] MLIR Transform note

2025-07-27

Resources

  • MLIR Transform dialect tutorial (where the quotes and examples here come from): https://mlir.llvm.org/docs/Tutorials/transform/

Payload and transform

In the Transform dialect there are two kinds of IR:

  • If an IR is operating on others, transforming them, it is called transform IR.
  • If an IR is being operated on, it's called payload IR.
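
For example, here's roughly what a tiny bit of transform IR looks like (my own sketch, using the named-sequence entry point that mlir-opt's --transform-interpreter runs): the ops below operate on the payload, grabbing handles to every linalg.generic inside it.

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(
      %payload: !transform.any_op {transform.readonly}) {
    // Transform IR acting on payload IR: match every linalg.generic
    // in the payload and return handles to them.
    %generics = transform.structured.match ops{["linalg.generic"]} in %payload
        : (!transform.any_op) -> !transform.any_op
    transform.yield
  }
}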

Structured linalg ops

indexing maps and affine_map: a lot of the time, MLIR ops accept something called an affine_map, which is an affine mapping from the iteration-space indices to the indices an array, vector, or tensor is actually accessed with.
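
For example (my own illustration), compare an identity map with a transpose map:

// Identity: iteration point (d0, d1) reads element (d0, d1).
affine_map<(d0, d1) -> (d0, d1)>
// Transpose: iteration point (d0, d1) reads element (d1, d0).
affine_map<(d0, d1) -> (d1, d0)>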

iterator_types: now, when you're iterating over these arrays or vectors, you need to tell the indexer, for each dimension, whether it is "parallel" (the dimension is kept) or "reduction" (the dimension is reduced away).

For example, the following code computes its result with the vector.contract op. The operation takes two attributes: indexing_maps and iterator_types. There are three affine_maps in indexing_maps, since we have 2 inputs and 1 output (making it 3). The singular i is because the iteration space has only 1 dimension. The i on the left-hand side is the available index variable, and the i on the right-hand side (if present) is the index the array, tensor or vector is accessed with. Since we're producing a scalar, the third affine_map goes from (i) to (), hence the "reduction" entry in iterator_types.

Note that since we have 8 elements in the first input and 8 elements in the second input (same dimension), we can use the same (i).

// %0, the vector<8xf32> input we're reducing, is assumed to be defined earlier.
// Neutral initializer for the addition.
%init = arith.constant 0.0 : f32
// Neutral element of multiplication.
%ones = arith.constant dense<1.0> : vector<8xf32>
// Actual contraction.
%result = vector.contract {
  indexing_maps = [affine_map<(i) -> (i)>,
                   affine_map<(i) -> (i)>,
                   affine_map<(i) -> ()>],
  iterator_types = ["reduction"]
} %0, %ones, %init : vector<8xf32>, vector<8xf32> into f32

The pseudo code for the loop is

for i in 0 to 8:
  init += %0[i] * ones[i]

Alright, seems like you've got that? Let's do another round:

%result = vector.contract {
  indexing_maps = [affine_map<(i, j, k) -> (i, k)>,
                   affine_map<(i, j, k) -> (k, j)>,
                   affine_map<(i, j, k) -> (i, j)>],
  iterator_types = ["parallel", "parallel", "reduction"]
} %lhs, %rhs, %init : vector<8x10xf32>, vector<10x16xf32> into vector<8x16xf32>

Here you take 2 inputs and get 1 output, so there are 3 affine_maps. The iteration space has 3 dimensions, so (i, j, k), and each vector picks out the two it needs: the first vector uses i and k to index its two dimensions, the second vector uses k and j, and the output vector uses i and j. The sizes are i = 8, k = 10, j = 16, which is why we need (i, j, k) instead of just i. (This is exactly a matrix multiplication: 8x10 times 10x16 gives 8x16.)

Here's the pseudo code in for-loops

for i in 0 to 8:
  for j in 0 to 16:
    for k in 0 to 10:
      init[i, j] += lhs[i, k] * rhs[k, j]

linalg.generic

linalg.generic can be used to implement other operations. Let's do this:

linalg.generic {
  indexing_maps = [affine_map<(i) -> (i)>, affine_map<(i) -> (i)>],
  iterator_types = ["parallel"]
} ins(%in : memref<?xf32>) outs(%out : memref<?xf32>) {
^bb0(%in_one : f32, %out_one : f32):
  %c0 = arith.constant 0.0 : f32
  %0 = arith.cmpf ogt, %in_one, %c0 : f32
  %1 = arith.select %0, %in_one, %c0 : f32
  linalg.yield %1 : f32
}

The prior code shows you how to use linalg.generic to implement a ReLU filter. The basic block bb0 (marked with a ^ in front) implements the ReLU kernel on a per-element basis, and the way we index these elements is specified by the indexing_maps attribute. Inside this basic block, you can use other scalar operations to implement the kernel:

  • Initialize the constant 0 as %c0.
  • Compare the input scalar to %c0, and put the result in %0.
  • Select either the input scalar or %c0, depending on the comparison result %0, assigning it to %1.
  • Yield %1.
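
Same drill as before, here's the pseudo code (with n being the runtime size of the memref):

for i in 0 to n:
  out[i] = max(in[i], 0.0)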

Seems easy enough? hahhahahaha
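
By the way, the same 8x10 times 10x16 matmul from the vector.contract section can be written as a linalg.generic over tensors (my own sketch, assuming %lhs, %rhs and %init are tensors of these shapes). This untiled form is exactly what gets tiled below:

%result = linalg.generic {
  indexing_maps = [affine_map<(i, j, k) -> (i, k)>,
                   affine_map<(i, j, k) -> (k, j)>,
                   affine_map<(i, j, k) -> (i, j)>],
  iterator_types = ["parallel", "parallel", "reduction"]
} ins(%lhs, %rhs : tensor<8x10xf32>, tensor<10x16xf32>)
  outs(%init : tensor<8x16xf32>) {
^bb0(%lhs_one: f32, %rhs_one: f32, %init_one: f32):
  // Multiply-accumulate, one scalar at a time.
  %prod = arith.mulf %lhs_one, %rhs_one : f32
  %sum = arith.addf %init_one, %prod : f32
  linalg.yield %sum : f32
} -> tensor<8x16xf32>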

Alright what about tiling and looping actually?

Let's talk about tiling first. I first learned about tiling through the PMPP book's example on matrix multiplication. The transform tutorial couldn't have put it any better:

Tiling, in general, can be seen as partitioning the iteration space into smaller parts, or tiles, so that the data required by each part fits into a level of cache for example. The order in which tiles are executed must preserve the original data dependencies.

%0 = scf.forall (%i, %j) in (4, 2)
     shared_outs(%shared = %init) -> (tensor<8x16xf32>) {

  // Scale the loop induction variables by the tile sizes.
  %3 = affine.apply affine_map<(d0) -> (d0 * 2)>(%i)
  %4 = affine.apply affine_map<(d0) -> (d0 * 8)>(%j)

  // Take slices of inputs and outputs. Only the "i" and "j" dimensions are sliced.
  %lhs_slice = tensor.extract_slice %lhs[%3, 0] [2, 10] [1, 1]
             : tensor<8x10xf32> to tensor<2x10xf32>
  %rhs_slice = tensor.extract_slice %rhs[0, %4] [10, 8] [1, 1]
             : tensor<10x16xf32> to tensor<10x8xf32>
  %result_slice = tensor.extract_slice %shared[%3, %4] [2, 8] [1, 1]
                : tensor<8x16xf32> to tensor<2x8xf32>

  // This is exactly the same operation as before, but now operating on smaller
  // slices of data.
  %partial = linalg.generic {
    indexing_maps = [affine_map<(i, j, k) -> (i, k)>,
                     affine_map<(i, j, k) -> (k, j)>,
                     affine_map<(i, j, k) -> (i, j)>],
    iterator_types = ["parallel", "parallel", "reduction"]
  } ins(%lhs_slice, %rhs_slice : tensor<2x10xf32>, tensor<10x8xf32>)
    outs(%result_slice : tensor<2x8xf32>) {
  ^bb0(%lhs_one: f32, %rhs_one: f32, %init_one: f32):
    %prod = arith.mulf %lhs_one, %rhs_one : f32
    %sum = arith.addf %init_one, %prod : f32
    linalg.yield %sum : f32
  } -> tensor<2x8xf32>

  // Terminator for the loop with tensor-insertion semantics. Inserts a slice
  // into a larger tensor, potentially in parallel.
  scf.forall.in_parallel {
    tensor.parallel_insert_slice %partial into %shared[%3, %4] [2, 8] [1, 1]
        : tensor<2x8xf32> into tensor<8x16xf32>
  }
}

scf.forall evaluates a block multiple times in parallel. Here it runs over (4, 2): the 8x16 result is cut into 2x8 tiles, so there are 8/2 = 4 tiles along i and 16/8 = 2 tiles along j.

A tile in the payload IR is a slice of the original data. Each tensor.extract_slice above carves one out using offsets, sizes, and strides: for example [%3, 0] [2, 10] [1, 1] means start at row %3 and column 0, take a 2x10 block, stepping by 1 in both dimensions. As the tutorial puts it:

In the case of linalg.generic operations, the iteration space is implicit and is defined by the shape of the operands. Therefore, a tile can be expressed by performing the same operation on a subset (slice) of the original data.
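
To tie this back to the Transform dialect: you don't write that tiled loop by hand, you ask transform IR to produce it. Here's a rough sketch (my own, modeled on the tutorial; I believe the loop handle comes first in the results, but check the op's docs for your MLIR version):

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(
      %payload: !transform.any_op {transform.readonly}) {
    %matmul = transform.structured.match ops{["linalg.generic"]} in %payload
        : (!transform.any_op) -> !transform.any_op
    // Tile sizes [2, 8] on the (i, j) dimensions yield the 2x8 tiles,
    // and thus the scf.forall over (4, 2), shown above.
    %forall, %tiled = transform.structured.tile_using_forall %matmul tile_sizes [2, 8]
        : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
    transform.yield
  }
}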

Arghhhh i hate this i don't get thissssss arghhhhhhhhhhh what the hell is going onnnnnnnnnnn.

gonna go watch netflix and have ice cream and takoyaki