import gradio as gr
import spaces
from threading import Thread

from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration
from transformers import TextIteratorStreamer
from peft import PeftModel
import torch
import time

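# Base LLaVA-NeXT checkpoint and the fine-tuned adapter applied on top of it.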
base_model = "llava-hf/llava-v1.6-mistral-7b-hf"
finetune_repo = "erwannd/llava-v1.6-mistral-7b-finetune-combined4k"

processor = LlavaNextProcessor.from_pretrained(base_model)

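# Load the fp16 base model, attach the PEFT adapter, and move everything to the GPU.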
model = LlavaNextForConditionalGeneration.from_pretrained(
    base_model,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
)
model = PeftModel.from_pretrained(model, finetune_repo)
model.to("cuda:0")


@spaces.GPU
def predict(image, input_text):
    image = image.convert("RGB")
    prompt = f"[INST] <image>\n{input_text} [/INST]"

    inputs = processor(text=prompt, images=image, return_tensors="pt").to("cuda:0", torch.float16)

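    # TextIteratorStreamer yields decoded text incrementally as generate() produces tokens.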
    streamer = TextIteratorStreamer(processor, skip_special_tokens=True)

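    # Run generation on a background thread so the streamer can be consumed while tokens arrive.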
    generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=200, do_sample=False)
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

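    # The streamed text echoes the prompt (minus the skipped <image> token), so strip it below.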
    text_prompt = f"[INST] \n{input_text} [/INST]"

    buffer = ""
    time.sleep(0.5)  # give the generation thread a moment to start producing tokens
    for new_text in streamer:
        buffer += new_text
        generated_text_without_prompt = buffer[len(text_prompt):]
        time.sleep(0.04)  # brief pause to smooth the streamed output
        yield generated_text_without_prompt


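# Gradio UI: an image and a text prompt in, the streamed model response out.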
image = gr.components.Image(type="pil")
input_prompt = gr.components.Textbox(label="Input Prompt")
model_output = gr.components.Textbox(label="Model Output")

examples = [
    ["./examples/bar_m01.png", "Evaluate and explain if this chart is misleading"],
    ["./examples/bar_n01.png", "Is this chart misleading? Explain"],
    ["./examples/fox_news_cropped.png", "Tell me if this chart is misleading"],
    ["./examples/line_m01.png", "Explain if this chart is misleading"],
    ["./examples/line_m04.png", "Evaluate and explain if this chart is misleading"],
    ["./examples/pie_m01.png", "Evaluate if this chart is misleading, if so explain"],
    ["./examples/pie_m02.png", "Is this chart misleading? Explain"],
]

title = "LlavaNext finetuned on Misleading Chart Dataset"
interface = gr.Interface(
    fn=predict,
    inputs=[image, input_prompt],
    outputs=model_output,
    examples=examples,
    title=title,
    theme="gradio/soft",
    cache_examples=False,
)

interface.launch()