mjavaid committed on
Commit
b3d358d
·
1 Parent(s): 43ab5c7

first commit

Browse files
Files changed (1) hide show
  1. app.py +48 -57
app.py CHANGED
@@ -6,90 +6,75 @@ import os
6
 
7
  hf_token = os.environ["HF_TOKEN"]
8
 
 
 
 
 
 
 
 
 
 
9
# Pipelines cached per task so the multi-gigabyte Gemma model is loaded
# once per process instead of being re-created on every chat request.
_PIPELINES = {}


def _get_pipeline(task):
    """Return a cached transformers pipeline for *task*, creating it lazily."""
    if task not in _PIPELINES:
        _PIPELINES[task] = pipeline(
            task,
            model="google/gemma-3-4b-it",
            device="cuda",
            torch_dtype=torch.bfloat16,
            token=hf_token,  # `use_auth_token` is deprecated in recent transformers
        )
    return _PIPELINES[task]


@spaces.GPU
def get_response(message, chat_history, image=None):
    """Generate a Gemma 3 reply for *message*, optionally grounded in *image*.

    Args:
        message: User's text prompt; may be empty when only an image is sent.
        chat_history: List of (user, assistant) tuples; mutated in place.
        image: Optional PIL image; selects the multimodal pipeline when given.

    Returns:
        ("", chat_history) so Gradio clears the textbox and refreshes the chat.
    """
    if image is not None:
        # Multimodal chat format: each message's content is a list of typed parts.
        pipe = _get_pipeline("image-text-to-text")
        messages = [
            {
                "role": "system",
                "content": [{"type": "text", "text": "You are a helpful assistant."}],
            }
        ]
        user_content = [{"type": "image", "image": image}]
        if message:
            user_content.append({"type": "text", "text": message})
        messages.append({"role": "user", "content": user_content})
    else:
        # Text-only chat format: plain string contents.
        pipe = _get_pipeline("text-generation")
        messages = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": message},
        ]

    output = pipe(text=messages, max_new_tokens=200)

    try:
        if image is not None:
            response = output[0]["generated_text"][-1]["content"]
        else:
            # NOTE(review): for chat-style input `generated_text` is usually the
            # whole message list, not a plain string — confirm what the UI shows.
            response = output[0]["generated_text"]
        chat_history.append((message, response))
        return "", chat_history
    except (KeyError, IndexError, TypeError) as e:
        error_message = f"Error processing the response: {str(e)}"
        chat_history.append((message, error_message))
        return "", chat_history
 
72
 
73
  with gr.Blocks() as demo:
74
- gr.Markdown("# Gemma 3 Chat Interface")
75
- gr.Markdown("Chat with Gemma 3 with optional image upload capability")
76
 
77
  chatbot = gr.Chatbot()
78
 
79
  with gr.Row():
80
  msg = gr.Textbox(
81
  show_label=False,
82
- placeholder="Type your message here...",
83
  scale=4
84
  )
85
  img = gr.Image(
86
  type="pil",
87
- label="Upload image (optional)",
88
  scale=1
89
  )
90
 
91
  submit_btn = gr.Button("Send")
92
 
 
 
 
 
 
 
93
  submit_btn.click(
94
  get_response,
95
  inputs=[msg, chatbot, img],
@@ -101,6 +86,12 @@ with gr.Blocks() as demo:
101
  inputs=[msg, chatbot, img],
102
  outputs=[msg, chatbot]
103
  )
 
 
 
 
 
 
104
 
105
if __name__ == "__main__":
    # Launch the Gradio UI when this file is run directly.
    demo.launch()
 
6
 
7
# Hugging Face access token for the gated Gemma checkpoint; fail fast if unset.
hf_token = os.environ["HF_TOKEN"]

# Load the Gemma 3 multimodal pipeline once at startup so every request
# reuses the already-initialised model instead of reloading it per call.
pipe = pipeline(
    "image-text-to-text",
    model="google/gemma-3-4b-it",
    device="cuda",
    torch_dtype=torch.bfloat16,
    token=hf_token,  # `use_auth_token` is deprecated in recent transformers releases
)
17
+
18
@spaces.GPU
def get_response(message, chat_history, image):
    """Generate a Gemma 3 reply about the uploaded image.

    Args:
        message: User's text prompt (may be empty).
        chat_history: List of (user, assistant) tuples; mutated in place.
        image: PIL image; required by the multimodal pipeline.

    Returns:
        ("", chat_history) so Gradio clears the textbox and refreshes the chat.
    """
    # The multimodal pipeline needs an image — refuse politely without one.
    if image is None:
        chat_history.append((message, "Please upload an image (required)"))
        return "", chat_history

    # User turn: the image part first, then the optional text part.
    user_parts = [{"type": "image", "image": image}]
    if message:
        user_parts.append({"type": "text", "text": message})

    conversation = [
        {
            "role": "system",
            "content": [{"type": "text", "text": "You are a helpful assistant."}],
        },
        {"role": "user", "content": user_parts},
    ]

    output = pipe(text=conversation, max_new_tokens=200)

    try:
        # Last message of the generated conversation is the assistant's reply.
        reply = output[0]["generated_text"][-1]["content"]
        chat_history.append((message, reply))
    except (KeyError, IndexError, TypeError) as e:
        error_message = f"Error processing the response: {str(e)}"
        chat_history.append((message, error_message))

    return "", chat_history
51
 
52
  with gr.Blocks() as demo:
53
+ gr.Markdown("# Gemma 3 Image Chat")
54
+ gr.Markdown("Chat with Gemma 3 about images. Image upload is required for each message.")
55
 
56
  chatbot = gr.Chatbot()
57
 
58
  with gr.Row():
59
  msg = gr.Textbox(
60
  show_label=False,
61
+ placeholder="Type your message here about the image...",
62
  scale=4
63
  )
64
  img = gr.Image(
65
  type="pil",
66
+ label="Upload image (required)",
67
  scale=1
68
  )
69
 
70
  submit_btn = gr.Button("Send")
71
 
72
+ # Clear button to reset the interface
73
+ clear_btn = gr.Button("Clear")
74
+
75
def clear_interface():
    """Reset the chat UI: empty textbox, empty history, no image."""
    return ("", [], None)
77
+
78
  submit_btn.click(
79
  get_response,
80
  inputs=[msg, chatbot, img],
 
86
  inputs=[msg, chatbot, img],
87
  outputs=[msg, chatbot]
88
  )
89
+
90
+ clear_btn.click(
91
+ clear_interface,
92
+ inputs=None,
93
+ outputs=[msg, chatbot, img]
94
+ )
95
 
96
if __name__ == "__main__":
    # Start the Gradio server only when executed as a script.
    demo.launch()