mateoluksenberg committed
Commit 41e4c1f · verified · 1 Parent(s): c002c58

Update app.py

Files changed (1): app.py (+25, -25)
app.py CHANGED
@@ -217,11 +217,11 @@ def simple_chat(message: dict, temperature: float = 0.8, max_length: int = 4096,
     try:
         model = AutoModelForCausalLM.from_pretrained(
             MODEL_ID,
-            torch_dtype=torch.bfloat16,
+            torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
             low_cpu_mem_usage=True,
             trust_remote_code=True
         )
-
+
         tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
 
         conversation = []
@@ -230,14 +230,14 @@ def simple_chat(message: dict, temperature: float = 0.8, max_length: int = 4096,
         if "file_content" in message and message["file_content"]:
             file_content = message["file_content"]
             file_name = message["file_name"]
-
+
             # Save the file to a temporary file
             with open(file_name, "wb") as f:
                 f.write(file_content.read())
-
+
             # Call `mode_load` with the file name
             choice, contents = mode_load(file_name)
-
+
             if choice == "image":
                 conversation.append({"role": "user", "image": contents, "content": message['text']})
             elif choice == "doc":
@@ -267,31 +267,31 @@ def simple_chat(message: dict, temperature: float = 0.8, max_length: int = 4096,
             eos_token_id=[151329, 151336, 151338],
         )
 
-        # gen_kwargs = {**input_ids, **generate_kwargs}
-
-        # with torch.no_grad():
-        #     thread = Thread(target=model.generate, kwargs=gen_kwargs)
-        #     thread.start()
-        #     buffer = ""
-        #     for new_text in streamer:
-        #         buffer += new_text
-        #         yield buffer
-
-        #     print("--------------")
-        #     print("Buffer: ")
-        #     print(" ")
-        #     print(buffer)
-        #     print(" ")
-        #     print("--------------")
-
-        with torch.no_grad():
-            generated_ids = model.generate(input_ids['input_ids'], **generate_kwargs)
-            generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
-
-        #generated_text = buffer
-
-        return PlainTextResponse(generated_text)
+        gen_kwargs = {**input_ids, **generate_kwargs}
+
+        # Define the function to run generation
+        def generate_text():
+            with torch.no_grad():
+                model.generate(**gen_kwargs, streamer=streamer)
+
+        # Start the generation in a separate thread
+        thread = Thread(target=generate_text)
+        thread.start()
+
+        def stream_response():
+            buffer = ""
+            for new_text in streamer:
+                buffer += new_text
+                yield new_text
+            print("--------------")
+            print("Buffer: ")
+            print(" ")
+            print(buffer)
+            print(" ")
+            print("--------------")
+
+        return StreamingResponse(stream_response(), media_type="text/plain")
 
     except Exception as e:
         return PlainTextResponse(f"Error: {str(e)}")
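For reference, a minimal self-contained sketch of the streaming pattern this commit switches to: model.generate runs in a background thread, TextIteratorStreamer yields decoded text as it is produced, and FastAPI's StreamingResponse forwards each chunk to the client. The model ID, endpoint path, and generation parameters below are placeholders for illustration, not the app's actual values.

    from threading import Thread

    import torch
    from fastapi import FastAPI
    from fastapi.responses import StreamingResponse
    from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

    MODEL_ID = "gpt2"  # placeholder; the app loads a different model with trust_remote_code=True
    app = FastAPI()

    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        # Same device-dependent dtype choice the commit introduces.
        torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
    )

    @app.get("/chat")
    def chat(text: str):
        input_ids = tokenizer(text, return_tensors="pt")
        # skip_prompt drops the echoed prompt; timeout keeps the iterator from blocking forever.
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, timeout=60.0)
        gen_kwargs = {**input_ids, "max_new_tokens": 256, "streamer": streamer}

        # Run generation in a background thread so the response can stream concurrently.
        def generate_text():
            with torch.no_grad():
                model.generate(**gen_kwargs)

        Thread(target=generate_text).start()

        # Yield each decoded chunk as the streamer produces it.
        def stream_response():
            for new_text in streamer:
                yield new_text

        return StreamingResponse(stream_response(), media_type="text/plain")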