WebAssembly techniques to speed up matrix multiplication

# WebAssembly Techniques to Speed Up Matrix Multiplication by 120x

*by [@bwasti](https://twitter.com/bwasti)*

This post is going to use [wasmblr](https://github.com/bwasti/wasmblr) to implement matrix multiplication in pure WebAssembly. Then we'll optimize it until it's comparable to [TensorFlow.js](https://www.tensorflow.org/js). The result is a **~120x speedup** over a plain JavaScript implementation, processing ~45 billion elements per second on an M1 chip.

The best thing about WebAssembly is that we'll be able to run all the code [in browser](https://bwasti.github.io/wasmblr/matmul/)! Here's the full [code listing](https://github.com/bwasti/wasmblr/tree/main/matmul_example).

<center>
<img src="https://i.imgur.com/WAb4K0l.png" style="display:inline;width:480px;max-width:80%;"/>
<img src="https://i.imgur.com/cLhu50x.png" style="display:inline;width:480px;max-width:80%;"/>
</center>

## Matrix Multiplication

Below is a nice visualization of 3x3x3 matrix multiplication:

<center>
<img src="https://www.mscroggs.co.uk/img/full/multiply_matrices.gif" style="width:480px;max-width:80%;"/>
<br>
<a target="_blank" href="https://www.mscroggs.co.uk/blog/tags/matrix%20multiplication">(source)</a>
</center>

The rows of the first matrix are ["dotted"](https://en.wikipedia.org/wiki/Dot_product) with the columns of the second matrix for every possible pairwise combination. In the above example, each dot product is ~6 operations (3 multiplies and 3 adds) and we're doing 9 pairwise dot products, so the total number of operations is 54. For this post we'll be focusing on matrices that result in 4.2M, 33.6M and 268M operations per matrix multiplication (square matrices with N = 128, 256 and 512).

## Baseline

We'll start with this implementation:

```javascript
for (let m = 0; m < M; ++m) {
  for (let n = 0; n < N; ++n) {
    for (let k = 0; k < K; ++k) {
      c[m * N + n] += a[m * K + k] * b[k * N + n];
    }
  }
}
```

If we're using `Float32Array`s, this can process around 380,000,000 elements per second on my M1 MacBook.
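For readers who want to reproduce the baseline measurement, here is a minimal, self-contained JavaScript sketch. It assumes row-major `Float32Array`s and counts one multiply plus one add per `(m, n, k)` step. That count reproduces the 54, 4.2M, 33.6M and 268M operation totals above. The helpers `opCount`, `gflops` and `benchBaseline` are hypothetical names chosen for illustration, not functions from the author's code listing.

```javascript
// Baseline triple loop from above, wrapped in a function.
// All matrices are row-major: a is M x K, b is K x N, c is M x N.
function matmulBaseline(a, b, c, M, N, K) {
  for (let m = 0; m < M; ++m) {
    for (let n = 0; n < N; ++n) {
      for (let k = 0; k < K; ++k) {
        c[m * N + n] += a[m * K + k] * b[k * N + n];
      }
    }
  }
}

// One multiply and one add per (m, n, k) triple.
function opCount(M, N, K) {
  return 2 * M * N * K;
}

// Convert an operation count and an elapsed time in milliseconds to GFlops
// (billions of floating point operations per second).
function gflops(ops, ms) {
  return ops / (ms / 1000) / 1e9;
}

// Time one run on random size x size inputs. A single run is noisy;
// averaging several runs gives steadier numbers.
function benchBaseline(size) {
  const a = Float32Array.from({ length: size * size }, () => Math.random());
  const b = Float32Array.from({ length: size * size }, () => Math.random());
  const c = new Float32Array(size * size);
  const start = performance.now();
  matmulBaseline(a, b, c, size, size, size);
  const elapsed = performance.now() - start;
  return gflops(opCount(size, size, size), elapsed);
}

console.log(opCount(128, 128, 128)); // 4194304, i.e. ~4.2M operations
console.log(`N=128 baseline: ${benchBaseline(128).toFixed(2)} gflops`);
```

`performance.now()` is available in browsers and as a global in recent versions of Node.js, so the sketch runs in either environment.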
It's typical to measure performance in this way (rather than "runs per second") because it's invariant to the size of the matrices involved. Another convention is to use a standard unit like "GFlops" (billion floating point operations per second) and drop all the zeros. The above implementation achieves 0.38 GFlops on my machine.

Let's make it 120 times faster.

## Implementation

To do so, we'll implement matrix multiplication in WebAssembly. I'm going to use [wasmblr](https://github.com/bwasti/wasmblr) because I like C++, but any in-browser assembler will work.

Why use an assembler instead of emscripten? One reason is that we can sweep many different optimization variants rather than trying to guess the best parameters. It turns out that we'll need to tune independently for Firefox and Chrome. Pre-compiled solutions like emscripten aren't ideal for this kind of iteration, because every variant has to be compiled ahead of time, which blows up the code size.

The code for this section is [here](https://github.com/bwasti/wasmblr/blob/main/matmul_example/mm.cc#L154-L259). If you'd like to skip it and just see the optimized code, that's [here](https://github.com/bwasti/wasmblr/blob/main/matmul_example/mm.cc#L261-L410).

#### Memory [[code]](https://github.com/bwasti/wasmblr/blob/main/matmul_example/mm.cc#L156-L160)

Before we write any computation code, we're going to preallocate memory for inputs and outputs. In a particularly advanced implementation, this might involve allocating scratch space as well, but we're not going to do that.

WebAssembly deals with pages of size 64KiB.
Let's calculate the number of pages we need (rounding up):

```cpp
auto pages = (M * N + K * N + M * K) * 4 / (1 << 16) + 1;
memory(pages).export_("mem");
```

Instead of passing messy pointers as arguments, we're just going to hardcode the offsets for the inputs and output:

```cpp
auto A_off = 0;
auto B_off = M * K * 4;
auto C_off = (M * K + K * N) * 4;
```

With the memory exported, the user can write their arrays directly to the heap:

```javascript
const mem = instance.exports.mem;
const a = new Float32Array(mem.buffer, 0, M * K);
const b = new Float32Array(mem.buffer, M * K * 4, K * N);
const c = new Float32Array(mem.buffer, (M * K + K * N) * 4, M * N);
```

#### Loops [[code]](https://github.com/bwasti/wasmblr/blob/main/matmul_example/mm.cc#L162-L179)

Matrix multiplication is naively $O(n^3)$. We'll be sticking with that approach, but curious readers should certainly check out the [Strassen algorithm](https://en.wikipedia.org/wiki/Strassen_algorithm), which offers an algorithmic speedup (at the cost of some memory overhead and numerical stability, which may be acceptable).

In WebAssembly (which is stack based), a loop might look like this (the stack comments list the top of the stack first):

```cpp
auto m = local(i32);  // we're going to loop over m
i32.const_(0);        // push 0 to the stack
local.set(m);         // set m = 0
loop(void_);          // start the loop!
  // body goes here
  // ...
  local.get(m);       // stack: [m]
  i32.const_(1);      // stack: [1, m]
  i32.add();          // stack: [m + 1]
  local.tee(m);       // stack: [m + 1] + update variable "m"
  i32.const_(M);      // stack: [M, m + 1]
  i32.lt_u();         // stack: [true/false] (check if m + 1 < M)
  br_if(0);           // if true, jump back to the start of the loop
end();                // end the loop!
```

Because the check sits at the bottom, the body always runs at least once. We need [three](https://github.com/bwasti/wasmblr/blob/main/matmul_example/mm.cc#L15-L28) of [those](https://github.com/bwasti/wasmblr/blob/main/matmul_example/mm.cc#L76-L104).

#### Body [[code]](https://github.com/bwasti/wasmblr/blob/main/matmul_example/mm.cc#L181-L225)

In the body of the loop we want to load from $A$, $B$ and $C$.
Loading $C$ (instead of overwriting it) means we accumulate into the existing output, which is typical of matrix multiplication implementations with an $\alpha$ value of $1$:

$$
C' = \alpha C + A \cdot B
$$

Each load operation will look something like this:

```cpp
// load original value of C
local.get(m);        // stack: [m]
i32.const_(N);       // stack: [N, m]
i32.mul();           // stack: [m * N]
local.get(n);        // stack: [n, m * N]
i32.add();           // stack: [m * N + n]
i32.const_(4);       // (size of a floating point number)
i32.mul();           // stack: [(m * N + n) * 4]
f32.load(0, C_off);
```

[(and do the same for $A$ and $B$)](https://github.com/bwasti/wasmblr/blob/main/matmul_example/mm.cc#L30-L58)

Now we can invoke the actual operation. Since $A$, $B$ and $C$ are on the stack in the right order, we simply call `mul` and then `add`.

```cpp
// stack: [B, A, C]
f32.mul();  // stack: [B * A, C]
f32.add();  // stack: [B * A + C]
auto c = local(f32);
local.set(c);
```

Note that we have to save the new value of $C$ to a local variable in order to store it: `f32.store` expects the address to be pushed *before* the value, so the result has to be set aside while we recompute the address (WebAssembly's stack and locals are a bit messy this way). The store operation looks a lot like the load operation:

```cpp
// store new value to C
local.get(m);
i32.const_(N);
i32.mul();
local.get(n);
i32.add();
i32.const_(4);
i32.mul();
local.get(c);
f32.store(0, C_off);
```

The result of all this [hard work](https://github.com/bwasti/wasmblr/blob/main/matmul_example/mm.cc#L4-L107)?

Firefox

```
N=128 (wasmblr): 0.57 gflops
```

Chrome

```
N=128 (wasmblr): 0.59 gflops
```

About 1.5x faster out of the box! Great.

## Optimization [[code]](https://github.com/bwasti/wasmblr/blob/main/matmul_example/mm.cc#L271-L407)

We can do better, but we'll need to pull out a couple of non-obvious techniques.

#### Vectorization

The first thing we can do is start vectorizing the multiplication. Let's vectorize the $N$ dimension. That means we are loading 1 element from $A$ (from dimension $M$) and 4 elements from $B$ and $C$.
In pseudo-JavaScript that's

```javascript
for (let m = 0; m < M; ++m) {
  for (let n = 0; n < N; n += 4) {
    for (let k = 0; k < K; ++k) {
      // splat converts a scalar to a vector
      A_vec = splat4(A[m * K + k]);
      B_vec = B[k * N + n];  // 4 contiguous elements
      C_vec = C[m * N + n];  // 4 contiguous elements
      tmp_vec = vec4_mul(A_vec, B_vec);
      C_vec = vec4_add(tmp_vec, C_vec);
      C[m * N + n] = C_vec;  // store the 4 elements back
    }
  }
}
```

Here are the corresponding `wasmblr` calls:

```cpp
v128.load32_splat(0, A_off);
v128.load(0, B_off);
v128.load(0, C_off);
// ...
v128.f32x4_mul();
v128.f32x4_add();
// ...
v128.store(0, C_off);
```

#### Unrolling

Unrolling is a simple idea. For example, the code below

```javascript
for (let n = 0; n < 4; ++n) {
  blah(n);
}
```

would become

```javascript
blah(0);
blah(1);
blah(2);
blah(3);
```

This increases the amount of straight-line code executed between branches. This is important in WebAssembly applications because loop book-keeping (the code to deal with the loop iteration variable) takes a fair number of instructions to execute: seven per iteration in the loop shown earlier.

#### Local Variables

This technique aims to increase the arithmetic intensity of the innermost loop. Typically, this is done by loading values from memory into registers. WebAssembly doesn't have registers, so we'll be using local variables and crossing our fingers that the browser's JIT figures things out for us.

Arithmetic intensity ($I$) refers to the number of arithmetic operations we can perform per load. We'll want to keep the CPU busy while we wait on load instructions. Memory access is slow compared to arithmetic, but luckily loads can overlap with computation.

Matrix multiplication involves doing *every* pairwise dot product along the $K$ dimension. That means for every $m$ loads of $A$ and $n$ loads of $B$ we can do $m \cdot n$ multiplications. That means we can amp up the arithmetic intensity:

$$
I \approx \frac{m \cdot n}{m + n}
$$

If we get $m$ and $n$ high enough, we'll never need to worry about keeping the CPU busy. We're only limited by how many values we can reasonably keep in local variables!
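To put numbers on the formula, the small JavaScript sketch below tabulates $I$ for a few block shapes, along with how many values must be live at once: $m$ loaded values of $A$, $n$ of $B$, and $m \cdot n$ running sums for $C$. It counts one load per value and one multiply-add per $(A, B)$ pair, and assumes the $C$ accumulators stay in locals across the $K$ loop so their loads and stores are amortized. The function names are hypothetical; this illustrates the trade-off and is not code from the author's listing.

```javascript
// Arithmetic intensity of an m x n block: each step along K loads m values
// of A and n values of B, then does m * n multiply-adds.
function arithmeticIntensity(m, n) {
  return (m * n) / (m + n);
}

// Values that need to be live at once (in locals, ideally in registers):
// m loads of A, n loads of B and m * n accumulators for C.
function liveValues(m, n) {
  return m + n + m * n;
}

for (const [m, n] of [[1, 1], [2, 4], [2, 8], [4, 4], [8, 8]]) {
  const I = arithmeticIntensity(m, n).toFixed(2);
  console.log(`m=${m} n=${n}  I=${I}  live values=${liveValues(m, n)}`);
}
```

For square blocks, $I$ grows linearly with the block size while the number of live values grows quadratically. For reference, AArch64 (the M1's instruction set) has 32 128-bit vector registers, so an 8x8 block of `v128` values could not all stay in registers. Whether the JIT keeps even a smaller block in registers is up to the browser.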
(*Aside: if you're curious how "busy" the CPU can get in the world of WebAssembly, you can benchmark varying levels of unrolled independent multiplications. Some numbers can be collected in your browser with this [example](https://bwasti.github.io/wasmblr/flops/).*)

To concretize the ideas above, here's a pseudo-implementation (where we'll assume every loop is actually completely unrolled).

```javascript
// load into localA, O(k_unroll * m_unroll)
for (let k = 0; k < k_unroll; ++k) {
  for (let m = 0; m < m_unroll; ++m) {
    localA[m * k_unroll + k] = A[base_A + m * K + k];
  }
}
// load into localB, O(k_unroll * n_unroll)
for (let k = 0; k < k_unroll; ++k) {
  for (let n = 0; n < n_unroll; ++n) {
    localB[n * k_unroll + k] = B[base_B + k * N + n];
  }
}
// compute C, O(m_unroll * k_unroll * n_unroll)
// (localC holds an m_unroll x n_unroll block, loaded from C before
// this and stored back afterwards)
for (let k = 0; k < k_unroll; ++k) {
  for (let m = 0; m < m_unroll; ++m) {
    for (let n = 0; n < n_unroll; ++n) {
      const tmp = localA[m * k_unroll + k] * localB[n * k_unroll + k];
      localC[m * n_unroll + n] += tmp;
    }
  }
}
```

This will work as a sub-program for most implementations, assuming correctly calculated `base_A` and `base_B` offsets into the global memory. In the real code, this will involve creating many local variables.

```cpp
std::vector<int> load_a;
std::vector<int> load_b;
for (auto j = 0; j < K_unroll; ++j) {
  for (auto i = 0; i < M_unroll; ++i) {
    load_a.emplace_back(local(v128));
  }
  for (auto i = 0; i < N_unroll; ++i) {
    load_b.emplace_back(local(v128));
  }
}
```

The actual loading process involves `local.set`ing all the `v128`s we pulled from memory:

```cpp
for (auto k_unroll = 0; k_unroll < K_unroll; ++k_unroll) {
  for (auto m_unroll = 0; m_unroll < M_unroll; ++m_unroll) {
    local.get(a_off);
    v128.load32_splat(0, A_off + (m_unroll * K + k_unroll) * 4);
    local.set(load_a.at(m_unroll * K_unroll + k_unroll));
  }
}
```

Note that the above code is unrolling things: we are looping through C++ constructs and emitting WebAssembly.
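The same generation-time unrolling can be sketched without wasmblr. The JavaScript below is a hypothetical text emitter (`emitUnrolledLoadA` is not part of wasmblr or the author's code) that mirrors the C++ loop above. The JavaScript loops run once, at generation time, and leave behind only straight-line instructions with constant offsets. It prints WebAssembly text format rather than binary and uses the current `v128.load32_splat` instruction name. The base-pointer local `$var99` and the numbering of the destination locals from `$var0` are assumptions made purely for illustration.

```javascript
// Hypothetical emitter: JS loops run at generation time and produce
// straight-line WebAssembly text, one splatted load of A per (m, k) pair.
function emitUnrolledLoadA(M_unroll, K_unroll, K, A_off, baseLocal, firstLocal) {
  const lines = [];
  for (let k = 0; k < K_unroll; ++k) {
    for (let m = 0; m < M_unroll; ++m) {
      // Constant byte offset of A[m][k] relative to the base pointer local.
      const offset = A_off + (m * K + k) * 4;
      lines.push(`local.get $var${baseLocal}`);
      lines.push(`v128.load32_splat offset=${offset}`);
      // Same local layout as load_a in the C++: index m * K_unroll + k.
      lines.push(`local.set $var${firstLocal + m * K_unroll + k}`);
    }
  }
  return lines;
}

// 2 x 2 unroll with K = 128: four loads, no loop instructions at all.
console.log(emitUnrolledLoadA(2, 2, 128, 0, 99, 0).join("\n"));
```

Only the emitted instructions end up in the module. The loop variables `m` and `k` exist solely in the generator, which is why their book-keeping costs nothing at run time.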
This type of "meta-programming" is particularly useful when writing optimized code. The WebAssembly ends up looking like this:

```wasm
local.get $var99
v32x4.load_splat offset=512 align=1
local.set $var5
local.get $var99
v32x4.load_splat offset=1024 align=1
local.set $var7
local.get $var99
v32x4.load_splat offset=1536 align=1
local.set $var9
local.get $var99
v32x4.load_splat offset=2048 align=1
local.set $var11
```

The same sort of thing (without splatting) should be done for $B$ and $C$.

#### Tuning

Finally, we're going to lazily find good parameters for the number of local variables and amount of unrolling by tuning everything (6 × 6 × 6 = 216 configurations).

```javascript
let best = { gflops: 0 };
for (let m of [1, 2, 4, 8, 16, 32]) {
  for (let n of [1, 2, 4, 8, 16, 32]) {
    for (let k of [1, 2, 4, 8, 16, 32]) {
      let gflops = await bench(mod, M, N, K, m, n, k);
      if (gflops > best.gflops) {
        best = { gflops, m, n, k };  // keep the fastest configuration
      }
    }
  }
}
```

This brute-force search pays off: Firefox and Chrome end up preferring different configurations despite implementing the exact same specification. Tuning surfaced properties of their JIT implementations that would have been hard to reason about by reading their code.

## Results

The result of these optimizations is ~150 lines of C++ code and ~50 lines of tuning JavaScript.

In order to get a sense of how good a job we've done, we can compare our performance with [this benchmark of TensorFlow.js](https://codepen.io/bwasti/pen/GRMebrx?editors=0012), a heavily optimized neural network library. This comparison isn't apples-to-apples because TF.js doesn't have pre-allocated outputs, but it's a useful reference point.
Firefox:

```
N=128 (tfjs-wasm):  9.99 gflops
N=256 (tfjs-wasm): 29.43 gflops
N=512 (tfjs-wasm): 31.47 gflops
N=128 (wasmblr):   43.95 gflops (unroll m: 2, n: 4, k: 16)
N=256 (wasmblr):   43.47 gflops (unroll m: 2, n: 4, k: 16)
N=512 (wasmblr):   43.47 gflops (unroll m: 2, n: 4, k: 8)
```

Chrome:

```
N=128 (tfjs-wasm): 29.54 gflops
N=256 (tfjs-wasm): 40.38 gflops
N=512 (tfjs-wasm): 44.03 gflops
N=128 (wasmblr):   46.14 gflops (unroll m: 2, n: 8, k: 1)
N=256 (wasmblr):   45.56 gflops (unroll m: 2, n: 8, k: 1)
N=512 (wasmblr):   45.98 gflops (unroll m: 2, n: 8, k: 1)
```

<center>
<img src="https://i.imgur.com/WAb4K0l.png" style="display:inline;width:480px;max-width:80%;"/>
<img src="https://i.imgur.com/cLhu50x.png" style="display:inline;width:480px;max-width:80%;"/>
</center>

Nice!

There are other optimizations worth exploring, such as tiling chunks of the input matrices directly into scratch space (rather than local variables) or tuning loop orders. We might also want to explore different unrolling parameters and the resulting "tail" code when the unrolling doesn't evenly divide the input size. This might let us use the optimal number of local variables for each JIT. I've left these ideas as an exercise to the reader. :^}

Thanks for reading!
