from huggingface_hub import InferenceClient
import gradio as gr
import datetime
import re

# Initialize the InferenceClient
client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")
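# Note: depending on how this is deployed, the serverless Inference API may require
# authentication. InferenceClient also accepts an explicit token=... argument or
# falls back to a locally saved Hugging Face token / the HF_TOKEN environment variable.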

# Define the system prompt templates
system_prompt_templates = {
    r"\btime\b|\bhour\b|\bclock\b": "server log: ~This message was sent at {formatted_time}.~",
    r"\bdate\b|\bcalendar\b": "server log: ~Today's date is {formatted_date}.~",
    

def format_prompt(message, history):
    # Build a Mixtral-instruct style prompt: each past turn is wrapped in
    # [INST] ... [/INST] followed by the model's reply, then the new message is appended.
    prompt = "<s>"
    for user_prompt, bot_response in history:
        prompt += f"[INST] {user_prompt} [/INST]"
        prompt += f" {bot_response}</s> "
    prompt += f"[INST] {message} [/INST]"
    return prompt
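# For reference, a first turn with empty history produces a prompt shaped like:
#   "<s>[INST] {system_prompt}, {message} [/INST]"
# which matches the chat template expected by Mixtral-8x7B-Instruct.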

def generate(prompt, history, system_prompt, temperature=0.9, max_new_tokens=9048, top_p=0.95, repetition_penalty=1.0):
    temperature = max(float(temperature), 1e-2)
    top_p = float(top_p)
    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        seed=42,
    )

    # Get current time and date
    now = datetime.datetime.now()
    formatted_time = now.strftime("%H.%M.%S, %B, %Y")
    formatted_date = now.strftime("%B %d, %Y")

    # Check for keywords in the user's input and update the system prompt accordingly
    for keyword, template in system_prompt_templates.items():
        if re.search(keyword, prompt, re.IGNORECASE):
            system_prompt = template.format(formatted_time=formatted_time, formatted_date=formatted_date)
            break

    formatted_prompt = format_prompt(f"{system_prompt}, {prompt}", history)
    stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
    output = ""
    for response in stream:
        output += response.token.text
        yield output
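# Quick check outside the UI (assumes the Inference API is reachable): generate()
# yields the cumulative reply, so the last yielded value is the full answer.
#     *_, reply = generate("what time is it?", history=[], system_prompt="")
#     print(reply)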

additional_inputs = [
    gr.Textbox(label="System Prompt", max_lines=1, interactive=True),
    gr.Slider(label="Temperature", value=0.9, minimum=0.0, maximum=1.0, step=0.05, interactive=True, info="Higher values produce more diverse outputs"),
    gr.Slider(label="Max new tokens", value=9048, minimum=256, maximum=9048, step=64, interactive=True, info="The maximum numbers of new tokens"),
    gr.Slider(label="Top-p (nucleus sampling)", value=0.90, minimum=0.0, maximum=1, step=0.05, interactive=True, info="Higher values sample more low-probability tokens"),
    gr.Slider(label="Repetition penalty", value=1.2, minimum=1.0, maximum=2.0, step=0.05, interactive=True, info="Penalize repeated tokens")
]
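# The order of these controls mirrors the keyword arguments of generate()
# (system_prompt, temperature, max_new_tokens, top_p, repetition_penalty), so the
# list can be rendered in the Blocks layout below or handed unchanged to
# gr.ChatInterface(additional_inputs=...).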

def check_keywords(text):
    """Return True if `text` matches any keyword pattern in system_prompt_templates."""
    for keyword, _ in system_prompt_templates.items():
        if re.search(keyword, text, re.IGNORECASE):
            return True
    return False

chatbot = gr.Chatbot(show_label=True, show_share_button=False, show_copy_button=True, likeable=True, layout="panel")

with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column(scale=3):
            user_input = gr.Textbox(label="Your message", placeholder="Type your message here...")
        with gr.Column(scale=1):
            submit_button = gr.Button("Send")

    with gr.Row():
        # Components created at module level are attached to the page with .render().
        chatbot.render()

    with gr.Accordion("Generation settings", open=False):
        for component in additional_inputs:
            component.render()

    def respond(message, history, system_prompt, temperature, max_new_tokens, top_p, repetition_penalty):
        # Wrap generate() so the Chatbot component receives a full history of
        # (user, bot) pairs rather than a bare string while the reply streams in.
        history = history or []
        for partial in generate(message, history, system_prompt, temperature,
                                max_new_tokens, top_p, repetition_penalty):
            yield history + [(message, partial)]

    submit_button.click(
        fn=respond,
        inputs=[user_input, chatbot] + additional_inputs,
        outputs=chatbot,
    )

demo.launch(show_api=False)
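# Alternative wiring (a sketch, not part of the original file): because
# additional_inputs already matches generate()'s extra parameters, the same
# function could be exposed through Gradio's higher-level chat wrapper instead
# of the manual Blocks layout above:
#
#     gr.ChatInterface(
#         fn=generate,
#         chatbot=chatbot,
#         additional_inputs=additional_inputs,
#     ).launch(show_api=False)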