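"""PTT voice chatbot.

A Gradio app that transcribes spoken input with Whisper and streams replies
from Mixtral-8x7B-Instruct via the Hugging Face Inference API.
"""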
import gradio as gr
from huggingface_hub import InferenceClient
import random
from transformers import pipeline
import numpy as np

# Load the Whisper model for automatic speech recognition
transcriber = pipeline("automatic-speech-recognition", model="openai/whisper-base.en")

# Define the model to be used
model = "mistralai/Mixtral-8x7B-Instruct-v0.1"
client = InferenceClient(model)
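# Note: InferenceClient calls the hosted Hugging Face Inference API; depending on the
# model's access settings, a token may need to be supplied, e.g. InferenceClient(model, token=...).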

# Embedded system prompt
system_prompt_text = (
    "You are a smart and helpful co-worker at PTT and PTTEP, Thailand-based "
    "multinational companies. You help with any kind of request and provide a "
    "detailed answer to the question. If you are asked about something unethical "
    "or dangerous, refuse and suggest a safe and respectful way to handle it."
)

# Function to transcribe audio input
def transcribe(audio):
    if audio is None:  # No recording was provided
        return ""
    sr, y = audio
    # Convert to mono if stereo
    if y.ndim > 1:
        y = y.mean(axis=1)

    y = y.astype(np.float32)
    peak = np.max(np.abs(y))
    if peak > 0:
        y /= peak  # Normalize audio to [-1, 1], guarding against silent input

    return transcriber({"sampling_rate": sr, "raw": y})["text"]  # Transcribe audio

def format_prompt_mixtral(message, history):
    prompt = "<s>"
    prompt += f"{system_prompt_text}\n\n"  # Add the system prompt

    if history:
        for user_prompt, bot_response in history:
            prompt += f"[INST] {user_prompt} [/INST]"
            prompt += f" {bot_response}</s> "
    prompt += f"[INST] {message} [/INST]"
    return prompt
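# Illustrative shape of the prompt built above for a one-turn history
# (hypothetical user/bot text):
#   <s>You are a smart and helpful co-worker ...
#
#   [INST] Hello [/INST] Hi, how can I help?</s> [INST] What is PTTEP? [/INST]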

def chat_inf(prompt, history, seed, temp, tokens, top_p, rep_p):
    history = history or []
    generate_kwargs = dict(
        temperature=temp,
        max_new_tokens=tokens,
        top_p=top_p,
        repetition_penalty=rep_p,
        do_sample=True,
        seed=seed,
    )

    formatted_prompt = format_prompt_mixtral(prompt, history)
    stream = client.text_generation(
        formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False
    )
    output = ""
    for response in stream:
        output += response.token.text
        # Keep earlier turns visible while the current reply streams in
        yield history + [(prompt, output)]
    history.append((prompt, output))
    yield history
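# Gradio treats generator callbacks as streaming outputs: each yielded list of
# (user, bot) tuples replaces the Chatbot contents, producing a typing effect.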

def clear_fn():
    return None, None

rand_val = random.randint(1, 1111111111111111)

def check_rand(inp, val):
    if inp:
        return gr.Slider(label="Seed", minimum=1, maximum=1111111111111111, step=1, value=random.randint(1, 1111111111111111))
    else:
        return gr.Slider(label="Seed", minimum=1, maximum=1111111111111111, step=1, value=int(val))

with gr.Blocks() as app:
    gr.HTML("""<center><h1 style='font-size:xx-large;'>PTT Chatbot</h1><br><h3>running on Hugging Face Inference</h3><br><h5>EXPERIMENTAL</h5></center>""")
    
    with gr.Row():
        chat = gr.Chatbot(height=500)

    with gr.Group():
        with gr.Row():
            with gr.Column(scale=3):
                inp = gr.Audio(type="numpy")  # "numpy" so transcribe() receives a (sample_rate, data) tuple
                with gr.Row():
                    with gr.Column(scale=2):
                        btn = gr.Button("Chat")
                    with gr.Column(scale=1):
                        with gr.Group():
                            stop_btn = gr.Button("Stop")
                            clear_btn = gr.Button("Clear")
            with gr.Column(scale=1):
                with gr.Group():
                    rand = gr.Checkbox(label="Random Seed", value=True)
                    seed = gr.Slider(label="Seed", minimum=1, maximum=1111111111111111, step=1, value=rand_val)
                    tokens = gr.Slider(label="Max new tokens", value=3840, minimum=0, maximum=8000, step=64, interactive=True, visible=True, info="Maximum number of new tokens to generate")
                    temp = gr.Slider(label="Temperature", step=0.01, minimum=0.01, maximum=1.0, value=0.9)
                    top_p = gr.Slider(label="Top-P", step=0.01, minimum=0.01, maximum=1.0, value=0.9)
                    rep_p = gr.Slider(label="Repetition Penalty", step=0.1, minimum=0.1, maximum=2.0, value=1.0)

    def handle_chat(audio_input, chat_history, seed, temp, tokens, top_p, rep_p):
        chat_history = chat_history or []
        user_message = transcribe(audio_input)  # Transcribe audio to text
        if not user_message:  # Empty recording or failed transcription
            yield chat_history + [(None, "Sorry, I couldn't understand that.")]
            return

        # Stream partial chat histories so the reply appears as it is generated
        yield from chat_inf(user_message, chat_history, seed, temp, tokens, top_p, rep_p)

    # Re-roll the seed first when "Random Seed" is checked, then run the chat
    go = btn.click(check_rand, [rand, seed], seed).then(
        handle_chat, [inp, chat, seed, temp, tokens, top_p, rep_p], chat
    )

    stop_btn.click(None, None, None, cancels=[go])
    clear_btn.click(clear_fn, None, [inp, chat])

app.queue(default_concurrency_limit=10).launch(share=True, auth=("admin", "0112358"))  # Launch the app with authentication