richardkimsm89 committed on
Commit
e560fe9
·
verified ·
1 Parent(s): 8f92fa8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +55 -6
app.py CHANGED
@@ -3,8 +3,11 @@
3
  import gradio as gr
4
  from huggingface_hub import InferenceClient
5
 
6
- model = "google/gemma-2-27b-it"
7
- client = InferenceClient(model)
 
 
 
8
 
9
  def fn_text(
10
  prompt,
@@ -24,7 +27,7 @@ def fn_text(
24
  history.append(messages[0])
25
 
26
  stream = client.chat.completions.create(
27
- model = model,
28
  messages = history,
29
  max_tokens = max_tokens,
30
  temperature = temperature,
@@ -47,12 +50,58 @@ app_text = gr.ChatInterface(
47
  gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-P"),
48
  ],
49
  title = "Google Gemma",
50
- description = model,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  )
52
 
53
  app = gr.TabbedInterface(
54
- [app_text],
55
- ["Text"]
56
  ).launch()
57
 
58
  #if __name__ == "__main__":
 
3
  import gradio as gr
4
  from huggingface_hub import InferenceClient
5
 
6
+ model_text = "google/gemma-2-27b-it"
7
+ client_text = InferenceClient(model_text)
8
+
9
+ model_vision = "google/paligemma2-3b-pt-224"
10
+ client_vision = InferenceClient(model_vision)
11
 
12
  def fn_text(
13
  prompt,
 
27
  history.append(messages[0])
28
 
29
  stream = client.chat.completions.create(
30
+ model = model_text,
31
  messages = history,
32
  max_tokens = max_tokens,
33
  temperature = temperature,
 
50
  gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-P"),
51
  ],
52
  title = "Google Gemma",
53
+ description = model_text,
54
+ )
55
+
56
def fn_vision(
    prompt,
    image_url,
    # system_prompt,  # disabled: not currently forwarded to the model
    max_tokens,
    temperature,
    top_p,
):
    """Stream a chat completion from the vision model.

    Builds a single user message containing the text prompt and, when
    *image_url* is non-empty, an image reference, then yields the growing
    response text chunk by chunk so Gradio can render it incrementally.
    """
    # Assemble the multimodal content payload for one user turn.
    content = [{"type": "text", "text": prompt}]
    if image_url:
        content.append({"type": "image_url", "image_url": {"url": image_url}})

    response = client_vision.chat.completions.create(
        model=model_vision,
        messages=[{"role": "user", "content": content}],
        max_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        stream=True,
    )

    # Accumulate delta fragments and re-yield the full text so far;
    # a delta's content may be None, which we treat as the empty string.
    text_so_far = ""
    for part in response:
        text_so_far += part.choices[0].delta.content or ""
        yield text_so_far
82
+
83
# Vision tab: a prompt box plus an optional image URL; the output textbox
# fills in progressively as fn_vision streams partial completions.
app_vision = gr.Interface(
    fn=fn_vision,
    inputs=[
        gr.Textbox(label="Prompt"),
        gr.Textbox(label="Image URL"),
    ],
    outputs=[
        gr.Textbox(label="Output"),
    ],
    additional_inputs=[
        # gr.Textbox(value="You are a helpful assistant.", label="System Prompt"),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max Tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-P"),
    ],
    title="Google Gemma",
    description=model_vision,
)
101
 
102
# Bundle the text and vision interfaces into one tabbed UI and start serving.
tab_apps = [app_text, app_vision]
tab_labels = ["Text", "Vision"]
app = gr.TabbedInterface(tab_apps, tab_labels).launch()
106
 
107
  #if __name__ == "__main__":