mateoluksenberg commited on
Commit
d96f949
·
verified ·
1 Parent(s): 33432bd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +68 -47
app.py CHANGED
@@ -27,53 +27,6 @@ async def test_endpoint(message: dict):
27
  return response
28
 
29
 
30
- @app.post("/chat/")
31
- async def chat_endpoint(message: dict):
32
- if "text" not in message:
33
- raise HTTPException(status_code=400, detail="Missing 'text' in request body")
34
-
35
- chat_message = message["text"]
36
- response_text = generate_chat_response(chat_message)
37
-
38
- return {"response": response_text}
39
-
40
- def generate_chat_response(text: str):
41
- model = AutoModelForCausalLM.from_pretrained(
42
- MODEL_ID,
43
- torch_dtype=torch.bfloat16,
44
- low_cpu_mem_usage=True,
45
- trust_remote_code=True
46
- )
47
-
48
- tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
49
-
50
- conversation = [{"role": "user", "content": text}]
51
- input_ids = tokenizer.apply_chat_template(conversation, tokenize=True, add_generation_prompt=True,
52
- return_tensors="pt", return_dict=True).to(model.device)
53
- streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
54
-
55
- generate_kwargs = dict(
56
- max_length=4096,
57
- streamer=streamer,
58
- do_sample=True,
59
- top_p=0.9,
60
- top_k=50,
61
- temperature=0.7,
62
- repetition_penalty=1.0,
63
- eos_token_id=[151329, 151336, 151338],
64
- )
65
- gen_kwargs = {**input_ids, **generate_kwargs}
66
-
67
- with torch.no_grad():
68
- thread = Thread(target=model.generate, kwargs=gen_kwargs)
69
- thread.start()
70
- buffer = ""
71
- for new_text in streamer:
72
- buffer += new_text
73
-
74
- return buffer
75
-
76
-
77
  MODEL_LIST = ["nikravan/glm-4vq"]
78
 
79
  HF_TOKEN = os.environ.get("HF_TOKEN", None)
@@ -252,6 +205,74 @@ EXAMPLES = [
252
  [{"text": "Quiero armar un JSON, solo el JSON sin texto, que contenga los datos de la primera mitad de la tabla de la imagen (las primeras 10 jurisdicciones 901-910). Ten en cuenta que los valores numéricos son decimales de cuatro dígitos. La tabla contiene las siguientes columnas: Codigo, Nombre, Fecha Inicio, Fecha Cese, Coeficiente Ingresos, Coeficiente Gastos y Coeficiente Unificado. La tabla puede contener valores vacíos, en ese caso dejarlos como null. Cada fila de la tabla representa una jurisdicción con sus respectivos valores.", }]
253
  ]
254
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
255
  with gr.Blocks(css=CSS, theme="soft", fill_height=True) as demo:
256
  gr.HTML(TITLE)
257
  gr.HTML(DESCRIPTION)
 
27
  return response
28
 
29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  MODEL_LIST = ["nikravan/glm-4vq"]
31
 
32
  HF_TOKEN = os.environ.get("HF_TOKEN", None)
 
205
  [{"text": "Quiero armar un JSON, solo el JSON sin texto, que contenga los datos de la primera mitad de la tabla de la imagen (las primeras 10 jurisdicciones 901-910). Ten en cuenta que los valores numéricos son decimales de cuatro dígitos. La tabla contiene las siguientes columnas: Codigo, Nombre, Fecha Inicio, Fecha Cese, Coeficiente Ingresos, Coeficiente Gastos y Coeficiente Unificado. La tabla puede contener valores vacíos, en ese caso dejarlos como null. Cada fila de la tabla representa una jurisdicción con sus respectivos valores.", }]
206
  ]
207
 
208
+ # Definir la estructura del mensaje utilizando Pydantic
209
+ class Message(BaseModel):
210
+ text: str
211
+ file: Optional[UploadFile] = None
212
+
213
+ # Definir la función simple_chat
214
+ def simple_chat(message: Message, temperature: float = 0.8, max_length: int = 4096, top_p: float = 1, top_k: int = 10, penalty: float = 1.0):
215
+ # Cargar el modelo preentrenado
216
+ model = AutoModelForCausalLM.from_pretrained(
217
+ MODEL_ID,
218
+ torch_dtype=torch.bfloat16,
219
+ low_cpu_mem_usage=True,
220
+ trust_remote_code=True
221
+ )
222
+
223
+ conversation = []
224
+
225
+ # Procesar el mensaje
226
+ if message.file:
227
+ file_contents = message.file.file.read()
228
+ # Aquí deberías procesar el archivo como corresponda, por ejemplo:
229
+ # choice, contents = mode_load(file_contents)
230
+ # Por ahora solo agregaremos un marcador de posición
231
+ choice = "doc"
232
+ contents = "Contenido del archivo"
233
+ if choice == "image":
234
+ conversation.append({"role": "user", "image": contents, "content": message.text})
235
+ elif choice == "doc":
236
+ format_msg = contents + "\n\n\n" + "{} files uploaded.\n" + message.text
237
+ conversation.append({"role": "user", "content": format_msg})
238
+ else:
239
+ conversation.append({"role": "user", "content": message.text})
240
+
241
+ # Preparar entrada para el modelo
242
+ input_ids = tokenizer.apply_chat_template(conversation, tokenize=True, add_generation_prompt=True,
243
+ return_tensors="pt", return_dict=True).to(model.device)
244
+ streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
245
+
246
+ # Configurar parámetros de generación
247
+ generate_kwargs = dict(
248
+ max_length=max_length,
249
+ streamer=streamer,
250
+ do_sample=True,
251
+ top_p=top_p,
252
+ top_k=top_k,
253
+ temperature=temperature,
254
+ repetition_penalty=penalty,
255
+ eos_token_id=[151329, 151336, 151338],
256
+ )
257
+ gen_kwargs = {**input_ids, **generate_kwargs}
258
+
259
+ # Generar respuesta de manera asíncrona
260
+ def generate():
261
+ with torch.no_grad():
262
+ thread = Thread(target=model.generate, kwargs=gen_kwargs)
263
+ thread.start()
264
+ buffer = ""
265
+ for new_text in streamer:
266
+ buffer += new_text
267
+ yield buffer.encode('utf-8')
268
+
269
+ return StreamingResponse(generate(), media_type="text/plain")
270
+
271
+ # Definir la ruta en FastAPI
272
+ @app.post("/chat")
273
+ async def chat(message: Message):
274
+ return simple_chat(message)
275
+
276
  with gr.Blocks(css=CSS, theme="soft", fill_height=True) as demo:
277
  gr.HTML(TITLE)
278
  gr.HTML(DESCRIPTION)