This post covers my work optimizing a Grouped GEMM kernel for GPUMode's NVFP4 kernel competition.


GPUMode recently ran a competition around several GEMM-like workloads using NVFP4, a 4-bit floating-point format introduced with Blackwell. I submitted kernels to both the Dual GEMM and Grouped GEMM leaderboards.

This post focuses on the Grouped GEMM workload and the kernel optimizations that improved my submission.


On the Grouped GEMM problem, my final kernel placed 6th on the NVIDIA leaderboard and 18th on the Modal leaderboard. I'll go through the main optimizations I used (TMA, warp specialization, thread block clusters), then cover what still separates this kernel from the top entries.


Foundations of the Kernel

First, a quick look at the workload itself: Grouped GEMM and the NVFP4 data format.


Grouped GEMM and MoE

In Mixture-of-Experts transformers, each token is routed (by a small gating network) to one or more experts, so different experts receive different numbers of tokens every step.

Figure: MoE Token Routing. After top-k routing, tokens are reorganized by expert, creating variable-sized per-expert input batches.

As a result, the sequence length dimension M (the number of tokens routed to a specific expert) varies widely across experts. The projection dimensions K (the input hidden size) and N (the expert intermediate size) remain uniform across the layer.

So the MoE forward pass decomposes into many independent GEMM operations (one per expert), all sharing K and N but with different M values.

Referring to the routing diagram above:

Expert   Tokens        GEMM
1        T1, T2, T3    M=3
2        T1            M=1
3        T2, T4        M=2

Per-expert GEMM: each expert receives a different number of tokens, so the GEMMs have different M dimensions, while K and N are shared across experts for a given projection.

A grouped GEMM is a single GPU kernel launch that executes a list of independent GEMMs together. In MoE, that maps to routing: each dispatched expert gets its own token subset, so each expert corresponds to one GEMM in the launch list.


For expert e, the FFN projection is:

Y_e = X_e @ W_e

where X_e contains activations for tokens routed to expert e, W_e is the expert's weight matrix, and Y_e is the output. A grouped GEMM bundles these into one batched computation:

{ X_1 @ W_1, X_2 @ W_2, X_3 @ W_3, ... }

Instead of paying launch overhead for multiple independent kernels, grouped GEMM combines them into one launch. The host passes an array of problem descriptors:

(A_1, B_1, C_1, M_1, N, K)
(A_2, B_2, C_2, M_2, N, K)
(A_3, B_3, C_3, M_3, N, K)

At the kernel level, the output space of each expert GEMM is partitioned into block-level tiles. Those tiles are then merged into one flattened global work queue.

Streaming Multiprocessors (SMs) pull from that queue and interleave tiles from different experts to keep occupancy up.


NVFP4

To ease memory bandwidth and capacity pressure in large models, Blackwell introduced NVFP4, a 4-bit floating-point data type. By quantizing weights and activations to 4 bits with block-level scaling, NVFP4 shrinks memory traffic and boosts throughput, with minimal impact on accuracy.

At the bit level, NVFP4 is an E2M1 representation (1 sign bit, 2 exponent bits, and 1 mantissa bit), yielding a representable range of −6 to 6.

Figure: NVFP4 format (E2M1 bit layout).

Because 4 bits doesn't give you much dynamic range, NVFP4 relies on block-level scaling. Every block of 16 elements gets a higher-precision scale factor (an FP8 E4M3 value). The true value can be recovered with x_fp = x_nvfp4 * scale.
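To make the format concrete, here is a minimal dequantization sketch for one 16-element block. The helper names and the low-nibble-first packing are my assumptions for illustration; a real kernel never materializes FP32 values like this, since the tensor cores consume the packed FP4 data and scales directly.

// E2M1 magnitudes for the 3-bit (exponent, mantissa) patterns:
// {0, 0.5, 1, 1.5, 2, 3, 4, 6}; the top bit of the nibble is the sign.
__device__ float decode_e2m1(uint8_t nibble) {
  const float lut[8] = {0.f, 0.5f, 1.f, 1.5f, 2.f, 3.f, 4.f, 6.f};
  float mag = lut[nibble & 0x7];
  return (nibble & 0x8) ? -mag : mag;
}

// Dequantize one 16-element block: two FP4 values per byte, one shared
// scale per block (assumed packing: low nibble holds the even element).
__device__ void dequant_block16(const uint8_t* packed, float scale, float* out) {
  for (int i = 0; i < 8; i++) {
    out[2 * i]     = decode_e2m1(packed[i] & 0xF) * scale;
    out[2 * i + 1] = decode_e2m1(packed[i] >> 4) * scale;
  }
}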

For the kernel, this means handling two memory streams per operand: while loading the quantized A and B matrices, it also has to asynchronously fetch their scale factors (SFA and SFB) from global memory into shared memory.


Kernel Worklog

The GPUMode competition evaluated kernels on four input shapes, each stressing a distinct workload. I'll use those same four shapes throughout the post.

Shape A – Deep-K

g=8, K=7168, N=4096, mixed M=[80,176,128,72,64,248,96,160]

This shape has a large inner dimension (K=7168) paired with small M values, so the inner loop runs for a long time.

It stresses deep-K pipeline behavior and sustained producer/consumer overlap.

Shape B – Wide-N

g=8, K=2048, N=7168, mixed M=[40,76,168,72,164,148,196,160]

This shape has a large number of output columns (N=7168) and very small M values, so memory bandwidth is dominated by streaming matrix B and its scale factors.

It stresses wide-N behavior, B/SFB traffic, and work distribution across many tiles.

Shape C – Mid-sized

g=2, K=4096, N=3072, M=[192,320]

This is a mid-sized group where M, N, and K are all moderately large. It is useful for checking whether mainloop improvements show up in end-to-end latency.

Shape D – Shorter-K

g=2, K=1536, N=4096, M=[128,384]

This shape has a very short inner loop, so fixed kernel costs dominate execution.

It heavily penalizes inefficiencies in CPU launch overhead, pipeline draining at the end of the loop, and the final memory-write phase (epilogue).


Here is how the kernel improved on those four shapes as the implementation evolved (all times in µs):

Version   Shape A   Shape B   Shape C   Shape D   Geo mean
v1        641       451       142       80.2      240
v2        313       286       111       63.7      159 (-34%)
v3        129       128       43.6      22.6      63.5 (-60%)
v4        47.5      45.0      14.3      10.5      23.8 (-63%)

Kernel reference

The competition organizers provided a reference implementation in CuTeDSL. Its benchmark numbers were:

Version     Shape A   Shape B   Shape C   Shape D
Reference   351 µs    341 µs    167 µs    146 µs

The geometric mean is 238 µs, which is the starting point.


v1: Correctness-First Baseline

My first version (v1) was a correctness-first CUDA/C++ implementation of the reference kernel. I used TMA (Tensor Memory Accelerator) to bulk-load tiles from global memory into shared memory asynchronously.

To hide the memory latency, I set up a 2-stage software pipeline using asynchronous mbarriers to overlap the TMA loads with tensor core compute.

For the actual compute, I used tcgen05.mma for mixed-precision NVFP4 math. It writes results directly into Tensor Memory (TMEM), which avoids putting that accumulation pressure on regular registers.

Host-side

To make use of Blackwell's Tensor Memory Accelerator, the kernel must describe the memory layout of its matrices to the hardware. Instead of standard pointers, TMA uses Tensor Maps.

// Host-side Tensor Map descriptor for a tiled NVFP4 operand.
cuTensorMapEncodeTiled(
    &tmap,
    CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN8B,
    3, (void*)ptr, globalDim, globalStrides, boxDim, elementStrides,
    /* config flags */
);

The important fields are the global tensor shape, the global strides, and the tile shape. Encoding those on the CPU hands index arithmetic and out-of-bounds handling off to dedicated hardware on the GPU.

Bulk Asynchronous Data Loads

On the device side, a single thread issues a hardware command to move an entire 2D or 3D tile with TMA. The usual alternative is to have every thread cooperatively load a few bytes into SMEM.

__device__ inline void tma_3d_gmem2smem(
    int dst_smem, const void *tmap_ptr, int x, int y, int z, int mbar_addr) {
  asm volatile(
    "cp.async.bulk.tensor.3d.shared::cta.global.mbarrier::complete_tx::bytes... "
    "[%0], [%1, {%2, %3, %4}], [%5], %6;"
    :: "r"(dst_smem), "l"(tmap_ptr), "r"(x), "r"(y), "r"(z), "r"(mbar_addr) ...
  );
}

The cp.async.bulk.tensor PTX instruction tells the TMA engine to fetch the exact sub-tile directly from global memory to shared memory (dst_smem). It runs fully asynchronously in the background. When the transfer finishes, it signals a hardware barrier (mbar_addr).
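The mbarrier_wait used throughout this post is a thin wrapper over the PTX mbarrier instructions. A minimal sketch of the init and parity-wait helpers (the signatures are mine; a real kernel also sets the expected transaction byte count with mbarrier.arrive.expect_tx):

// Initialize an mbarrier in shared memory with an arrival count.
// (mbar_addr is a shared-memory address from __cvta_generic_to_shared.)
__device__ __forceinline__ void mbarrier_init(uint32_t mbar_addr, uint32_t count) {
  asm volatile("mbarrier.init.shared::cta.b64 [%0], %1;"
               :: "r"(mbar_addr), "r"(count));
}

// Spin until the barrier completes the given phase (0/1 parity).
__device__ __forceinline__ void mbarrier_wait(uint32_t mbar_addr, uint32_t phase) {
  uint32_t done = 0;
  while (!done) {
    asm volatile(
      "{.reg .pred p;\n"
      " mbarrier.try_wait.parity.shared::cta.b64 p, [%1], %2;\n"
      " selp.u32 %0, 1, 0, p;}"
      : "=r"(done) : "r"(mbar_addr), "r"(phase));
  }
}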

Computing the MMA

Once the FP4 data and scale factors are in shared memory, the kernel passes them to the tensor cores (tcgen05).

__device__ __forceinline__ void tcgen05_mma_nvfp4(
  int d_tmem, uint64_t a_desc, uint64_t b_desc, ...
) {
  asm volatile(
    "tcgen05.mma.cta_group::1.kind::mxf4nvf4.block_scale.block16 "
    "  [%0], %1, %2, %3, [%4], [%5], p;"
    :: "r"(d_tmem), "l"(a_desc), "l"(b_desc) ...
  );
}

Note that the destination accumulator ([%0] / d_tmem) is not a standard GPU register. It points to TMEM (Tensor Memory), a specialized on-chip memory space used to hold accumulated matrix results without blowing up register pressure.

Double-buffered main loop

To maximize performance, the memory fetches (TMA) and the math (MMA) must happen simultaneously. The baseline does this with a two-stage software pipeline.

for (int k_iter = 0; k_iter < num_k_iters; k_iter++) {
  int stage = k_iter % 2;
  int next_stage = (k_iter + 1) % 2;

  // Wait for the hardware barrier to confirm TMA has loaded 'stage'
  if (tid == 0) mbarrier_wait(tma_mbar[stage], phase);
  __syncthreads();

  // Tell TMA to start fetching the *next* block of data in the background
  if (tid == 0) issue_tma_loads(..., next_stage);

  // Command tensor cores to multiply the *current* block of data
  if (tid == 0) issue_mma(..., stage);
  __syncthreads();
}

Because TMA and tcgen05 are both asynchronous, the loop above mainly acts as an orchestrator. While the tensor cores compute the MMA for stage 0, the memory subsystem is already fetching the data for stage 1.

That overlap hides memory latency and keeps the compute units busy.


The benchmark timings for this kernel were:

Version   Shape A   Shape B   Shape C   Shape D
v1        641 µs    451 µs    142 µs    80.2 µs

Bottleneck: Host Orchestration Overhead

Even though the device-side kernel already used TMA and tcgen05, the host side was still wasting a lot of time.

The host processed each problem one by one, paying metadata setup, launch, and synchronization overhead for every expert. Once experts get small (low M), that sequential overhead becomes a large fraction of total runtime.

// v1: process each independent problem one by one.
for (int64_t prob_idx = 0; prob_idx < G; prob_idx++) {
  build_work_items(prob_idx, work_items);
  cudaMemcpy(d_work_items, work_items.data(), ...);  // per-problem metadata copy
  grouped_gemm<<<...>>>(...);
  cudaDeviceSynchronize();                           // per-problem sync
}

Each launch is cheap in isolation, but the repeated serialization adds up to significant latency.


v2: Persistent Kernel

To fix that, v2 moves to a persistent kernel. Instead of launching one kernel per problem, it flattens all tiles from every problem into one global work queue and batches the metadata into a single host-to-device copy.

This amortizes launch overhead across the whole MoE layer. With one kernel launch processing the entire queue, scheduling stays on the GPU and the CPU bottleneck mostly disappears.

On the host side, dispatch changes from a loop of launches to a single launch:

// Flatten tiles from all problems into one persistent work queue
std::vector<WorkItem> all_work_items;

for (int64_t prob_idx = 0; prob_idx < G; prob_idx++) {
  // ... nested loops over M and N tiles ...
  // Each work item tracks which problem it belongs to
  all_work_items.push_back({(int)prob_idx, tm, tn});
}

// Host->device copy of all metadata and worklists in a single batch
cudaMemcpy(d_problem_infos, problem_infos.data(), ...);
cudaMemcpy(d_work_items, all_work_items.data(), ...);

// Launch a single persistent kernel over all work items across all problems
grouped_gemm<<<num_items, ...>>>(...);

// A single synchronization point at the end for all problems
cudaDeviceSynchronize();

On the device side, each block uses blockIdx.x to grab a work item from the flattened list, look up its problem metadata, and then run the usual TMA + MMA loop.

__global__ void grouped_gemm(
  const ProblemInfo* __restrict__ global_probs,
  const WorkItem* __restrict__ work_items,
  int num_items
) {
  const int global_idx = blockIdx.x;
  if (global_idx >= num_items) return;

  // Use the global index to fetch the exact tile and problem metadata
  const WorkItem& work = work_items[global_idx];
  const ProblemInfo& prob = global_probs[work.problem_idx];

  const int m_offset = work.tile_m * TMA_BLOCK_M;
  const int n_offset = work.tile_n * TMA_BLOCK_N;

  // Proceed with TMA loads & math for this specific tile...
}

This one architectural change was enough to cut runtime roughly in half on the larger grouped cases.


The benchmark timings for this kernel were:

Version   Shape A         Shape B         Shape C         Shape D
v1        641 µs          451 µs          142 µs          80.2 µs
v2        313 µs (-51%)   286 µs (-37%)   111 µs (-22%)   63.7 µs (-21%)

v3: Warp Specialization and Dynamic Scheduling

v2 improved host-side performance considerably, but the device-side kernel still had a lot of optimization room left. v3 focused on three areas:

  1. The mainloop pipeline was too shallow and intra-CTA synchronization was too restrictive.
  2. Memory traffic was inefficient on skewed shapes like Wide-N.
  3. Fixed overheads in scheduling and the epilogue were hurting performance on shorter problems.

Mainloop: Warp Specialization and Pipeline Deepening

In v2, the software pipeline was limited to 2 stages, and the thread block executed a synchronized control flow loop.

In v3, I rewrote the mainloop around warp specialization.

First, I bumped the pipeline from 2 stages to 4 stages to keep more K-tiles in flight. All the stages are primed up front:

// v2
constexpr int NUM_STAGES = 2;

// v3
constexpr int NUM_STAGES = 4;

for (int k_iter = 0; k_iter < NUM_STAGES && k_iter < num_k_iters; k_iter++) {
  issue_tma(k_iter, k_iter);
}

With only 2 stages, a small delay in TMA completion or barrier arrival stalls tensor core issue almost immediately. With 4 stages, the same delay is amortized across a larger prefetch window, so the consumer is much less likely to run out of work. This matters most on deep-K problems, where small bubbles repeated many times add up to significant latency.


Second, I decoupled the work inside the block using warp specialization. Instead of forcing the whole thread block to step through a synchronized load/compute cycle, I split the warps: one producer warp issues the TMA loads, and a consumer warp runs the math:

// Warp 4: producer
if (warp_id == TMA_WARP && elect_sync()) {
  for (int k_iter = NUM_STAGES; k_iter < num_k_iters; k_iter++) {
    const int stage = k_iter % NUM_STAGES;
    const int mma_phase = (k_iter / NUM_STAGES - 1) % 2;
    mbarrier_wait(mbar_base + (NUM_STAGES + stage) * 8, mma_phase);
    issue_tma(k_iter, stage);
  }
}

// Warp 5: consumer
else if (warp_id == MMA_WARP && elect_sync()) {
  for (int k_iter = 0; k_iter < num_k_iters; k_iter++) {
    const int stage = k_iter % NUM_STAGES;
    const int tma_phase = (k_iter / NUM_STAGES) % 2;
    mbarrier_wait(mbar_base + stage * 8, tma_phase);
    tcgen05_mma_nvfp4(...);
    tcgen05_commit(mbar_base + (NUM_STAGES + stage) * 8);
  }
}

This is a better match for the hardware. TMA and tcgen05 are both asynchronous engines, so forcing the entire block to reconverge around them just creates avoidable stalls. Letting one warp drive memory and another drive the MMA makes the producer-consumer split much cleaner.
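Both branches above gate on an elect_sync() helper so that exactly one thread per warp issues the asynchronous commands. A minimal sketch over the elect.sync PTX instruction (the wrapper name follows the snippets; the signature is mine):

// Returns true in exactly one active thread of the warp.
__device__ __forceinline__ bool elect_sync() {
  uint32_t pred = 0;
  asm volatile(
    "{.reg .pred p;\n"
    " elect.sync _|p, 0xffffffff;\n"
    " selp.u32 %0, 1, 0, p;}"
    : "=r"(pred));
  return pred != 0;
}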


Third, with a deeper, decoupled pipeline, the kernel needs precise mbarrier phase tracking. Since the pipeline is a circular buffer, the barrier parity has to exactly match the reuse cadence to avoid overwriting data the consumer hasn't read yet:

const int stage = k_iter % NUM_STAGES;
const int mma_phase = (k_iter / NUM_STAGES - 1) % 2;
mbarrier_wait(mbar_base + (NUM_STAGES + stage) * 8, mma_phase);
issue_tma(k_iter, stage);

This explicit tracking makes sure the producer only blocks when it has to. Together, these changes tightened the steady-state mainloop and showed up clearly on deep-K workloads (Shape A).


Memory Traffic: Dynamic L2 Cache Eviction Policies

On Wide-N shapes, the kernel becomes memory-bound just from streaming matrix B and its scale factors. In that regime, cache behavior matters.

To get more out of the available bandwidth, I used TMA cache eviction hints (EVICT_FIRST, EVICT_LAST). The rule is simple: keep the smaller operand, which has higher reuse, in L2 and stream the larger one:

uint64_t cache_a, cache_b;
if (prob.m > prob.n) {
  cache_a = evict_first;
  cache_b = evict_last;
} else {
  cache_a = evict_last;
  cache_b = evict_first;
}

tma_3d_gmem2smem<1>(..., cache_a);
tma_3d_gmem2smem<1>(..., cache_b);

When the problem is Wide-N, matrix A is the smaller reusable operand, so the kernel should keep it hot in cache and stream B. On taller problems, it's the other way around. The policy just needs to follow the geometry of the problem.

This reduces unnecessary DRAM traffic by improving L2 hits on the operand with the most reuse. The gain shows up clearly on Shape B.
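The evict_first and evict_last values above are 64-bit L2 cache policy descriptors. A minimal sketch of how such policies can be built with the createpolicy PTX instruction and attached to the bulk copy (an assumption about this kernel's plumbing; the descriptor travels through the L2::cache_hint operand):

// Build an L2 policy that marks lines for early eviction (streaming data).
__device__ __forceinline__ uint64_t make_evict_first() {
  uint64_t policy;
  asm volatile("createpolicy.fractional.L2::evict_first.b64 %0, 1.0;" : "=l"(policy));
  return policy;
}

// Build an L2 policy that keeps lines resident as long as possible (reused data).
__device__ __forceinline__ uint64_t make_evict_last() {
  uint64_t policy;
  asm volatile("createpolicy.fractional.L2::evict_last.b64 %0, 1.0;" : "=l"(policy));
  return policy;
}

// The policy rides along on the TMA request via the cache-hint variant:
//   cp.async.bulk.tensor.3d.shared::cta.global
//     .mbarrier::complete_tx::bytes.L2::cache_hint
//     [dst], [tmap, {x, y, z}], [mbar], policy;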


Reducing fixed overheads

Even with a better mainloop and memory system, the kernel was still paying unnecessary fixed costs.

One big source was scheduling. In v2, flattening the work removed the host launch overhead, but tile assignment was still static (blockIdx.x). That would be fine if all tiles cost the same, but MoE is irregular: some experts have more tiles, some have longer K loops, and some finish much faster.

To fix this, v3 moves to true persistent scheduling using a global atomic counter, rather than mapping exactly to blockIdx.x:

__shared__ int shared_work_idx;

while (true) {
  if (tid == 0) {
    shared_work_idx = atomicAdd(work_counter, 1);
  }
  __syncthreads();

  int work_idx = shared_work_idx;
  if (work_idx >= num_items) break;

  // ... process tile ...

  __syncthreads();  // everyone has read work_idx before it is overwritten
}

Now, as soon as an SM finishes a tile, it grabs the next available one. That smooths out tail effects and improves load balance across the irregular grouped workload.

The other fixed cost was the epilogue. On shorter-K shapes, the MMA loop is not long enough to hide synchronization, pipeline drain, and final stores.

I specialized the epilogue for the common paths to cut this down:

constexpr int LOW_M_THRESHOLD = 96;
const bool low_m = (prob.M <= LOW_M_THRESHOLD);

if (low_m) epilogue_store<true>(...);
else       epilogue_store<false>(...);

if (full_tile && contiguous) {
  half2* row0_ptr = reinterpret_cast<half2*>(C_ptr + row0 * Cs0 + n_offset);
  row0_ptr[h2_idx] = __halves2half2(...);
}

Low-M tiles should not pay for the control-flow overhead of a fully generic store path. And if the output tile is full and contiguous, the kernel can write half2 instead of scalar halves, which reduces instruction count and improves coalescing.

These optimizations matter most when the K-loop is short, which is why Shape D improved so much.


The result is a much better balanced v3, with a more asynchronous mainloop and much lower fixed overheads.

Version   Shape A         Shape B         Shape C         Shape D
v2        313 µs          286 µs          111 µs          63.7 µs
v3        129 µs (-59%)   128 µs (-55%)   43.6 µs (-61%)  22.6 µs (-65%)

v4: Cluster Multicast and Parallel TMA Producers

v3 was a big step up, but it still was not fully saturating the hardware. For v4, I used a couple more Blackwell features: thread block clusters and SMEM multicast. I also split up the TMA workload to reduce producer stalls.


Parallel TMA Producers

In v3, one warp issued all the TMA descriptors for a stage (A, B, SFA, SFB). Even though they're non-blocking, issuing four TMA commands back-to-back takes a measurable amount of time. If the producer falls behind, the consumer warp ends up stalling on the barrier.

Since the CTA has spare warps, v4 splits this work across two dedicated producer warps:

// Warp 4: producer for A and SFA
constexpr int TMA_WARP = 4;
// Warp 6: producer for B and SFB
constexpr int TMA_WARP_B = 6;
constexpr int MMA_WARP = 5;

if ((warp_id == TMA_WARP || warp_id == TMA_WARP_B) && elect_sync()) {
  const bool do_A = (warp_id == TMA_WARP);
  const bool do_B = (warp_id == TMA_WARP_B);

  if (do_A) {
    tma_3d_gmem2smem<1>(stage_base, &prob.A_tmap, ...);
    tma_gmem2smem(stage_base + SFA_off, SFA_src, ...);
  } else if (do_B) {
    tma_3d_gmem2smem<1>(stage_base + B_off, &prob.B_tmap, ...);
    tma_gmem2smem(stage_base + SFB_off, SFB_src, ...);
  }
}

This roughly cuts producer-side issue latency in half, so the tensor cores spend much less time waiting for data requests to be issued.


Thread Block Clusters and SMEM Multicast

The biggest measurable win in v4 came from thread block clusters. Clusters let multiple CTAs scheduled on the same GPC communicate over a very fast interconnect.

In a standard GEMM, different CTAs often read the exact same input data. For example, CTAs computing output tiles (M_0, N_0) and (M_0, N_1) both need the exact same row of Matrix A. Usually, each CTA just redundantly fetches this from global memory.

Instead, the kernel can group 4 CTAs into a cluster and map them to the same M-tile across different N-tiles. Now, Matrix A only needs to be fetched from global memory once. The master CTA (Rank 0) issues the TMA load, and the hardware multicasts the data directly into the shared memory of all 4 CTAs over the cluster network.

if (do_A) {
  if constexpr (CLUSTER_SIZE > 1) {
    // CTA rank 0 fetches from global memory and multicasts to the cluster
    if (cta_rank == 0) {
      uint16_t mc = (1 << CLUSTER_SIZE) - 1;
      int cluster_dst = mapa_cta_to_cluster(stage_base, 0);
      int cluster_mbar = mapa_cta_to_cluster(mbar_addr, 0);
      tma_3d_gmem2smem_multicast(cluster_dst, &prob.A_tmap, 0, m_offset, off_k / 256, cluster_mbar, mc);
    }
  } else {
    // ... non-cluster fallback ...
  }
}

This cuts Matrix A read traffic by up to 4x, which translates to meaningful improvement on memory-bound shapes.
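The mapa_cta_to_cluster helper above translates a shared-memory address in the local CTA into the matching address in a peer CTA of the cluster. A minimal sketch over the mapa PTX instruction (the wrapper name follows the snippet; the signature is mine):

// Map a local SMEM address to the same SMEM offset in the CTA with
// rank `cta_rank` inside the current cluster.
__device__ __forceinline__ uint32_t mapa_cta_to_cluster(uint32_t smem_addr,
                                                        uint32_t cta_rank) {
  uint32_t out;
  asm volatile("mapa.shared::cluster.u32 %0, %1, %2;"
               : "=r"(out) : "r"(smem_addr), "r"(cta_rank));
  return out;
}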


Wide-N Optimization and Pipeline Tuning

Finally, v4 adds dynamic block sizing and pipeline depth selection. For shapes with large N dimensions, the kernel switches from a 128x128 tile to a wider 128x256 tile, computing twice as much output for every A tile loaded.

Because shared memory is tight, a wider tile means fewer pipeline stages fit. The kernel precomputes occupancy and picks either a deeper pipeline (for example, 6 stages for 128x128) for maximum overlap, or a shallower one (2-4 stages for 128x256) so the wider tile still fits.
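A host-side sketch of that selection, with illustrative names and thresholds (smem_per_stage, pick_config, and the 7168 cutoff are mine; the real heuristic also budgets for mbarrier storage and other bookkeeping):

// Per-stage SMEM footprint for a bm x bn tile with a bk-deep K slice:
// packed FP4 operands (half a byte per element) plus one FP8 scale
// per 16 elements.
constexpr int smem_per_stage(int bm, int bn, int bk) {
  int ab_bytes = (bm + bn) * bk / 2;
  int sf_bytes = (bm + bn) * (bk / 16);
  return ab_bytes + sf_bytes;
}

constexpr int SMEM_BUDGET = 227 * 1024;  // SM100 per-CTA shared memory

// Wide-N problems amortize each A tile over more output columns, so they
// get the 128x256 tile; everything else keeps 128x128 and a deeper pipeline.
void pick_config(int n, int& block_n, int& num_stages) {
  block_n = (n >= 7168) ? 256 : 128;
  num_stages = SMEM_BUDGET / smem_per_stage(128, block_n, 256);
}

With these numbers, the square tile fits 6 stages and the wide tile fits 4, which lines up with the presets described above.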


End results

Combining parallel TMA producers, cluster multicast, and dynamic tile sizing brings the kernel into the 10-50 µs range.

Version   Shape A         Shape B         Shape C         Shape D
v3        129 µs          128 µs          43.6 µs         22.6 µs
v4        47.5 µs (-63%)  45.0 µs (-65%)  14.3 µs (-67%)  10.5 µs (-54%)

This cuts end-to-end latency by another 54-67% versus v3, bringing the final geometric mean down to 23.8 µs, roughly a 10x improvement over the initial baseline.


Gap to SoTA

My v4 kernel performed well in the competition, but I still missed several meaningful optimization opportunities.

After the competition ended, I read the winning kernel, which had a geometric mean of 16.029 µs on the NVIDIA leaderboard.

The winning entries are better because several design choices work together well on Blackwell:


1. Cooperative CTA-Level MMA (cta_group::2)

My v4 kernel only uses clusters on the load side, where Matrix A is multicast. Once the data arrives, each CTA still performs its own independent tcgen05.mma.cta_group::1 sequence.

The winning kernel uses cooperative tensor-core execution across 2 CTAs via cta_group::2. The cluster becomes part of the compute primitive itself.

// My v4
tcgen05.mma.cta_group::1.kind::mxf4nvf4.block_scale.block16 ...

// Winner
tcgen05.mma.cta_group::2.kind::mxf4nvf4.block_scale.block16 ...

Running the same MMA instruction stream cooperatively has several benefits:

  • A wider effective math tile.
  • Shared TMEM allocation across CTAs.
  • A larger, naturally aligned accumulator tile for the epilogue.
  • Amortized per-CTA control overhead.

By sharing the A tile across CTAs, the tensor core effectively works on a larger combined tile, which raises compute throughput. My v4 saves bandwidth but misses this more efficient cooperative MMA mode.


2. Output Matrix Transpose (M-major layout)

My kernel writes into a standard PyTorch row-major tensor (contiguous N). That is natural for the framework, but is suboptimal for the hardware.

TMEM readback and thread mapping naturally align with the M dimension. After the MMA finishes, accumulator fragments are arranged for contiguous walks along M, not N. In my v4 kernel, the epilogue wastes instructions remapping and scattering values to fit the N-contiguous destination.

The winning kernel allocates C in an M-major layout (padding M to a multiple of 16). The epilogue can simply stream results straight out in the hardware's native order.

# Winner (host-side Python)
new_M = (M + 16 - 1) // 16 * 16
new_C = torch.empty(new_M * N, dtype=torch.half, device="cuda")
new_C = new_C.as_strided((M, N, 1), (1, new_M, 0))

// My v4 (device-side)
const bool contiguous = (Cs1 == 1);
half2* row0_ptr = reinterpret_cast<half2*>(C_ptr + row0 * Cs0 + n_offset + col_base);

This host-side change dramatically simplifies the device-side epilogue, turning a slow scatter into a fast contiguous store.


3. Host-Side Problem Sorting

Grouped GEMM for MoE is highly irregular. Some experts receive many tokens, others very few. Even with persistent scheduling, that creates a tail effect: most SMs finish early and idle while a few work through the largest remaining experts.

The winning kernel mitigates this by sorting experts by descending M before launch:

// Winner
for (int i = 0; i < NUM_GROUPS; i++) {
  values[i] = A_list[i].size(0);
}
argsort_desc<NUM_GROUPS>(values, indices);

// My v4
for (int64_t i = 0; i < G; i++) {
  for (int tm = 0; tm < num_tiles_m[i]; tm++) {
    for (int tn = 0; tn < num_tiles_n[i]; tn++) {
      cached_work_items_256.push_back({(int)i, tm, tn});
    }
  }
}

This simple host-side trick helps in two ways:

  1. Starting the longest-running experts early to maximize parallel overlap.
  2. Pushing shorter experts to the end to fill holes and smooth out the tail.

Persistent scheduling solves intra-launch assignment, but global work ordering still dictates how quickly the tail drains.
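Folding this into my host-side queue construction would be a small change; a sketch assuming M_sizes holds each expert's token count (needs <algorithm> and <numeric>):

// Build the work queue in descending-M order so the longest-running
// experts are dispatched first and short ones fill the tail.
std::vector<int> order(G);
std::iota(order.begin(), order.end(), 0);
std::sort(order.begin(), order.end(),
          [&](int a, int b) { return M_sizes[a] > M_sizes[b]; });

for (int idx : order) {
  for (int tm = 0; tm < num_tiles_m[idx]; tm++)
    for (int tn = 0; tn < num_tiles_n[idx]; tn++)
      cached_work_items_256.push_back({idx, tm, tn});
}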


4. Vectorized PTX Epilogues with Cache Hints

While v4 uses half2 stores for contiguous tiles, it still relies on normal C++ pointers and leaves instruction selection to the compiler.

The winner uses inline PTX to emit 256-bit vector stores (st.global.v8.b32), writing 16 half values at once. This is made possible by the M-major layout, which allows threads to write large, aligned vectors without shared-memory transposes.

// Winner
asm volatile(
  "st.relaxed.cta.global.L1::no_allocate%17.v8.b32 [%16], {%0, %2, %4, %6, %8, %10, %12, %14};"
  ...
);

// My v4
row0_ptr[h2_idx] = __float22half2_rn(make_float2(tmp[idx + 0], tmp[idx + 1]));
row1_ptr[h2_idx] = __float22half2_rn(make_float2(tmp[idx + 2], tmp[idx + 3]));

Furthermore, the winner pairs these stores with explicit cache hints:

  • L1::no_allocate avoids polluting L1 with write-only output traffic.
  • .L2::evict_last prevents the result stream from evicting hot operand data needed by the mainloop.

Especially on short-K shapes (Shape D), wider stores and less cache disruption cut epilogue latency noticeably.


5. Fully Maximized Dynamic Pipeline Depth

My v4 kernel adapts pipeline depth using a small set of presets (for example, 6 stages for 128x128, 2-4 stages for 128x256).

The winner computes the exact number of stages that fit in the available SMEM budget for any given kernel geometry:

// My v4
constexpr int NS_DEEP_HI = 6;
constexpr int NS_DEEP_LO = 3;
constexpr int NS_WIDE_HI = 4;
constexpr int NS_WIDE_LO = 2;

// Winner
constexpr int sm100_size = 227 * 1024;
constexpr int dynamic_size = AB_size + SF_size + 2 * 8;
constexpr int static_size = 3 * 2 * 8 + 4;
constexpr int NUM_STAGES = (sm100_size - static_size) / dynamic_size;

On Blackwell (SM100), shared memory is generous enough that squeezing in one more pipeline stage can matter. It extends the prefetch window, keeps more K-tiles in flight, and gives the producer warp a bit more breathing room.

By leaving some shared memory unused, v4 gives up overlap it could have had. On this hardware, extra on-chip storage is often what buys you better latency hiding, so that tradeoff costs real performance.


Taken together, these missing pieces explain most of the remaining gap.