#pragma once
#include <torch/torch.h>
#include <core/scalar_type.hpp>
// Fused SiLU activation + elementwise multiply, writing the result into
// `out`. NOTE(review): presumably `input` packs the two operand halves
// (gate/up) along its last dimension, SwiGLU-style — confirm against the
// kernel implementation.
void silu_and_mul(torch::Tensor &out, torch::Tensor &input);
// Softmax over `gating_output` followed by top-k expert selection.
// Results are written in place into `topk_weights`, `topk_indices`, and
// `token_expert_indices` (all output parameters).
void topk_softmax(torch::Tensor &topk_weights, torch::Tensor &topk_indices,
torch::Tensor &token_expert_indices,
torch::Tensor &gating_output);
// Reduction of `input` into `output`. NOTE(review): parameter order is
// (input, output) here, the opposite of the out-first convention used by
// the other ops in this header — double-check at call sites.
void moe_sum(torch::Tensor &input, torch::Tensor &output);
// Groups/pads the token->expert assignments in `topk_ids` so each expert's
// token count is a multiple of `block_size`, filling the output tensors
// `sorted_token_ids`, `experts_ids`, and `num_tokens_post_pad`.
// Tensors are taken by value: torch::Tensor is a cheap ref-counted handle,
// so the underlying storage is still shared (outputs are mutated in place).
void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
int64_t block_size, torch::Tensor sorted_token_ids,
torch::Tensor experts_ids,
torch::Tensor num_tokens_post_pad);
// Variant of moe_align_block_size with an identical signature.
// NOTE(review): the `sgl_` prefix suggests an alternative (SGLang-derived?)
// kernel; the functional difference is not visible from this header —
// consult the .cu implementation before choosing between the two.
void sgl_moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
int64_t block_size,
torch::Tensor sorted_token_ids,
torch::Tensor experts_ids,
torch::Tensor num_tokens_post_pad);
// FP8 quantization of `input` into `out` using a caller-supplied
// (pre-computed) `scale`; `scale` is const — it is only read.
void static_scaled_fp8_quant(torch::Tensor &out, torch::Tensor const &input,
torch::Tensor const &scale);
// FP8 quantization of `input` into `out` where the scale is computed by the
// op itself: unlike the static variant above, `scale` is non-const and is
// written as an output.
void dynamic_scaled_fp8_quant(torch::Tensor &out, torch::Tensor const &input,
torch::Tensor &scale);
// Per-token dynamic FP8 quantization: `scale` is an output holding one scale
// per token. `scale_ub`, when provided, is an upper bound applied to the
// computed scales (std::nullopt disables the clamp — presumed from the
// optional type; confirm in the kernel).
void dynamic_per_token_scaled_fp8_quant(
torch::Tensor &out, torch::Tensor const &input, torch::Tensor &scale,
std::optional<torch::Tensor> const &scale_ub);
// Marlin MoE grouped GEMM over quantized weights. CUDA-only: compiled out
// under ROCm builds via the guard below.
//   a                 : activations
//   b_q_weights       : quantized expert weight matrices
//   sorted_ids        : token ordering produced by moe_align_block_size
//   topk_weights/ids  : per-token routing weights and expert indices
//   b_scales/b_zeros  : dequantization scales and zero points
//   g_idx, perm       : group index / permutation for act-order weights
//   workspace         : scratch buffer (mutated)
//   b_q_type_id       : weight quantization type (vllm::ScalarTypeId)
//   size_m/n/k        : GEMM problem dimensions
//   is_k_full         : whether K is unsplit (full reduction dimension)
//   moe_block_size    : block granularity matching moe_align_block_size
//   replicate_input   : NOTE(review): semantics not visible here — see kernel
//   apply_weights     : whether topk_weights are multiplied into the output
// Returns the result tensor.
#ifndef USE_ROCM
torch::Tensor marlin_gemm_moe(
const torch::Tensor &a, const torch::Tensor &b_q_weights,
const torch::Tensor &sorted_ids, const torch::Tensor &topk_weights,
const torch::Tensor &topk_ids, const torch::Tensor &b_scales,
torch::Tensor &b_zeros, const torch::Tensor &g_idx,
const torch::Tensor &perm, torch::Tensor &workspace,
vllm::ScalarTypeId const b_q_type_id, int64_t size_m, int64_t size_n,
int64_t size_k, bool is_k_full, int64_t num_experts, int64_t topk,
int64_t moe_block_size, bool replicate_input, bool apply_weights);
#endif