| title: "Optimizing LLM Performance Using Triton" | |
| format: | |
| revealjs: | |
| theme: dark | |
| transition: slide | |
| slide-number: true | |
| author: "Matej Sirovatka" | |
| date: today | |
## `whoami`

- My name is Matej
- I'm a Master's student at the Brno University of Technology
- I currently make GPUs go `brrrrrr` at Hugging Face 🤗
## `What is Triton?`

- OpenAI's open-source programming language for GPU kernels
- Designed for AI/ML workloads
- Simplifies GPU programming compared to CUDA
## `Why Optimize with Triton?`

- Simple yet effective
- Less headache than CUDA
- GPUs go `brrrrrrr`
- Feel cool when your kernel is faster than PyTorch
## `Example Problem: KL Divergence`

- Commonly used in LLMs for knowledge distillation
- For probability distributions $P$ and $Q$, the Kullback-Leibler divergence is defined as:

$$
D_{KL}(P \| Q) = \sum_{i} P_i \log\left(\frac{P_i}{Q_i}\right)
$$

```python
import torch
from torch.nn.functional import kl_div

def kl_div_torch(p: torch.Tensor, q: torch.Tensor) -> torch.Tensor:
    # F.kl_div expects the *input* as log-probabilities and the *target*
    # as probabilities: it computes target * (log(target) - input),
    # so passing (q.log(), p) yields the elementwise P * log(P / Q)
    return kl_div(q.log(), p, reduction="none")
```
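In distillation, $P$ and $Q$ typically come from teacher and student logits via a softmax. A minimal sketch reusing `kl_div_torch` from above (shapes and names are purely illustrative):

```python
import torch

# Hypothetical distillation setup: probabilities come from logits via softmax
teacher_logits = torch.randn(4, 32000)  # illustrative (batch, vocab) logits
student_logits = torch.randn(4, 32000)

p = torch.softmax(teacher_logits, dim=-1)  # target distribution P (teacher)
q = torch.softmax(student_logits, dim=-1)  # model distribution Q (student)

loss = kl_div_torch(p, q).sum(dim=-1).mean()  # KL per row, averaged over batch
```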
## `How about Triton?`

```python
import triton
import triton.language as tl

@triton.jit
def kl_div_triton(
    p_ptr, q_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr
):
    # Each program instance handles one contiguous block of elements
    pid = tl.program_id(0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements  # guard against out-of-bounds accesses
    p = tl.load(p_ptr + offsets, mask=mask)
    q = tl.load(q_ptr + offsets, mask=mask)
    output = p * (tl.log(p) - tl.log(q))  # elementwise P * log(P / Q)
    tl.store(output_ptr + offsets, output, mask=mask)
```
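As a quick sanity check, here is a sketch (assuming a CUDA device is available) that launches the kernel directly and compares it against the PyTorch reference from the previous slide:

```python
import torch

n = 4096
p = torch.softmax(torch.randn(n, device="cuda"), dim=0)
q = torch.softmax(torch.randn(n, device="cuda"), dim=0)
out = torch.empty_like(p)

BLOCK_SIZE = 512
grid = ((n + BLOCK_SIZE - 1) // BLOCK_SIZE,)  # one program per BLOCK_SIZE chunk
kl_div_triton[grid](p, q, out, n, BLOCK_SIZE=BLOCK_SIZE)

torch.testing.assert_close(out, kl_div_torch(p, q))  # should match elementwise
```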
## `How to integrate with PyTorch?`

- Triton works with raw pointers
- How do we use our custom kernel with PyTorch autograd?

```python
import torch

class KLDiv(torch.autograd.Function):
    @staticmethod
    def forward(ctx, p, q):
        # Both inputs are needed to compute the gradients later
        ctx.save_for_backward(p, q)
        output = torch.empty_like(p)
        n_elements = p.numel()
        # Triton expects the launch grid as a tuple: one program per block
        grid = ((n_elements + 512 - 1) // 512,)
        kl_div_triton[grid](p, q, output, n_elements, BLOCK_SIZE=512)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        p, q = ctx.saved_tensors
        # Calculate gradients (another triton kernel)
        return ...
```
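With the `Function` defined, the kernel behaves like any other differentiable op. A minimal usage sketch (assuming a CUDA device, and that the elided backward kernel is actually implemented):

```python
import torch

p = torch.softmax(torch.randn(4096, device="cuda"), dim=0).requires_grad_()
q = torch.softmax(torch.randn(4096, device="cuda"), dim=0)

loss = KLDiv.apply(p, q).sum()  # forward runs the Triton kernel
loss.backward()                 # dispatches to our custom backward
```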
## `Some benchmarks`

- A KL Divergence kernel that is currently used in [Liger Kernel](https://github.com/linkedin/liger-kernel), written by @me
## `Do I have to write everything?`

- TL;DR: No
- Many cool projects already use Triton
- Better integration with PyTorch and even Hugging Face 🤗
- Liger Kernel, Unsloth AI, etc.
## `So how can I use this in my LLM?`

- Liger Kernel is a great example, showing how to integrate with the Hugging Face 🤗 Trainer

```diff
- from transformers import AutoModelForCausalLM
+ from liger_kernel.transformers import AutoLigerKernelForCausalLM
model_path = "meta-llama/Meta-Llama-3-8B-Instruct"
- model = AutoModelForCausalLM.from_pretrained(model_path)
+ model = AutoLigerKernelForCausalLM.from_pretrained(model_path)
# training/inference logic...
```
## `Key Optimization Techniques adopted by Liger Kernel`

- Kernel fusion (sketched below)
- Domain-specific optimizations
- Memory access patterns
- Preemptive memory freeing
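To illustrate the first point: fusion means doing several elementwise steps in one pass over global memory instead of materializing intermediates. A minimal sketch of the idea as a toy bias + ReLU kernel (not one of Liger's actual kernels):

```python
import triton
import triton.language as tl

@triton.jit
def fused_bias_relu(x_ptr, b_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    b = tl.load(b_ptr + offsets, mask=mask)
    # Bias add + ReLU in a single kernel: one read and one write per element,
    # instead of writing the intermediate (x + b) back to global memory
    tl.store(out_ptr + offsets, tl.maximum(x + b, 0.0), mask=mask)
```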
## `Aaand some more benchmarks`
## `Last benchmark I promise...`
::: {.incremental}
*Attention is all you need, so I thank you for yours!* 🤗
:::