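# Gemma 3 Image Chat: a Gradio Space for chatting about an uploaded image
# with the google/gemma-3-4b-it vision-language model (runs on ZeroGPU).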
import spaces
import gradio as gr
from transformers import pipeline
import torch
import os

# HF_TOKEN must be set as a Space secret; access to google/gemma-3-4b-it is gated
hf_token = os.environ["HF_TOKEN"]

# Load the Gemma 3 pipeline
pipe = pipeline(
    "image-text-to-text",
    model="google/gemma-3-4b-it",
    device="cuda",
    torch_dtype=torch.bfloat16,
    token=hf_token  # use_auth_token is deprecated in recent transformers
)

# ZeroGPU: a GPU is attached only while this function runs
@spaces.GPU
def get_response(message, chat_history, image):
    # Check if image is provided
    if image is None:
        chat_history.append((message, "Please upload an image; one is required for each message."))
        return "", chat_history
    
    # Build a chat-template message list: system prompt first, then the user turn
    messages = [
        {
            "role": "system",
            "content": [{"type": "text", "text": "You are a helpful assistant."}]
        }
    ]
    
    # The user turn always includes the image
    user_content = [{"type": "image", "image": image}]
    
    # Add text message if provided
    if message:
        user_content.append({"type": "text", "text": message})
        
    messages.append({"role": "user", "content": user_content})
    
    # Call the pipeline
    output = pipe(text=messages, max_new_tokens=200)
    
    try:
        # The pipeline returns the full conversation; the last message is the model's reply
        response = output[0]["generated_text"][-1]["content"]
        chat_history.append((message, response))
    except (KeyError, IndexError, TypeError) as e:
        error_message = f"Error processing the response: {str(e)}"
        chat_history.append((message, error_message))
    
    return "", chat_history

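# Build the UI: chat window, message box, required image upload, and buttons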
with gr.Blocks() as demo:
    gr.Markdown("# Gemma 3 Image Chat")
    gr.Markdown("Chat with Gemma 3 about images. Image upload is required for each message.")
    
    chatbot = gr.Chatbot()
    
    with gr.Row():
        msg = gr.Textbox(
            show_label=False,
            placeholder="Type your message here about the image...",
            scale=4
        )
        img = gr.Image(
            type="pil", 
            label="Upload image (required)", 
            scale=1
        )
    
    submit_btn = gr.Button("Send")
    
    # Clear button to reset the interface
    clear_btn = gr.Button("Clear")
    
    def clear_interface():
        return "", [], None
    
    submit_btn.click(
        get_response,
        inputs=[msg, chatbot, img],
        outputs=[msg, chatbot]
    )
    
    msg.submit(
        get_response,
        inputs=[msg, chatbot, img],
        outputs=[msg, chatbot]
    )
    
    clear_btn.click(
        clear_interface,
        inputs=None,
        outputs=[msg, chatbot, img]
    )

if __name__ == "__main__":
    demo.launch()