mateoluksenberg committed on
Commit
695706a
·
verified ·
1 Parent(s): 968c4c5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +60 -60
app.py CHANGED
@@ -209,68 +209,68 @@ EXAMPLES = [
209
 
210
 
211
  # Definir la función simple_chat
212
- # @spaces.GPU()
213
- # def simple_chat(message: dict, temperature: float = 0.8, max_length: int = 4096, top_p: float = 1, top_k: int = 10, penalty: float = 1.0):
214
- # # Cargar el modelo preentrenado
215
- # model = AutoModelForCausalLM.from_pretrained(
216
- # MODEL_ID,
217
- # torch_dtype=torch.bfloat16,
218
- # low_cpu_mem_usage=True,
219
- # trust_remote_code=True
220
- # )
221
 
222
- # conversation = []
223
-
224
- # if "file" in message and message["file"]:
225
- # file_path = message["file"]
226
- # choice, contents = mode_load(file_path)
227
- # if choice == "image":
228
- # conversation.append({"role": "user", "image": contents, "content": message["text"]})
229
- # elif choice == "doc":
230
- # format_msg = contents + "\n\n\n" + "{} files uploaded.\n" + message["text"]
231
- # conversation.append({"role": "user", "content": format_msg})
232
- # else:
233
- # conversation.append({"role": "user", "content": message["text"]})
234
-
235
- # # Preparar entrada para el modelo
236
- # input_ids = tokenizer.apply_chat_template(conversation, tokenize=True, add_generation_prompt=True,
237
- # return_tensors="pt", return_dict=True).to(model.device)
238
- # streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
239
-
240
- # # Configurar parámetros de generación
241
- # generate_kwargs = dict(
242
- # max_length=max_length,
243
- # streamer=streamer,
244
- # do_sample=True,
245
- # top_p=top_p,
246
- # top_k=top_k,
247
- # temperature=temperature,
248
- # repetition_penalty=penalty,
249
- # eos_token_id=[151329, 151336, 151338],
250
- # )
251
- # gen_kwargs = {**input_ids, **generate_kwargs}
252
-
253
- # # Generar respuesta de manera asíncrona
254
- # def generate():
255
- # with torch.no_grad():
256
- # thread = Thread(target=model.generate, kwargs=gen_kwargs)
257
- # thread.start()
258
- # buffer = ""
259
- # for new_text in streamer:
260
- # buffer += new_text
261
- # yield buffer.encode('utf-8')
262
-
263
- # return StreamingResponse(generate(), media_type="text/plain")
264
-
265
- # @app.post("/chat/")
266
- # async def test_endpoint(message: dict):
267
- # if "text" not in message:
268
- # raise HTTPException(status_code=400, detail="Missing 'text' in request body")
269
-
270
- # if "file" not in message:
271
- # print("Sin File")
272
 
273
- # return simple_chat(message)
274
 
275
  with gr.Blocks(css=CSS, theme="soft", fill_height=True) as demo:
276
  gr.HTML(TITLE)
 
209
 
210
 
211
  # Definir la función simple_chat
212
@spaces.GPU()
def simple_chat(message: dict, temperature: float = 0.8, max_length: int = 4096, top_p: float = 1, top_k: int = 10, penalty: float = 1.0):
    """Answer a single chat message and stream the generated text over HTTP.

    Args:
        message: Request payload. ``message["text"]`` is the user prompt. An
            optional, truthy ``message["file"]`` is passed to ``mode_load``
            and its contents are attached to the prompt.
        temperature: Sampling temperature forwarded to ``model.generate``.
        max_length: Forwarded to ``model.generate`` as ``max_length``.
        top_p: Nucleus-sampling threshold forwarded to ``model.generate``.
        top_k: Top-k sampling cutoff forwarded to ``model.generate``.
        penalty: Forwarded to ``model.generate`` as ``repetition_penalty``.

    Returns:
        A ``StreamingResponse`` with media type ``text/plain`` whose body is
        the generated text, sent as UTF-8 chunks as the streamer yields them.

    Raises:
        KeyError: If ``message`` has no ``"text"`` key.
    """
    # NOTE(review): the model is re-loaded on every call; consider loading it
    # once at module level if the hosting environment allows it.
    # NOTE(security): trust_remote_code=True runs code shipped with the model
    # repository, so MODEL_ID must point at a trusted source.
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
        trust_remote_code=True
    )

    user_text = message["text"]
    conversation = []

    file_path = message.get("file")
    if file_path:
        choice, contents = mode_load(file_path)
        if choice == "image":
            conversation.append({"role": "user", "image": contents, "content": user_text})
        elif choice == "doc":
            # Exactly one file is handled per request, so state the count
            # instead of sending an unformatted "{}" placeholder to the model.
            format_msg = contents + "\n\n\n" + "1 file uploaded.\n" + user_text
            conversation.append({"role": "user", "content": format_msg})

    if not conversation:
        # No file, or a file kind that is neither "image" nor "doc": fall back
        # to a text-only turn so the chat template never gets an empty list.
        conversation.append({"role": "user", "content": user_text})

    # Tokenize the conversation; return_dict=True yields a mapping that is
    # unpacked into the generate() keyword arguments below.
    model_inputs = tokenizer.apply_chat_template(conversation, tokenize=True, add_generation_prompt=True,
                                                 return_tensors="pt", return_dict=True).to(model.device)
    streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)

    generate_kwargs = dict(
        max_length=max_length,
        streamer=streamer,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        repetition_penalty=penalty,
        # presumably the model's end-of-sequence / end-of-turn token ids —
        # TODO confirm against the tokenizer
        eos_token_id=[151329, 151336, 151338],
    )
    gen_kwargs = {**model_inputs, **generate_kwargs}

    def _run_generation():
        # Grad mode is thread-local, so it has to be disabled inside the
        # worker thread that actually runs generate(), not in the caller.
        with torch.no_grad():
            model.generate(**gen_kwargs)

    def generate():
        # Generation runs in the background; the streamer hands decoded text
        # to this generator as soon as it becomes available.
        Thread(target=_run_generation).start()
        for new_text in streamer:
            # StreamingResponse concatenates chunks, so send only the new
            # text; yielding a cumulative buffer would repeat earlier output.
            if new_text:
                yield new_text.encode('utf-8')

    return StreamingResponse(generate(), media_type="text/plain")
264
+
265
@app.post("/chat/")
async def test_endpoint(message: dict):
    """Handle ``POST /chat/``: validate the JSON body and delegate to ``simple_chat``.

    Args:
        message: Parsed request body. Must contain a ``"text"`` key; a
            ``"file"`` key is optional.

    Returns:
        Whatever ``simple_chat(message)`` returns.

    Raises:
        HTTPException: With status 400 when ``"text"`` is missing from the body.
    """
    # Function-scope import keeps this block self-contained.
    import asyncio

    if "text" not in message:
        raise HTTPException(status_code=400, detail="Missing 'text' in request body")

    if "file" not in message:
        print("Sin File")

    # simple_chat does blocking work (model load, tokenization). Running it in
    # a worker thread keeps the event loop free to serve other requests.
    return await asyncio.to_thread(simple_chat, message)
274
 
275
  with gr.Blocks(css=CSS, theme="soft", fill_height=True) as demo:
276
  gr.HTML(TITLE)