# Granite Vision 3.1 2B: Gradio chat demo, deployed as a Hugging Face Space
# running on ZeroGPU.
import random

import gradio as gr
import spaces
import torch
from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration

# Load the checkpoint once at startup. torch_dtype="auto" keeps the dtype the
# weights were saved in; device_map="auto" places them on the available device.
model_path = "ibm-granite/granite-vision-3.1-2b-preview"
processor = LlavaNextProcessor.from_pretrained(model_path, use_fast=True)
model = LlavaNextForConditionalGeneration.from_pretrained(
    model_path, torch_dtype="auto", device_map="auto"
)
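# Sanity-check sketch (not part of the app): the same processor/model pair can
# be exercised for a single turn outside Gradio. The image path "bus.png" is a
# placeholder, not a file shipped with this Space.
#
#   from PIL import Image
#   conv = [{"role": "user", "content": [
#       {"type": "image", "image": Image.open("bus.png")},
#       {"type": "text", "text": "What is this?"},
#   ]}]
#   inputs = processor.apply_chat_template(
#       conv, add_generation_prompt=True, tokenize=True,
#       return_dict=True, return_tensors="pt",
#   ).to(model.device)
#   out = model.generate(**inputs, max_new_tokens=64)
#   print(processor.decode(out[0], skip_special_tokens=True))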
def get_text_from_content(content):
    """Flatten a list of content items into a single display string."""
    texts = []
    for item in content:
        if item["type"] == "text":
            texts.append(item["text"])
        elif item["type"] == "image":
            texts.append("[Image]")
    return " ".join(texts)
# On a ZeroGPU Space, the GPU is only attached inside functions decorated with
# @spaces.GPU (which is why `spaces` is imported above).
@spaces.GPU
def chat_inference(image, text, temperature, top_p, top_k, max_tokens, conversation):
    if conversation is None:
        conversation = []

    # Build the user turn from the optional image and the (stripped) text.
    user_content = []
    if image is not None:
        user_content.append({"type": "image", "image": image})
    if text and text.strip():
        user_content.append({"type": "text", "text": text.strip()})
    if not user_content:
        # Nothing to send; leave the conversation unchanged.
        return conversation_display(conversation), conversation

    conversation.append({"role": "user", "content": user_content})

    inputs = processor.apply_chat_template(
        conversation,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt"
    ).to("cuda")

    # Re-seed so repeated prompts can sample different continuations.
    torch.manual_seed(random.randint(0, 10000))

    generation_kwargs = {
        "max_new_tokens": max_tokens,
        "temperature": temperature,
        "top_p": top_p,
        "top_k": top_k,
        "do_sample": True,
    }

    output = model.generate(**inputs, **generation_kwargs)
    # decode() returns the full sequence, prompt included; conversation_display()
    # trims everything up to the last assistant marker.
    assistant_response = processor.decode(output[0], skip_special_tokens=True)

    conversation.append({
        "role": "assistant",
        "content": [{"type": "text", "text": assistant_response.strip()}]
    })

    return conversation_display(conversation), conversation
def conversation_display(conversation):
    """Convert the internal conversation into Gradio's "messages" format."""
    chat_history = []
    user_text = ""
    for msg in conversation:
        if msg["role"] == "user":
            user_text = get_text_from_content(msg["content"])
        elif msg["role"] == "assistant":
            # The decoded output still contains the prompt; keep only the text
            # after the last assistant marker, paired with the preceding user turn.
            assistant_text = msg["content"][0]["text"].split("<|assistant|>")[-1].strip()
            chat_history.append({"role": "user", "content": user_text})
            chat_history.append({"role": "assistant", "content": assistant_text})
    return chat_history
def clear_chat():
    # Reset the chat display, conversation state, message box, and image input.
    return [], [], "", None
with gr.Blocks(title="Granite Vision 3.1 2B", css="h1 { overflow: hidden; }") as demo:
    gr.Markdown("# Granite Vision 3.1 2B")

    with gr.Row():
        with gr.Column(scale=2):
            image_input = gr.Image(type="pil", label="Upload Image (optional)")
            with gr.Column():
                temperature_input = gr.Slider(minimum=0.0, maximum=2.0, value=0.2, step=0.01, label="Temperature")
                top_p_input = gr.Slider(minimum=0.0, maximum=1.0, value=0.95, step=0.01, label="Top p")
                top_k_input = gr.Slider(minimum=0, maximum=100, value=50, step=1, label="Top k")
                max_tokens_input = gr.Slider(minimum=10, maximum=300, value=128, step=1, label="Max Tokens")
        with gr.Column(scale=3):
            chatbot = gr.Chatbot(label="Chat History", elem_id="chatbot", type="messages")
            text_input = gr.Textbox(lines=2, placeholder="Enter your message here", label="Message")
            with gr.Row():
                send_button = gr.Button("Chat")
                clear_button = gr.Button("Clear Chat")

    state = gr.State([])

    send_button.click(
        chat_inference,
        inputs=[image_input, text_input, temperature_input, top_p_input, top_k_input, max_tokens_input, state],
        outputs=[chatbot, state]
    )
    clear_button.click(
        clear_chat,
        inputs=None,
        outputs=[chatbot, state, text_input, image_input]
    )

    gr.Examples(
        examples=[
            ["https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/bus.png", "What is this?"]
        ],
        inputs=[image_input, text_input]
    )
if __name__ == "__main__":
    demo.launch()
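# To run locally (assuming a CUDA-capable machine with the imported packages
# installed, e.g. `pip install torch transformers gradio spaces`):
#   python app.py
# then open the local URL that Gradio prints.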