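"""FastAPI service wrapping the GLM-4vq multimodal model.

Accepts chat messages with optional file uploads (pdf/txt/py/docx/pptx or
images) and returns a model-generated reply as a single response body.
"""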
import os
from threading import Thread
from typing import Dict, List

import docx
import pymupdf
import torch
from fastapi import FastAPI, File, UploadFile, HTTPException
from PIL import Image
from pptx import Presentation
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

app = FastAPI()
# Model and tokenizer initialization
MODEL_LIST = ["nikravan/glm-4vq"]
HF_TOKEN = os.environ.get("HF_TOKEN", None)
MODEL_ID = MODEL_LIST[0]
MODEL_NAME = "GLM-4vq"

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    token=HF_TOKEN,  # HF_TOKEN was read above but never used; pass it for gated repos
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    trust_remote_code=True,
)
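# NOTE: no explicit device placement is set above, so on a CPU-only host the
# weights stay on CPU. Passing device_map="auto" (an addition not present in
# the original) is the usual way to place the model onto available GPUs.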
def extract_text(path):
    # Read a plain-text file; the context manager closes the handle.
    with open(path, "r", encoding="utf-8") as f:
        return f.read()
def extract_pdf(path):
    # Concatenate the text of every page in the document.
    doc = pymupdf.open(path)
    text = ""
    for page in doc:
        text += page.get_text()
    return text
def extract_docx(path):
    doc = docx.Document(path)
    data = [paragraph.text for paragraph in doc.paragraphs]
    return "\n\n".join(data)
def extract_pptx(path):
    # Collect the text of every shape on every slide.
    prs = Presentation(path)
    text = ""
    for slide in prs.slides:
        for shape in slide.shapes:
            if hasattr(shape, "text"):
                text += shape.text + "\n"
    return text
def mode_load(path):
    # Dispatch on file extension: documents come back as truncated text,
    # images as PIL images; anything else is rejected with a 400.
    file_type = path.split(".")[-1].lower()
    if file_type in ["pdf", "txt", "py", "docx", "pptx"]:
        if file_type == "pdf":
            content = extract_pdf(path)
        elif file_type == "docx":
            content = extract_docx(path)
        elif file_type == "pptx":
            content = extract_pptx(path)
        else:
            content = extract_text(path)
        return "doc", content[:5000]  # cap document text to bound the prompt
    elif file_type in ["png", "jpg", "jpeg", "bmp", "tiff", "webp"]:
        content = Image.open(path).convert("RGB")
        return "image", content
    else:
        raise HTTPException(status_code=400, detail="Unsupported file type")
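# Example return values (hypothetical file names):
#   mode_load("report.pdf") -> ("doc", <first 5000 chars of the PDF text>)
#   mode_load("photo.jpg")  -> ("image", <PIL.Image in RGB mode>)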
@app.post("/test")  # route path assumed; the decorator is missing in the source
async def test_endpoint(message: Dict[str, str]):
    if "text" not in message:
        raise HTTPException(status_code=400, detail="Missing 'text' in request body")
    return {"message": f"Received your message: {message['text']}"}
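# Example request (assuming the /test route above and a local server):
#   curl -X POST http://localhost:8000/test \
#        -H "Content-Type: application/json" -d '{"text": "hello"}'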
@app.post("/chat")  # route path assumed; the decorator is missing in the source
async def chat_endpoint(
    message: Dict[str, str],
    history: List[List[str]] = [],  # [prompt, answer] pairs; answer may be None
    temperature: float = 0.8,
    max_length: int = 4096,
    top_p: float = 1.0,
    top_k: int = 10,
    penalty: float = 1.0,
):
    conversation = []
    if "files" in message and message["files"]:
        # A file accompanies this turn: attach the most recent upload.
        choice, contents = mode_load(message["files"][-1])
        if choice == "image":
            conversation.append({"role": "user", "image": contents, "content": message["text"]})
        elif choice == "doc":
            format_msg = contents + "\n\n\n" + "{} files uploaded.\n" + message["text"]
            conversation.append({"role": "user", "content": format_msg})
    else:
        if len(history) == 0:
            conversation.append({"role": "user", "content": message["text"]})
        else:
            # Replay earlier turns; a missing answer marks a file-only turn.
            for prompt, answer in history:
                if answer is None:
                    conversation.extend([{"role": "user", "content": ""}, {"role": "assistant", "content": ""}])
                else:
                    conversation.extend([{"role": "user", "content": prompt}, {"role": "assistant", "content": answer}])
            # Reload the last referenced file so the model keeps its context.
            # The fallback is an added safeguard: mode_load() raises a 400 when
            # the last history entry is plain text rather than a file path.
            try:
                choice, contents = mode_load(history[-1][0])
            except HTTPException:
                choice, contents = "", None
            if choice == "image":
                conversation.append({"role": "user", "image": contents, "content": message["text"]})
            elif choice == "doc":
                format_msg = contents + "\n\n\n" + "{} files uploaded.\n" + message["text"]
                conversation.append({"role": "user", "content": format_msg})
            else:
                conversation.append({"role": "user", "content": message["text"]})
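    # At this point `conversation` is a list of chat turns, e.g.
    #   [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]
    # with an optional "image" key carrying a PIL image for vision turns.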
    input_ids = tokenizer.apply_chat_template(
        conversation,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True,
    ).to(model.device)
    streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        **input_ids,  # input_ids and attention_mask from the chat template
        streamer=streamer,
        max_length=max_length,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        repetition_penalty=penalty,
    )
    # generate() must run in a background thread so this coroutine can drain
    # the streamer; the original never started generation, so the loop below
    # would have blocked until the streamer timed out. (torch.no_grad() is not
    # needed here: generate() disables gradients itself, and the context would
    # not cross the thread boundary anyway.)
    thread = Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()
    buffer = ""
    for new_text in streamer:
        buffer += new_text
    thread.join()
    return {"response": buffer}