#include <android/log.h>
#include <jni.h>
#include <algorithm>
#include <iomanip>
#include <math.h>
#include <sstream>
#include <string>
#include <unistd.h>
#include "llama.h"
#include "common.h"

#define TAG "llama-android.cpp"
#define LOGi(...) __android_log_print(ANDROID_LOG_INFO, TAG, __VA_ARGS__)
#define LOGe(...) __android_log_print(ANDROID_LOG_ERROR, TAG, __VA_ARGS__)

// JNI handles for the caller's IntVar object (class and method IDs), cached lazily in completion_loop().
jclass la_int_var;
jmethodID la_int_var_value;
jmethodID la_int_var_inc;

// Buffers token pieces until they form valid UTF-8 that can be returned to the JVM.
std::string cached_token_chars;

// Returns true if `string` is null or consists entirely of well-formed UTF-8 sequences.
bool is_valid_utf8(const char * string) {
    if (!string) {
        return true;
    }

    const unsigned char * bytes = (const unsigned char *)string;
    int num;

    while (*bytes != 0x00) {
        if ((*bytes & 0x80) == 0x00) {
            // U+0000 to U+007F
            num = 1;
        } else if ((*bytes & 0xE0) == 0xC0) {
            // U+0080 to U+07FF
            num = 2;
        } else if ((*bytes & 0xF0) == 0xE0) {
            // U+0800 to U+FFFF
            num = 3;
        } else if ((*bytes & 0xF8) == 0xF0) {
            // U+10000 to U+10FFFF
            num = 4;
        } else {
            return false;
        }

        bytes += 1;
        for (int i = 1; i < num; ++i) {
            // Every byte after the lead byte must be a continuation byte (10xxxxxx).
            if ((*bytes & 0xC0) != 0x80) {
                return false;
            }
            bytes += 1;
        }
    }

    return true;
}

// Forwards llama.cpp / ggml log messages to logcat.
static void log_callback(ggml_log_level level, const char * fmt, void * data) {
    if (level == GGML_LOG_LEVEL_ERROR) __android_log_print(ANDROID_LOG_ERROR, TAG, fmt, data);
    else if (level == GGML_LOG_LEVEL_INFO) __android_log_print(ANDROID_LOG_INFO, TAG, fmt, data);
    else if (level == GGML_LOG_LEVEL_WARN) __android_log_print(ANDROID_LOG_WARN, TAG, fmt, data);
    else __android_log_print(ANDROID_LOG_DEFAULT, TAG, fmt, data);
}

extern "C" |
|
JNIEXPORT jlong JNICALL |
|
Java_android_llama_cpp_LLamaAndroid_load_1model(JNIEnv *env, jobject, jstring filename) { |
|
llama_model_params model_params = llama_model_default_params(); |
|
|
|
auto path_to_model = env->GetStringUTFChars(filename, 0); |
|
LOGi("Loading model from %s", path_to_model); |
|
|
|
auto model = llama_load_model_from_file(path_to_model, model_params); |
|
env->ReleaseStringUTFChars(filename, path_to_model); |
|
|
|
if (!model) { |
|
LOGe("load_model() failed"); |
|
env->ThrowNew(env->FindClass("java/lang/IllegalStateException"), "load_model() failed"); |
|
return 0; |
|
} |
|
|
|
return reinterpret_cast<jlong>(model); |
|
} |
|
|
|
extern "C" |
|
JNIEXPORT void JNICALL |
|
Java_android_llama_cpp_LLamaAndroid_free_1model(JNIEnv *, jobject, jlong model) { |
|
llama_free_model(reinterpret_cast<llama_model *>(model)); |
|
} |
|
|
|
extern "C" |
|
JNIEXPORT jlong JNICALL |
|
Java_android_llama_cpp_LLamaAndroid_new_1context(JNIEnv *env, jobject, jlong jmodel) { |
|
auto model = reinterpret_cast<llama_model *>(jmodel); |
|
|
|
if (!model) { |
|
LOGe("new_context(): model cannot be null"); |
|
env->ThrowNew(env->FindClass("java/lang/IllegalArgumentException"), "Model cannot be null"); |
|
return 0; |
|
} |
|
|
|
int n_threads = std::max(1, std::min(8, (int) sysconf(_SC_NPROCESSORS_ONLN) - 2)); |
|
LOGi("Using %d threads", n_threads); |
|
|
|
llama_context_params ctx_params = llama_context_default_params(); |
|
|
|
ctx_params.n_ctx = 2048; |
|
ctx_params.n_threads = n_threads; |
|
ctx_params.n_threads_batch = n_threads; |
|
|
|
llama_context * context = llama_new_context_with_model(model, ctx_params); |
|
|
|
if (!context) { |
|
LOGe("llama_new_context_with_model() returned null)"); |
|
env->ThrowNew(env->FindClass("java/lang/IllegalStateException"), |
|
"llama_new_context_with_model() returned null)"); |
|
return 0; |
|
} |
|
|
|
return reinterpret_cast<jlong>(context); |
|
} |
|
|
|
extern "C" |
|
JNIEXPORT void JNICALL |
|
Java_android_llama_cpp_LLamaAndroid_free_1context(JNIEnv *, jobject, jlong context) { |
|
llama_free(reinterpret_cast<llama_context *>(context)); |
|
} |
|
|
|
extern "C" |
|
JNIEXPORT void JNICALL |
|
Java_android_llama_cpp_LLamaAndroid_backend_1free(JNIEnv *, jobject) { |
|
llama_backend_free(); |
|
} |
|
|
|
extern "C" |
|
JNIEXPORT void JNICALL |
|
Java_android_llama_cpp_LLamaAndroid_log_1to_1android(JNIEnv *, jobject) { |
|
llama_log_set(log_callback, NULL); |
|
} |
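
// Benchmarks the model: `pp` prompt tokens are decoded in one batch, then `tg` batches
// of `pl` tokens each are decoded to simulate text generation; the whole run is repeated
// `nr` times. Results are returned as a markdown table of tokens/second (mean ± std).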
extern "C" |
|
JNIEXPORT jstring JNICALL |
|
Java_android_llama_cpp_LLamaAndroid_bench_1model( |
|
JNIEnv *env, |
|
jobject, |
|
jlong context_pointer, |
|
jlong model_pointer, |
|
jlong batch_pointer, |
|
jint pp, |
|
jint tg, |
|
jint pl, |
|
jint nr |
|
) { |
|
auto pp_avg = 0.0; |
|
auto tg_avg = 0.0; |
|
auto pp_std = 0.0; |
|
auto tg_std = 0.0; |
|
|
|
const auto context = reinterpret_cast<llama_context *>(context_pointer); |
|
const auto model = reinterpret_cast<llama_model *>(model_pointer); |
|
const auto batch = reinterpret_cast<llama_batch *>(batch_pointer); |
|
|
|
const int n_ctx = llama_n_ctx(context); |
|
|
|
LOGi("n_ctx = %d", n_ctx); |
|
|
|
int i, j; |
|
int nri; |
|
for (nri = 0; nri < nr; nri++) { |
|
LOGi("Benchmark prompt processing (pp)"); |
|
|
|
common_batch_clear(*batch); |
|
|
|
const int n_tokens = pp; |
|
for (i = 0; i < n_tokens; i++) { |
|
common_batch_add(*batch, 0, i, { 0 }, false); |
|
} |
|
|
|
batch->logits[batch->n_tokens - 1] = true; |
|
llama_kv_cache_clear(context); |
|
|
|
const auto t_pp_start = ggml_time_us(); |
|
if (llama_decode(context, *batch) != 0) { |
|
LOGi("llama_decode() failed during prompt processing"); |
|
} |
|
const auto t_pp_end = ggml_time_us(); |
|
|
|
|
|
|
|
LOGi("Benchmark text generation (tg)"); |
|
|
|
llama_kv_cache_clear(context); |
|
const auto t_tg_start = ggml_time_us(); |
|
for (i = 0; i < tg; i++) { |
|
|
|
common_batch_clear(*batch); |
|
for (j = 0; j < pl; j++) { |
|
common_batch_add(*batch, 0, i, { j }, true); |
|
} |
|
|
|
LOGi("llama_decode() text generation: %d", i); |
|
if (llama_decode(context, *batch) != 0) { |
|
LOGi("llama_decode() failed during text generation"); |
|
} |
|
} |
|
|
|
const auto t_tg_end = ggml_time_us(); |
|
|
|
llama_kv_cache_clear(context); |
|
|
|
const auto t_pp = double(t_pp_end - t_pp_start) / 1000000.0; |
|
const auto t_tg = double(t_tg_end - t_tg_start) / 1000000.0; |
|
|
|
const auto speed_pp = double(pp) / t_pp; |
|
const auto speed_tg = double(pl * tg) / t_tg; |
|
|
|
pp_avg += speed_pp; |
|
tg_avg += speed_tg; |
|
|
|
pp_std += speed_pp * speed_pp; |
|
tg_std += speed_tg * speed_tg; |
|
|
|
LOGi("pp %f t/s, tg %f t/s", speed_pp, speed_tg); |
|
} |
|
|
|
pp_avg /= double(nr); |
|
tg_avg /= double(nr); |
|
|
|
if (nr > 1) { |
|
pp_std = sqrt(pp_std / double(nr - 1) - pp_avg * pp_avg * double(nr) / double(nr - 1)); |
|
tg_std = sqrt(tg_std / double(nr - 1) - tg_avg * tg_avg * double(nr) / double(nr - 1)); |
|
} else { |
|
pp_std = 0; |
|
tg_std = 0; |
|
} |
|
|
|
char model_desc[128]; |
|
llama_model_desc(model, model_desc, sizeof(model_desc)); |
|
|
|
const auto model_size = double(llama_model_size(model)) / 1024.0 / 1024.0 / 1024.0; |
|
const auto model_n_params = double(llama_model_n_params(model)) / 1e9; |
|
|
|
const auto backend = "(Android)"; |
|
|
|
std::stringstream result; |
|
result << std::setprecision(2); |
|
result << "| model | size | params | backend | test | t/s |\n"; |
|
result << "| --- | --- | --- | --- | --- | --- |\n"; |
|
result << "| " << model_desc << " | " << model_size << "GiB | " << model_n_params << "B | " << backend << " | pp " << pp << " | " << pp_avg << " ± " << pp_std << " |\n"; |
|
result << "| " << model_desc << " | " << model_size << "GiB | " << model_n_params << "B | " << backend << " | tg " << tg << " | " << tg_avg << " ± " << tg_std << " |\n"; |
|
|
|
return env->NewStringUTF(result.str().c_str()); |
|
} |
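
// Heap-allocates a llama_batch (mirroring llama_batch_init). Exactly one of the
// `token` and `embd` buffers is allocated, depending on whether `embd` is non-zero.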
extern "C" |
|
JNIEXPORT jlong JNICALL |
|
Java_android_llama_cpp_LLamaAndroid_new_1batch(JNIEnv *, jobject, jint n_tokens, jint embd, jint n_seq_max) { |
|
|
|
|
|
|
|
llama_batch *batch = new llama_batch { |
|
0, |
|
nullptr, |
|
nullptr, |
|
nullptr, |
|
nullptr, |
|
nullptr, |
|
nullptr, |
|
}; |
|
|
|
if (embd) { |
|
batch->embd = (float *) malloc(sizeof(float) * n_tokens * embd); |
|
} else { |
|
batch->token = (llama_token *) malloc(sizeof(llama_token) * n_tokens); |
|
} |
|
|
|
batch->pos = (llama_pos *) malloc(sizeof(llama_pos) * n_tokens); |
|
batch->n_seq_id = (int32_t *) malloc(sizeof(int32_t) * n_tokens); |
|
batch->seq_id = (llama_seq_id **) malloc(sizeof(llama_seq_id *) * n_tokens); |
|
for (int i = 0; i < n_tokens; ++i) { |
|
batch->seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_seq_max); |
|
} |
|
batch->logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens); |
|
|
|
return reinterpret_cast<jlong>(batch); |
|
} |
|
|
|
extern "C" |
|
JNIEXPORT void JNICALL |
|
Java_android_llama_cpp_LLamaAndroid_free_1batch(JNIEnv *, jobject, jlong batch_pointer) { |
|
llama_batch_free(*reinterpret_cast<llama_batch *>(batch_pointer)); |
|
} |
|
|
|
extern "C" |
|
JNIEXPORT jlong JNICALL |
|
Java_android_llama_cpp_LLamaAndroid_new_1sampler(JNIEnv *, jobject) { |
|
auto sparams = llama_sampler_chain_default_params(); |
|
sparams.no_perf = true; |
|
llama_sampler * smpl = llama_sampler_chain_init(sparams); |
|
llama_sampler_chain_add(smpl, llama_sampler_init_greedy()); |
|
|
|
return reinterpret_cast<jlong>(smpl); |
|
} |
|
|
|
extern "C" |
|
JNIEXPORT void JNICALL |
|
Java_android_llama_cpp_LLamaAndroid_free_1sampler(JNIEnv *, jobject, jlong sampler_pointer) { |
|
llama_sampler_free(reinterpret_cast<llama_sampler *>(sampler_pointer)); |
|
} |
|
|
|
extern "C" |
|
JNIEXPORT void JNICALL |
|
Java_android_llama_cpp_LLamaAndroid_backend_1init(JNIEnv *, jobject) { |
|
llama_backend_init(); |
|
} |
|
|
|
extern "C" |
|
JNIEXPORT jstring JNICALL |
|
Java_android_llama_cpp_LLamaAndroid_system_1info(JNIEnv *env, jobject) { |
|
return env->NewStringUTF(llama_print_system_info()); |
|
} |
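
// Tokenizes the prompt, evaluates it in a single batch, and returns the number of
// prompt tokens. completion_loop() continues generation from that position.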
extern "C" |
|
JNIEXPORT jint JNICALL |
|
Java_android_llama_cpp_LLamaAndroid_completion_1init( |
|
JNIEnv *env, |
|
jobject, |
|
jlong context_pointer, |
|
jlong batch_pointer, |
|
jstring jtext, |
|
jint n_len |
|
) { |
|
|
|
cached_token_chars.clear(); |
|
|
|
const auto text = env->GetStringUTFChars(jtext, 0); |
|
const auto context = reinterpret_cast<llama_context *>(context_pointer); |
|
const auto batch = reinterpret_cast<llama_batch *>(batch_pointer); |
|
|
|
const auto tokens_list = common_tokenize(context, text, 1); |
|
|
|
auto n_ctx = llama_n_ctx(context); |
|
auto n_kv_req = tokens_list.size() + (n_len - tokens_list.size()); |
|
|
|
LOGi("n_len = %d, n_ctx = %d, n_kv_req = %d", n_len, n_ctx, n_kv_req); |
|
|
|
if (n_kv_req > n_ctx) { |
|
LOGe("error: n_kv_req > n_ctx, the required KV cache size is not big enough"); |
|
} |
|
|
|
for (auto id : tokens_list) { |
|
LOGi("%s", common_token_to_piece(context, id).c_str()); |
|
} |
|
|
|
common_batch_clear(*batch); |
|
|
|
|
|
for (auto i = 0; i < tokens_list.size(); i++) { |
|
common_batch_add(*batch, tokens_list[i], i, { 0 }, false); |
|
} |
|
|
|
|
|
batch->logits[batch->n_tokens - 1] = true; |
|
|
|
if (llama_decode(context, *batch) != 0) { |
|
LOGe("llama_decode() failed"); |
|
} |
|
|
|
env->ReleaseStringUTFChars(jtext, text); |
|
|
|
return batch->n_tokens; |
|
} |
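
// Samples one token, decodes it, and returns its text. Returns null when an
// end-of-generation token is sampled or the length limit `n_len` is reached.
// Incomplete UTF-8 sequences are buffered in cached_token_chars and an empty
// string is returned until the sequence is complete.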
extern "C" |
|
JNIEXPORT jstring JNICALL |
|
Java_android_llama_cpp_LLamaAndroid_completion_1loop( |
|
JNIEnv * env, |
|
jobject, |
|
jlong context_pointer, |
|
jlong batch_pointer, |
|
jlong sampler_pointer, |
|
jint n_len, |
|
jobject intvar_ncur |
|
) { |
|
const auto context = reinterpret_cast<llama_context *>(context_pointer); |
|
const auto batch = reinterpret_cast<llama_batch *>(batch_pointer); |
|
const auto sampler = reinterpret_cast<llama_sampler *>(sampler_pointer); |
|
const auto model = llama_get_model(context); |
|
|
|
if (!la_int_var) la_int_var = env->GetObjectClass(intvar_ncur); |
|
if (!la_int_var_value) la_int_var_value = env->GetMethodID(la_int_var, "getValue", "()I"); |
|
if (!la_int_var_inc) la_int_var_inc = env->GetMethodID(la_int_var, "inc", "()V"); |
|
|
|
|
|
const auto new_token_id = llama_sampler_sample(sampler, context, -1); |
|
|
|
const auto n_cur = env->CallIntMethod(intvar_ncur, la_int_var_value); |
|
if (llama_token_is_eog(model, new_token_id) || n_cur == n_len) { |
|
return nullptr; |
|
} |
|
|
|
auto new_token_chars = common_token_to_piece(context, new_token_id); |
|
cached_token_chars += new_token_chars; |
|
|
|
jstring new_token = nullptr; |
|
if (is_valid_utf8(cached_token_chars.c_str())) { |
|
new_token = env->NewStringUTF(cached_token_chars.c_str()); |
|
LOGi("cached: %s, new_token_chars: `%s`, id: %d", cached_token_chars.c_str(), new_token_chars.c_str(), new_token_id); |
|
cached_token_chars.clear(); |
|
} else { |
|
new_token = env->NewStringUTF(""); |
|
} |
|
|
|
common_batch_clear(*batch); |
|
common_batch_add(*batch, new_token_id, n_cur, { 0 }, true); |
|
|
|
env->CallVoidMethod(intvar_ncur, la_int_var_inc); |
|
|
|
if (llama_decode(context, *batch) != 0) { |
|
LOGe("llama_decode() returned null"); |
|
} |
|
|
|
return new_token; |
|
} |
|
|
|
extern "C" |
|
JNIEXPORT void JNICALL |
|
Java_android_llama_cpp_LLamaAndroid_kv_1cache_1clear(JNIEnv *, jobject, jlong context) { |
|
llama_kv_cache_clear(reinterpret_cast<llama_context *>(context)); |
|
} |
|
|