mateoluksenberg committed
Commit 8326f1d · verified · 1 Parent(s): 877632c

Update app.py

Files changed (1)
  1. app.py +9 -17
app.py CHANGED
@@ -211,7 +211,7 @@ EXAMPLES = [
 
 # Define the simple_chat function
 @spaces.GPU()
-def simple_chat(message: dict, temperature: float = 0.8, max_length: int = 4096, top_p: float = 1, top_k: int = 10, penalty: float = 1.0):
+async def simple_chat(message: dict, temperature: float = 0.8, max_length: int = 4096, top_p: float = 1, top_k: int = 10, penalty: float = 1.0):
     # Load the pretrained model
     model = AutoModelForCausalLM.from_pretrained(
         MODEL_ID,
@@ -219,7 +219,7 @@ def simple_chat(message: dict, temperature: float = 0.8, max_length: int = 4096,
         low_cpu_mem_usage=True,
         trust_remote_code=True
     )
-
+
     conversation = []
 
     if "file" in message and message["file"]:
@@ -236,12 +236,10 @@ def simple_chat(message: dict, temperature: float = 0.8, max_length: int = 4096,
     # Prepare the input for the model
     input_ids = tokenizer.apply_chat_template(conversation, tokenize=True, add_generation_prompt=True,
                                               return_tensors="pt", return_dict=True).to(model.device)
-    streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
-
+
     # Configure generation parameters
     generate_kwargs = dict(
         max_length=max_length,
-        streamer=streamer,
         do_sample=True,
         top_p=top_p,
         top_k=top_k,
@@ -249,19 +247,13 @@ def simple_chat(message: dict, temperature: float = 0.8, max_length: int = 4096,
         repetition_penalty=penalty,
         eos_token_id=[151329, 151336, 151338],
     )
-    gen_kwargs = {**input_ids, **generate_kwargs}
 
-    # Generate the response asynchronously
-    def generate():
-        with torch.no_grad():
-            thread = Thread(target=model.generate, kwargs=gen_kwargs)
-            thread.start()
-            buffer = ""
-            for new_text in streamer:
-                buffer += new_text
-                yield buffer.encode('utf-8')
-
-    #return StreamingResponse(generate(), media_type="text/plain")
+    # Generate the response
+    with torch.no_grad():
+        generated_ids = model.generate(input_ids['input_ids'], **generate_kwargs)
+        generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
+
+    # Return the full response
     return PlainTextResponse(generated_text)
 
 @app.post("/chat/")
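
For reference, below is a minimal, self-contained sketch of the synchronous flow this commit switches to: instead of streaming tokens through a TextIteratorStreamer on a background thread, the handler now makes a single blocking model.generate() call and decodes the whole reply. The model name, the example conversation, and the sampling values here are placeholders, not taken from app.py; the real MODEL_ID, tokenizer, and eos_token_id list are defined elsewhere in that file.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_ID = "your-org/your-chat-model"  # placeholder; app.py defines the real MODEL_ID

# Load tokenizer and model (assumes the tokenizer ships a chat template).
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    low_cpu_mem_usage=True,
    trust_remote_code=True,
)

# Build a one-turn conversation and tokenize it with the chat template.
conversation = [{"role": "user", "content": "Hello!"}]
inputs = tokenizer.apply_chat_template(
    conversation,
    tokenize=True,
    add_generation_prompt=True,
    return_tensors="pt",
    return_dict=True,
).to(model.device)

# Single blocking generate() call, then decode the whole sequence.
# Note: decoding generated_ids[0] returns prompt + reply, matching the diff's behavior.
with torch.no_grad():
    generated_ids = model.generate(
        inputs["input_ids"],
        max_length=4096,
        do_sample=True,
        top_p=1.0,
        top_k=10,
        temperature=0.8,
        repetition_penalty=1.0,
    )
generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
print(generated_text)

The removed path built the same generate_kwargs but attached a TextIteratorStreamer and ran model.generate on a Thread so partial text could be yielded incrementally; the new code trades that streaming output for a simpler request/response shape returned as a PlainTextResponse.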