mjavaid committed on
Commit
a3b555a
·
1 Parent(s): 199e7c3

first commit

Browse files
Files changed (1) hide show
  1. app.py +43 -22
app.py CHANGED
@@ -6,9 +6,9 @@ import os
6
 
7
  hf_token = os.environ["HF_TOKEN"]
8
 
9
- # Load the Gemma 3 pipeline
10
  pipe = pipeline(
11
- "image-text-to-text",
12
  model="google/gemma-3-4b-it",
13
  device="cuda",
14
  torch_dtype=torch.bfloat16,
@@ -16,10 +16,7 @@ pipe = pipeline(
16
  )
17
 
18
  @spaces.GPU
19
- def generate_response(user_text, user_image):
20
- if user_image is None:
21
- return "Please upload an image (required)"
22
-
23
  messages = [
24
  {
25
  "role": "system",
@@ -27,36 +24,60 @@ def generate_response(user_text, user_image):
27
  }
28
  ]
29
 
30
- user_content = [{"type": "image", "image": user_image}]
31
- if user_text:
32
- user_content.append({"type": "text", "text": user_text})
 
 
33
 
 
 
 
 
34
  messages.append({"role": "user", "content": user_content})
35
 
36
- # Call the pipeline with the provided messages
37
  output = pipe(text=messages, max_new_tokens=200)
38
 
39
  try:
40
  response = output[0]["generated_text"][-1]["content"]
41
- return response
42
- except (KeyError, IndexError, TypeError):
43
- return "Error processing the response. Please try again."
 
 
 
44
 
45
  with gr.Blocks() as demo:
46
- gr.Markdown("# Gemma 3 Image Analysis")
47
- gr.Markdown("Upload an image and optionally add a prompt to get the model's response.")
 
 
48
 
49
  with gr.Row():
50
- img = gr.Image(type="pil", label="Upload an image (required)")
51
- txt = gr.Textbox(label="Your prompt (optional)", placeholder="Describe what you see in this image")
 
 
 
 
 
 
 
 
52
 
53
- output = gr.Textbox(label="Model Response")
54
 
55
- submit_btn = gr.Button("Submit")
56
  submit_btn.click(
57
- generate_response,
58
- inputs=[txt, img],
59
- outputs=output
 
 
 
 
 
 
60
  )
61
 
62
  if __name__ == "__main__":
 
6
 
7
  hf_token = os.environ["HF_TOKEN"]
8
 
9
+ # Load the Gemma 3 pipeline - use the multimodal version for all cases
10
  pipe = pipeline(
11
+ "image-text-to-text", # This pipeline can handle both text-only and text+image
12
  model="google/gemma-3-4b-it",
13
  device="cuda",
14
  torch_dtype=torch.bfloat16,
 
16
  )
17
 
18
  @spaces.GPU
19
+ def get_response(message, chat_history, image=None):
 
 
 
20
  messages = [
21
  {
22
  "role": "system",
 
24
  }
25
  ]
26
 
27
+ user_content = []
28
+
29
+ # Only add image if provided
30
+ if image is not None:
31
+ user_content.append({"type": "image", "image": image})
32
 
33
+ # Add the text message when it is non-empty
34
+ if message:
35
+ user_content.append({"type": "text", "text": message})
36
+
37
  messages.append({"role": "user", "content": user_content})
38
 
39
+ # Call the pipeline
40
  output = pipe(text=messages, max_new_tokens=200)
41
 
42
  try:
43
  response = output[0]["generated_text"][-1]["content"]
44
+ chat_history.append((message, response))
45
+ except (KeyError, IndexError, TypeError) as e:
46
+ error_message = f"Error processing the response: {str(e)}"
47
+ chat_history.append((message, error_message))
48
+
49
+ return "", chat_history
50
 
51
  with gr.Blocks() as demo:
52
+ gr.Markdown("# Gemma 3 Chat Interface")
53
+ gr.Markdown("Chat with Gemma 3 with optional image upload capability")
54
+
55
+ chatbot = gr.Chatbot()
56
 
57
  with gr.Row():
58
+ msg = gr.Textbox(
59
+ show_label=False,
60
+ placeholder="Type your message here...",
61
+ scale=4
62
+ )
63
+ img = gr.Image(
64
+ type="pil",
65
+ label="Upload image (optional)",
66
+ scale=1
67
+ )
68
 
69
+ submit_btn = gr.Button("Send")
70
 
 
71
  submit_btn.click(
72
+ get_response,
73
+ inputs=[msg, chatbot, img],
74
+ outputs=[msg, chatbot]
75
+ )
76
+
77
+ msg.submit(
78
+ get_response,
79
+ inputs=[msg, chatbot, img],
80
+ outputs=[msg, chatbot]
81
  )
82
 
83
  if __name__ == "__main__":