# Inference
#
# A two-tab Gradio app that streams chat completions from Hugging Face
# hosted Gemma models: a text chat tab (Gemma 2 27B IT) and a vision tab
# (PaliGemma 2) that takes a prompt plus an optional image URL.

import gradio as gr
from huggingface_hub import InferenceClient

# Hosted models: instruction-tuned Gemma 2 for text, PaliGemma 2 for vision.
model_text = "google/gemma-2-27b-it"
model_vision = "google/paligemma2-3b-pt-224"

client = InferenceClient()
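
# A note on auth: with no arguments, InferenceClient resolves a token from the
# HF_TOKEN environment variable or the local `huggingface-cli login` cache.
# A minimal sketch of passing one explicitly instead (assuming HF_TOKEN is set):
#
#   import os
#   client = InferenceClient(token=os.environ["HF_TOKEN"])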

def fn_text(
    prompt,
    history,
    #system_prompt,
    max_tokens,
    temperature,
    top_p,
):
    #messages = [{"role": "system", "content": system_prompt}]
    #history.append(messages[0])
    
    #messages.append({"role": "user", "content": prompt})
    #history.append(messages[1])

    messages = [{"role": "user", "content": prompt}]
    history.append(messages[0])
    
    stream = client.chat.completions.create(
        model = model_text,
        messages = history,
        max_tokens = max_tokens,
        temperature = temperature,
        top_p = top_p,
        stream = True,
    )
    
    # Accumulate streamed deltas and yield the running text so the chat UI
    # updates incrementally.
    chunks = []
    for chunk in stream:
        chunks.append(chunk.choices[0].delta.content or "")
        yield "".join(chunks)

app_text = gr.ChatInterface(
    fn = fn_text,
    type = "messages",
    additional_inputs = [
        #gr.Textbox(value="You are a helpful assistant.", label="System Prompt"),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max Tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-P"),
    ],
    title = "Google Gemma",
    description = model_text,
)

def fn_vision(
    prompt,
    image_url,
    #system_prompt,
    max_tokens,
    temperature,
    top_p,
):
    messages = [{"role": "user", "content": [{"type": "text", "text": prompt}]}]
    
    if image_url:
        messages[0]["content"].append({"type": "image_url", "image_url": {"url": image_url}})
    
    stream = client.chat.completions.create(
        model = model_vision,
        messages = messages,
        max_tokens = max_tokens,
        temperature = temperature,
        top_p = top_p,
        stream = True,
    )
    
    # Same streaming accumulation as fn_text.
    chunks = []
    for chunk in stream:
        chunks.append(chunk.choices[0].delta.content or "")
        yield "".join(chunks)

app_vision = gr.Interface(
    fn = fn_vision,
    inputs = [
        gr.Textbox(label="Prompt"),
        gr.Textbox(label="Image URL")
    ],
    outputs = [
        gr.Textbox(label="Output")
    ],
    additional_inputs = [
        #gr.Textbox(value="You are a helpful assistant.", label="System Prompt"),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max Tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-P"),
    ],
    title = "Google Gemma",
    description = model_vision,
)

app = gr.TabbedInterface(
    [app_text, app_vision],
    ["Text", "Vision"]
)

if __name__ == "__main__":
    app.launch()
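
# For local testing, a temporary public URL can be requested at launch; this
# is standard Gradio behavior, not specific to this app:
#
#   app.launch(share=True)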