Today’s challenge is extremely simple, so let’s make it quick 😆.
The Challenge: Reverse Array
The task is to reverse an array in place. We get an array of numbers and we need to reverse it, modifying the original array.
Constraints:
- 1 ≤ N ≤ 100,000,000
Example
Input: [1.0, 2.0, 3.0, 4.0]
Output: [4.0, 3.0, 2.0, 1.0]

Here’s how we’d solve this on the CPU:

```python
arr = [1.0, 2.0, 3.0, 4.0]
N = len(arr)

for i in range(N // 2):
    j = N - 1 - i
    arr[i], arr[j] = arr[j], arr[i]

# arr is now [4.0, 3.0, 2.0, 1.0]
```

We basically do the following:
- Loop over the first half of the array (from `0` to `N // 2 - 1`)
- For each position `i`, calculate the corresponding position on the other side: `j = N - 1 - i`
- Swap the values at positions `i` and `j`
Notice how if the array has an odd number of elements, we automatically skip the middle element since it doesn’t need to be swapped.
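For example, with five elements the loop only runs twice and the middle index is never touched. Here’s a quick sketch reusing the loop above:

```python
# Reversing an odd-length array: the loop runs N // 2 = 2 times,
# swapping (0, 4) and (1, 3); the middle element at index 2 stays put.
arr = [1.0, 2.0, 3.0, 4.0, 5.0]
N = len(arr)

for i in range(N // 2):
    j = N - 1 - i
    arr[i], arr[j] = arr[j], arr[i]

print(arr)  # [5.0, 4.0, 3.0, 2.0, 1.0]
```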
Solving the Challenge
To solve this challenge we’re just going to parallelize the CPU solution above on the GPU.
Given how simple this challenge is and how many challenges we’ve covered so far with in-depth pseudocode and explanations, I think this time we can skip the pseudocode and get straight to the solution :)
The Solution
Step 1. Boilerplate
Let’s start with a basic boilerplate:
```python
import torch
import triton
import triton.language as tl


@triton.jit
def reverse_kernel(
    input,
    half,
    N,
    BLOCK_SIZE: tl.constexpr
):
    pass


# input is a tensor on the GPU
def solve(input: torch.Tensor, N: int):
    BLOCK_SIZE = 256
    half = N // 2
    n_blocks = triton.cdiv(half, BLOCK_SIZE)
    grid = (n_blocks,)

    reverse_kernel[grid](
        input,
        half,
        N,
        BLOCK_SIZE=BLOCK_SIZE
    )
```

From the top in `solve`:
- Since we get an array, we’re just doing a 1D grid, similar to what we did in the vector addition challenge
- `BLOCK_SIZE = 256` is the number of elements we want each program to handle
- `half = N // 2` is the midpoint. We only need to process elements up to this point since swapping is symmetric, as we covered above.
- We calculate `n_blocks = triton.cdiv(half, BLOCK_SIZE)` to determine how many programs we need
- We pass `half` and `N` to the kernel so it knows where to stop and how to calculate the right side’s indices
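To make the launch math concrete, here’s a quick sketch with a hypothetical N (the numbers are just for illustration; `math.ceil` stands in for `triton.cdiv`):

```python
import math

# Hypothetical problem size to illustrate the grid arithmetic in solve().
N = 1_000_000
BLOCK_SIZE = 256

half = N // 2                            # 500_000 pairs to swap
n_blocks = math.ceil(half / BLOCK_SIZE)  # same result as triton.cdiv(half, BLOCK_SIZE)

print(n_blocks)  # 1954
```

So for a million elements we launch 1954 programs, each responsible for up to 256 swaps.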
Step 2. Calculating the Indices and the Mask
Now let’s fill in the kernel with the index calculations and the mask:
```python
@triton.jit
def reverse_kernel(
    input,
    half,
    N,
    BLOCK_SIZE: tl.constexpr
):
    pid = tl.program_id(axis=0)
    left_offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    right_offsets = (N - 1) - left_offsets
    mask = left_offsets < half
```

From the top:

- `pid = tl.program_id(axis=0)` gets the program ID from our 1D grid
- `left_offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)` creates the indices for the left side elements this program handles
- `right_offsets = (N - 1) - left_offsets` calculates the corresponding right side indices
- `mask = left_offsets < half` creates a mask that prevents us from processing beyond the midpoint. This makes sure we don’t try to swap elements in the second half — those pairs are already handled from the left side, so swapping them again would undo the reversal.
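To see the index math in action, here’s a host-side sketch with toy values (N = 10, BLOCK_SIZE = 4, looking at the program with pid = 1 — numbers picked purely for illustration):

```python
import numpy as np

# Simulating the kernel's index math on the host for toy values.
N, BLOCK_SIZE, pid = 10, 4, 1
half = N // 2  # 5

left_offsets = pid * BLOCK_SIZE + np.arange(BLOCK_SIZE)  # [4 5 6 7]
right_offsets = (N - 1) - left_offsets                   # [5 4 3 2]
mask = left_offsets < half                               # [True False False False]
```

Only the first lane (left index 4, paired with right index 5) is active; the other three lanes are masked off because their left indices fall past the midpoint.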
Step 3. Loading, Swapping, and Storing
Now let’s load the values and perform the swap:
```python
    l_val = tl.load(input + left_offsets, mask=mask)
    r_val = tl.load(input + right_offsets, mask=mask)

    tl.store(input + left_offsets, r_val, mask=mask)
    tl.store(input + right_offsets, l_val, mask=mask)
```

From the top:

- `l_val = tl.load(input + left_offsets, mask=mask)` loads the values from the left side positions, using the mask to prevent out-of-bounds reads
- `r_val = tl.load(input + right_offsets, mask=mask)` does the same for the right side positions
- `tl.store(input + left_offsets, r_val, mask=mask)` stores the right side values into the left side positions; the mask keeps inactive lanes from writing and undoing swaps we already made
- `tl.store(input + right_offsets, l_val, mask=mask)` stores the left side values into the right side positions, masked for the same reason
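To convince ourselves the masked swap really reverses the array, here’s a host-side NumPy emulation of the kernel’s logic — a sketch, not the actual GPU code, just mirroring the per-program load/swap/store on toy values:

```python
import numpy as np

# Emulating the kernel on the host for a small odd-length array.
arr = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0])
N, BLOCK_SIZE = len(arr), 4
half = N // 2
n_blocks = -(-half // BLOCK_SIZE)  # ceil division, like triton.cdiv

for pid in range(n_blocks):
    left = pid * BLOCK_SIZE + np.arange(BLOCK_SIZE)
    right = (N - 1) - left
    mask = left < half

    # Masked loads: only active lanes participate
    l_val = arr[left[mask]].copy()
    r_val = arr[right[mask]].copy()

    # Masked stores: write each side's values to the opposite side
    arr[left[mask]] = r_val
    arr[right[mask]] = l_val

print(arr)  # [7. 6. 5. 4. 3. 2. 1.]
```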
Step 4. Final Code
We’re done! Here’s the complete solution:
```python
import torch
import triton
import triton.language as tl


@triton.jit
def reverse_kernel(
    input,
    half,
    N,
    BLOCK_SIZE: tl.constexpr
):
    pid = tl.program_id(axis=0)
    left_offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    right_offsets = (N - 1) - left_offsets
    mask = left_offsets < half

    l_val = tl.load(input + left_offsets, mask=mask)
    r_val = tl.load(input + right_offsets, mask=mask)

    tl.store(input + left_offsets, r_val, mask=mask)
    tl.store(input + right_offsets, l_val, mask=mask)


# input is a tensor on the GPU
def solve(input: torch.Tensor, N: int):
    BLOCK_SIZE = 256
    half = N // 2
    n_blocks = triton.cdiv(half, BLOCK_SIZE)
    grid = (n_blocks,)

    reverse_kernel[grid](
        input,
        half,
        N,
        BLOCK_SIZE=BLOCK_SIZE
    )
```

What’s Next
That’s it for this one! The next few challenges are going to be similarly simple up to a certain point, so I’ll be able to post them faster than usual since the write-ups take less time. Stay tuned!