mateoluksenberg committed on
Commit c1be641 · verified · 1 Parent(s): d4978b7

Update app.py

Files changed (1)
  1. app.py +80 -87
app.py CHANGED
@@ -117,37 +117,6 @@ def mode_load(path):
     else:
         raise gr.Error("Oops, unsupported files.")
 
- # def mode_load(file_obj):
- #     try:
- #         file_obj.seek(0)  # Make sure the pointer is at the start of the file
-
- #         # Detect the file type from the first bytes, if possible
- #         file_header = file_obj.read(4)
- #         file_obj.seek(0)  # Go back to the start of the file for full processing
-
- #         if file_header.startswith(b'%PDF'):
- #             content = extract_pdf(file_obj)
- #             choice = "doc"
- #         elif file_obj.name.endswith(".docx"):
- #             content = extract_docx(file_obj)
- #             choice = "doc"
- #         elif file_obj.name.endswith(".pptx"):
- #             content = extract_pptx(file_obj)
- #             choice = "doc"
- #         elif file_obj.name.endswith(".txt") or file_obj.name.endswith(".py") or file_obj.name.endswith(".json") or file_obj.name.endswith(".cpp") or file_obj.name.endswith(".md"):
- #             content = file_obj.read().decode('utf-8', errors='ignore')
- #             choice = "doc"
- #         elif file_obj.name.endswith((".png", ".jpg", ".jpeg", ".bmp", ".tiff", ".webp")):
- #             content = Image.open(file_obj).convert('RGB')
- #             choice = "image"
- #         else:
- #             raise ValueError("Unsupported file type.")
-
- #         return choice, content
-
- #     except Exception as e:
- #         raise ValueError(f"Error processing file: {str(e)}")
-
 
 @spaces.GPU()
 def stream_chat(message, history: list, temperature: float, max_length: int, top_p: float, top_k: int, penalty: float):
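The removed mode_load variant above sniffed PDFs by their magic bytes instead of trusting the file extension. A minimal sketch of that check, assuming a binary file object (sniff_is_pdf is an illustrative name, not part of app.py):

    import io

    def sniff_is_pdf(file_obj) -> bool:
        file_obj.seek(0)           # start from the first byte
        header = file_obj.read(4)  # every PDF begins with the magic bytes b'%PDF'
        file_obj.seek(0)           # rewind so later readers get the full content
        return header.startswith(b"%PDF")

    print(sniff_is_pdf(io.BytesIO(b"%PDF-1.7 ...")))    # True
    print(sniff_is_pdf(io.BytesIO(b"PK\x03\x04 ...")))  # False (zip magic, e.g. a .docx)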
@@ -243,62 +212,6 @@ EXAMPLES = [
 ]
 
 
- # Define the simple_chat function
- # @spaces.GPU()
- # def simple_chat(message: dict, temperature: float = 0.8, max_length: int = 4096, top_p: float = 1, top_k: int = 10, penalty: float = 1.0):
- #     try:
- #         model = AutoModelForCausalLM.from_pretrained(
- #             MODEL_ID,
- #             torch_dtype=torch.bfloat16,
- #             low_cpu_mem_usage=True,
- #             trust_remote_code=True
- #         )
-
- #         conversation = []
-
- #         if "file" in message and message["file"]:
- #             file_path = message["file"]
- #             choice, contents = mode_load(file_path)
- #             if choice == "image":
- #                 conversation.append({"role": "user", "image": contents, "content": message["text"]})
- #             elif choice == "doc":
- #                 format_msg = contents + "\n\n\n" + "{} files uploaded.\n" + message["text"]
- #                 conversation.append({"role": "user", "content": format_msg})
- #         else:
- #             conversation.append({"role": "user", "content": message["text"]})
-
- #         input_ids = tokenizer.apply_chat_template(conversation, tokenize=True, add_generation_prompt=True, return_tensors="pt", return_dict=True).to(model.device)
-
- #         generate_kwargs = dict(
- #             max_length=max_length,
- #             do_sample=True,
- #             top_p=top_p,
- #             top_k=top_k,
- #             temperature=temperature,
- #             repetition_penalty=penalty,
- #             eos_token_id=[151329, 151336, 151338],
- #         )
-
- #         with torch.no_grad():
- #             generated_ids = model.generate(input_ids['input_ids'], **generate_kwargs)
- #             generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
-
- #         return PlainTextResponse(generated_text)
- #     except Exception as e:
- #         return PlainTextResponse(f"Error: {str(e)}")
-
- # @app.post("/chat/")
- # async def test_endpoint(message: dict):
- #     if "text" not in message:
- #         raise HTTPException(status_code=400, detail="Missing 'text' in request body")
-
- #     if "file" not in message:
- #         print("No File")
-
- #     response = simple_chat(message)
- #     return response
-
-
 @spaces.GPU()
 def simple_chat(message: dict, temperature: float = 0.8, max_length: int = 4096, top_p: float = 1, top_k: int = 10, penalty: float = 1.0):
     try:
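The deleted block above also carried the commented-out FastAPI route that would have exposed simple_chat over HTTP. A minimal sketch of that endpoint shape, assuming a FastAPI app object (the echoed text stands in for the real simple_chat call):

    from fastapi import FastAPI, HTTPException
    from fastapi.responses import PlainTextResponse

    app = FastAPI()

    @app.post("/chat/")
    async def chat(message: dict) -> PlainTextResponse:
        # Mirror the removed test_endpoint's validation: "text" is required.
        if "text" not in message:
            raise HTTPException(status_code=400, detail="Missing 'text' in request body")
        # The real app would return simple_chat(message) here.
        return PlainTextResponse(message["text"])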
@@ -342,6 +255,8 @@ def simple_chat(message: dict, temperature: float = 0.8, max_length: int = 4096,
 
         input_ids = tokenizer.apply_chat_template(conversation, tokenize=True, add_generation_prompt=True, return_tensors="pt", return_dict=True).to(model.device)
 
+        streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
+
         generate_kwargs = dict(
             max_length=max_length,
             do_sample=True,
@@ -352,6 +267,23 @@ def simple_chat(message: dict, temperature: float = 0.8, max_length: int = 4096,
             eos_token_id=[151329, 151336, 151338],
         )
 
+        gen_kwargs = {**input_ids, **generate_kwargs}
+
+        with torch.no_grad():
+            thread = Thread(target=model.generate, kwargs=gen_kwargs)
+            thread.start()
+            buffer = ""
+            for new_text in streamer:
+                buffer += new_text
+                yield buffer
+
+        print("--------------")
+        print("Buffer: ")
+        print(" ")
+        print(buffer)
+        print(" ")
+        print("--------------")
+
         with torch.no_grad():
             generated_ids = model.generate(input_ids['input_ids'], **generate_kwargs)
             generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
@@ -360,6 +292,67 @@ def simple_chat(message: dict, temperature: float = 0.8, max_length: int = 4096,
 
         return PlainTextResponse(generated_text)
     except Exception as e:
         return PlainTextResponse(f"Error: {str(e)}")
+
+ # def simple_chat(message: dict, temperature: float = 0.8, max_length: int = 4096, top_p: float = 1, top_k: int = 10, penalty: float = 1.0):
+ #     try:
+ #         model = AutoModelForCausalLM.from_pretrained(
+ #             MODEL_ID,
+ #             torch_dtype=torch.bfloat16,
+ #             low_cpu_mem_usage=True,
+ #             trust_remote_code=True
+ #         )
+
+ #         tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
+
+ #         conversation = []
+
+ #         # Access the file content and the file name
+ #         if "file_content" in message and message["file_content"]:
+ #             file_content = message["file_content"]
+ #             file_name = message["file_name"]
+
+ #             # Save the file to a temporary file
+ #             with open(file_name, "wb") as f:
+ #                 f.write(file_content.read())
+
+ #             # Call `mode_load` with the file name
+ #             choice, contents = mode_load(file_name)
+
+ #             if choice == "image":
+ #                 conversation.append({"role": "user", "image": contents, "content": message['text']})
+ #             elif choice == "doc":
+ #                 format_msg = contents + "\n\n\n" + "{} files uploaded.\n".format(1) + message['text']
+ #                 conversation.append({"role": "user", "content": format_msg})
+ #         else:
+ #             # Handle the case where no file is uploaded
+ #             conversation.append({"role": "user", "content": message['text']})
+
+ #         print("--------------")
+ #         print(" ")
+ #         print(conversation)
+ #         print(" ")
+ #         print("--------------")
+
+ #         input_ids = tokenizer.apply_chat_template(conversation, tokenize=True, add_generation_prompt=True, return_tensors="pt", return_dict=True).to(model.device)
+
+ #         generate_kwargs = dict(
+ #             max_length=max_length,
+ #             do_sample=True,
+ #             top_p=top_p,
+ #             top_k=top_k,
+ #             temperature=temperature,
+ #             repetition_penalty=penalty,
+ #             eos_token_id=[151329, 151336, 151338],
+ #         )
+
+ #         with torch.no_grad():
+ #             generated_ids = model.generate(input_ids['input_ids'], **generate_kwargs)
+ #             generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
+
+ #         return PlainTextResponse(generated_text)
+ #     except Exception as e:
+ #         return PlainTextResponse(f"Error: {str(e)}")
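The core of this commit is the TextIteratorStreamer pattern: model.generate blocks, so it runs in a worker thread while the caller iterates the streamer and yields partial text. A minimal, self-contained sketch of that pattern, using "gpt2" purely as a stand-in for the app's MODEL_ID; note the streamer only receives tokens if it is passed to generate:

    from threading import Thread

    from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

    model_id = "gpt2"  # stand-in; the app loads its own MODEL_ID
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id)

    inputs = tokenizer("The quick brown fox", return_tensors="pt").to(model.device)

    # skip_prompt=True keeps the echoed prompt out of the stream.
    streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)

    # generate blocks, so it runs in a worker thread; handing it the streamer
    # is what makes decoded text chunks arrive on the main thread.
    thread = Thread(target=model.generate,
                    kwargs={**inputs, "max_new_tokens": 32, "streamer": streamer})
    thread.start()

    buffer = ""
    for new_text in streamer:  # iteration ends when generation finishes
        buffer += new_text
        print(buffer)          # each step holds the accumulated text so far
    thread.join()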
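The newly added (still commented-out) simple_chat variant stages the raw upload to disk so mode_load, which dispatches on a path, can handle it. A minimal sketch of that staging step, assuming the bytes arrive in the request body (stage_upload and "report.txt" are illustrative names, not part of app.py):

    import io

    def stage_upload(file_content: io.BytesIO, file_name: str) -> str:
        # Persist the upload under its original name so extension-based
        # dispatch (mode_load in app.py) can recognize the file type.
        with open(file_name, "wb") as f:
            f.write(file_content.read())
        return file_name

    path = stage_upload(io.BytesIO(b"hello"), "report.txt")
    # choice, contents = mode_load(path)  # app.py would branch on "doc" vs "image"
    print(path)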