# chat / app.py
# mateoluksenberg's picture
# Update app.py
# 2a0024c verified
# raw
# history blame
# 15.4 kB
import torch
from PIL import Image
import gradio as gr
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import os
from threading import Thread
from pydantic import BaseModel
from typing import Optional
import io
import pymupdf
import docx
from pptx import Presentation
from fastapi import FastAPI, File, Form, UploadFile, HTTPException
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from fastapi.responses import StreamingResponse
from fastapi.responses import PlainTextResponse
import uvicorn
# FastAPI application object; the @app.post routes below register on it and
# the Gradio UI is mounted onto it in the __main__ block at the end of the file.
app = FastAPI()
@app.post("/test/")
async def test_endpoint(message: dict):
    """Echo endpoint for checking that the API is reachable.

    Args:
        message: Request body parsed as a dict; must contain a ``"text"`` key.

    Returns:
        ``{"message": "Received your message: <text>"}``.

    Raises:
        HTTPException: Status 400 when the body has no ``"text"`` key.
    """
    if "text" in message:
        return {"message": f"Received your message: {message['text']}"}
    raise HTTPException(status_code=400, detail="Missing 'text' in request body")
# Model repositories known to this app; only the first entry is used (see MODEL_ID).
MODEL_LIST = ["nikravan/glm-4vq"]
# Optional Hugging Face token read from the environment (None when unset).
# NOTE(review): not referenced anywhere else in the visible code.
HF_TOKEN = os.environ.get("HF_TOKEN", None)
# Repository id handed to the tokenizer and model loaders.
MODEL_ID = MODEL_LIST[0]
# Display name interpolated into DESCRIPTION.
MODEL_NAME = "GLM-4vq"
# HTML heading shown at the top of the Gradio page.
TITLE = "<h1>AI CHAT DOCS</h1>"
# HTML blurb shown under the title; links to the model page.
DESCRIPTION = f"""
<center>
<p>
<br>
USANDO MODELO: <a href="https://hf.co/nikravan/glm-4vq">{MODEL_NAME}</a>
</center>"""
# Custom CSS passed to gr.Blocks; centers the <h1> title.
CSS = """
h1 {
text-align: center;
display: block;
}
"""
# The tokenizer is loaded once at import time and shared by both chat functions.
# trust_remote_code=True allows the tokenizer code shipped in the model repository to run.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
def extract_text(path):
    """Return the entire contents of the text file at ``path``.

    The file is opened in text mode with the platform's default encoding and
    is closed before returning, so no file handle is left open.

    Args:
        path: Filesystem path of the file to read.

    Returns:
        The file contents as a single string.
    """
    with open(path, 'r') as handle:
        return handle.read()
def extract_pdf(path):
    """Return the concatenated plain text of every page of a PDF.

    Args:
        path: Where the PDF comes from. A ``str`` or ``os.PathLike`` is
            opened as a file on disk; ``bytes``/``bytearray`` are parsed
            directly; any other object is treated as a binary file-like
            object and is read from its current position.

    Returns:
        The text of all pages in page order, with nothing inserted between
        pages.
    """
    if isinstance(path, (str, os.PathLike)):
        doc = pymupdf.open(path)
    else:
        # In-memory data has no file name for PyMuPDF to inspect, so it is
        # passed as a stream with the file type stated explicitly.
        data = path if isinstance(path, (bytes, bytearray)) else path.read()
        doc = pymupdf.open(stream=data, filetype="pdf")
    # The context manager closes the document once the text is collected.
    with doc:
        return "".join(page.get_text() for page in doc)
def extract_docx(path):
    """Return the text of a Word document, one paragraph per block.

    Args:
        path: Value handed unchanged to ``docx.Document``.

    Returns:
        The text of every paragraph, joined with a blank line
        (``"\\n\\n"``) between consecutive paragraphs.
    """
    document = docx.Document(path)
    return '\n\n'.join(para.text for para in document.paragraphs)
def extract_pptx(path):
    """Return the text found in the shapes of a PowerPoint presentation.

    Args:
        path: Value handed unchanged to ``Presentation``.

    Returns:
        The ``text`` of every shape that has one, in slide order and then
        shape order, each followed by a newline.
    """
    pieces = [
        shape.text + "\n"
        for slide in Presentation(path).slides
        for shape in slide.shapes
        if hasattr(shape, "text")
    ]
    return "".join(pieces)
