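# ConvoLite: a Gradio ChatInterface that streams chat completions from
# Mixtral-8x7B-Instruct through the Hugging Face Inference API.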
from huggingface_hub import InferenceClient
import gradio as gr
import datetime

# Initialize the InferenceClient
client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")
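# Anonymous Inference API calls are rate-limited; InferenceClient also accepts
# token="hf_..." for authenticated, higher-limit access (the token value here
# is illustrative, not part of the original script).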

def format_prompt(message, history):
    # Build a Mixtral-style instruction prompt from the chat history.
    prompt = "<s>"
    for user_prompt, bot_response in history:
        prompt += f"[INST] {user_prompt} [/INST]"
        prompt += f" {bot_response}</s> "
    prompt += f"[INST] {message} [/INST]"
    return prompt
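# Example: format_prompt("How are you?", [("Hi", "Hello!")]) returns
# "<s>[INST] Hi [/INST] Hello!</s> [INST] How are you? [/INST]"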

def generate(prompt, history, system_prompt, temperature=0.9, max_new_tokens=9048, top_p=0.95, repetition_penalty=1.0):
    # Clamp temperature to a small positive value; the API rejects 0.
    temperature = max(float(temperature), 1e-2)
    top_p = float(top_p)
    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        seed=42,  # fixed seed keeps sampled output reproducible
    )

    # Prefix the message with any user-supplied system prompt plus the
    # current server time.
    now = datetime.datetime.now()
    formatted_time = now.strftime("%H:%M:%S, %B %d, %Y")
    system_prompt = f"{system_prompt} System time: (UTC+1 {formatted_time})".strip()

    formatted_prompt = format_prompt(f"{system_prompt}, {prompt}", history)
    stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
    output = ""
    for response in stream:
        output += response.token.text
        yield output

additional_inputs = [
    gr.Textbox(label="System Prompt", max_lines=1, interactive=True),
    gr.Slider(label="Temperature", value=0.9, minimum=0.0, maximum=1.0, step=0.05, interactive=True, info="Higher values produce more diverse outputs"),
    gr.Slider(label="Max new tokens", value=9048, minimum=256, maximum=9048, step=64, interactive=True, info="The maximum numbers of new tokens"),
    gr.Slider(label="Top-p (nucleus sampling)", value=0.90, minimum=0.0, maximum=1, step=0.05, interactive=True, info="Higher values sample more low-probability tokens"),
    gr.Slider(label="Repetition penalty", value=1.2, minimum=1.0, maximum=2.0, step=0.05, interactive=True, info="Penalize repeated tokens")
]
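# gr.ChatInterface passes these controls positionally to generate() after
# (message, history): system_prompt, temperature, max_new_tokens, top_p,
# repetition_penalty.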

gr.ChatInterface(
    fn=generate,
    chatbot=gr.Chatbot(show_label=True, show_share_button=False, show_copy_button=True, likeable=True, layout="panel"),
    additional_inputs=additional_inputs,
    title="ConvoLite",
    submit_btn="➢",
    retry_btn="Retry",
    undo_btn="↩ Undo",
    clear_btn="Clear (new chat)",
    stop_btn="Stop ▢",
    concurrency_limit=100,
).launch(show_api=False)