File size: 4,206 Bytes
9dc5c64
3f01084
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9f8d6c9
3f01084
 
9dc5c64
3f01084
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9dc5c64
3f01084
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import spaces
import random
import torch
import gradio as gr
from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration

# Checkpoint identifier passed to from_pretrained for both the processor and the model.
model_path = "ibm-granite/granite-vision-3.1-2b-preview"
# Processor used by chat_inference to apply the chat template and tokenize
# text + image inputs (use_fast=True asks for the fast implementation).
processor = LlavaNextProcessor.from_pretrained(model_path, use_fast=True)
# Loaded once at import time and shared by every request; torch_dtype="auto" and
# device_map="auto" leave dtype selection and weight placement to transformers.
model = LlavaNextForConditionalGeneration.from_pretrained(model_path, torch_dtype="auto", device_map="auto")

def get_text_from_content(content):
    """Flatten a message's content parts into one space-separated string.

    Text parts contribute their ``"text"`` value, image parts contribute the
    placeholder ``"<image>"``, and parts of any other type are dropped.
    """
    return " ".join(
        part["text"] if part["type"] == "text" else "<image>"
        for part in content
        if part["type"] in ("text", "image")
    )

def _build_generation_kwargs(temperature, top_p, top_k, max_tokens):
    """Translate the UI sampling controls into ``model.generate`` keyword arguments.

    A temperature of 0 (the slider's minimum) selects greedy decoding: with
    ``do_sample=True`` transformers rejects a non-positive temperature, so the
    sampling parameters are left out entirely in that case. Values are coerced
    to the types transformers' logits warpers expect (``int`` for ``top_k`` and
    ``max_new_tokens``, ``float`` for ``temperature``/``top_p``) in case the UI
    delivers whole numbers as the other numeric type.
    """
    kwargs = {"max_new_tokens": int(max_tokens)}
    if temperature is not None and temperature > 0:
        kwargs.update(
            do_sample=True,
            temperature=float(temperature),
            top_p=float(top_p),
            top_k=int(top_k),
        )
    else:
        kwargs["do_sample"] = False
    return kwargs


@spaces.GPU
def chat_inference(image, text, temperature, top_p, top_k, max_tokens, conversation):
    """Run one chat turn against the vision model and return the updated chat.

    Args:
        image: Optional image attached to this turn (``None`` when absent).
        text: The user's message; surrounding whitespace is stripped.
        temperature, top_p, top_k: Sampling controls; ``temperature <= 0``
            switches to greedy decoding.
        max_tokens: Maximum number of new tokens to generate.
        conversation: Prior turns as a list of ``{"role", "content"}`` dicts
            (or ``None`` for a fresh chat). It is not modified in place.

    Returns:
        ``(chat_history, conversation)``: the messages for the Chatbot component
        and the new conversation list to store as state. When the turn has
        neither an image nor text, nothing is generated and the existing
        history is returned.
    """
    # Work on a copy so that if generation raises, the stored history does not
    # keep a user turn that never received a reply.
    history = list(conversation) if conversation else []

    user_content = []
    if image is not None:
        user_content.append({"type": "image", "image": image})
    if text and text.strip():
        user_content.append({"type": "text", "text": text.strip()})
    if not user_content:
        return conversation_display(history), history

    history.append({"role": "user", "content": user_content})

    inputs = processor.apply_chat_template(
        history,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to("cuda")

    torch.manual_seed(random.randint(0, 10000))

    output = model.generate(
        **inputs, **_build_generation_kwargs(temperature, top_p, top_k, max_tokens)
    )
    # generate() returns the prompt tokens followed by the new ones. Decode only
    # the new tokens so the stored assistant turn holds just the reply rather
    # than the whole rendered prompt, which would be re-fed on every later turn.
    prompt_length = inputs["input_ids"].shape[-1]
    assistant_response = processor.decode(
        output[0][prompt_length:], skip_special_tokens=True
    )

    history.append({
        "role": "assistant",
        "content": [{"type": "text", "text": assistant_response.strip()}],
    })

    return conversation_display(history), history

def conversation_display(conversation):
    """Convert the stored conversation into ``gr.Chatbot(type="messages")`` entries.

    Each stored turn becomes one ``{"role", "content"}`` message with plain
    string content:

    - User turns are flattened with ``get_text_from_content``, so images
      appear as ``"<image>"``.
    - Assistant turns keep only the text after the last ``"<|assistant|>"``
      marker, in case the stored reply still contains the rendered prompt.

    Every user/assistant turn is emitted in order. The display therefore
    mirrors the history the model is prompted with, including a user turn
    that has not been answered. Turns with any other role are skipped.
    ``None`` is treated as an empty conversation.
    """
    chat_history = []
    for msg in conversation or []:
        role = msg["role"]
        if role == "user":
            chat_history.append(
                {"role": "user", "content": get_text_from_content(msg["content"])}
            )
        elif role == "assistant":
            reply = msg["content"][0]["text"].split("<|assistant|>")[-1].strip()
            chat_history.append({"role": "assistant", "content": reply})
    return chat_history

def clear_chat():
    """Return fresh values for the "Clear Chat" button's outputs.

    In order: an empty chatbot transcript, an empty conversation state, an
    empty message box, and no image.
    """
    empty_transcript = []
    empty_state = []
    return empty_transcript, empty_state, "", None
    
with gr.Blocks(title="Granite Vision 3.1 2B", css="h1 { overflow: hidden; }") as demo:
    gr.Markdown("# Granite Vision 3.1 2B")

    with gr.Row():
        # Left column: optional image input and sampling controls.
        with gr.Column(scale=2):
            image_input = gr.Image(type="pil", label="Upload Image (optional)")
            with gr.Column():
                temperature_input = gr.Slider(minimum=0.0, maximum=2.0, value=0.2, step=0.01, label="Temperature")
                top_p_input = gr.Slider(minimum=0.0, maximum=1.0, value=0.95, step=0.01, label="Top p")
                top_k_input = gr.Slider(minimum=0, maximum=100, value=50, step=1, label="Top k")
                max_tokens_input = gr.Slider(minimum=10, maximum=300, value=128, step=1, label="Max Tokens")

        # Right column: chat transcript, message box and action buttons.
        with gr.Column(scale=3):
            chatbot = gr.Chatbot(label="Chat History", elem_id="chatbot", type='messages')
            text_input = gr.Textbox(lines=2, placeholder="Enter your message here", label="Message")
            with gr.Row():
                send_button = gr.Button("Chat")
                clear_button = gr.Button("Clear Chat")

    # Per-session conversation history, passed to chat_inference as `conversation`.
    state = gr.State([])

    send_button.click(
        chat_inference,
        inputs=[image_input, text_input, temperature_input, top_p_input, top_k_input, max_tokens_input, state],
        outputs=[chatbot, state]
    ).success(
        # After a successful turn, clear the inputs. The image is already part of
        # the stored conversation, and leaving it in the image box would attach
        # it again to every follow-up message. On failure the inputs are kept so
        # the user can retry.
        lambda: (None, ""),
        inputs=None,
        outputs=[image_input, text_input]
    )

    clear_button.click(
        clear_chat,
        inputs=None,
        outputs=[chatbot, state, text_input, image_input]
    )

    gr.Examples(
        examples=[
            ["https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/bus.png", "What is this?"]
        ],
        inputs=[image_input, text_input]
    )

# Start the Gradio app only when this file is executed directly, not on import.
if __name__ == "__main__":
    demo.launch()