import random

import gradio as gr
import spaces
import torch
from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration

model_path = "ibm-granite/granite-vision-3.1-2b-preview"
processor = LlavaNextProcessor.from_pretrained(model_path, use_fast=True)
model = LlavaNextForConditionalGeneration.from_pretrained(
    model_path, torch_dtype="auto", device_map="auto"
)


def get_text_from_content(content):
    """Flatten a multimodal content list into a display string, replacing images with a placeholder."""
    texts = []
    for item in content:
        if item["type"] == "text":
            texts.append(item["text"])
        elif item["type"] == "image":
            texts.append("[Image]")
    return " ".join(texts)


@spaces.GPU
def chat_inference(image, text, temperature, top_p, top_k, max_tokens, conversation):
    if conversation is None:
        conversation = []

    # Build the user turn from whichever inputs were provided.
    user_content = []
    if image is not None:
        user_content.append({"type": "image", "image": image})
    if text and text.strip():
        user_content.append({"type": "text", "text": text.strip()})
    if not user_content:
        return conversation_display(conversation), conversation

    conversation.append({"role": "user", "content": user_content})

    inputs = processor.apply_chat_template(
        conversation,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to("cuda")

    # Re-seed so repeated identical prompts can still produce different samples.
    torch.manual_seed(random.randint(0, 10000))

    generation_kwargs = {
        "max_new_tokens": max_tokens,
        "temperature": temperature,
        "top_p": top_p,
        "top_k": top_k,
        # Sampling with temperature=0 is undefined; fall back to greedy decoding.
        "do_sample": temperature > 0,
    }

    output = model.generate(**inputs, **generation_kwargs)
    # Decode only the newly generated tokens so the stored assistant turn
    # does not also contain the echoed prompt on later turns.
    new_tokens = output[0][inputs["input_ids"].shape[-1]:]
    assistant_response = processor.decode(new_tokens, skip_special_tokens=True)

    conversation.append(
        {"role": "assistant", "content": [{"type": "text", "text": assistant_response.strip()}]}
    )

    return conversation_display(conversation), conversation


def conversation_display(conversation):
    """Convert the internal conversation state into Gradio's `messages` chat format."""
    chat_history = []
    user_text = ""
    for msg in conversation:
        if msg["role"] == "user":
            user_text = get_text_from_content(msg["content"])
        elif msg["role"] == "assistant":
            # Keep only the text after the final assistant marker, in case the
            # decoded output still contains the prompt.
            assistant_text = msg["content"][0]["text"].split("<|assistant|>")[-1].strip()
            chat_history.append({"role": "user", "content": user_text})
            chat_history.append({"role": "assistant", "content": assistant_text})
    return chat_history


def clear_chat():
    return [], [], "", None


with gr.Blocks(title="Granite Vision 3.1 2B", css="h1 { overflow: hidden; }") as demo:
    gr.Markdown("# Granite Vision 3.1 2B")

    with gr.Row():
        with gr.Column(scale=2):
            image_input = gr.Image(type="pil", label="Upload Image (optional)")
            with gr.Column():
                temperature_input = gr.Slider(minimum=0.0, maximum=2.0, value=0.2, step=0.01, label="Temperature")
                top_p_input = gr.Slider(minimum=0.0, maximum=1.0, value=0.95, step=0.01, label="Top p")
                top_k_input = gr.Slider(minimum=0, maximum=100, value=50, step=1, label="Top k")
                max_tokens_input = gr.Slider(minimum=10, maximum=300, value=128, step=1, label="Max Tokens")
        with gr.Column(scale=3):
            chatbot = gr.Chatbot(label="Chat History", elem_id="chatbot", type="messages")
            text_input = gr.Textbox(lines=2, placeholder="Enter your message here", label="Message")
            with gr.Row():
                send_button = gr.Button("Chat")
                clear_button = gr.Button("Clear Chat")

    state = gr.State([])

    send_button.click(
        chat_inference,
        inputs=[image_input, text_input, temperature_input, top_p_input, top_k_input, max_tokens_input, state],
        outputs=[chatbot, state],
    )
    clear_button.click(
        clear_chat,
        inputs=None,
        outputs=[chatbot, state, text_input, image_input],
    )

    gr.Examples(
        examples=[
            ["https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/bus.png", "What is this?"]
        ],
        inputs=[image_input, text_input],
    )

if __name__ == "__main__":
    demo.launch()