Spaces:
Build error
Build error
// kernel argument structs
//
// - element counters (e.g. ne00) typically use int32_t to reduce register usage
//   however, be careful of int overflows when using those in the kernel implementation
//
// - strides (e.g. nb00) use uint64_t
// Kernel argument struct for the concat kernel.
//
// - ne* fields are element counters (int32_t), nb* fields are strides (uint64_t)
// - NOTE(review): by the usual ggml naming, the 0x / 1x / single-digit suffixes
//   presumably refer to src0 / src1 / dst - confirm against the kernel
// - NOTE(review): this struct is presumably mirrored on the kernel side, so the
//   field order and types are kept exactly as they were - confirm before reordering
typedef struct {
    int32_t  ne00;
    int32_t  ne01;
    int32_t  ne02;
    int32_t  ne03;
    uint64_t nb00;
    uint64_t nb01;
    uint64_t nb02;
    uint64_t nb03;
    int32_t  ne10;
    int32_t  ne11;
    int32_t  ne12;
    int32_t  ne13;
    uint64_t nb10;
    uint64_t nb11;
    uint64_t nb12;
    uint64_t nb13;
    int32_t  ne0;
    int32_t  ne1;
    int32_t  ne2;
    int32_t  ne3;
    uint64_t nb0;
    uint64_t nb1;
    uint64_t nb2;
    uint64_t nb3;
    int32_t  dim; // presumably the dimension to concatenate along - TODO confirm
} ggml_metal_kargs_concat;
// Kernel argument struct for the "bin" kernels.
//
// - ne* fields are element counters (int32_t), nb* fields are strides (uint64_t)
// - NOTE(review): "bin" presumably stands for binary (two-operand) ops, and the
//   0x / 1x / single-digit suffixes for src0 / src1 / dst - confirm against the kernel
// - NOTE(review): this struct is presumably mirrored on the kernel side, so the
//   field order and types are kept exactly as they were - confirm before reordering
typedef struct {
    int32_t  ne00;
    int32_t  ne01;
    int32_t  ne02;
    int32_t  ne03;
    uint64_t nb00;
    uint64_t nb01;
    uint64_t nb02;
    uint64_t nb03;
    int32_t  ne10;
    int32_t  ne11;
    int32_t  ne12;
    int32_t  ne13;
    uint64_t nb10;
    uint64_t nb11;
    uint64_t nb12;
    uint64_t nb13;
    int32_t  ne0;
    int32_t  ne1;
    int32_t  ne2;
    int32_t  ne3;
    uint64_t nb0;
    uint64_t nb1;
    uint64_t nb2;
    uint64_t nb3;
    uint64_t offs; // presumably an offset applied to one of the buffers - TODO confirm which
} ggml_metal_kargs_bin;
// Kernel argument struct for the repeat kernel.
//
// - ne* fields are element counters (int32_t), nb* fields are strides (uint64_t)
// - NOTE(review): the 0x / single-digit suffixes presumably refer to src0 / dst -
//   confirm against the kernel
// - NOTE(review): this struct is presumably mirrored on the kernel side, so the
//   field order and types are kept exactly as they were - confirm before reordering
typedef struct {
    int32_t  ne00;
    int32_t  ne01;
    int32_t  ne02;
    int32_t  ne03;
    uint64_t nb00;
    uint64_t nb01;
    uint64_t nb02;
    uint64_t nb03;
    int32_t  ne0;
    int32_t  ne1;
    int32_t  ne2;
    int32_t  ne3;
    uint64_t nb0;
    uint64_t nb1;
    uint64_t nb2;
    uint64_t nb3;
} ggml_metal_kargs_repeat;
// Kernel argument struct for the cpy kernel.
//
// - ne* fields are element counters; note that they are int64_t here, not the
//   int32_t used by most other kernel argument structs
// - nb* fields are strides (uint64_t)
// - NOTE(review): the 0x / single-digit suffixes presumably refer to src0 / dst -
//   confirm against the kernel
// - NOTE(review): this struct is presumably mirrored on the kernel side, so the
//   field order and types are kept exactly as they were - confirm before reordering
typedef struct {
    int64_t  ne00;
    int64_t  ne01;
    int64_t  ne02;
    int64_t  ne03;
    uint64_t nb00;
    uint64_t nb01;
    uint64_t nb02;
    uint64_t nb03;
    int64_t  ne0;
    int64_t  ne1;
    int64_t  ne2;
    int64_t  ne3;
    uint64_t nb0;
    uint64_t nb1;
    uint64_t nb2;
    uint64_t nb3;
} ggml_metal_kargs_cpy;
// Kernel argument struct for the set kernel.
//
// - ne* fields are element counters; note that they are int64_t here, not the
//   int32_t used by most other kernel argument structs
// - nb* fields are strides (uint64_t)
// - NOTE(review): the 1x / single-digit suffixes presumably refer to src1 / dst -
//   confirm against the kernel
// - NOTE(review): this struct is presumably mirrored on the kernel side, so the
//   field order and types are kept exactly as they were - confirm before reordering
// - when compiled as C, `bool` requires <stdbool.h> to be in scope
typedef struct {
    int64_t  ne10;
    int64_t  ne11;
    int64_t  ne12;
    uint64_t nb10;
    uint64_t nb11;
    uint64_t nb12;
    uint64_t nb13;
    uint64_t nb1;
    uint64_t nb2;
    uint64_t nb3;
    uint64_t offs;    // presumably an offset into the destination - TODO confirm
    bool     inplace; // presumably selects in-place operation - TODO confirm
} ggml_metal_kargs_set;
// Kernel argument struct for the rope kernel.
//
// - ne* fields are element counters (int32_t), nb* fields are strides (uint64_t)
// - the trailing int32_t / float fields are scalar parameters of the operation;
//   their exact meaning is not visible here - see the kernel implementation
// - NOTE(review): the 0x / single-digit suffixes presumably refer to src0 / dst -
//   confirm against the kernel
// - NOTE(review): this struct is presumably mirrored on the kernel side, so the
//   field order and types are kept exactly as they were - confirm before reordering
typedef struct {
    int32_t  ne00;
    int32_t  ne01;
    int32_t  ne02;
    int32_t  ne03;
    uint64_t nb00;
    uint64_t nb01;
    uint64_t nb02;
    uint64_t nb03;
    int32_t  ne0;
    int32_t  ne1;
    int32_t  ne2;
    int32_t  ne3;
    uint64_t nb0;
    uint64_t nb1;
    uint64_t nb2;
    uint64_t nb3;
    int32_t  n_past;
    int32_t  n_dims;
    int32_t  n_ctx_orig;
    float    freq_base;
    float    freq_scale;
    float    ext_factor;
    float    attn_factor;
    float    beta_fast;
    float    beta_slow;
} ggml_metal_kargs_rope;
// Kernel argument struct for the flash_attn_ext kernel.
//
// - ne* fields are element counters (int32_t), nb* fields are strides (uint64_t)
// - the `_12_` fields carry a single value used for both K and V, which are
//   assumed to have the same shape
// - the trailing float / uint16_t fields are scalar parameters of the operation;
//   their exact meaning is not visible here - see the kernel implementation
// - NOTE(review): this struct is presumably mirrored on the kernel side, so the
//   field order and types are kept exactly as they were - confirm before reordering
typedef struct {
    int32_t  ne01;
    int32_t  ne02;
    int32_t  ne03;
    uint64_t nb01;
    uint64_t nb02;
    uint64_t nb03;
    int32_t  ne11;
    int32_t  ne_12_2; // assume K and V are same shape
    int32_t  ne_12_3;
    uint64_t nb_12_1;
    uint64_t nb_12_2;
    uint64_t nb_12_3;
    uint64_t nb31;
    int32_t  ne1;
    int32_t  ne2;
    float    scale;
    float    max_bias;
    float    m0;
    float    m1;
    uint16_t n_head_log2;
    float    logit_softcap;
} ggml_metal_kargs_flash_attn_ext;
// Kernel argument struct for the mul_mm kernel.
//
// - ne* fields are element counters (int32_t), nb* fields are strides (uint64_t)
// - NOTE(review): the 0x / 1x / single-digit suffixes presumably refer to
//   src0 / src1 / dst, and r2 / r3 to ratios between src1 and src0 dimensions -
//   confirm against the kernel
// - NOTE(review): this struct is presumably mirrored on the kernel side, so the
//   field order and types are kept exactly as they were - confirm before reordering
typedef struct {
    int32_t  ne00;
    int32_t  ne02;
    uint64_t nb01;
    uint64_t nb02;
    uint64_t nb03;
    int32_t  ne12;
    uint64_t nb10;
    uint64_t nb11;
    uint64_t nb12;
    uint64_t nb13;
    int32_t  ne0;
    int32_t  ne1;
    int16_t  r2;
    int16_t  r3;
} ggml_metal_kargs_mul_mm;
// Kernel argument struct for the mul_mv kernel.
//
// - ne* fields are element counters (int32_t), nb* fields are strides (uint64_t)
// - NOTE(review): the 0x / 1x / single-digit suffixes presumably refer to
//   src0 / src1 / dst, and r2 / r3 to ratios between src1 and src0 dimensions -
//   confirm against the kernel
// - NOTE(review): this struct is presumably mirrored on the kernel side, so the
//   field order and types are kept exactly as they were - confirm before reordering
typedef struct {
    int32_t  ne00;
    int32_t  ne01;
    int32_t  ne02;
    uint64_t nb00;
    uint64_t nb01;
    uint64_t nb02;
    uint64_t nb03;
    int32_t  ne10;
    int32_t  ne11;
    int32_t  ne12;
    uint64_t nb10;
    uint64_t nb11;
    uint64_t nb12;
    uint64_t nb13;
    int32_t  ne0;
    int32_t  ne1;
    int16_t  r2;
    int16_t  r3;
} ggml_metal_kargs_mul_mv;
// Kernel argument struct for the mul_mv_ext kernel.
//
// - ne* fields are element counters (int32_t), nb* fields are strides (uint64_t)
// - the leading fields match ggml_metal_kargs_mul_mv; nsg, nxpsg and r1ptg are
//   additional int16_t parameters
// - NOTE(review): the 0x / 1x / single-digit suffixes presumably refer to
//   src0 / src1 / dst - confirm against the kernel
// - NOTE(review): this struct is presumably mirrored on the kernel side, so the
//   field order and types are kept exactly as they were - confirm before reordering
typedef struct {
    int32_t  ne00;
    int32_t  ne01;
    int32_t  ne02;
    uint64_t nb00;
    uint64_t nb01;
    uint64_t nb02;
    uint64_t nb03;
    int32_t  ne10;
    int32_t  ne11;
    int32_t  ne12;
    uint64_t nb10;
    uint64_t nb11;
    uint64_t nb12;
    uint64_t nb13;
    int32_t  ne0;
    int32_t  ne1;
    int16_t  r2;
    int16_t  r3;
    int16_t  nsg;   // presumably the number of simdgroups - TODO confirm
    int16_t  nxpsg; // meaning not visible here - see the kernel implementation
    int16_t  r1ptg; // meaning not visible here - see the kernel implementation
} ggml_metal_kargs_mul_mv_ext;
// Kernel argument struct for the mul_mm_id kernel.
//
// - ne* fields are element counters (int32_t), nb* fields are strides (uint64_t)
// - NOTE(review): nei0 / nei1 / nbi1 presumably describe the ids tensor, and the
//   0x / 1x / single-digit suffixes src0 / src1 / dst - confirm against the kernel
// - NOTE(review): this struct is presumably mirrored on the kernel side, so the
//   field order and types are kept exactly as they were - confirm before reordering
typedef struct {
    int32_t  nei0;
    int32_t  nei1;
    uint64_t nbi1;
    int32_t  ne00;
    int32_t  ne02;
    uint64_t nb01;
    uint64_t nb02;
    int32_t  ne11;
    int32_t  ne12;
    int32_t  ne13;
    uint64_t nb10;
    uint64_t nb11;
    uint64_t nb12;
    int32_t  ne0;
    int32_t  ne1;
} ggml_metal_kargs_mul_mm_id;
// Kernel argument struct for the mul_mv_id kernel.
//
// - ne* fields are element counters (int32_t), nb* fields are strides (uint64_t)
// - NOTE(review): nei0 / nei1 / nbi1 presumably describe the ids tensor, and the
//   0x / 1x / single-digit suffixes src0 / src1 / dst - confirm against the kernel
// - NOTE(review): this struct is presumably mirrored on the kernel side, so the
//   field order and types are kept exactly as they were - confirm before reordering
typedef struct {
    int32_t  nei0;
    int32_t  nei1;
    uint64_t nbi1;
    int32_t  ne00;
    int32_t  ne01;
    int32_t  ne02;
    uint64_t nb00;
    uint64_t nb01;
    uint64_t nb02;
    int32_t  ne10;
    int32_t  ne11;
    int32_t  ne12;
    int32_t  ne13;
    uint64_t nb10;
    uint64_t nb11;
    uint64_t nb12;
    int32_t  ne0;
    int32_t  ne1;
    uint64_t nb1;
} ggml_metal_kargs_mul_mv_id;
// Kernel argument struct for the norm kernel.
//
// - ne* fields are element counters (int32_t), nb01 is a stride (uint64_t)
// - NOTE(review): this struct is presumably mirrored on the kernel side, so the
//   field order and types are kept exactly as they were - confirm before reordering
typedef struct {
    int32_t  ne00;
    int32_t  ne00_4; // presumably ne00/4 - TODO confirm
    uint64_t nb01;
    float    eps;    // presumably the epsilon used by the normalization - TODO confirm
} ggml_metal_kargs_norm;
// Kernel argument struct for the rms_norm kernel.
//
// - ne* fields are element counters (int32_t), nb01 is a stride (uint64_t)
// - NOTE(review): this struct is presumably mirrored on the kernel side, so the
//   field order and types are kept exactly as they were - confirm before reordering
typedef struct {
    int32_t  ne00;
    int32_t  ne00_4; // presumably ne00/4 - TODO confirm
    uint64_t nb01;
    float    eps;    // presumably the epsilon used by the normalization - TODO confirm
} ggml_metal_kargs_rms_norm;