Optimizing a WebGPU Matmul Kernel for 1TFLOP+ Performance
Building Surfgrad, a high-performance, WebGPU-powered autograd library
I work at Nomic, where many of my colleagues work on making large TSNE-like visualizations work in the browser1. Showing tens of millions of data points in the browser without rendering your computer an oven is no easy challenge, and I overhear plenty about the scaling problems solved by Deepscatter, first developed by Ben Schmidt.
However, many of the conversations I overhear tend to revolve around Typescript and how awesome WebGPU is. At the time of writing, I couldn't find any autograd libraries built with WebGPU2. So, as an educational exercise to learn WebGPU and Typescript, I decided to build Surfgrad3, a high-performance, WebGPU-powered autograd library that enables browser-based tensor operations.
In this post, I'll cover how I optimized a naive WebGPU matrix multiplication (matmul) kernel to 1 TFLOPS+ of throughput4. The goal isn't to build the fastest autograd library, but to show the nuances of WebGPU and how it might differ from CUDA.
Perhaps in the future, we can even use Surfgrad for running the next Llama models.
What is WebGPU?
WebGPU is an API designed to let people write GPU code that runs on any phone or computer with a web browser. Previously, people hacked around WebGL to run machine learning workloads, doing things like rendering to an invisible canvas and reading numbers back out as colors. Now people can take advantage of the increasing power of GPUs5 in laptops and run proper compute kernels (i.e. data in, data out, without any funny business).
WebGPU was created to give the “compute” shader first-class support and open the doors for in-browser, private machine learning development.
The compute (and vertex and fragment) shaders are written in WGSL. WGSL is designed for developers to write a single shader that gets compiled to lower level languages like SPIR-V for Vulkan and MSL for Metal.
Ben's also written some great articles on what WebGPU is and why it's important.
WebGPU vs. CUDA
NVIDIA is the most popular hardware choice, and CUDA, its programming API, is one of the reasons why. But CUDA only works on NVIDIA hardware.
WebGPU and CUDA share similar terminology, but they don't have exactly the same functionality. WebGPU just introduced support for subgroups, which allow threads within a group to efficiently share data, a big win for things like matrix multiplies where you may otherwise recompute similar values.
WebGPU also sits half a step above CUDA in that it compiles to other GPU APIs like Vulkan and Metal. It's kind of like React Native for GPU compute shaders.
WebGPU Compute Shader Basics
The smallest unit is a thread, which executes the compute shader.
workGroups are groups of threads: they are scheduled together and run in parallel (they're called threadBlocks in CUDA). Threads in the same workGroup can access the same shared memory.
WebGPU can dispatch many of these workGroups at once; CUDA calls this a grid (which is made up of threadBlocks).
Similarly to CUDA, workGroups and workGroup dispatches are defined in 3D. The size of a workGroup is defined by @workgroup_size(x, y, z), where the number of threads per workGroup is x * y * z.
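For example (a hypothetical sizing, not one of the kernels below): a workGroup declared with @workgroup_size(8, 8, 1) contains 64 threads, and dispatching 16 x 16 x 1 workGroups launches 16,384 threads in total.

// Hypothetical example: threads per workGroup and total invocations.
const workgroupSize = { x: 8, y: 8, z: 1 };   // @workgroup_size(8, 8, 1)
const dispatch = { x: 16, y: 16, z: 1 };      // pass.dispatchWorkgroups(16, 16, 1)

const threadsPerWorkgroup = workgroupSize.x * workgroupSize.y * workgroupSize.z; // 64
const totalWorkgroups = dispatch.x * dispatch.y * dispatch.z;                    // 256
const totalInvocations = threadsPerWorkgroup * totalWorkgroups;                  // 16,384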
Writing a Fast Matrix Multiply
Matrix multiplications make up most of the floating-point operations (FLOPs) in Large Language Models like GPT-4 and Llama. The matmul is the basic primitive for most training and inference workloads.
Native WebGPU support for matrix multiplication is limited to small matrices (up to mat4x4), which aren't useful for modern deep learning workloads where matrices can be large6.
A few quick notes on notation.
Matrix Multiply
First, a matrix multiply is defined by three matrices: A of shape (M, K), B of shape (K, N), and the output C of shape (M, N), where C = A * B.
The total FLOPs required for a matrix multiply is 2 * M * K * N, as each of the M * K * N operations requires both a multiply and an add (hence the 2).
Lower Bounding Our Kernel
Following the example from Simon Boehm's great article, we multiply two 4092x4092 matrices, followed by the addition of a 4092x4092 matrix. That gives us:
Total FLOPs: 137 GFLOP
Total data to read: 201 MB
Total data to store: 67 MB
However, I am developing on a Mac M2 Pro, which has roughly 6 TFLOPS of compute throughput and 200 GB/s of memory bandwidth.
So, the fastest the compute can possibly take is
(137 GFLOP) / (6 TFLOPS) ≈ 22 ms
and the memory access takes
(268 MB) / (200 GB/s) ≈ 1.34 ms
so we should be compute bound (by ~16x, too!).
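As a quick sanity check on those numbers, here's the back-of-the-envelope math (a sketch using the M2 Pro estimates above):

// Roofline-style bound for C = A * B + C with 4092x4092 f32 matrices.
const M = 4092, K = 4092, N = 4092;
const flops = 2 * M * K * N;                     // ~1.37e11 FLOPs (~137 GFLOP)
const bytes = (M * K + K * N + 2 * M * N) * 4;   // read A, B, C and write C: ~268 MB
const computeMs = (flops / 6e12) * 1e3;          // ~22.8 ms at ~6 TFLOPS
const memoryMs = (bytes / 200e9) * 1e3;          // ~1.34 ms at 200 GB/s
console.log((computeMs / memoryMs).toFixed(1));  // ~16-17x: heavily compute bound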
Writing the Kernel
Kernel 1: Naive Kernel
The simplest way to compute C = A * B is, for each row of A (there are M of them) and each column of B (there are N of them), to iterate over the shared dimension K, multiplying and accumulating the corresponding values of A and B. In Python, this looks like
def matmul(a, b, c):
    """
    Perform naive matrix multiplication: C = A * B
    :param a: Input matrix A of shape (m, k)
    :param b: Input matrix B of shape (k, n)
    :param c: Output matrix C of shape (m, n) to store the result
    """
    m = len(a)
    k = len(a[0])
    n = len(b[0])
    # Perform the matrix multiplication
    for i in range(m):
        for j in range(n):
            c[i][j] = 0
            for l in range(k):
                c[i][j] += a[i][l] * b[l][j]
Similar to the Python code above, we define7 our inputs8
struct Dimensions {
  M: u32,
  K: u32,
  N: u32,
}

@group(0) @binding(0) var<uniform> dimensions: Dimensions;
@group(0) @binding(1) var<storage, read> a: array<f32>;
@group(0) @binding(2) var<storage, read> b: array<f32>;
@group(0) @binding(3) var<storage, read_write> result: array<f32>;
and our compute kernel:
@compute @workgroup_size(1)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
  let index = global_id.x;
  let row = index / dimensions.N;
  let col = index % dimensions.N;

  if (index < dimensions.M * dimensions.N) {
    var sum = 0.0;
    for (var i: u32 = 0u; i < dimensions.K; i = i + 1u) {
      sum = sum + a[row * dimensions.K + i] * b[i * dimensions.N + col];
    }
    result[row * dimensions.N + col] = sum;
  }
}
The code is functionally equivalent to the Python code above! We define how big our workGroup is with workgroup_size(1) (remember, this is represented in 3D, so this is really (1, 1, 1)). So each workGroup, since it's only one thread, processes a single result[i, j].
To calculate the full matrix, we need to launch as many workGroups as there are entries in the output, which we do by calling dispatchWorkgroups9
pass.dispatchWorkgroups(a.shape[0] * b.shape[1])
where a.shape[0] == M and b.shape[1] == N for (most) any MxN matrix.
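The full host-side setup is boilerplate I'm mostly leaving out (see the footnotes), but roughly, that call sits inside something like the following sketch, where shaderCode holds the WGSL above and the buffer names are placeholders assumed to already be created and filled:

// Placeholders assumed to exist: a GPUDevice, the WGSL source, and filled GPU buffers.
declare const device: GPUDevice;
declare const shaderCode: string;
declare const dimensionsBuffer: GPUBuffer, aBuffer: GPUBuffer, bBuffer: GPUBuffer, resultBuffer: GPUBuffer;
declare const a: { shape: number[] }, b: { shape: number[] };

const module = device.createShaderModule({ code: shaderCode });
const pipeline = device.createComputePipeline({
  layout: "auto",
  compute: { module, entryPoint: "main" },
});
const bindGroup = device.createBindGroup({
  layout: pipeline.getBindGroupLayout(0),
  entries: [
    { binding: 0, resource: { buffer: dimensionsBuffer } }, // uniform with M, K, N
    { binding: 1, resource: { buffer: aBuffer } },
    { binding: 2, resource: { buffer: bBuffer } },
    { binding: 3, resource: { buffer: resultBuffer } },
  ],
});

const encoder = device.createCommandEncoder();
const pass = encoder.beginComputePass();
pass.setPipeline(pipeline);
pass.setBindGroup(0, bindGroup);
pass.dispatchWorkgroups(a.shape[0] * b.shape[1]); // one workGroup per output element
pass.end();
device.queue.submit([encoder.finish()]);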
Now as we see below, we have lots of room for improvement!
The largest square matrix multiply we can calculate is 128x128 due to limits in WebGPU (more on this later). We only achieve 1.64 GFLOPS, a far cry from the theoretical max of ~6 TFLOPS.
Why is this kernel so slow? In effect, each workGroup calculates a single entry of the 16,384 (128^2) total elements. Although the workGroups run in parallel, each one loads its own copy of the inputs and can't take advantage of any caching, and the overhead of launching that many workGroups is likely greater than if each workGroup had more threads and calculated more results.
Kernel 2: Moarrr Threads!
With the first kernel, we're only able to compute small square matrices due to the limit on the number of workGroups you can dispatch per dimension (maxComputeWorkgroupsPerDimension, 65,535 by default).
Since we're launching one workGroup per output entry, a 256x256 matrix (65,536 entries) is already larger than our limit!
Remember this line?
@compute @workgroup_size(1)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
We can reduce the number of dispatched workGroups by increasing the number of threads per workGroup!
If we update our code
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
we can reduce the number of total dispatched workGroups per dimension:
const WORKGROUP_SIZE = 256;
pass.dispatchWorkgroups((a.shape[0] * b.shape[1]) / WORKGROUP_SIZE);
Why 256? Well, there's another limit: maxComputeInvocationsPerWorkgroup caps the number of threads per workGroup at 256 by default :)
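If you want to see what your own device allows, the limits are exposed on the GPUDevice (a small sketch; the commented values are the WebGPU spec defaults, and real hardware may report more):

// Query the compute-related limits of the current device.
const adapter = await navigator.gpu.requestAdapter();
const device = await adapter!.requestDevice();

console.log(device.limits.maxComputeInvocationsPerWorkgroup); // 256 by default (caps x * y * z)
console.log(device.limits.maxComputeWorkgroupSizeX);          // 256 by default
console.log(device.limits.maxComputeWorkgroupSizeY);          // 256 by default
console.log(device.limits.maxComputeWorkgroupSizeZ);          // 64 by default
console.log(device.limits.maxComputeWorkgroupsPerDimension);  // 65,535 by default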
By increasing the workgroup size, we're able to improve our kernel by 200x!
Kernel 3: Calculating with 2D workGroups
However, doing all the computation in "1 dimension" limits the matrix size we can calculate10.
Although we don’t change much about our code, if we distribute our work in 2 dimensions we’re able to bypass these limits and launch more workGroups that are larger. This allows us to calculate a 4096x4096 matmul.
We update our kernel to @workgroup_size(8, 8), check our bounds,
@compute @workgroup_size(8, 8)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
  let row = global_id.x;
  let col = global_id.y;

  if (row < dimensions.M && col < dimensions.N) {
    var sum: f32 = 0.0;
    for (var i: u32 = 0u; i < dimensions.K; i = i + 1u) {
      sum = sum + a[row * dimensions.K + i] * b[i * dimensions.N + col];
    }
    result[row * dimensions.N + col] = sum;
  }
}
and dispatch workgroups in 2D
const WORKGROUP_SIZE = 8; // must match @workgroup_size(8, 8) above
pass.dispatchWorkgroups(
  Math.ceil(a.shape[0] / WORKGROUP_SIZE),
  Math.ceil(b.shape[1] / WORKGROUP_SIZE),
);
But this is slower than our original kernel! What’s going on?
If we make a small change to the code
@compute @workgroup_size(8, 8)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
  let row = global_id.y;
  let col = global_id.x;
we get much better kernel performance.
Why is this? We're able to take more advantage of cached inputs. Threads are numbered with the x dimension varying fastest in the global_invocation_id, so with row = global_id.y, consecutive threads in a workGroup share the same row of matrix A (which stays in cache) and read neighboring columns of B. With the original mapping, consecutive threads each touch a different row of A, so each thread spends extra cycles reading from global memory rather than the cache.
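A toy way to see this (hypothetical numbers, assuming an 8-wide workgroup in x and K = 4096): list which element of a the first few threads read at inner-loop step i = 0 under each mapping.

// Threads in a workgroup are numbered with x varying fastest.
const K = 4096; // assumed inner dimension
for (let t = 0; t < 4; t++) {
  const x = t % 8;             // global_id.x for the first few threads
  const y = Math.floor(t / 8); // global_id.y (0 for all of them here)
  console.log(`thread ${t}: row=x reads a[${x * K}], row=y reads a[${y * K}]`);
}
// row = global_id.x: a[0], a[4096], a[8192], a[12288] -> four different rows, strided loads
// row = global_id.y: a[0], a[0], a[0], a[0]           -> one shared row, served from cache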
Kernel 4: Kernel Tiling
Another thing to consider is how much work each thread does.
Up to now, each thread has computed only one output entry, but launching every workGroup carries some overhead.
If calculating more elements per thread costs less than the overhead of launching the extra workGroups, we should see a big speedup11.
To do so, we calculate 4 results per thread (e.g. a 1x4 Tile).
const BLOCKSIZE: u32 = 16;
const TILESIZE: u32 = 4;

@compute @workgroup_size(BLOCKSIZE, BLOCKSIZE)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
  let row = global_id.y;
  let col = global_id.x * TILESIZE;

  if (row >= dimensions.M || col >= dimensions.N) {
    return;
  }

  var sum00: f32 = 0.0;
  var sum01: f32 = 0.0;
  var sum02: f32 = 0.0;
  var sum03: f32 = 0.0;

  for (var i: u32 = 0u; i < dimensions.K; i = i + 1u) {
    let a_elem = a[row * dimensions.K + i];
    sum00 = sum00 + a_elem * b[i * dimensions.N + col];
    sum01 = sum01 + a_elem * b[i * dimensions.N + col + 1u];
    sum02 = sum02 + a_elem * b[i * dimensions.N + col + 2u];
    sum03 = sum03 + a_elem * b[i * dimensions.N + col + 3u];
  }

  result[row * dimensions.N + col] = sum00;
  result[row * dimensions.N + col + 1u] = sum01;
  result[row * dimensions.N + col + 2u] = sum02;
  result[row * dimensions.N + col + 3u] = sum03;
}
The kernel looks roughly the same as before, except we've unrolled the computation and are calculating TILESIZE results per thread.
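One host-side detail to keep in mind (a sketch following the same conventions as the earlier dispatch code, not Surfgrad's literal call): each thread now covers TILESIZE columns, so the x dispatch shrinks by that factor.

// Each workgroup now covers BLOCKSIZE rows and BLOCKSIZE * TILESIZE columns.
const BLOCKSIZE = 16;
const TILESIZE = 4;
pass.dispatchWorkgroups(
  Math.ceil(b.shape[1] / (BLOCKSIZE * TILESIZE)), // x walks the N (column) dimension
  Math.ceil(a.shape[0] / BLOCKSIZE),              // y walks the M (row) dimension
);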
We can take this a step further and calculate a 2D tile of results per thread! Instead of calculating 4 elements in a single row, we can calculate 4 elements in each of 4 rows (i.e. a 2D tile).
const BLOCKSIZE: u32 = 16;
const TILE_M: u32 = 4; // Tile size in M dimension
const TILE_N: u32 = 4; // Tile size in N dimension

@compute @workgroup_size(BLOCKSIZE, BLOCKSIZE)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
  let row = global_id.y * TILE_M;
  let col = global_id.x * TILE_N;

  // initialize the array with all 0s
  var sums: array<array<f32, TILE_N>, TILE_M>;
  for (var i = 0u; i < TILE_M; i++) {
    for (var j = 0u; j < TILE_N; j++) {
      sums[i][j] = 0.0;
    }
  }

  // Compute the 2D tile
  for (var k = 0u; k < dimensions.K; k++) {
    // for each row
    for (var i = 0u; i < TILE_M; i++) {
      let a_element = a[(row + i) * dimensions.K + k];
      // calculate the dot product
      for (var j = 0u; j < TILE_N; j++) {
        let b_element = b[k * dimensions.N + (col + j)];
        sums[i][j] += a_element * b_element;
      }
    }
  }

  // Write results
  for (var i = 0u; i < TILE_M; i++) {
    for (var j = 0u; j < TILE_N; j++) {
      let output_row = row + i;
      let output_col = col + j;
      if (output_row < dimensions.M && output_col < dimensions.N) {
        result[output_row * dimensions.N + output_col] = sums[i][j];
      }
    }
  }
}
Each thread now calculates a 4x4 grid of the output matrix and we see a slight improvement over the last kernel.
Surprisingly, 2D tiling is quite slow. Why haven’t we amortized the time it takes to launch workGroups by doing more work? And why are we slower than doing one item of work per thread?
Kernel 5: Unrolling
To answer the last question, we will need to dig into the compiled WebGPU kernels.
Some compilers will automatically unroll loops if the loop bounds are known at compile time, but we've been writing a general kernel for variable-shaped inputs!
Also, when writing WGSL, we don't have any control over the compiler's directives.
Looking at the bitcode compiled from the generated Metal shader, we can see that it still includes the for loop!
%51 = phi i32 [ 0, %41 ], [ %61, %50 ]
%52 = add i32 %37, %51
%53 = zext i32 %52 to i64
%54 = getelementptr inbounds [1 x float], ptr addrspace(1) %3, i64 0, i64 %53
%55 = load float, ptr addrspace(1) %54, align 4, !tbaa !27, !alias.scope !43, !noalias !44
%56 = zext i32 %51 to i64
%57 = getelementptr inbounds %struct.type_5, ptr %7, i64 0, i32 0, i64 %49, i32 0, i64 %56
%58 = load float, ptr %57, align 4, !tbaa !27
%59 = fmul fast float %55, %48
%60 = fadd fast float %58, %59
store float %60, ptr %57, align 4, !tbaa !27
%61 = add nuw nsw i32 %51, 1
%62 = icmp eq i32 %61, 4
br i1 %62, label %38, label %50 // branching for loop
Whereas the unrolled WGSL code gets compiled to
...
%141 = fmul fast float %112, %103
%142 = fadd fast float %141, %82
%143 = fmul fast float %116, %103
%144 = fadd fast float %143, %81
%145 = fmul fast float %120, %103
%146 = fadd fast float %145, %80
%147 = fmul fast float %124, %103
%148 = fadd fast float %147, %79
%149 = fmul fast float %112, %107
%150 = fadd fast float %149, %78
%151 = fmul fast float %116, %107
%152 = fadd fast float %151, %77
%153 = fmul fast float %120, %107
%154 = fadd fast float %153, %76
%155 = fmul fast float %124, %107
%156 = fadd fast float %155, %75
%157 = add nuw i32 %91, 1
%158 = icmp eq i32 %157, %27
br i1 %158, label %159, label %74
Because of the manual unrolling, the GPU can reduce overhead: it doesn't have to initialize and increment the inner loop counter, it can take advantage of instruction-level parallelism, and it amortizes the cost of launching workGroups by doing more work per thread. With the loop left in place, kernel #4 couldn't take advantage of these optimizations and ended up slower than simply launching more workGroups (#3).
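Since WGSL gives us no unroll directive, one practical workaround (a sketch of the general idea, not necessarily what Surfgrad ships) is to generate the unrolled WGSL on the host, where the tile size is a compile-time constant:

// Emit the unrolled accumulation statements for a fixed tile size.
// The generated lines mirror the hand-unrolled sum00..sum03 updates above.
function unrolledInnerLoop(tileSize: number): string {
  const lines: string[] = [];
  for (let j = 0; j < tileSize; j++) {
    lines.push(`sum${j} = sum${j} + a_elem * b[i * dimensions.N + col + ${j}u];`);
  }
  return lines.join("\n");
}
// Splice the result into the WGSL template before device.createShaderModule({ code }).

Because the shader is compiled from a string at runtime anyway, this kind of specialization is cheap.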
And if we make our tile 8x8, we get a 3x boost over the looped 4x4 version and surpass 1 TFLOPS!
Conclusion
Through these optimizations, we were able to build a matmul kernel that is roughly 1000x faster than the naive kernel and approaches the Apple M2 Pro's theoretical peak.
And with frequent updates to WebGPU, there are still optimizations to be made! For example, we didn’t take advantage of subgroups, a feature that is new as of Chrome 125 and should allow for faster memory access and sharing across subgroups to reduce repeated computations.
And a big thank you to Abhishaike Mahajan (who writes an incredible blog) and Elman Mansimov for their feedback and encouragement to write this article!
Visualizing these 2-dimensional maps poses two problems: projecting (e.g. with TSNE or UMAP) into a 2D coordinate system is slow and not RAM-friendly as dataset size grows, and visualizing millions of data points in the browser without turning your laptop into a toaster is hard.
I would be remiss not to mention two repos that do something similar: webGPT (Transformer-based inference only) and webgpu-blas (fast matmul kernels in webGPU).
The format of the blog follows a similar path to Simon Boehm’s article on Optimizing a CUDA Matmul Kernel.
Apple’s M3 Pro has a reported ~7TFLOPS. You can even run Llama3.2 (with ONNX) in your browser with 85 tokens/s
For reference, Llama 3.1 70B has matrices of size (8192x28672)
There’s quite a bit of boilerplate for running WebGPU code from Typescript, which I’ll leave for the curious to explore: https://webgpufundamentals.org/webgpu/lessons/webgpu-fundamentals.html
To simplify the article and the amount of code, I removed much of the boilerplate needed to set up the GPU buffers and focus only on what's required to understand how I optimized the WGSL kernels.
Due to another limitation: maxComputeWorkgroupsPerDimension.
And this is something Apple suggests when building compute kernels