# def mode_load(path):
# choice = ""
# file_type = path.split(".")[-1]
# print(file_type)
# if file_type in ["pdf", "txt", "py", "docx", "pptx", "json", "cpp", "md"]:
# if file_type.endswith("pdf"):
# content = extract_pdf(path)
# elif file_type.endswith("docx"):
# content = extract_docx(path)
# elif file_type.endswith("pptx"):
# content = extract_pptx(path)
# else:
# content = extract_text(path)
# choice = "doc"
# print(content[:100])
# return choice, content[:5000]
# elif file_type in ["png", "jpg", "jpeg", "bmp", "tiff", "webp"]:
# content = Image.open(path).convert('RGB')
# choice = "image"
# return choice, content
# else:
# raise gr.Error("Oops, unsupported files.")
def mode_load(file_obj):
    """Load an uploaded file and classify it as a document or an image.

    PDFs are recognised by their leading ``%PDF`` bytes; every other type is
    recognised by its file-name suffix, compared case-insensitively.

    Args:
        file_obj: Either a filesystem path (``str`` or ``os.PathLike``) or a
            seekable binary file-like object. For file-like objects the
            suffix is taken from the ``name`` attribute when there is one.

    Returns:
        A ``(choice, content)`` tuple: ``("doc", text)`` for PDF, .docx,
        .pptx, .txt, .py, .json, .cpp and .md sources, or
        ``("image", pil_image)`` with the image converted to RGB for .png,
        .jpg, .jpeg, .bmp, .tiff and .webp sources.

    Raises:
        ValueError: If the type is unsupported or anything fails while
            reading or parsing; the message starts with
            ``"Error processing file: "`` and the original exception is
            chained as the cause.
    """
    text_suffixes = (".txt", ".py", ".json", ".cpp", ".md")
    image_suffixes = (".png", ".jpg", ".jpeg", ".bmp", ".tiff", ".webp")

    try:
        is_path = isinstance(file_obj, (str, os.PathLike))
        if is_path:
            name = os.fspath(file_obj)
            with open(file_obj, "rb") as handle:
                magic = handle.read(4)
        else:
            name = str(getattr(file_obj, "name", "") or "")
            file_obj.seek(0)
            magic = file_obj.read(4)
            # Rewind so whichever loader runs next sees the whole stream.
            file_obj.seek(0)
        lowered = name.lower()

        if magic == b'%PDF':
            return "doc", extract_pdf(file_obj)
        if lowered.endswith(".docx"):
            return "doc", extract_docx(file_obj)
        if lowered.endswith(".pptx"):
            return "doc", extract_pptx(file_obj)
        if lowered.endswith(text_suffixes):
            if is_path:
                with open(file_obj, "rb") as handle:
                    raw = handle.read()
            else:
                raw = file_obj.read()
            # Undecodable bytes are dropped rather than failing the upload.
            return "doc", raw.decode('utf-8', errors='ignore')
        if lowered.endswith(image_suffixes):
            return "image", Image.open(file_obj).convert('RGB')
        raise ValueError("Unsupported file type.")
    except Exception as e:
        raise ValueError(f"Error processing file: {str(e)}") from e
