I’ve been working on optimizing the following PyTorch functions by rewriting them in Triton to speed up execution:
PyTorch Code

```python
def split_sys_img_que(x):
    return (
        x[..., :num_sys, :],
        x[..., num_sys:num_sys + num_img, :],
        x[..., -num_que:, :],
    )

def construct_chunk(x, y):
    sink_tok = x[..., :1, :].expand(num_chunks, -1, -1, -1)
    x = torch.cat([x, self.pad_tok], dim=-2).view(1, num_heads, num_sys_chk, -1, head_dim).transpose(1, 2).view(num_sys_chk, num_heads, -1, head_dim)
    y = torch.cat([sink_tok, y.view(1, num_heads, num_chunks, -1, head_dim).transpose(1, 2).view(num_chunks, num_heads, -1, head_dim)], dim=-2)
    return torch.cat([x, y], dim=0)
```
In these functions:
- split_sys_img_que:
  - Typical input shape: x.shape = [1, 32, 645, 128]
  - Global variable values: num_sys=35, num_img=576, num_que=34
- construct_chunk:
  - Typical input shapes: x.shape = [1, 32, 35, 128], y.shape = [1, 32, 576, 128]
  - Global variable values: num_chunks=12, num_sys_chk=1
  - self.pad_tok.shape = [1, 32, 14, 128]
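One observation worth checking before reaching for Triton: split_sys_img_que is pure slicing, so it returns views into x without copying any data, and a custom kernel has essentially nothing to gain there. A minimal sketch under the shapes above (on CPU for illustration; torch.split is an equivalent one-call form):

```python
import torch

num_sys, num_img, num_que = 35, 576, 34

def split_sys_img_que(x):
    # Pure slicing: each result is a view into x, no data is copied.
    return (
        x[..., :num_sys, :],
        x[..., num_sys:num_sys + num_img, :],
        x[..., -num_que:, :],
    )

x = torch.randn(1, 32, 645, 128)
sys_part, img_part, que_part = split_sys_img_que(x)

# torch.split yields the same partition in one call (35 + 576 + 34 = 645):
sys2, img2, que2 = torch.split(x, [num_sys, num_img, num_que], dim=-2)
```

Since these are views, any real cost only appears later, when a downstream op forces a contiguous copy.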
Request for Help
I found these functions to be time-consuming and attempted to optimize them using Triton. However, my Triton implementation turned out to be slower than the original PyTorch code (possibly due to my lack of experience with Triton).
Could anyone suggest improvements to my approach or provide guidance on how to efficiently use Triton/CUDA for these functions?
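Not part of the original question, but one approach to consider: each torch.cat in construct_chunk allocates a fresh tensor and copies its inputs, so the function makes several passes over the data. Preallocating the final output once and writing each piece into its slot can avoid the intermediate allocations. A sketch under the stated shapes (construct_chunk_prealloc, pad_len, chunk_len, and out_len are names introduced here; pad_tok is a module-level stand-in for self.pad_tok just to keep the snippet self-contained):

```python
import torch

num_heads, head_dim = 32, 128
num_sys, num_img = 35, 576
num_chunks, num_sys_chk = 12, 1
pad_len = 14                        # matches self.pad_tok.shape[-2]
chunk_len = num_img // num_chunks   # 48 image tokens per chunk
out_len = num_sys + pad_len         # 49, which also equals 1 sink token + chunk_len

pad_tok = torch.zeros(1, num_heads, pad_len, head_dim)

def construct_chunk_prealloc(x, y):
    # Allocate the final [13, 32, 49, 128] result once, then write each
    # piece into its slot instead of chaining torch.cat calls.
    out = x.new_empty(num_sys_chk + num_chunks, num_heads, out_len, head_dim)
    # System chunk: x followed by the padding tokens.
    out[:num_sys_chk, :, :num_sys] = x
    out[:num_sys_chk, :, num_sys:] = pad_tok
    # Image chunks: a shared sink token, then one chunk_len slice of y each.
    out[num_sys_chk:, :, :1] = x[..., :1, :]  # broadcasts over num_chunks
    out[num_sys_chk:, :, 1:] = (
        y.view(1, num_heads, num_chunks, chunk_len, head_dim)
         .transpose(1, 2)
         .reshape(num_chunks, num_heads, chunk_len, head_dim)
    )
    return out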
rain_ is a new contributor to this site. Take care in asking for clarification, commenting, and answering.
Check out our Code of Conduct.