mjavaid committed · Commit 43ab5c7 · 1 Parent(s): a3b555a

first commit

Files changed (1):
  1. app.py +50 -28
app.py CHANGED
@@ -6,47 +6,69 @@ import os
 
 hf_token = os.environ["HF_TOKEN"]
 
-# Load the Gemma 3 pipeline - use the multimodal version for all cases
-pipe = pipeline(
-    "image-text-to-text",  # This pipeline can handle both text-only and text+image
-    model="google/gemma-3-4b-it",
-    device="cuda",
-    torch_dtype=torch.bfloat16,
-    use_auth_token=hf_token
-)
-
 @spaces.GPU
 def get_response(message, chat_history, image=None):
-    messages = [
-        {
-            "role": "system",
-            "content": [{"type": "text", "text": "You are a helpful assistant."}]
-        }
-    ]
-
-    user_content = []
-
-    # Only add image if provided
+    # Choose the appropriate pipeline based on whether an image is provided
     if image is not None:
+        # Multimodal pipeline for text+image
+        pipe = pipeline(
+            "image-text-to-text",
+            model="google/gemma-3-4b-it",
+            device="cuda",
+            torch_dtype=torch.bfloat16,
+            use_auth_token=hf_token
+        )
+
+        messages = [
+            {
+                "role": "system",
+                "content": [{"type": "text", "text": "You are a helpful assistant."}]
+            }
+        ]
+
+        user_content = []
         user_content.append({"type": "image", "image": image})
-
-    # Always add the text message
-    if message:
-        user_content.append({"type": "text", "text": message})
+        if message:
+            user_content.append({"type": "text", "text": message})
 
-    messages.append({"role": "user", "content": user_content})
+        messages.append({"role": "user", "content": user_content})
 
-    # Call the pipeline
+    else:
+        # Text-only pipeline
+        pipe = pipeline(
+            "text-generation",
+            model="google/gemma-3-4b-it",
+            device="cuda",
+            torch_dtype=torch.bfloat16,
+            use_auth_token=hf_token
+        )
+
+        messages = [
+            {
+                "role": "system",
+                "content": "You are a helpful assistant."
+            },
+            {
+                "role": "user",
+                "content": message
+            }
+        ]
+
+    # Call the appropriate pipeline
     output = pipe(text=messages, max_new_tokens=200)
 
     try:
-        response = output[0]["generated_text"][-1]["content"]
+        if image is not None:
+            response = output[0]["generated_text"][-1]["content"]
+        else:
+            response = output[0]["generated_text"]
+
         chat_history.append((message, response))
+        return "", chat_history
     except (KeyError, IndexError, TypeError) as e:
         error_message = f"Error processing the response: {str(e)}"
         chat_history.append((message, error_message))
-
-    return "", chat_history
+        return "", chat_history
 
 with gr.Blocks() as demo:
     gr.Markdown("# Gemma 3 Chat Interface")
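The diff is truncated right after gr.Markdown(...), so the rest of the Blocks UI is not shown in this commit. For orientation only, here is a minimal sketch of how the revised get_response could be wired into a chat interface; the component names (chatbot, msg, img) and the .submit() wiring are illustrative assumptions, not part of this commit. The function returns ("", chat_history), which maps onto clearing the textbox and refreshing the chat history.

# Hypothetical wiring sketch (not part of this commit); assumes the history
# stays in the (user, bot) tuple format that get_response appends to.
with gr.Blocks() as demo:
    gr.Markdown("# Gemma 3 Chat Interface")
    chatbot = gr.Chatbot()                              # assumed component
    msg = gr.Textbox(label="Message")                   # assumed component
    img = gr.Image(type="pil", label="Optional image")  # assumed component
    # get_response(message, chat_history, image) -> ("", chat_history)
    msg.submit(get_response, [msg, chatbot, img], [msg, chatbot])

demo.launch()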