@spaces.GPU()
def stream_chat(message, history: list, temperature: float, max_length: int, top_p: float, top_k: int, penalty: float):
    """Stream the model's reply to a chat message that may carry files.

    Args:
        message: Dict with a ``"text"`` entry and an optional ``"files"``
            list. When files are present only the last one is loaded, and
            ``history`` is not added to the conversation.
        history: Earlier turns as ``(prompt, answer)`` pairs. A pair whose
            answer is ``None`` is treated as a file upload whose reference is
            ``prompt[0]``; the most recent such file is re-attached to the
            new message.
        temperature: Sampling temperature passed to ``generate``.
        max_length: ``max_length`` passed to ``generate``.
        top_p: Nucleus-sampling threshold passed to ``generate``.
        top_k: Top-k cutoff passed to ``generate``.
        penalty: Value passed to ``generate`` as ``repetition_penalty``.

    Yields:
        The reply accumulated so far, growing each time new text arrives.
    """
    # NOTE(review): the model is loaded on every call; confirm this is intended.
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
        trust_remote_code=True
    )

    print(f'message is - {message}')
    print(f'history is - {history}')

    conversation = []
    prompt_files = []
    # Tolerate messages that carry no "files" key at all.
    files = message.get("files")

    if files:
        choice, contents = mode_load(files[-1])
        if choice == "image":
            conversation.append({"role": "user", "image": contents, "content": message['text']})
        elif choice == "doc":
            # Fill in the placeholder so the prompt shows a count, not "{}".
            format_msg = contents + "\n\n\n" + "{} files uploaded.\n".format(len(files)) + message['text']
            conversation.append({"role": "user", "content": format_msg})
    elif len(history) == 0:
        conversation.append({"role": "user", "content": message['text']})
    else:
        for prompt, answer in history:
            if answer is None:
                # File-upload turn: remember the file and keep an empty
                # exchange in its place.
                prompt_files.append(prompt[0])
                conversation.extend([{"role": "user", "content": ""}, {"role": "assistant", "content": ""}])
            else:
                conversation.extend([{"role": "user", "content": prompt}, {"role": "assistant", "content": answer}])
        if len(prompt_files) > 0:
            choice, contents = mode_load(prompt_files[-1])
        else:
            choice = ""
            conversation.append({"role": "user", "image": "", "content": message['text']})
        if choice == "image":
            conversation.append({"role": "user", "image": contents, "content": message['text']})
        elif choice == "doc":
            format_msg = contents + "\n\n\n" + "{} files uploaded.\n".format(len(prompt_files)) + message['text']
            conversation.append({"role": "user", "content": format_msg})

    print(f"Conversation is -\n{conversation}")

    input_ids = tokenizer.apply_chat_template(conversation, tokenize=True, add_generation_prompt=True,
                                              return_tensors="pt", return_dict=True).to(model.device)
    streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)

    generate_kwargs = dict(
        max_length=max_length,
        streamer=streamer,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        repetition_penalty=penalty,
        eos_token_id=[151329, 151336, 151338],
    )
    # Everything the chat template returned is forwarded to generate().
    gen_kwargs = {**input_ids, **generate_kwargs}

    with torch.no_grad():
        # generate() runs on a worker thread so this generator can relay
        # text from the streamer while generation is still in progress.
        thread = Thread(target=model.generate, kwargs=gen_kwargs)
        thread.start()
        buffer = ""
        for new_text in streamer:
            buffer += new_text
            yield buffer
