Spaces:
Build error
Build error
// Write C++ code here.
//
// Do not forget to dynamically load the C++ library into your application.
//
// For instance,
//
// In MainActivity.java:
//    static {
//       System.loadLibrary("llama-android");
//    }
//
// Or, in MainActivity.kt:
//    companion object {
//      init {
//         System.loadLibrary("llama-android")
//      }
//    }
jclass la_int_var; | |
jmethodID la_int_var_value; | |
jmethodID la_int_var_inc; | |
std::string cached_token_chars; | |
// Structural UTF-8 check for a NUL-terminated byte string.
//
// Every lead byte must announce a 1- to 4-byte sequence and be followed by
// the matching number of continuation bytes (10xxxxxx). Only the bit patterns
// are inspected: overlong encodings, surrogates and code points above
// U+10FFFF are not rejected. A null pointer is reported as valid.
bool is_valid_utf8(const char * string) {
    if (string == nullptr) {
        return true;
    }

    const auto * cursor = reinterpret_cast<const unsigned char *>(string);

    while (*cursor != 0x00) {
        const unsigned char lead = *cursor++;

        // Number of continuation bytes the lead byte promises.
        int trailing;
        if (lead < 0x80) {
            trailing = 0;               // 0xxxxxxx: U+0000 to U+007F
        } else if ((lead >> 5) == 0x06) {
            trailing = 1;               // 110xxxxx: U+0080 to U+07FF
        } else if ((lead >> 4) == 0x0E) {
            trailing = 2;               // 1110xxxx: U+0800 to U+FFFF
        } else if ((lead >> 3) == 0x1E) {
            trailing = 3;               // 11110xxx: U+10000 to U+10FFFF
        } else {
            return false;               // stray continuation byte or invalid lead
        }

        // A terminating NUL inside a sequence fails this test as well, so the
        // scan never walks past the end of the string.
        for (; trailing > 0; --trailing) {
            if ((*cursor & 0xC0) != 0x80) {
                return false;
            }
            ++cursor;
        }
    }

    return true;
}
// Log sink handed to llama_log_set(): forwards one message to Android logcat.
//
// level - ggml severity, mapped to the closest Android log priority.
// fmt   - the message text. Despite the parameter name it is NOT a printf
//         format string (ggml log callbacks receive already-formatted text),
//         so it is emitted through "%s". Passing it as the format would be
//         undefined behaviour whenever the message contains a '%'.
// data  - opaque user pointer given at registration; unused here.
static void log_callback(ggml_log_level level, const char * fmt, void * data) {
    (void) data;

    if (fmt == nullptr) {
        return;
    }

    int priority;
    switch (level) {
        case GGML_LOG_LEVEL_ERROR: priority = ANDROID_LOG_ERROR;   break;
        case GGML_LOG_LEVEL_INFO:  priority = ANDROID_LOG_INFO;    break;
        case GGML_LOG_LEVEL_WARN:  priority = ANDROID_LOG_WARN;    break;
        default:                   priority = ANDROID_LOG_DEFAULT; break;
    }

    __android_log_print(priority, TAG, "%s", fmt);
}
extern "C" | |
JNIEXPORT jlong JNICALL | |
Java_android_llama_cpp_LLamaAndroid_load_1model(JNIEnv *env, jobject, jstring filename) { | |
llama_model_params model_params = llama_model_default_params(); | |
auto path_to_model = env->GetStringUTFChars(filename, 0); | |
LOGi("Loading model from %s", path_to_model); | |
auto model = llama_model_load_from_file(path_to_model, model_params); | |
env->ReleaseStringUTFChars(filename, path_to_model); | |
if (!model) { | |
LOGe("load_model() failed"); | |
env->ThrowNew(env->FindClass("java/lang/IllegalStateException"), "load_model() failed"); | |
return 0; | |
} | |
return reinterpret_cast<jlong>(model); | |
} | |
extern "C" | |
JNIEXPORT void JNICALL | |
Java_android_llama_cpp_LLamaAndroid_free_1model(JNIEnv *, jobject, jlong model) { | |
llama_model_free(reinterpret_cast<llama_model *>(model)); | |
} | |
extern "C" | |
JNIEXPORT jlong JNICALL | |
Java_android_llama_cpp_LLamaAndroid_new_1context(JNIEnv *env, jobject, jlong jmodel) { | |
auto model = reinterpret_cast<llama_model *>(jmodel); | |
if (!model) { | |
LOGe("new_context(): model cannot be null"); | |
env->ThrowNew(env->FindClass("java/lang/IllegalArgumentException"), "Model cannot be null"); | |
return 0; | |
} | |
int n_threads = std::max(1, std::min(8, (int) sysconf(_SC_NPROCESSORS_ONLN) - 2)); | |
LOGi("Using %d threads", n_threads); | |
llama_context_params ctx_params = llama_context_default_params(); | |
ctx_params.n_ctx = 2048; | |
ctx_params.n_threads = n_threads; | |
ctx_params.n_threads_batch = n_threads; | |
llama_context * context = llama_new_context_with_model(model, ctx_params); | |
if (!context) { | |
LOGe("llama_new_context_with_model() returned null)"); | |
env->ThrowNew(env->FindClass("java/lang/IllegalStateException"), | |
"llama_new_context_with_model() returned null)"); | |
return 0; | |
} | |
return reinterpret_cast<jlong>(context); | |
} | |
extern "C" | |
JNIEXPORT void JNICALL | |
Java_android_llama_cpp_LLamaAndroid_free_1context(JNIEnv *, jobject, jlong context) { | |
llama_free(reinterpret_cast<llama_context *>(context)); | |
} | |
extern "C" | |
JNIEXPORT void JNICALL | |
Java_android_llama_cpp_LLamaAndroid_backend_1free(JNIEnv *, jobject) { | |
llama_backend_free(); | |
} | |
extern "C" | |
JNIEXPORT void JNICALL | |
Java_android_llama_cpp_LLamaAndroid_log_1to_1android(JNIEnv *, jobject) { | |
llama_log_set(log_callback, NULL); | |
} | |
extern "C" | |
JNIEXPORT jstring JNICALL | |
Java_android_llama_cpp_LLamaAndroid_bench_1model( | |
JNIEnv *env, | |
jobject, | |
jlong context_pointer, | |
jlong model_pointer, | |
jlong batch_pointer, | |
jint pp, | |
jint tg, | |
jint pl, | |
jint nr | |
) { | |
auto pp_avg = 0.0; | |
auto tg_avg = 0.0; | |
auto pp_std = 0.0; | |
auto tg_std = 0.0; | |
const auto context = reinterpret_cast<llama_context *>(context_pointer); | |
const auto model = reinterpret_cast<llama_model *>(model_pointer); | |
const auto batch = reinterpret_cast<llama_batch *>(batch_pointer); | |
const int n_ctx = llama_n_ctx(context); | |
LOGi("n_ctx = %d", n_ctx); | |
int i, j; | |
int nri; | |
for (nri = 0; nri < nr; nri++) { | |
LOGi("Benchmark prompt processing (pp)"); | |
common_batch_clear(*batch); | |
const int n_tokens = pp; | |
for (i = 0; i < n_tokens; i++) { | |
common_batch_add(*batch, 0, i, { 0 }, false); | |
} | |
batch->logits[batch->n_tokens - 1] = true; | |
llama_kv_cache_clear(context); | |
const auto t_pp_start = ggml_time_us(); | |
if (llama_decode(context, *batch) != 0) { | |
LOGi("llama_decode() failed during prompt processing"); | |
} | |
const auto t_pp_end = ggml_time_us(); | |
// bench text generation | |
LOGi("Benchmark text generation (tg)"); | |
llama_kv_cache_clear(context); | |
const auto t_tg_start = ggml_time_us(); | |
for (i = 0; i < tg; i++) { | |
common_batch_clear(*batch); | |
for (j = 0; j < pl; j++) { | |
common_batch_add(*batch, 0, i, { j }, true); | |
} | |
LOGi("llama_decode() text generation: %d", i); | |
if (llama_decode(context, *batch) != 0) { | |
LOGi("llama_decode() failed during text generation"); | |
} | |
} | |
const auto t_tg_end = ggml_time_us(); | |
llama_kv_cache_clear(context); | |
const auto t_pp = double(t_pp_end - t_pp_start) / 1000000.0; | |
const auto t_tg = double(t_tg_end - t_tg_start) / 1000000.0; | |
const auto speed_pp = double(pp) / t_pp; | |
const auto speed_tg = double(pl * tg) / t_tg; | |
pp_avg += speed_pp; | |
tg_avg += speed_tg; | |
pp_std += speed_pp * speed_pp; | |
tg_std += speed_tg * speed_tg; | |
LOGi("pp %f t/s, tg %f t/s", speed_pp, speed_tg); | |
} | |
pp_avg /= double(nr); | |
tg_avg /= double(nr); | |
if (nr > 1) { | |
pp_std = sqrt(pp_std / double(nr - 1) - pp_avg * pp_avg * double(nr) / double(nr - 1)); | |
tg_std = sqrt(tg_std / double(nr - 1) - tg_avg * tg_avg * double(nr) / double(nr - 1)); | |
} else { | |
pp_std = 0; | |
tg_std = 0; | |
} | |
char model_desc[128]; | |
llama_model_desc(model, model_desc, sizeof(model_desc)); | |
const auto model_size = double(llama_model_size(model)) / 1024.0 / 1024.0 / 1024.0; | |
const auto model_n_params = double(llama_model_n_params(model)) / 1e9; | |
const auto backend = "(Android)"; // TODO: What should this be? | |
std::stringstream result; | |
result << std::setprecision(2); | |
result << "| model | size | params | backend | test | t/s |\n"; | |
result << "| --- | --- | --- | --- | --- | --- |\n"; | |
result << "| " << model_desc << " | " << model_size << "GiB | " << model_n_params << "B | " << backend << " | pp " << pp << " | " << pp_avg << " ± " << pp_std << " |\n"; | |
result << "| " << model_desc << " | " << model_size << "GiB | " << model_n_params << "B | " << backend << " | tg " << tg << " | " << tg_avg << " ± " << tg_std << " |\n"; | |
return env->NewStringUTF(result.str().c_str()); | |
} | |
extern "C" | |
JNIEXPORT jlong JNICALL | |
Java_android_llama_cpp_LLamaAndroid_new_1batch(JNIEnv *, jobject, jint n_tokens, jint embd, jint n_seq_max) { | |
// Source: Copy of llama.cpp:llama_batch_init but heap-allocated. | |
llama_batch *batch = new llama_batch { | |
0, | |
nullptr, | |
nullptr, | |
nullptr, | |
nullptr, | |
nullptr, | |
nullptr, | |
}; | |
if (embd) { | |
batch->embd = (float *) malloc(sizeof(float) * n_tokens * embd); | |
} else { | |
batch->token = (llama_token *) malloc(sizeof(llama_token) * n_tokens); | |
} | |
batch->pos = (llama_pos *) malloc(sizeof(llama_pos) * n_tokens); | |
batch->n_seq_id = (int32_t *) malloc(sizeof(int32_t) * n_tokens); | |
batch->seq_id = (llama_seq_id **) malloc(sizeof(llama_seq_id *) * n_tokens); | |
for (int i = 0; i < n_tokens; ++i) { | |
batch->seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_seq_max); | |
} | |
batch->logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens); | |
return reinterpret_cast<jlong>(batch); | |
} | |
extern "C" | |
JNIEXPORT void JNICALL | |
Java_android_llama_cpp_LLamaAndroid_free_1batch(JNIEnv *, jobject, jlong batch_pointer) { | |
//llama_batch_free(*reinterpret_cast<llama_batch *>(batch_pointer)); | |
const auto batch = reinterpret_cast<llama_batch *>(batch_pointer); | |
delete batch; | |
} | |
extern "C" | |
JNIEXPORT jlong JNICALL | |
Java_android_llama_cpp_LLamaAndroid_new_1sampler(JNIEnv *, jobject) { | |
auto sparams = llama_sampler_chain_default_params(); | |
sparams.no_perf = true; | |
llama_sampler * smpl = llama_sampler_chain_init(sparams); | |
llama_sampler_chain_add(smpl, llama_sampler_init_greedy()); | |
return reinterpret_cast<jlong>(smpl); | |
} | |
extern "C" | |
JNIEXPORT void JNICALL | |
Java_android_llama_cpp_LLamaAndroid_free_1sampler(JNIEnv *, jobject, jlong sampler_pointer) { | |
llama_sampler_free(reinterpret_cast<llama_sampler *>(sampler_pointer)); | |
} | |
extern "C" | |
JNIEXPORT void JNICALL | |
Java_android_llama_cpp_LLamaAndroid_backend_1init(JNIEnv *, jobject) { | |
llama_backend_init(); | |
} | |
extern "C" | |
JNIEXPORT jstring JNICALL | |
Java_android_llama_cpp_LLamaAndroid_system_1info(JNIEnv *env, jobject) { | |
return env->NewStringUTF(llama_print_system_info()); | |
} | |
extern "C" | |
JNIEXPORT jint JNICALL | |
Java_android_llama_cpp_LLamaAndroid_completion_1init( | |
JNIEnv *env, | |
jobject, | |
jlong context_pointer, | |
jlong batch_pointer, | |
jstring jtext, | |
jboolean format_chat, | |
jint n_len | |
) { | |
cached_token_chars.clear(); | |
const auto text = env->GetStringUTFChars(jtext, 0); | |
const auto context = reinterpret_cast<llama_context *>(context_pointer); | |
const auto batch = reinterpret_cast<llama_batch *>(batch_pointer); | |
bool parse_special = (format_chat == JNI_TRUE); | |
const auto tokens_list = common_tokenize(context, text, true, parse_special); | |
auto n_ctx = llama_n_ctx(context); | |
auto n_kv_req = tokens_list.size() + (n_len - tokens_list.size()); | |
LOGi("n_len = %d, n_ctx = %d, n_kv_req = %d", n_len, n_ctx, n_kv_req); | |
if (n_kv_req > n_ctx) { | |
LOGe("error: n_kv_req > n_ctx, the required KV cache size is not big enough"); | |
} | |
for (auto id : tokens_list) { | |
LOGi("token: `%s`-> %d ", common_token_to_piece(context, id).c_str(), id); | |
} | |
common_batch_clear(*batch); | |
// evaluate the initial prompt | |
for (auto i = 0; i < tokens_list.size(); i++) { | |
common_batch_add(*batch, tokens_list[i], i, { 0 }, false); | |
} | |
// llama_decode will output logits only for the last token of the prompt | |
batch->logits[batch->n_tokens - 1] = true; | |
if (llama_decode(context, *batch) != 0) { | |
LOGe("llama_decode() failed"); | |
} | |
env->ReleaseStringUTFChars(jtext, text); | |
return batch->n_tokens; | |
} | |
extern "C" | |
JNIEXPORT jstring JNICALL | |
Java_android_llama_cpp_LLamaAndroid_completion_1loop( | |
JNIEnv * env, | |
jobject, | |
jlong context_pointer, | |
jlong batch_pointer, | |
jlong sampler_pointer, | |
jint n_len, | |
jobject intvar_ncur | |
) { | |
const auto context = reinterpret_cast<llama_context *>(context_pointer); | |
const auto batch = reinterpret_cast<llama_batch *>(batch_pointer); | |
const auto sampler = reinterpret_cast<llama_sampler *>(sampler_pointer); | |
const auto model = llama_get_model(context); | |
const auto vocab = llama_model_get_vocab(model); | |
if (!la_int_var) la_int_var = env->GetObjectClass(intvar_ncur); | |
if (!la_int_var_value) la_int_var_value = env->GetMethodID(la_int_var, "getValue", "()I"); | |
if (!la_int_var_inc) la_int_var_inc = env->GetMethodID(la_int_var, "inc", "()V"); | |
// sample the most likely token | |
const auto new_token_id = llama_sampler_sample(sampler, context, -1); | |
const auto n_cur = env->CallIntMethod(intvar_ncur, la_int_var_value); | |
if (llama_vocab_is_eog(vocab, new_token_id) || n_cur == n_len) { | |
return nullptr; | |
} | |
auto new_token_chars = common_token_to_piece(context, new_token_id); | |
cached_token_chars += new_token_chars; | |
jstring new_token = nullptr; | |
if (is_valid_utf8(cached_token_chars.c_str())) { | |
new_token = env->NewStringUTF(cached_token_chars.c_str()); | |
LOGi("cached: %s, new_token_chars: `%s`, id: %d", cached_token_chars.c_str(), new_token_chars.c_str(), new_token_id); | |
cached_token_chars.clear(); | |
} else { | |
new_token = env->NewStringUTF(""); | |
} | |
common_batch_clear(*batch); | |
common_batch_add(*batch, new_token_id, n_cur, { 0 }, true); | |
env->CallVoidMethod(intvar_ncur, la_int_var_inc); | |
if (llama_decode(context, *batch) != 0) { | |
LOGe("llama_decode() returned null"); | |
} | |
return new_token; | |
} | |
extern "C" | |
JNIEXPORT void JNICALL | |
Java_android_llama_cpp_LLamaAndroid_kv_1cache_1clear(JNIEnv *, jobject, jlong context) { | |
llama_kv_cache_clear(reinterpret_cast<llama_context *>(context)); | |
} | |