CosmosLLaVA / app.py
erndgn's picture
Create app.py
14deb22 verified
raw
history blame
4.59 kB
import spaces
import time
from threading import Thread
import gradio as gr
import torch
from PIL import Image
from transformers import AutoProcessor
from llava.constants import (
IMAGE_TOKEN_INDEX,
DEFAULT_IMAGE_TOKEN,
DEFAULT_IM_START_TOKEN,
DEFAULT_IM_END_TOKEN,
IMAGE_PLACEHOLDER,
)
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from llava.mm_utils import (
process_images,
tokenizer_image_token,
get_model_name_from_path,
)
from io import BytesIO
import requests
import os
from conversation import Conversation, SeparatorStyle
# Hugging Face Hub id of the Turkish LLaVA checkpoint served by this app.
model_id = "ytu-ce-cosmos/Turkish-LLaVA-v0.1"
# Must run before the model is built. Presumably it disables torch's default
# weight initialization to speed up loading — TODO confirm in llava.utils.
disable_torch_init()
# Model name derived from the checkpoint id; passed to the loader below.
model_name = get_model_name_from_path(model_id)
# Loaded once at import time. `tokenizer`, `model` and `image_processor` are
# module-level globals shared by every inference call.
# Positional args are presumably (model_path, model_base, model_name), with
# None meaning "no separate base model" — confirm against llava.model.builder.
tokenizer, model, image_processor, context_len = load_pretrained_model(
model_id, None, model_name
)
def load_image(image_file):
    """Load an image from an HTTP(S) URL or a local path as an RGB PIL image.

    Args:
        image_file: An ``http://`` / ``https://`` URL or a filesystem path.

    Returns:
        A ``PIL.Image.Image`` converted to RGB mode.

    Raises:
        FileNotFoundError: If ``image_file`` is not a URL and no such path exists.
        requests.HTTPError: If downloading the URL returns an error status.
    """
    # Match the full scheme so a local file that merely starts with "http"
    # (e.g. "http_photo.png") is not sent to requests.
    if image_file.startswith(("http://", "https://")):
        # The timeout stops a stalled host from hanging the worker indefinitely.
        # raise_for_status turns a 404/500 page into a clear error instead of
        # Pillow's opaque "cannot identify image file".
        response = requests.get(image_file, timeout=30)
        response.raise_for_status()
        with Image.open(BytesIO(response.content)) as img:
            return img.convert("RGB")
    if os.path.exists(image_file):
        # convert() returns a new image that owns its pixel data, so the
        # source file handle can be closed as soon as the context exits.
        with Image.open(image_file) as img:
            return img.convert("RGB")
    raise FileNotFoundError(f"Image file {image_file} not found.")
def _insert_image_token(prompt):
    """Return ``prompt`` with the image token form that ``model.config`` expects."""
    if model.config.mm_use_im_start_end:
        image_token = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
    else:
        image_token = DEFAULT_IMAGE_TOKEN
    if IMAGE_PLACEHOLDER in prompt:
        # Literal substring replacement: the placeholder is plain text, so no
        # regex (and no `re` import) is needed.
        return prompt.replace(IMAGE_PLACEHOLDER, image_token)
    # No explicit placeholder: put the image first, on its own line.
    return image_token + "\n" + prompt


def _build_llama3_prompt(user_message):
    """Wrap ``user_message`` in the Llama-3 chat template with the Turkish system prompt."""
    conv = Conversation(
        system="""<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nSen bir yapay zeka asistanısın. Kullanıcı sana bir görev verecek. Amacın görevi olabildiğince sadık bir şekilde tamamlamak. Görevi yerine getirirken adım adım düşün ve adımlarını gerekçelendir.""",
        roles=("<|start_header_id|>user<|end_header_id|>\n\n", "<|start_header_id|>assistant<|end_header_id|>\n\n"),
        version="llama3",
        messages=[],
        offset=0,
        sep_style=SeparatorStyle.MPT,
        sep="<|eot_id|>",
    )
    conv.append_message(conv.roles[0], user_message)
    # A None assistant message presumably renders an open assistant turn for
    # the model to complete — conversation.py is not shown; confirm there.
    conv.append_message(conv.roles[1], None)
    return conv.get_prompt()


def infer_single_image(model_id, image_file, prompt, max_new_tokens=512):
    """Run greedy LLaVA generation for one image and one user prompt.

    Args:
        model_id: Unused; kept for call-site compatibility. The module-level
            ``model``, ``tokenizer`` and ``image_processor`` are always used.
        image_file: URL or local path accepted by ``load_image``.
        prompt: User text. If it contains ``IMAGE_PLACEHOLDER``, the image
            token is substituted there; otherwise it is prepended on its own line.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The decoded response of the first (only) sequence, with special tokens
        removed and surrounding whitespace stripped.
    """
    full_prompt = _build_llama3_prompt(_insert_image_token(prompt))
    print("full prompt: ", full_prompt)

    image = load_image(image_file)
    image_tensor = process_images([image], image_processor, model.config).to(
        model.device, dtype=torch.float16
    )
    # Put the ids on the model's device, like the image tensor, instead of
    # hard-coding the default CUDA device.
    input_ids = (
        tokenizer_image_token(full_prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
        .unsqueeze(0)
        .to(model.device)
    )
    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            images=image_tensor,
            image_sizes=[image.size],
            do_sample=False,
            max_new_tokens=max_new_tokens,
            use_cache=True,
        )
    return tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
@spaces.GPU
def bot_streaming(message, history):
    """Gradio ChatInterface callback: answer ``message`` about the most recent image.

    Args:
        message: Multimodal textbox payload with ``"text"`` and ``"files"`` keys.
            Each file entry is either a path string or a dict with a ``"path"`` key.
        history: Prior chat turns. A turn whose first element is a tuple is
            treated as an uploaded file whose first item is the file path.

    Yields:
        The complete model answer, as a single chunk.

    Raises:
        gr.Error: If neither the message nor the history contains an image.
    """
    print(message)
    image = None
    if message["files"]:
        last_file = message["files"][-1]
        image = last_file["path"] if isinstance(last_file, dict) else last_file
    else:
        # Fall back to the most recently uploaded image earlier in the chat.
        for hist in reversed(history):
            if isinstance(hist[0], tuple):
                image = hist[0][0]
                break
    if image is None:
        # gr.Error must be raised, not just constructed, for Gradio to show it
        # and to stop the request before inference runs without an image.
        raise gr.Error("You need to upload an image for LLaVA to work.")
    prompt = message["text"]
    result = infer_single_image(model_id, image, prompt)
    yield result
# UI components are created up front so they can be passed into ChatInterface.
chatbot = gr.Chatbot(scale=1)
chat_input = gr.MultimodalTextbox(interactive=True, file_types=["image"], placeholder="Enter message or upload file...", show_label=False)

with gr.Blocks(fill_height=True) as demo:
    gr.ChatInterface(
        fn=bot_streaming,
        title="LLaVA Llama-3-8B",
        examples=[
            {"text": "Çiçeğin üzerinde ne var?", "files": ["./bee.jpg"]},
            {"text": "Bu tatlı nasıl yapılır?", "files": ["./baklava.png"]},
        ],
        description="",
        stop_btn="Stop Generation",
        multimodal=True,
        textbox=chat_input,
        chatbot=chatbot,
    )

if __name__ == "__main__":
    # Start the (blocking) server only when this file is run as a script, so
    # importing the module (e.g. Gradio reload mode) does not launch it.
    demo.queue(api_open=False)
    demo.launch(show_api=False, share=False)