# Chat history widget handed to gr.ChatInterface below.
chatbot = gr.Chatbot(
#rtl=True,
)
# Combined text/file input box shared by the chat interface and the examples.
chat_input = gr.MultimodalTextbox(
interactive=True,
placeholder="Enter message or upload a file ...",
show_label=False,
#rtl=True,
)
# Example prompts passed to gr.Examples for chat_input; only the third one
# attaches a file ("perro.jpg").
EXAMPLES = [
[{"text": "Resumir Documento"}],
[{"text": "Explicar la Imagen"}],
[{"text": "¿De qué es la foto?", "files": ["perro.jpg"]}],
[{"text": "Quiero armar un JSON, solo el JSON sin texto, que contenga los datos de la primera mitad de la tabla de la imagen (las primeras 10 jurisdicciones 901-910). Ten en cuenta que los valores numéricos son decimales de cuatro dígitos. La tabla contiene las siguientes columnas: Codigo, Nombre, Fecha Inicio, Fecha Cese, Coeficiente Ingresos, Coeficiente Gastos y Coeficiente Unificado. La tabla puede contener valores vacíos, en ese caso dejarlos como null. Cada fila de la tabla representa una jurisdicción con sus respectivos valores.", }]
]
# Define the simple_chat function (superseded path-based version, kept commented out; the active version follows it)
# @spaces.GPU()
# def simple_chat(message: dict, temperature: float = 0.8, max_length: int = 4096, top_p: float = 1, top_k: int = 10, penalty: float = 1.0):
# try:
# model = AutoModelForCausalLM.from_pretrained(
# MODEL_ID,
# torch_dtype=torch.bfloat16,
# low_cpu_mem_usage=True,
# trust_remote_code=True
# )
# conversation = []
# if "file" in message and message["file"]:
# file_path = message["file"]
# choice, contents = mode_load(file_path)
# if choice == "image":
# conversation.append({"role": "user", "image": contents, "content": message["text"]})
# elif choice == "doc":
# format_msg = contents + "\n\n\n" + "{} files uploaded.\n" + message["text"]
# conversation.append({"role": "user", "content": format_msg})
# else:
# conversation.append({"role": "user", "content": message["text"]})
# input_ids = tokenizer.apply_chat_template(conversation, tokenize=True, add_generation_prompt=True, return_tensors="pt", return_dict=True).to(model.device)
# generate_kwargs = dict(
# max_length=max_length,
# do_sample=True,
# top_p=top_p,
# top_k=top_k,
# temperature=temperature,
# repetition_penalty=penalty,
# eos_token_id=[151329, 151336, 151338],
# )
# with torch.no_grad():
# generated_ids = model.generate(input_ids['input_ids'], **generate_kwargs)
# generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
# return PlainTextResponse(generated_text)
# except Exception as e:
# return PlainTextResponse(f"Error: {str(e)}")
# @app.post("/chat/")
# async def test_endpoint(message: dict):
# if "text" not in message:
# raise HTTPException(status_code=400, detail="Missing 'text' in request body")
# if "file" not in message:
# print("Sin File")
# response = simple_chat(message)
# return response
def _sniff_upload_suffix(data):
    """Guess a file-name suffix for raw upload bytes from their leading bytes."""
    if data[:4] == b"%PDF":
        return ".pdf"
    if data[:8] == b"\x89PNG\r\n\x1a\n":
        return ".png"
    if data[:3] == b"\xff\xd8\xff":
        return ".jpg"
    if data[:4] in (b"II*\x00", b"MM\x00*"):
        return ".tiff"
    if data[:4] == b"RIFF" and data[8:12] == b"WEBP":
        return ".webp"
    # "BM" alone is too weak a signal, so also require the size field of the
    # BMP header to match the payload length.
    if data[:2] == b"BM" and int.from_bytes(data[2:6], "little") == len(data):
        return ".bmp"
    if data[:4] == b"PK\x03\x04":
        # .docx and .pptx are both ZIP archives; tell them apart by content.
        import zipfile
        try:
            with zipfile.ZipFile(io.BytesIO(data)) as archive:
                members = archive.namelist()
        except zipfile.BadZipFile:
            return ""
        if "word/document.xml" in members:
            return ".docx"
        if "ppt/presentation.xml" in members:
            return ".pptx"
        return ""
    # Anything unrecognised is treated as plain text.
    return ".txt"


@spaces.GPU()
def simple_chat(message: dict, temperature: float = 0.8, max_length: int = 4096, top_p: float = 1, top_k: int = 10, penalty: float = 1.0):
    """Answer a single message, optionally about an uploaded file, in one shot.

    Args:
        message: Dict with a ``"text"`` entry, an optional ``"file"`` entry
            holding the upload's raw bytes, and an optional ``"filename"``
            entry. When no filename is given the file type is guessed from
            the leading bytes of the upload.
        temperature: Sampling temperature passed to ``generate``.
        max_length: ``max_length`` passed to ``generate``.
        top_p: Nucleus-sampling threshold passed to ``generate``.
        top_k: Top-k cutoff passed to ``generate``.
        penalty: Value passed to ``generate`` as ``repetition_penalty``.

    Returns:
        A ``PlainTextResponse`` holding the generated reply. Any exception is
        caught and reported as a ``PlainTextResponse`` whose text starts with
        ``"Error: "``.
    """
    try:
        # NOTE(review): the model is loaded on every call; confirm this is intended.
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID,
            torch_dtype=torch.bfloat16,
            low_cpu_mem_usage=True,
            trust_remote_code=True
        )

        conversation = []

        upload = message.get("file")
        if upload:
            # mode_load needs a seekable binary stream carrying a name whose
            # suffix identifies the file type.
            stream = io.BytesIO(upload)
            stream.name = message.get("filename") or ("upload" + _sniff_upload_suffix(upload))
            choice, contents = mode_load(stream)
            if choice == "image":
                conversation.append({"role": "user", "image": contents, "content": message["text"]})
            elif choice == "doc":
                format_msg = contents + "\n\n\n" + "{} files uploaded.\n".format("1") + message["text"]
                conversation.append({"role": "user", "content": format_msg})
        else:
            conversation.append({"role": "user", "content": message["text"]})

        inputs = tokenizer.apply_chat_template(conversation, tokenize=True, add_generation_prompt=True, return_tensors="pt", return_dict=True).to(model.device)

        generate_kwargs = dict(
            max_length=max_length,
            do_sample=True,
            top_p=top_p,
            top_k=top_k,
            temperature=temperature,
            repetition_penalty=penalty,
            eos_token_id=[151329, 151336, 151338],
        )

        with torch.no_grad():
            # Forward every tensor the chat template produced, as stream_chat
            # does, rather than the token ids alone.
            generated_ids = model.generate(**inputs, **generate_kwargs)

        # The output sequence begins with the prompt tokens; decode only what
        # was generated after them so the reply does not echo the prompt.
        prompt_length = inputs['input_ids'].shape[1]
        generated_text = tokenizer.decode(generated_ids[0][prompt_length:], skip_special_tokens=True)

        return PlainTextResponse(generated_text)
    except Exception as e:
        return PlainTextResponse(f"Error: {str(e)}")
