import gradio as gr
import spaces
from threading import Thread

from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration
from transformers import TextIteratorStreamer
from PIL import Image
from peft import PeftModel
import torch


base_model = "llava-hf/llava-v1.6-mistral-7b-hf"
finetune_repo = "erwannd/llava-v1.6-mistral-7b-finetune-combined4k"

processor = LlavaNextProcessor.from_pretrained(base_model)

# Load the fp16 base model, then attach the LoRA adapter on top of it
model = LlavaNextForConditionalGeneration.from_pretrained(
    base_model,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
)
model = PeftModel.from_pretrained(model, finetune_repo)
model.to("cuda:0")


@spaces.GPU
def predict(image, input_text):
    image = image.convert("RGB")
    prompt = f"[INST] <image>\n{input_text} [/INST]"

    inputs = processor(text=prompt, images=image, return_tensors="pt").to("cuda:0", torch.float16)

    # skip_prompt keeps the echoed input out of the stream, so only newly
    # generated tokens are yielded (no manual prompt-stripping needed)
    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=200, do_sample=False)

    # generate() blocks until decoding finishes, so run it in a background
    # thread and consume the streamer here as tokens arrive
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    buffer = ""
    for new_text in streamer:
        buffer += new_text
        yield buffer
    thread.join()

    # Non-streaming alternative: generate everything, then decode past the prompt.
    # prompt_length = inputs["input_ids"].shape[1]
    # generate_ids = model.generate(**inputs, max_new_tokens=512)
    # output_text = processor.batch_decode(
    #     generate_ids[:, prompt_length:], skip_special_tokens=True,
    #     clean_up_tokenization_spaces=False)[0]
    # return output_text

   
image = gr.components.Image(type="pil")
input_prompt = gr.components.Textbox(label="Input Prompt")
model_output = gr.components.Textbox(label="Model Output")
examples = [["./examples/bar_m01.png", "Evaluate and explain if this chart is misleading"],
            ["./examples/bar_n01.png", "Is this chart misleading? Explain"],
            ["./examples/fox_news_cropped.png", "Tell me if this chart is misleading"],
            ["./examples/line_m01.png", "Explain if this chart is misleading"],
            ["./examples/line_m04.png", "Evaluate and explain if this chart is misleading"],
            ["./examples/pie_m01.png", "Evaluate if this chart is misleading, if so explain"],
            ["./examples/pie_m02.png", "Is this chart misleading? Explain"]]

title = "LlavaNext fine-tuned on the Misleading Chart Dataset"
interface = gr.Interface(
    fn=predict, 
    inputs=[image, input_prompt], 
    outputs=model_output, 
    examples=examples, 
    title=title,
    theme='gradio/soft',
    cache_examples=False
)

interface.launch()
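
# Note: with several concurrent users, enabling Gradio's request queue before
# launching keeps streaming outputs responsive (a sketch using the standard
# queue() API; recent Gradio releases enable queueing by default on launch):
# interface.queue().launch()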