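"""Gradio chat demo for the IBM Granite Vision 3.1 2B preview model.

Users upload an optional image, type a message, and tune the sampling
settings (temperature, top-p, top-k, max new tokens) from the UI.
"""
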
import spaces
import random
import torch
import gradio as gr
from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration
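
# Load the processor and model once at startup: device_map="auto" places the
# weights on the available accelerator, and torch_dtype="auto" keeps the
# dtype stored in the checkpoint.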
model_path = "ibm-granite/granite-vision-3.1-2b-preview"
processor = LlavaNextProcessor.from_pretrained(model_path, use_fast=True)
model = LlavaNextForConditionalGeneration.from_pretrained(model_path, torch_dtype="auto", device_map="auto")

def get_text_from_content(content):
    # Flatten a multimodal content list into plain text for display,
    # substituting a placeholder for images.
    texts = []
    for item in content:
        if item["type"] == "text":
            texts.append(item["text"])
        elif item["type"] == "image":
            texts.append("[Image]")
    return " ".join(texts)
@spaces.GPU
def chat_inference(image, text, temperature, top_p, top_k, max_tokens, conversation):
    if conversation is None:
        conversation = []

    # Assemble the user turn from whichever inputs were supplied.
    user_content = []
    if image is not None:
        user_content.append({"type": "image", "image": image})
    if text and text.strip():
        user_content.append({"type": "text", "text": text.strip()})
    if not user_content:
        return conversation_display(conversation), conversation

    conversation.append({
        "role": "user",
        "content": user_content
    })

    inputs = processor.apply_chat_template(
        conversation,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt"
    ).to("cuda")

    # Reseed so identical prompts can still produce varied samples.
    torch.manual_seed(random.randint(0, 10000))

    generation_kwargs = {
        "max_new_tokens": max_tokens,
        "temperature": temperature,
        "top_p": top_p,
        "top_k": top_k,
        "do_sample": True,
    }

    output = model.generate(**inputs, **generation_kwargs)

    # Decode only the newly generated tokens so the stored assistant turn does
    # not echo the prompt back into subsequent chat-template calls.
    prompt_len = inputs["input_ids"].shape[1]
    assistant_response = processor.decode(output[0][prompt_len:], skip_special_tokens=True)

    conversation.append({
        "role": "assistant",
        "content": [{"type": "text", "text": assistant_response.strip()}]
    })

    return conversation_display(conversation), conversation

def conversation_display(conversation):
    # Convert the internal conversation into Gradio "messages" entries,
    # emitting each user/assistant pair once the assistant reply exists.
    chat_history = []
    user_text = ""
    for msg in conversation:
        if msg["role"] == "user":
            user_text = get_text_from_content(msg["content"])
        elif msg["role"] == "assistant":
            # The split is a safeguard in case a decoded response still
            # carries the chat-template role marker.
            assistant_text = msg["content"][0]["text"].split("<|assistant|>")[-1].strip()
            chat_history.append({"role": "user", "content": user_text})
            chat_history.append({"role": "assistant", "content": assistant_text})
    return chat_history

def clear_chat():
    # Reset, in order: chatbot display, conversation state, message box, image input.
    return [], [], "", None
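
# Build the UI: image upload and sampling controls on the left; chat history,
# message box, and buttons on the right.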
with gr.Blocks(title="Granite Vision 3.1 2B", css="h1 { overflow: hidden; }") as demo:
    gr.Markdown("# Granite Vision 3.1 2B")

    with gr.Row():
        with gr.Column(scale=2):
            image_input = gr.Image(type="pil", label="Upload Image (optional)")
            with gr.Column():
                temperature_input = gr.Slider(minimum=0.0, maximum=2.0, value=0.2, step=0.01, label="Temperature")
                top_p_input = gr.Slider(minimum=0.0, maximum=1.0, value=0.95, step=0.01, label="Top p")
                top_k_input = gr.Slider(minimum=0, maximum=100, value=50, step=1, label="Top k")
                max_tokens_input = gr.Slider(minimum=10, maximum=300, value=128, step=1, label="Max Tokens")
        with gr.Column(scale=3):
            chatbot = gr.Chatbot(label="Chat History", elem_id="chatbot", type="messages")
            text_input = gr.Textbox(lines=2, placeholder="Enter your message here", label="Message")
            with gr.Row():
                send_button = gr.Button("Chat")
                clear_button = gr.Button("Clear Chat")

    state = gr.State([])

    send_button.click(
        chat_inference,
        inputs=[image_input, text_input, temperature_input, top_p_input, top_k_input, max_tokens_input, state],
        outputs=[chatbot, state]
    )

    clear_button.click(
        clear_chat,
        inputs=None,
        outputs=[chatbot, state, text_input, image_input]
    )

    gr.Examples(
        examples=[
            ["https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/bus.png", "What is this?"]
        ],
        inputs=[image_input, text_input]
    )

if __name__ == "__main__":
    demo.launch()