import os
from threading import Thread
from typing import Any, Dict, List, Optional

import docx
import pymupdf
import torch
from fastapi import FastAPI, HTTPException
from PIL import Image
from pptx import Presentation
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

app = FastAPI()
# Model and tokenizer initialization
MODEL_LIST = ["nikravan/glm-4vq"]
HF_TOKEN = os.environ.get("HF_TOKEN", None)
MODEL_ID = MODEL_LIST[0]
MODEL_NAME = "GLM-4vq"
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    trust_remote_code=True,
)

def extract_text(path):
    # Read a plain-text file; a context manager ensures the handle is closed.
    with open(path, "r", encoding="utf-8") as f:
        return f.read()

def extract_pdf(path):
    doc = pymupdf.open(path)
    text = ""
    for page in doc:
        text += page.get_text()
    return text

def extract_docx(path):
    doc = docx.Document(path)
    data = [paragraph.text for paragraph in doc.paragraphs]
    return "\n\n".join(data)

def extract_pptx(path):
    prs = Presentation(path)
    text = ""
    for slide in prs.slides:
        for shape in slide.shapes:
            if hasattr(shape, "text"):
                text += shape.text + "\n"
    return text

def mode_load(path):
    # Dispatch on file extension: documents are extracted to text (truncated
    # to 5000 characters), images are returned as RGB PIL images.
    file_type = path.split(".")[-1].lower()
    if file_type in ["pdf", "txt", "py", "docx", "pptx"]:
        if file_type == "pdf":
            content = extract_pdf(path)
        elif file_type == "docx":
            content = extract_docx(path)
        elif file_type == "pptx":
            content = extract_pptx(path)
        else:
            content = extract_text(path)
        return "doc", content[:5000]
    elif file_type in ["png", "jpg", "jpeg", "bmp", "tiff", "webp"]:
        content = Image.open(path).convert("RGB")
        return "image", content
    else:
        raise HTTPException(status_code=400, detail="Unsupported file type")
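
# Illustrative calls (hypothetical paths): mode_load returns a (kind, payload)
# pair that the chat endpoint below branches on.
#
#   kind, payload = mode_load("report.pdf")  # -> ("doc", first 5000 chars of text)
#   kind, payload = mode_load("photo.jpg")   # -> ("image", RGB PIL.Image)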
@app.post("/test/")
async def test_endpoint(message: Dict[str, str]):
if "text" not in message:
raise HTTPException(status_code=400, detail="Missing 'text' in request body")
response = {"message": f"Received your message: {message['text']}"}
return response
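
# Quick smoke test for the endpoint above (assumes the app is served locally
# on port 8000; adjust to your deployment):
#
#   curl -X POST http://localhost:8000/test/ \
#        -H "Content-Type: application/json" \
#        -d '{"text": "hello"}'
#   # -> {"message": "Received your message: hello"}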
@app.post("/chat/")
async def chat_endpoint(
message: Dict[str, str],
history: List[Dict[str, str]] = [],
temperature: float = 0.8,
max_length: int = 4096,
top_p: float = 1.0,
top_k: int = 10,
penalty: float = 1.0
):
conversation = []
if "files" in message and message["files"]:
choice, contents = mode_load(message["files"][-1])
if choice == "image":
conversation.append({"role": "user", "image": contents, "content": message['text']})
elif choice == "doc":
format_msg = contents + "\n\n\n" + "{} files uploaded.\n" + message['text']
conversation.append({"role": "user", "content": format_msg})
else:
if len(history) == 0:
conversation.append({"role": "user", "content": message['text']})
else:
for prompt, answer in history:
if answer is None:
conversation.extend([{"role": "user", "content": ""}, {"role": "assistant", "content": ""}])
else:
conversation.extend([{"role": "user", "content": prompt}, {"role": "assistant", "content": answer}])
if len(history) > 0:
choice, contents = mode_load(history[-1][0])
if choice == "image":
conversation.append({"role": "user", "image": contents, "content": message['text']})
elif choice == "doc":
format_msg = contents + "\n\n\n" + "{} files uploaded.\n" + message['text']
conversation.append({"role": "user", "content": format_msg})
else:
conversation.append({"role": "user", "content": message['text']})
input_ids = tokenizer.apply_chat_template(conversation, tokenize=True, add_generation_prompt=True, return_tensors="pt", return_dict=True).to(model.device)
streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
generate_kwargs = dict(
max_length=max_length,
streamer=streamer,
do_sample=True,
top_p=top_p,
top_k=top_k,
temperature=temperature,
repetition_penalty=penalty
)
with torch.no_grad():
buffer = ""
for new_text in streamer:
buffer += new_text
return {"response": buffer}