import spaces
import random
import torch
import hashlib
import gradio as gr
import threading
from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration, TextIteratorStreamer
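
# Load the processor and model once at startup; device_map="auto" places the
# weights automatically and torch_dtype="auto" keeps the checkpoint's dtype.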
model_id = "ibm-granite/granite-vision-3.1-2b-preview"
processor = LlavaNextProcessor.from_pretrained(model_id, use_fast=True)
model = LlavaNextForConditionalGeneration.from_pretrained(model_id, torch_dtype="auto", device_map="auto")

SYSTEM_PROMPT = (
    "A chat between a curious user and an artificial intelligence assistant. "
    "The assistant gives helpful, detailed, and polite answers to the user's questions."
)

# On ZeroGPU Spaces, @spaces.GPU requests a GPU for the duration of each call.
@spaces.GPU
def chat_inference(image, text, temperature, top_p, top_k, max_tokens, conversation):
    # Start a fresh conversation with the system prompt on the first turn.
    if not conversation:
        conversation = [{
            "role": "system",
            "content": [{"type": "text", "text": SYSTEM_PROMPT}]
        }]
    user_content = []
    if image is not None:
        # Image.thumbnail resizes in place and preserves the aspect ratio.
        if image.width > 512 or image.height > 512:
            image.thumbnail((512, 512))
        user_content.append({"type": "image", "image": image})
    if text and text.strip():
        user_content.append({"type": "text", "text": text.strip()})
    # Nothing to send: this function is a generator, so yield (not return)
    # the unchanged state and stop.
    if not user_content:
        yield conversation_display(conversation), conversation, "", False
        return
    conversation.append({
        "role": "user",
        "content": user_content
    })
    # Drop an image that duplicates the one sent in the previous turn.
    conversation = preprocess_conversation(conversation)
    # Build the model inputs from the conversation via the chat template,
    # and move them to wherever the model weights live.
    inputs = processor.apply_chat_template(
        conversation,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt"
    ).to(model.device)
    torch.manual_seed(random.randint(0, 10000))
    generation_kwargs = {
        "max_new_tokens": max_tokens,
        "top_p": top_p,
        "top_k": top_k,
        # Sample only when temperature > 0; a temperature of 0 means greedy decoding.
        "do_sample": temperature > 0,
    }
    if temperature > 0:
        generation_kwargs["temperature"] = temperature
    conversation.append({
        "role": "assistant",
        "content": [{"type": "text", "text": ""}]
    })
    yield conversation_display(conversation), conversation, "Processing...", True
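    # TextIteratorStreamer yields decoded text pieces as generate() produces them.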
    streamer = TextIteratorStreamer(processor.tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs["streamer"] = streamer
    def generate_thread():
        model.generate(**inputs, **generation_kwargs)

    thread = threading.Thread(target=generate_thread)
    thread.start()
    assistant_text = ""
    for new_text in streamer:
        assistant_text += new_text
        conversation[-1]["content"][0]["text"] = extract_answer(assistant_text)
        yield conversation_display(conversation), conversation, "Processing...", True
    thread.join()
    # The final yield clears the "Processing..." status and re-enables the inputs.
    yield conversation_display(conversation), conversation, "", False
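
# If the chat template's assistant tag leaks into the decoded text, keep only
# the text after the final "<|assistant|>" marker.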
def extract_answer(response):
    if "<|assistant|>" in response:
        return response.split("<|assistant|>")[-1].strip()
    return response.strip()
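
# Hash the raw RGB pixel bytes so re-uploads of the same image compare equal.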
def compute_image_hash(image):
    image = image.convert("RGB")
    image_bytes = image.tobytes()
    return hashlib.md5(image_bytes).hexdigest()

def preprocess_conversation(conversation):
    # Find the most recently sent image in earlier user messages
    # (excluding the latest message).
    last_image_hash = None
    for msg in reversed(conversation[:-1]):
        if msg.get("role") == "user":
            for item in msg.get("content", []):
                if item.get("type") == "image" and item.get("image") is not None:
                    try:
                        last_image_hash = compute_image_hash(item["image"])
                        break
                    except Exception:
                        continue
        if last_image_hash is not None:
            break
    # In the latest user message, drop any image identical to the last one sent,
    # so the same upload is not fed to the model twice in a row.
    latest_msg = conversation[-1]
    if latest_msg.get("role") == "user":
        new_content = []
        for item in latest_msg.get("content", []):
            if item.get("type") == "image" and item.get("image") is not None:
                try:
                    current_hash = compute_image_hash(item["image"])
                except Exception:
                    current_hash = None
                # Skip the image if it matches the last sent image.
                if last_image_hash is not None and current_hash == last_image_hash:
                    continue
                new_content.append(item)
            else:
                new_content.append(item)
        latest_msg["content"] = new_content
    return conversation
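
# Flatten the structured conversation into the Chatbot's "messages" format;
# images are shown as an "<image>" placeholder.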
def conversation_display(conversation):
    chat_history = []
    for msg in conversation:
        if msg["role"] == "user":
            texts = []
            for item in msg["content"]:
                if item["type"] == "image":
                    texts.append("<image>")
                elif item["type"] == "text":
                    texts.append(item["text"])
            chat_history.append({
                "role": "user",
                "content": "\n".join(texts)
            })
        else:
            chat_history.append({
                "role": msg["role"],
                "content": msg["content"][0]["text"]
            })
    return chat_history
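
# Reset the UI state, ignoring the click while a response is still streaming.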
def clear_chat(chat_history, conversation, text_value, image, is_generating):
    if is_generating:
        return chat_history, conversation, text_value, image, is_generating
    return [], [], "", None, is_generating
with gr.Blocks(title="Granite Vision 3.1 2B", css="h1 { overflow: hidden; }") as demo:
    gr.Markdown("# [Granite Vision 3.1 2B](https://huggingface.co/ibm-granite/granite-vision-3.1-2b-preview)")
    with gr.Row():
        with gr.Column(scale=2):
            image_input = gr.Image(type="pil", label="Upload Image (optional)")
            with gr.Column():
                temperature_input = gr.Slider(minimum=0.0, maximum=1.0, value=0.1, step=0.01, label="Temperature")
                top_p_input = gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.01, label="Top p")
                top_k_input = gr.Slider(minimum=1, maximum=100, value=50, step=1, label="Top k")
                max_tokens_input = gr.Slider(minimum=1, maximum=1024, value=512, step=1, label="Max Tokens")
        with gr.Column(scale=3):
            chatbot = gr.Chatbot(label="Chat History", elem_id="chatbot", type='messages')
            text_input = gr.Textbox(lines=2, placeholder="Enter your message here", label="Message")
            with gr.Row():
                send_button = gr.Button("Chat")
                clear_button = gr.Button("Clear Chat")
    conversation_state = gr.State([])
    is_generating = gr.State(False)
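    # Wire the buttons to the streaming chat generator and the clear handler.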
    send_button.click(
        chat_inference,
        inputs=[image_input, text_input, temperature_input, top_p_input, top_k_input, max_tokens_input, conversation_state],
        outputs=[chatbot, conversation_state, text_input, is_generating]
    )
    clear_button.click(
        clear_chat,
        inputs=[chatbot, conversation_state, text_input, image_input, is_generating],
        outputs=[chatbot, conversation_state, text_input, image_input, is_generating]
    )
    gr.Examples(
        examples=[
            ["https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/bus.png", "What is this?"]
        ],
        inputs=[image_input, text_input]
    )
| if __name__ == "__main__": | |
| demo.launch(show_api=False) |