@app.post("/chat/")
async def test_endpoint(
    text: str = Form(...),
    file: Optional[UploadFile] = File(None)
):
    """Handle a chat request sent as form data, with an optional file upload.

    Args:
        text: The user's message (required form field).
        file: Optional uploaded file to discuss.

    Returns:
        Whatever ``simple_chat`` returns for the assembled message.

    Raises:
        HTTPException: Status 400 when ``text`` is empty.
    """
    if not text:
        raise HTTPException(status_code=400, detail="Missing 'text' in request body")

    # Read the upload, if any, keeping its name so the file type can be
    # determined from the suffix.
    file_contents = None
    filename = None
    if file:
        file_contents = await file.read()
        filename = file.filename

    # Build the dict that simple_chat expects.
    message = {
        "text": text,
        "file": file_contents,
        "filename": filename,
    }

    # Log a summary instead of the dict itself so raw upload bytes never
    # reach stdout.
    print({
        "text": text,
        "filename": filename,
        "file_bytes": len(file_contents) if file_contents is not None else 0,
    })

    response = simple_chat(message)
    return response
with gr.Blocks(css=CSS, theme="soft", fill_height=True) as demo:
    gr.HTML(TITLE)
    gr.HTML(DESCRIPTION)

    # Collapsed accordion that holds the generation parameters.
    _params_accordion = gr.Accordion(label="⚙️ Parameters", open=False, render=False)

    # (label, minimum, maximum, step, value), listed in the order stream_chat
    # takes them after (message, history).
    _slider_specs = [
        ("Temperature", 0, 1, 0.1, 0.8),
        ("Max Length", 1024, 8192, 1, 4096),
        ("top_p", 0.0, 1.0, 0.1, 1.0),
        ("top_k", 1, 20, 1, 10),
        ("Repetition penalty", 0.0, 2.0, 0.1, 1.0),
    ]
    _param_sliders = [
        gr.Slider(
            minimum=low,
            maximum=high,
            step=step,
            value=default,
            label=label,
            render=False,
        )
        for label, low, high, step, default in _slider_specs
    ]

    gr.ChatInterface(
        fn=stream_chat,
        multimodal=True,
        textbox=chat_input,
        chatbot=chatbot,
        fill_height=True,
        additional_inputs_accordion=_params_accordion,
        additional_inputs=_param_sliders,
    )
    gr.Examples(EXAMPLES, [chat_input])
if __name__ == "__main__":
    # Apply the queue setting before the app starts serving. It was
    # previously chained to a launch() call placed after uvicorn.run(), which
    # only ran once the server had already shut down.
    demo.queue(api_open=False)

    # Serve the Gradio UI at "/" on the same FastAPI app that exposes the
    # /test/ and /chat/ routes.
    app = gr.mount_gradio_app(app, demo, "/")

    # uvicorn.run() blocks until the server stops, so nothing may follow it;
    # in particular no demo.launch(), which would start a second server.
    uvicorn.run(app, host="0.0.0.0", port=7860)