
LeetGPU Challenge #6: 1D Convolution (Triton)

January 11, 2026
9 min read

It’s been a while since the last post. Life got in the way a bit, but I’m back with a new challenge! This one felt like a fun puzzle to solve.

The Challenge: 1D Convolution

The task is to implement 1D convolution. We get an input array and a kernel (filter) and we need to produce the convolved output.

We do not pad the input, and the kernel is only applied where it fully overlaps the input (often called a "valid" convolution). This means the output size is going to be input_size - kernel_size + 1.

The convolution operation is defined as:

output[i] = \sum_{j=0}^{kernel\_size - 1} input[i+j] \times kernel[j]

where i ranges from 0 to input_size - kernel_size.

Constraints:

  • 1 ≤ input_size ≤ 1,500,000
  • 1 ≤ kernel_size ≤ 2047
  • kernel_size ≤ input_size

Example

Input: input = [1, 2, 3, 4, 5]
kernel = [1, 0, -1]
Output: output = [-2, -2, -2]

Here’s how we would solve this on the CPU:

input = [1, 2, 3, 4, 5]
kernel = [1, 0, -1]
output = []
output_size = len(input) - len(kernel) + 1
for i in range(output_size):
    acc = 0.0  # local accumulator (avoids shadowing the built-in sum)
    for j in range(len(kernel)):
        acc += input[i + j] * kernel[j]
    output.append(acc)
# output now contains [-2.0, -2.0, -2.0]

We basically do the following:

  1. For each position i in the output (from 0 to input_size - kernel_size)
  2. We go over the kernel elements (from 0 to kernel_size - 1)
  3. We multiply input[i + j] by kernel[j] and accumulate the result
  4. We store the accumulated result in output[i]
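Since this pair of loops is just a "valid"-mode cross-correlation, we can sanity-check the CPU version against NumPy, which implements exactly this operation as np.correlate:

```python
import numpy as np

input_arr = np.array([1, 2, 3, 4, 5], dtype=np.float32)
kernel_arr = np.array([1, 0, -1], dtype=np.float32)

# np.correlate in "valid" mode slides the kernel only where it
# fully overlaps the input, matching the challenge's definition.
expected = np.correlate(input_arr, kernel_arr, mode="valid")
print(expected)  # [-2. -2. -2.]
```

This is handy to keep around as a reference when testing the GPU version later.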

Solving the Challenge

This challenge is pretty simple to solve once you understand the pattern. We need to parallelize the outer loop (over output positions) and handle the inner loop (over kernel elements) inside each program with a local accumulator, much like what we did in the matrix multiplication challenge.

The only tricky part is loading the input elements efficiently. For example, if kernel_size is 3, we need to load:

  • For output position 0: input[0], input[1], input[2]
  • For output position 1: input[1], input[2], input[3]
  • For output position 2: input[2], input[3], input[4]
  • And so on…

So we’re dealing with a sliding window pattern where each output position needs a different window from the input array.

The clean way to handle this is using broadcasting. We can create a 2D array where:

  • Each row represents an output position
  • Each column represents a kernel position
  • We use broadcasting to create all the input indices we need at once

For example, if we have output positions [0, 1, 2] and kernel positions [0, 1, 2], broadcasting output_positions[:, None] + kernel_positions[None, :] gives us:

[[0+0, 0+1, 0+2],      [[0, 1, 2],
 [1+0, 1+1, 1+2],  =    [1, 2, 3],
 [2+0, 2+1, 2+2]]       [2, 3, 4]]
Note

If you’re not familiar with numpy, the [:, None] and [None, :] syntax might look weird.

Here are some easy-to-understand examples:

[:, None]:

[1, 2, 3, 4] -> [[1],
                 [2],
                 [3],
                 [4]]

[None, :]:

[1, 2, 3, 4] -> [[1, 2, 3, 4]]

It basically just changes the “shape” of the array and how we treat it.

This gives us exactly the sliding window pattern we need!
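Here's the same broadcasting trick in plain NumPy, so you can poke at it outside of a GPU kernel:

```python
import numpy as np

output_positions = np.arange(3)  # [0, 1, 2]
kernel_positions = np.arange(3)  # [0, 1, 2]

# Column vector + row vector broadcasts to a (3, 3) grid of indices.
indices = output_positions[:, None] + kernel_positions[None, :]
print(indices)
# [[0 1 2]
#  [1 2 3]
#  [2 3 4]]

# Indexing the input with this grid pulls out every sliding window at once.
input_arr = np.array([1, 2, 3, 4, 5])
windows = input_arr[indices]
print(windows)
# [[1 2 3]
#  [2 3 4]
#  [3 4 5]]
```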

As always, let’s first get started with the pseudocode:

Note

If you’re new to Triton or GPU programming, the following pseudocode might be hard to understand. If that’s the case please start from the first challenge’s post and then come back!

def conv1d_gpu(input, kernel, output, input_size, kernel_size):
    BLOCK_SIZE = 2048  # output elements per program
    K_BLOCK = 4        # kernel elements to process at a time
    TOTAL = input_size - kernel_size + 1  # total output size
    n_programs = ceil(TOTAL / BLOCK_SIZE)
    gpu_programs = get_gpu_programs(count=n_programs)

    # Each program handles BLOCK_SIZE output elements
    # Calculate which output elements this program handles
    gpu_programs.generate_output_indices(
        offset=program_id * BLOCK_SIZE,
        count=BLOCK_SIZE,
        as=output_indices
    )
    # For program_id=0 (and an illustrative BLOCK_SIZE of 3), output_indices would be: [0, 1, 2]

    # Initialize accumulator for all output elements this program handles
    gpu_programs.initialize_accumulator(
        size=BLOCK_SIZE,
        initial_value=0.0,
        as=accumulator
    )

    # Process kernel in chunks
    num_k_chunks = ceil(kernel_size / K_BLOCK)
    for k_chunk in range(num_k_chunks):
        k_chunk_start = k_chunk * K_BLOCK
        k_chunk_end = min(k_chunk_start + K_BLOCK, kernel_size)

        # Generate kernel indices for this chunk
        gpu_programs.generate_kernel_indices(
            start=k_chunk_start,
            end=k_chunk_end,
            as=kernel_indices_chunk
        )
        # For k_chunk=0 and K_BLOCK=4, kernel_indices_chunk would be: [0, 1, 2, 3]

        # Load kernel values for this chunk
        gpu_programs.load_values(
            from=kernel,
            indices=kernel_indices_chunk,
            as=kernel_values_chunk
        )

        # Use broadcasting to create input indices for all output positions at once
        # We create a 2D array where:
        # - Each row represents one output position
        # - Each column represents one kernel position in this chunk
        # - We use broadcasting: output_indices[:, None] + kernel_indices_chunk[None, :]
        # This creates a column vector from output_indices and a row vector from kernel_indices_chunk
        # Broadcasting adds them together to create all combinations
        gpu_programs.create_broadcasted_indices(
            output_indices=output_indices,
            kernel_indices=kernel_indices_chunk,
            as=input_indices_2d
        )
        # For output_indices=[0, 1, 2] and kernel_indices_chunk=[0, 1, 2, 3], input_indices_2d would be:
        # [[0+0, 0+1, 0+2, 0+3],      [[0, 1, 2, 3],
        #  [1+0, 1+1, 1+2, 1+3],  =    [1, 2, 3, 4],
        #  [2+0, 2+1, 2+2, 2+3]]       [2, 3, 4, 5]]
        # Each row is the sliding window of input indices we need for that output position

        # Load input values for all output positions at once using the indices we just created
        gpu_programs.load_values(
            from=input,
            indices=input_indices_2d,
            as=input_values_2d
        )
        # This loads a 2D array where each row contains the input window for one output position
        # For our example above, input_values_2d would be:
        # [[input[0], input[1], input[2], input[3]],
        #  [input[1], input[2], input[3], input[4]],
        #  [input[2], input[3], input[4], input[5]]]

        # Multiply the elements and accumulate the results along the kernel dimension
        # Each row gets multiplied elementwise with kernel_values_chunk and summed
        # Then we add this result to the accumulator
        gpu_programs.multiply_and_accumulate(
            input_values_2d=input_values_2d,
            kernel_values=kernel_values_chunk,
            accumulator=accumulator
        )
        # For each output position i:
        # accumulator[i] += sum(input_values_2d[i] * kernel_values_chunk)

    # After processing all kernel chunks, accumulator contains the final result
    # Store results
    gpu_programs.store_values(
        into=output,
        indices=output_indices,
        values=accumulator
    )

def main(input, kernel, output, input_size, kernel_size):
    conv1d_gpu(input, kernel, output, input_size, kernel_size)

The Solution

Now let’s write the actual solution code :)

Step 1. Boilerplate

We’ll start with LeetGPU’s boilerplate (it’s modified a bit to fit how we’re going to actually implement it):

import torch
import triton
import triton.language as tl

@triton.jit
def conv1d_kernel(
    input, kernel, output,
    kernel_size,
    TOTAL,
    BLOCK_SIZE: tl.constexpr,
    K_BLOCK: tl.constexpr,
):
    pass

# input, kernel, output are tensors on the GPU
def solve(input: torch.Tensor, kernel: torch.Tensor, output: torch.Tensor, input_size: int, kernel_size: int):
    BLOCK_SIZE = 2048
    K_BLOCK = 4
    TOTAL = input_size - kernel_size + 1
    n_blocks = triton.cdiv(TOTAL, BLOCK_SIZE)
    grid = (n_blocks,)
    conv1d_kernel[grid](
        input, kernel, output,
        kernel_size,
        TOTAL=TOTAL,
        BLOCK_SIZE=BLOCK_SIZE,
        K_BLOCK=K_BLOCK,
    )

From the top in solve:

  • Since we get arrays, we’re just doing a 1D grid, similar to what we did in the vector addition challenge
  • BLOCK_SIZE = 2048 is the number of output elements each program handles
  • K_BLOCK = 4 is the chunk size for handling the kernel (we’ll process the kernel in chunks of 4 elements at a time)
  • TOTAL = input_size - kernel_size + 1 is the total number of output elements we need to compute
  • We calculate n_blocks = triton.cdiv(TOTAL, BLOCK_SIZE) to determine how many programs we need
  • The kernel function takes TOTAL as a parameter instead of input_size since that’s what we actually need for bounds checking
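Under the hood, triton.cdiv is just ceiling division. A quick pure-Python sketch (the cdiv helper here is my own stand-in, not Triton's implementation) shows how the grid size falls out of the numbers:

```python
def cdiv(a: int, b: int) -> int:
    # Ceiling division: how many blocks of size b are needed to cover a elements.
    return (a + b - 1) // b

# Largest possible case from the constraints: input_size = 1_500_000, kernel_size = 1
TOTAL = 1_500_000 - 1 + 1
print(cdiv(TOTAL, 2048))  # 733 programs
```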

Step 2. Setting Up the Indices and the Accumulator

Next we set up the indices for the output elements and initialize our accumulator:

@triton.jit
def conv1d_kernel(
    input, kernel, output,
    kernel_size,
    TOTAL,
    BLOCK_SIZE: tl.constexpr,
    K_BLOCK: tl.constexpr,
):
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)[:, None]
    mask = offsets < TOTAL
    k_offsets = tl.arange(0, K_BLOCK)[None, :]
    acc = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)

From the top:

  • pid = tl.program_id(0) gets the program ID from our 1D grid
  • offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)[:, None] creates a column vector of output indices we handle in this program.
  • mask = offsets < TOTAL creates a mask for us that will help prevent out of bounds memory access.
  • k_offsets = tl.arange(0, K_BLOCK)[None, :] creates the base row vector [0, 1, 2, ..., K_BLOCK-1] for kernel indices.
  • acc = tl.zeros((BLOCK_SIZE,), dtype=tl.float32) initializes our accumulator as a 1D array of zeros, one per output element we handle in this program.
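To make the shapes concrete, here's a NumPy mock-up of these index tensors (using a toy BLOCK_SIZE of 4 and K_BLOCK of 2 instead of the real values):

```python
import numpy as np

BLOCK_SIZE, K_BLOCK = 4, 2
TOTAL = 3  # pretend only 3 output elements exist
pid = 0

# Column vector of output indices, shape (BLOCK_SIZE, 1)
offsets = pid * BLOCK_SIZE + np.arange(BLOCK_SIZE)[:, None]
# Row vector of kernel offsets, shape (1, K_BLOCK)
k_offsets = np.arange(K_BLOCK)[None, :]

# The last output index (3) is out of bounds, so the mask turns it off.
mask = offsets < TOTAL
print(offsets.shape, k_offsets.shape)  # (4, 1) (1, 2)
print(mask.ravel())  # [ True  True  True False]
```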

Step 3. Processing the Kernel in Chunks

Now let’s add the loop that processes the kernel in chunks:

    for k in range(0, kernel_size, K_BLOCK):
        k_mask = (k + k_offsets) < kernel_size
        k_vals = tl.load(kernel + k + k_offsets, mask=k_mask, other=0.0)
        in_ptrs = input + offsets + k + k_offsets
        in_mask = mask & k_mask
        in_vals = tl.load(in_ptrs, mask=in_mask, other=0.0)
        acc += tl.sum(in_vals * k_vals, axis=1)

From the top:

  • for k in range(0, kernel_size, K_BLOCK) iterates over the kernel in chunks of K_BLOCK elements.
  • k_mask = (k + k_offsets) < kernel_size creates a mask for the kernel chunk to prevent out-of-bounds memory access.
  • k_vals = tl.load(kernel + k + k_offsets, mask=k_mask, other=0.0) loads the kernel values for this chunk. Passing other=0.0 fills the masked-off lanes with zeros so they contribute nothing to the sum below.
  • in_ptrs = input + offsets + k + k_offsets uses broadcasting to create a BLOCK_SIZE x K_BLOCK grid of input pointers.
  • in_mask = mask & k_mask combines both masks to prevent out-of-bounds memory access.
  • in_vals = tl.load(in_ptrs, mask=in_mask, other=0.0) loads the input values, again zero-filling the masked-off lanes.
  • acc += tl.sum(in_vals * k_vals, axis=1) multiplies and sums along the kernel dimension, then adds the result to our accumulator.
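If you want to trace the chunked multiply-accumulate without a GPU, here's a NumPy sketch of the same logic. The masked loads are emulated with clamped indices and np.where; this mirrors the kernel's behavior but is not Triton code:

```python
import numpy as np

input_arr = np.array([1, 2, 3, 4, 5], dtype=np.float32)
kernel_arr = np.array([1, 0, -1], dtype=np.float32)
K_BLOCK = 2
kernel_size = len(kernel_arr)
TOTAL = len(input_arr) - kernel_size + 1

offsets = np.arange(TOTAL)[:, None]      # column vector of output indices
k_offsets = np.arange(K_BLOCK)[None, :]  # row vector of kernel offsets
acc = np.zeros(TOTAL, dtype=np.float32)

for k in range(0, kernel_size, K_BLOCK):
    k_idx = k + k_offsets
    k_mask = k_idx < kernel_size
    # Emulated masked load: clamp indices to stay in bounds,
    # then zero out the lanes the mask turned off.
    k_vals = np.where(k_mask, kernel_arr[np.minimum(k_idx, kernel_size - 1)], 0.0)
    in_idx = offsets + k_idx
    in_vals = np.where(k_mask, input_arr[np.minimum(in_idx, len(input_arr) - 1)], 0.0)
    acc += np.sum(in_vals * k_vals, axis=1)

print(acc)  # [-2. -2. -2.]
```

Note how the second chunk (k=2) only has one valid kernel element; the zero-filled lanes make the extra column a no-op.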

Step 4. Storing the Results

Now all that’s left to do is to store the results:

    tl.store(output + offsets, acc[:, None], mask=mask)

Here we use acc[:, None] to reshape the accumulator into a column vector of shape (BLOCK_SIZE, 1). Since offsets is itself a column vector, the pointers output + offsets (and the mask) have that shape, and tl.store requires the values to match it.

Step 5. Final Code

We’re done! Here’s the complete solution:

import torch
import triton
import triton.language as tl

@triton.jit
def conv1d_kernel(
    input, kernel, output,
    kernel_size,
    TOTAL,
    BLOCK_SIZE: tl.constexpr,
    K_BLOCK: tl.constexpr,
):
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)[:, None]
    mask = offsets < TOTAL
    k_offsets = tl.arange(0, K_BLOCK)[None, :]
    acc = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
    for k in range(0, kernel_size, K_BLOCK):
        k_mask = (k + k_offsets) < kernel_size
        # other=0.0 keeps masked-off lanes from polluting the sum
        k_vals = tl.load(kernel + k + k_offsets, mask=k_mask, other=0.0)
        in_ptrs = input + offsets + k + k_offsets
        in_mask = mask & k_mask
        in_vals = tl.load(in_ptrs, mask=in_mask, other=0.0)
        acc += tl.sum(in_vals * k_vals, axis=1)
    tl.store(output + offsets, acc[:, None], mask=mask)

# input, kernel, output are tensors on the GPU
def solve(input: torch.Tensor, kernel: torch.Tensor, output: torch.Tensor, input_size: int, kernel_size: int):
    BLOCK_SIZE = 2048
    K_BLOCK = 4
    TOTAL = input_size - kernel_size + 1
    n_blocks = triton.cdiv(TOTAL, BLOCK_SIZE)
    grid = (n_blocks,)
    conv1d_kernel[grid](
        input, kernel, output,
        kernel_size,
        TOTAL=TOTAL,
        BLOCK_SIZE=BLOCK_SIZE,
        K_BLOCK=K_BLOCK,
    )

What’s Next

In this challenge, figuring out how to structure the broadcasting for the inner loop felt like a pretty fun puzzle! I won’t be able to post as often as I’d like for a while, but I promise to post at least once a week, so please stay tuned :)
