Unlike the challenges we’ve covered so far, this one is extremely easy to solve, and honestly I’m pretty happy about that because those posts took a long time to write and this one is gonna be short in comparison 😂.
The highlight of this challenge is pointer reinterpretation. We’re also going to apply a typical little leetcode trick with bitwise operations.
The Challenge: Color Inversion
The task is to invert the RGB channels of an image while keeping the alpha channel untouched.
The image comes in as a 1D array of uint8 values in RGBA order.
To invert a channel, we just need to subtract it from 255.
Constraints:
- 1 ≤ width ≤ 4096
- 1 ≤ height ≤ 4096
- width * height ≤ 8,388,608
Example
Input: image = [255, 0, 128, 255, 0, 255, 0, 255], width = 1, height = 2
Output: image = [0, 255, 127, 255, 255, 0, 255, 255], width = 1, height = 2
(Modified in place)

Here’s how you’d naively solve it on the CPU:

image = [255, 0, 128, 255, 0, 255, 0, 255]
width = 1
height = 2
for i in range(0, width * height * 4, 4):
    image[i] = 255 - image[i]
    image[i + 1] = 255 - image[i + 1]
    image[i + 2] = 255 - image[i + 2]
    # image[i + 3] left intact since it's the alpha channel
# image now contains [0, 255, 127, 255, 255, 0, 255, 255]

- Loop over each pixel by jumping 4 bytes at a time (since each pixel is 4 bytes).
- Subtract each RGB value from 255 and write it back into the image array.
Solving the Challenge
This challenge is extremely easy to solve. We’re just gonna parallelize the CPU solution above on the GPU where each Triton program handles a chunk of the pixels.
We’ll make 2 changes though:
- Instead of reading R, G, B, and A separately as individual bytes, we’ll reinterpret the u8 pointer as a u32 pointer so we can read all four channels as a single value.
- Instead of doing three separate subtractions (255 - r, 255 - g, 255 - b), we’ll use the XOR bitwise operation to invert the RGB channels in a single instruction. (Typical leetcode trick)
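To get a feel for both changes before touching the GPU, here’s a minimal CPU sketch using only the Python standard library. It assumes memoryview.cast as a stand-in for the pointer reinterpretation (an illustration only, not what Triton does internally), and the helper name invert_in_place is made up for this example. The mask is built from its byte layout so the sketch doesn’t depend on the host’s byte order.

```python
import sys


def invert_in_place(image: bytearray) -> None:
    """Invert R, G, B of every RGBA pixel in place, one 32-bit word at a time."""
    # View the same bytes as unsigned 32-bit words in native byte order.
    # Assumes the "I" format is 4 bytes wide, which holds on common platforms.
    # No data is copied: writes through `words` land in `image`.
    words = memoryview(image).cast("I")
    # 0xFF over the R, G, B bytes and 0x00 over the A byte.
    # On a little endian host this evaluates to 0x00FFFFFF.
    mask = int.from_bytes(b"\xff\xff\xff\x00", sys.byteorder)
    for i in range(len(words)):
        words[i] ^= mask  # one XOR per pixel instead of three subtractions


image = bytearray([255, 0, 128, 255, 0, 255, 0, 255])
invert_in_place(image)
print(list(image))
```

The loop still visits pixels one by one here; on the GPU that loop is what gets spread across programs.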
As always, let’s first start with the pseudocode:
def invert_colors_gpu(image, width, height):
    BLOCK_SIZE = 1024  # pixels per program
    n_pixels = width * height  # total pixel count

    gpu_programs = get_gpu_programs(count=ceil(n_pixels / BLOCK_SIZE))

    # image is a u8 pointer so loading image[0] will give us the first byte
    # We'll reinterpret it as a u32 pointer so image[0] will give us the first 4 bytes (RGBA) together
    gpu_programs.reinterpret_pointer(image, as=u32_pointer)
    # GPUs use little endian encoding and since the challenge says the values are in RGBA order,
    # the u32 value we end up with reads as ABGR from the most significant byte to the least significant byte

    gpu_programs.generate_indices(offset=program_id * BLOCK_SIZE, count=BLOCK_SIZE, as=offsets)

    gpu_programs.load_pixels(from=image, indices=offsets, as=pixels)
    gpu_programs.xor_pixels(pixels, with=0x00FFFFFF, as=inverted)

    gpu_programs.store_pixels(into=image, indices=offsets, values=inverted)

def main(image, width, height):
    invert_colors_gpu(image, width, height)

The XOR trick works because when we XOR a byte with 0xFF (all ones), every bit flips, which gives us the same result as subtracting from 255.
Our 32-bit mask 0x00FFFFFF has 0xFF in the lower three bytes and 0x00 in the top byte, so the RGB channels get inverted while the alpha channel stays intact.
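If the XOR trick feels like magic, it’s quick to check exhaustively. The sketch below is plain Python (nothing Triton specific): it confirms that XOR with 0xFF matches 255 - x for every possible byte, then applies the 32-bit mask to one packed value picked purely for illustration.

```python
# Every byte value: flipping all 8 bits is the same as subtracting from 255.
assert all((x ^ 0xFF) == 255 - x for x in range(256))

MASK = 0x00FFFFFF  # 0xFF over the three low bytes, 0x00 over the top byte

# Illustrative packed value: top byte 0x80, low bytes 0x12, 0x34, 0x56.
pixel = 0x80123456
inverted = pixel ^ MASK

# The low three bytes are flipped, the top byte is still 0x80.
print(hex(inverted))
```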
Note
When we reinterpret the pointer as a uint32 pointer, Triton reads each pixel as ABGR (from the most significant byte to the least significant byte) because GPUs use little endian encoding.
Since the challenge says the values are in RGBA order, the first byte (R) becomes the least significant byte in the 32-bit value we load, and the last byte (A) becomes the most significant byte.
That’s why our mask is 0x00FFFFFF instead of 0xFFFFFF00 like you might expect.
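The byte order is easy to see on the CPU too. This sketch uses int.from_bytes with an explicit "little" byte order to mimic what the u32 load sees for the first pixel of the example. It’s only a model of the memory layout, not Triton code.

```python
# First pixel from the example, in memory order: R, G, B, A
r, g, b, a = 255, 0, 128, 255
memory = bytes([r, g, b, a])

# Little endian: the first byte in memory is the least significant byte.
value = int.from_bytes(memory, "little")
print(f"{value:#010x}")  # hex digits read A, B, G, R from left to right

# Correct mask: flips R, G, B and leaves A alone.
inverted = value ^ 0x00FFFFFF
# Writing the value back in little endian restores the R, G, B, A memory order.
print(list(inverted.to_bytes(4, "little")))

# The mask you might expect flips G, B and A instead, and leaves R alone.
wrong = value ^ 0xFFFFFF00
print(list(wrong.to_bytes(4, "little")))
```

With the wrong mask, the red channel survives untouched and the alpha channel gets inverted, which is the opposite of what the challenge asks for.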
The Solution
Now let’s write the actual solution code :)
Step 1. Boilerplate
As always, LeetGPU gives us a boilerplate:
import torch
import triton
import triton.language as tl

@triton.jit
def invert_kernel(image, width, height, BLOCK_SIZE: tl.constexpr):
    pass

# image is a tensor on the GPU
def solve(image: torch.Tensor, width: int, height: int):
    BLOCK_SIZE = 1024
    n_pixels = width * height
    grid = (triton.cdiv(n_pixels, BLOCK_SIZE),)

    invert_kernel[grid](image, width, height, BLOCK_SIZE)

Since we get an array, we’re just doing a 1D grid, similar to what we did in the vector addition challenge. And we don’t have to worry about tiling (like in matrix multiplication) or memory coalescing (like in matrix transpose) since we read from the same place, we write into the same place and it’s also all contiguous! This is so much easier than what we’ve dealt with in the previous challenges!!
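For a sense of scale, here’s the grid size arithmetic in plain Python. The local cdiv helper is an assumption standing in for triton.cdiv (ceiling division) so the snippet runs without Triton installed, and the image sizes are picked for illustration.

```python
def cdiv(a: int, b: int) -> int:
    """Ceiling division, standing in for triton.cdiv."""
    return (a + b - 1) // b


BLOCK_SIZE = 1024

# The example image, an uneven size, and the largest pixel count the
# constraints allow (4096 * 2048 = 8,388,608).
sizes = [(1, 2), (50, 50), (4096, 2048)]
grids = [cdiv(width * height, BLOCK_SIZE) for width, height in sizes]

for (width, height), n_programs in zip(sizes, grids):
    print(f"{width} x {height}: {width * height} pixels, {n_programs} programs")
```

Even the tiny 1 x 2 example launches one full program of 1024 lanes, which is why the kernel needs a mask.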
Step 2. Reinterpreting the Pointer and Preparing the Indices
@triton.jit
def invert_kernel(image, width, height, BLOCK_SIZE: tl.constexpr):
    # reinterpret image as a pointer to u32 data rather than u8
    image_ptr = image.to(tl.pointer_type(tl.uint32))

    n_pixels = width * height
    pid = tl.program_id(0)

    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_pixels

From the top:
- image_ptr = image.to(tl.pointer_type(tl.uint32)) reinterprets the image pointer as a pointer to u32 data rather than u8 data. That way each load/store works per pixel instead of per byte.
- n_pixels = width * height calculates the total number of pixels.
- pid = tl.program_id(0) gets the program ID. We only have one dimension in the grid, so we pass 0 to get the ID from the first axis.
- offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) builds the indices of the pixels the current program will load and handle. For example, if pid is 1 and BLOCK_SIZE is 1024, this is [1024..2047].
- mask = offsets < n_pixels prevents out of bounds reads for the last block if BLOCK_SIZE doesn’t divide n_pixels evenly.
This part is very similar to what we did in the vector addition challenge, so if you’re not sure about it, check out that post first then come back!
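To make the mask concrete, here’s a small plain Python model of the index math for the last program of a 50 x 50 image (sizes picked just for illustration). Lists stand in for Triton’s block tensors.

```python
BLOCK_SIZE = 1024
n_pixels = 50 * 50  # 2500 pixels, so 3 programs are launched
pid = 2             # the last program

# Same arithmetic as pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
offsets = [pid * BLOCK_SIZE + lane for lane in range(BLOCK_SIZE)]
mask = [offset < n_pixels for offset in offsets]

first, last = offsets[0], offsets[-1]
valid_lanes = sum(mask)

print(first, last)   # range of pixel indices this program covers
print(valid_lanes)   # lanes that point at real pixels
```

The remaining lanes of that last program point past the end of the image, so the mask turns their loads and stores into no-ops.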
Step 3. Inverting the RGB Channels and Storing the Result
    pixels = tl.load(image_ptr + offsets, mask=mask)

    # XOR with 0x00FFFFFF inverts the RGB channels and leaves the alpha channel untouched
    inverted = pixels ^ 0x00FFFFFF

    tl.store(image_ptr + offsets, inverted, mask=mask)

From the top:
- tl.load(image_ptr + offsets, mask=mask) loads the pixels as u32 values from the indices we prepared above.
- pixels ^ 0x00FFFFFF uses broadcasting to apply the XOR operation to every single pixel we loaded. Since pixels is an array of u32 values (one per pixel) and 0x00FFFFFF is a single scalar value, Triton automatically broadcasts the scalar across all elements in the array.
- tl.store(image_ptr + offsets, inverted, mask=mask) writes the inverted pixels back to the image array at the same indices we loaded them from, with the mask to prevent out of bounds writes.
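If you want to sanity check the logic without a GPU, one option is to replay the kernel program by program on the CPU and compare against the naive version. This is a rough emulation in plain Python, not a benchmark or a replacement for running the real kernel; the function names emulate_kernel and invert_naive are made up for this sketch.

```python
import random

BLOCK_SIZE = 1024


def emulate_kernel(image: bytearray, width: int, height: int) -> None:
    """Replay the kernel's steps program by program on the CPU (illustrative)."""
    n_pixels = width * height
    n_programs = (n_pixels + BLOCK_SIZE - 1) // BLOCK_SIZE
    for pid in range(n_programs):
        for lane in range(BLOCK_SIZE):
            offset = pid * BLOCK_SIZE + lane
            if offset >= n_pixels:  # the mask: skip lanes past the end
                continue
            start = offset * 4
            pixel = int.from_bytes(image[start:start + 4], "little")  # load
            inverted = pixel ^ 0x00FFFFFF                             # XOR
            image[start:start + 4] = inverted.to_bytes(4, "little")   # store


def invert_naive(image: bytearray) -> bytearray:
    """Reference: subtract R, G, B from 255 and keep alpha as is."""
    out = bytearray(image)
    for i in range(0, len(out), 4):
        for channel in range(3):
            out[i + channel] = 255 - out[i + channel]
    return out


random.seed(0)
width, height = 37, 41  # 1517 pixels, not a multiple of BLOCK_SIZE
image = bytearray(random.randrange(256) for _ in range(width * height * 4))
expected = invert_naive(image)
emulate_kernel(image, width, height)
print(image == expected)
```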
Step 4. Final Code
And that’s it! We’re done 😆
import torch
import triton
import triton.language as tl

@triton.jit
def invert_kernel(image, width, height, BLOCK_SIZE: tl.constexpr):
    image_ptr = image.to(tl.pointer_type(tl.uint32))

    n_pixels = width * height
    pid = tl.program_id(0)

    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_pixels

    pixels = tl.load(image_ptr + offsets, mask=mask)
    inverted = pixels ^ 0x00FFFFFF
    tl.store(image_ptr + offsets, inverted, mask=mask)

# image is a tensor on the GPU
def solve(image: torch.Tensor, width: int, height: int):
    BLOCK_SIZE = 1024
    n_pixels = width * height
    grid = (triton.cdiv(n_pixels, BLOCK_SIZE),)

    invert_kernel[grid](image, width, height, BLOCK_SIZE)

What’s Next
The challenges are getting easier to write about thankfully since we’ve covered a lot of core concepts already. We can mostly focus on solving challenges and writing GPU code now in these blog posts rather than getting deep into explanations so that’s great! I’m excited to be able to write more and make posts faster :)
There are some cool concepts we haven’t covered yet but we’ll get there once we get to those challenges!