from pathlib import Path
from threading import Thread

import cv2
import gradio as gr
import spaces
import torch
from PIL import Image
from transformers import AutoProcessor, Gemma3ForConditionalGeneration, TextIteratorStreamer

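# Load both Gemma 3 instruction-tuned checkpoints once at startup so the model
# dropdown can switch between them without reloading weights.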
model_12b_name = "google/gemma-3-12b-it"
model_4b_name = "google/gemma-3-4b-it"

model_12b = Gemma3ForConditionalGeneration.from_pretrained(
    model_12b_name,
    device_map="auto",
    torch_dtype=torch.bfloat16,
).eval()
processor_12b = AutoProcessor.from_pretrained(model_12b_name)

model_4b = Gemma3ForConditionalGeneration.from_pretrained(
    model_4b_name,
    device_map="auto",
    torch_dtype=torch.bfloat16,
).eval()
processor_4b = AutoProcessor.from_pretrained(model_4b_name)

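# Sample up to `num_frames` roughly evenly spaced frames from a video and return
# them as RGB PIL images.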
def extract_video_frames(video_path, num_frames=8):
    cap = cv2.VideoCapture(video_path)
    frames = []
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    step = max(total_frames // num_frames, 1)

    for i in range(num_frames):
        cap.set(cv2.CAP_PROP_POS_FRAMES, i * step)
        ret, frame = cap.read()
        if ret:
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frames.append(Image.fromarray(frame))
    cap.release()
    return frames

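# Build the "content" list for a user turn. Literal "<image>" placeholders in the
# prompt are replaced, in order, by the uploaded files; any remaining files are
# appended afterwards (videos are sampled into frames).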
def format_message(content, files):
    message_content = []

    if content:
        parts = content.split('<image>')
        for i, part in enumerate(parts):
            if part.strip():
                message_content.append({"type": "text", "text": part.strip()})
            if i < len(parts) - 1 and files:
                file = files.pop(0)
                file_path = file if isinstance(file, str) else file.name
                message_content.append({"type": "image", "image": Image.open(file_path)})
    for file in files:
        file_path = file if isinstance(file, str) else file.name
        if Path(file_path).suffix.lower() in ['.jpg', '.jpeg', '.png']:
            img = Image.open(file_path)
            message_content.append({"type": "image", "image": img})
        elif Path(file_path).suffix.lower() in ['.mp4', '.mov']:
            frames = extract_video_frames(file_path)
            for frame in frames:
                message_content.append({"type": "image", "image": frame})
    return message_content

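# Convert Gradio's "messages"-style chat history into the chat format expected by
# the processor, grouping consecutive user entries (text and media) into one turn.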
def format_conversation_history(chat_history):
    messages = []
    current_user_content = []
    for item in chat_history:
        role = item["role"]
        content = item["content"]
        if role == "user":
            if isinstance(content, str):
                current_user_content.append({"type": "text", "text": content})
            elif isinstance(content, list):
                current_user_content.extend(content)
            else:
                current_user_content.append({"type": "text", "text": str(content)})
        elif role == "assistant":
            if current_user_content:
                messages.append({"role": "user", "content": current_user_content})
                current_user_content = []
            messages.append({"role": "assistant", "content": [{"type": "text", "text": str(content)}]})
    if current_user_content:
        messages.append({"role": "user", "content": current_user_content})
    return messages

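# Handle one chat turn: assemble the multimodal conversation, pick the selected
# model, and stream generated text back to the UI. The @spaces.GPU decorator
# requests a ZeroGPU allocation for up to 120 seconds per call.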
@spaces.GPU(duration=120)
def generate_response(input_data, chat_history, model_choice, max_new_tokens, system_prompt, temperature, top_p, top_k, repetition_penalty):
    if isinstance(input_data, dict) and "text" in input_data:
        text = input_data["text"]
        files = input_data.get("files", [])
    else:
        text = str(input_data)
        files = []

    new_message_content = format_message(text, files)
    new_message = {"role": "user", "content": new_message_content}
    system_message = [{"role": "system", "content": [{"type": "text", "text": system_prompt}]}] if system_prompt else []
    processed_history = format_conversation_history(chat_history)
    messages = system_message + processed_history
    if messages and messages[-1]["role"] == "user":
        messages[-1]["content"].extend(new_message["content"])
    else:
        messages.append(new_message)
    if model_choice == "Gemma 3 12B":
        model = model_12b
        processor = processor_12b
    else:
        model = model_4b
        processor = processor_4b
    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_tensors="pt",
        return_dict=True,
    ).to(model.device, dtype=torch.bfloat16)  # cast pixel values to match the model's bfloat16 weights
    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(
        inputs,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
    )
    # Run generation in a background thread and stream partial text as it arrives.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    outputs = []
    for new_text in streamer:
        outputs.append(new_text)
        yield "".join(outputs)

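# Wire everything into a multimodal ChatInterface; each control in
# additional_inputs maps onto a generate_response parameter in the same order.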
demo = gr.ChatInterface(
    fn=generate_response,
    additional_inputs=[
        gr.Dropdown(
            label="Model",
            choices=["Gemma 3 12B", "Gemma 3 4B"],
            value="Gemma 3 12B"
        ),
        gr.Slider(label="Max new tokens", minimum=100, maximum=2000, step=1, value=512),
        gr.Textbox(
            label="System Prompt",
            value="You are a friendly chatbot.",
            lines=4,
            placeholder="Change system prompt"
        ),
        gr.Slider(label="Temperature", minimum=0.1, maximum=2.0, step=0.1, value=0.7),
        gr.Slider(label="Top-p", minimum=0.05, maximum=1.0, step=0.05, value=0.9),
        gr.Slider(label="Top-k", minimum=1, maximum=100, step=1, value=50),
        gr.Slider(label="Repetition Penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.0),
    ],
    examples=[
        [{"text": "Explain this image", "files": ["examples/image1.jpg"]}],
    ],
    cache_examples=False,
    type="messages",
    description="""
# Gemma 3
Pick the 12B or 4B model, upload images or videos, and adjust the settings below to customize your experience.
""",
    fill_height=True,
    textbox=gr.MultimodalTextbox(
        label="Query Input",
        file_types=["image", "video"],
        file_count="multiple",
        placeholder="Type your message or upload media"
    ),
    stop_btn="Stop Generation",
    multimodal=True,
    theme=gr.themes.Soft(),
)

if __name__ == "__main__":
    demo.launch()