Akjava committed on
Commit 9c2d729 · verified · 1 Parent(s): 4744ae9

Update app.py

Files changed (1): app.py +137 -0
app.py CHANGED
@@ -60,6 +60,142 @@ description = """Gemma 3 is a family of lightweight, multimodal open models that
 llm = None
 llm_model = None
 
+import ctypes
+import os
+import multiprocessing
+
+import llama_cpp
+
+
+def test():
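+    # initialize the llama.cpp backend once per process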
+    llama_cpp.llama_backend_init()
+
+    N_THREADS = multiprocessing.cpu_count()
+    MODEL_PATH = os.environ.get("MODEL", "/mnt/md0/models/t5-base.gguf")
+
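+    # encoder-decoder prompt: the T5 task prefix tells the model what to do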
+    prompt = b"translate English to German: The house is wonderful."
+
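+    # load the GGUF model and fetch its vocabulary for tokenization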
+    lparams = llama_cpp.llama_model_default_params()
+    model = llama_cpp.llama_load_model_from_file(MODEL_PATH.encode("utf-8"), lparams)
+
+    vocab = llama_cpp.llama_model_get_vocab(model)
+
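+    # create an inference context and a sampler chain that greedily picks the top token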
+    cparams = llama_cpp.llama_context_default_params()
+    cparams.no_perf = False
+    ctx = llama_cpp.llama_init_from_model(model, cparams)
+
+    sparams = llama_cpp.llama_sampler_chain_default_params()
+    smpl = llama_cpp.llama_sampler_chain_init(sparams)
+    llama_cpp.llama_sampler_chain_add(smpl, llama_cpp.llama_sampler_init_greedy())
+
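+    # tokenize the prompt into a ctypes array of llama_token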
+    n_past = 0
+
+    embd_inp = (llama_cpp.llama_token * (len(prompt) + 1))()
+
+    n_of_tok = llama_cpp.llama_tokenize(
+        vocab,
+        prompt,
+        len(prompt),
+        embd_inp,
+        len(embd_inp),
+        True,
+        True,
+    )
+
+    embd_inp = embd_inp[:n_of_tok]
+
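+    # cap generation so that prompt plus output still fits in the context window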
+    n_ctx = llama_cpp.llama_n_ctx(ctx)
+
+    n_predict = 20
+    n_predict = min(n_predict, n_ctx - len(embd_inp))
+
+    input_consumed = 0
+    input_noecho = False
+
+    remaining_tokens = n_predict
+
+    embd = []
+    last_n_size = 64
+    last_n_tokens_data = [0] * last_n_size
+    n_batch = 24
+    last_n_repeat = 64
+    repeat_penalty = 1
+    frequency_penalty = 0.0
+    presence_penalty = 0.0
+
+    batch = llama_cpp.llama_batch_init(n_batch, 0, 1)
+
+    # prepare batch for encoding containing the prompt
+    batch.n_tokens = len(embd_inp)
+    for i in range(batch.n_tokens):
+        batch.token[i] = embd_inp[i]
+        batch.pos[i] = i
+        batch.n_seq_id[i] = 1
+        batch.seq_id[i][0] = 0
+        batch.logits[i] = False
+
+    llama_cpp.llama_encode(
+        ctx,
+        batch
+    )
+
+    # now overwrite embd_inp so batch for decoding will initially contain only
+    # a single token with id acquired from llama_model_decoder_start_token(model)
+    embd_inp = [llama_cpp.llama_model_decoder_start_token(model)]
+
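+    # generation loop: decode any pending tokens, then sample the next one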
+    while remaining_tokens > 0:
+        if len(embd) > 0:
+            batch.n_tokens = len(embd)
+            for i in range(batch.n_tokens):
+                batch.token[i] = embd[i]
+                batch.pos[i] = n_past + i
+                batch.n_seq_id[i] = 1
+                batch.seq_id[i][0] = 0
+                batch.logits[i] = i == batch.n_tokens - 1
+
+            llama_cpp.llama_decode(
+                ctx,
+                batch
+            )
+
+        n_past += len(embd)
+        embd = []
+        if len(embd_inp) <= input_consumed:
+            id = llama_cpp.llama_sampler_sample(smpl, ctx, -1)
+
+            last_n_tokens_data = last_n_tokens_data[1:] + [id]
+            embd.append(id)
+            input_noecho = False
+            remaining_tokens -= 1
+        else:
+            while len(embd_inp) > input_consumed:
+                embd.append(embd_inp[input_consumed])
+                last_n_tokens_data = last_n_tokens_data[1:] + [embd_inp[input_consumed]]
+                input_consumed += 1
+                if len(embd) >= n_batch:
+                    break
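+        # echo freshly decoded tokens to stdout as UTF-8 text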
+        if not input_noecho:
+            for id in embd:
+                size = 32
+                buffer = (ctypes.c_char * size)()
+                n = llama_cpp.llama_token_to_piece(
+                    vocab, llama_cpp.llama_token(id), buffer, size, 0, True
+                )
+                assert n <= size
+                print(
+                    buffer[:n].decode("utf-8"),
+                    end="",
+                    flush=True,
+                )
+
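+        # stop as soon as an end-of-sequence or end-of-turn token is produced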
+        if len(embd) > 0 and embd[-1] in [llama_cpp.llama_token_eos(vocab), llama_cpp.llama_token_eot(vocab)]:
+            break
+
+    print()
+
+
 def trans(text):
 
 
@@ -305,3 +441,4 @@ demo = gr.ChatInterface(
 # Launch the chat interface
 if __name__ == "__main__":
     demo.launch(debug=False)
+    test()