Update README.md
Browse files
README.md
CHANGED
|
@@ -1,3 +1,113 @@
|
|
| 1 |
-
---
|
| 2 |
-
license: apache-2.0
|
| 3 |
-
---
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
license: apache-2.0
|
| 3 |
+
---
|
| 4 |
+
|
| 5 |
+
# How to use this model
|
| 6 |
+
|
| 7 |
+
```python
|
| 8 |
+
tl_methods = [
|
| 9 |
+
'PropagateNan', 'TRITON_MAX_TENSOR_NUMEL', 'abs', 'advance', 'arange',
|
| 10 |
+
'argmax', 'argmin', 'associative_scan', 'atomic_add', 'atomic_and',
|
| 11 |
+
'atomic_cas', 'atomic_max', 'atomic_min', 'atomic_or', 'atomic_xchg',
|
| 12 |
+
'atomic_xor', 'bfloat16', 'block_type', 'broadcast', 'broadcast_to',
|
| 13 |
+
'cast', 'cat', 'cdiv', 'ceil', 'clamp', 'const', 'const_pointer_type',
|
| 14 |
+
'constexpr', 'cos', 'cumprod', 'cumsum', 'debug_barrier', 'device_assert',
|
| 15 |
+
'device_print', 'div_rn', 'dot', 'dtype', 'erf', 'exp', 'exp2',
|
| 16 |
+
'expand_dims', 'fdiv', 'flip', 'float16', 'float32', 'float64',
|
| 17 |
+
'float8e4b15', 'float8e4b8', 'float8e4nv', 'float8e5', 'float8e5b16',
|
| 18 |
+
'floor', 'fma', 'full', 'function_type', 'histogram',
|
| 19 |
+
'inline_asm_elementwise', 'int1', 'int16', 'int32', 'int64', 'int8',
|
| 20 |
+
'interleave', 'join', 'load', 'log', 'log2', 'make_block_ptr', 'max',
|
| 21 |
+
'max_constancy', 'max_contiguous', 'maximum', 'min', 'minimum',
|
| 22 |
+
'multiple_of', 'num_programs', 'pair_uniform_to_normal', 'permute',
|
| 23 |
+
'philox', 'pi32_t', 'pointer_type', 'program_id', 'rand', 'rand4x',
|
| 24 |
+
'randint', 'randint4x', 'randn', 'randn4x', 'range', 'ravel', 'reduce',
|
| 25 |
+
'reshape', 'rsqrt', 'sigmoid', 'sin', 'softmax', 'sort', 'split', 'sqrt',
|
| 26 |
+
'sqrt_rn', 'static_assert', 'static_print', 'static_range', 'store',
|
| 27 |
+
'str_to_ty', 'sum', 'swizzle2d', 'tensor', 'trans', 'uint16', 'uint32',
|
| 28 |
+
'uint64', 'uint8', 'uint_to_uniform_float', 'umulhi', 'view', 'void',
|
| 29 |
+
'where', 'xor_sum', 'zeros', 'zeros_like'
|
| 30 |
+
]
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def get_user_prompt(name, pytorch_impl):
|
| 34 |
+
prompt = f"""Convert this PyTorch module implementation into an equivalent Triton kernel:
|
| 35 |
+
|
| 36 |
+
<torch_code>
|
| 37 |
+
{pytorch_impl}
|
| 38 |
+
</torch_code>
|
| 39 |
+
|
| 40 |
+
The Triton kernel should:
|
| 41 |
+
1. Import torch, triton, and triton.language as tl and other necessary modules
|
| 42 |
+
2. Use @triton.jit decorator on the kernel implementation (not the entrypoint function)
|
| 43 |
+
3. Have proper grid and block sizes
|
| 44 |
+
4. Use a mask in the load/store operations
|
| 45 |
+
5. Use typed constants (tl.constexpr)
|
| 46 |
+
6. Handle tensor dimensions correctly
|
| 47 |
+
7. Return output matching PyTorch's implementation
|
| 48 |
+
8. Do not include any test code in your response, only the Triton kernel implementation and entrypoint function
|
| 49 |
+
|
| 50 |
+
The triton.language (tl) module supports the following methods: {", ".join(tl_methods)}
|
| 51 |
+
|
| 52 |
+
The entrypoint function must be named: {name}_triton
|
| 53 |
+
The Triton kernel implementation (called by the entrypoint) must be named: {name}_kernel
|
| 54 |
+
|
| 55 |
+
No computation logic should be done within the entrypoint function. All computation logic should be done within the Triton kernel implementation.
|
| 56 |
+
|
| 57 |
+
The final generated code in the response must start with <triton_code> and end with </triton_code> tags."""
|
| 58 |
+
|
| 59 |
+
return prompt
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
SYSTEM_PROMPT = """You are a helpful assistant that converts PyTorch code into Triton kernels."""
|
| 63 |
+
|
| 64 |
+
messages = [
|
| 65 |
+
{"role": "system", "content": SYSTEM_PROMPT},
|
| 66 |
+
{"role": "user", "content": get_user_prompt(name, code)},
|
| 67 |
+
]
|
| 68 |
+
|
| 69 |
+
...
|
| 70 |
+
```
|
| 71 |
+
|
| 72 |
+
Example PyTorch code (from KernelBench):
|
| 73 |
+
|
| 74 |
+
```python
|
| 75 |
+
import torch
|
| 76 |
+
import torch.nn as nn
|
| 77 |
+
|
| 78 |
+
class Model(nn.Module):
|
| 79 |
+
"""
|
| 80 |
+
Simple model that performs a LeakyReLU activation.
|
| 81 |
+
"""
|
| 82 |
+
def __init__(self, negative_slope: float = 0.01):
|
| 83 |
+
"""
|
| 84 |
+
Initializes the LeakyReLU module.
|
| 85 |
+
|
| 86 |
+
Args:
|
| 87 |
+
negative_slope (float, optional): The negative slope of the activation function. Defaults to 0.01.
|
| 88 |
+
"""
|
| 89 |
+
super(Model, self).__init__()
|
| 90 |
+
self.negative_slope = negative_slope
|
| 91 |
+
|
| 92 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 93 |
+
"""
|
| 94 |
+
Applies LeakyReLU activation to the input tensor.
|
| 95 |
+
|
| 96 |
+
Args:
|
| 97 |
+
x (torch.Tensor): Input tensor of any shape.
|
| 98 |
+
|
| 99 |
+
Returns:
|
| 100 |
+
torch.Tensor: Output tensor with LeakyReLU applied, same shape as input.
|
| 101 |
+
"""
|
| 102 |
+
return torch.nn.functional.leaky_relu(x, negative_slope=self.negative_slope)
|
| 103 |
+
|
| 104 |
+
batch_size = 16
|
| 105 |
+
dim = 16384
|
| 106 |
+
|
| 107 |
+
def get_inputs():
|
| 108 |
+
x = torch.randn(batch_size, dim)
|
| 109 |
+
return [x]
|
| 110 |
+
|
| 111 |
+
def get_init_inputs():
|
| 112 |
+
return [] # No special initialization inputs needed
|
| 113 |
+
```
|