File size: 2,990 Bytes
7416d8a
5c0e14a
 
7416d8a
 
 
5c0e14a
 
 
 
 
 
7416d8a
08eb742
 
 
 
 
5c0e14a
7416d8a
5c0e14a
 
 
 
 
7416d8a
 
 
 
 
 
 
 
 
5c0e14a
 
7416d8a
 
5c0e14a
7416d8a
 
 
 
17fba42
7416d8a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44b3bcb
 
 
5c0e14a
44b3bcb
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import json
from huggingface_hub import InferenceClient
import gradio as gr
import random

# Base URL of the Hugging Face hosted inference endpoint.
# NOTE(review): not referenced by the code visible in this file; kept in case
# other code imports it.
API_URL = "https://api-inference.huggingface.co/models/"

# Module-wide inference client bound to the Mistral 7B Instruct model;
# generate() sends every text-generation request through it.
client = InferenceClient(model="mistralai/Mistral-7B-Instruct-v0.1")

def format_prompt(message, history):
    """Build the instruction-tagged prompt string sent to the model.

    Args:
        message: The new user message to be answered.
        history: Iterable of ``(user_prompt, bot_response)`` pairs from
            earlier turns, oldest first.

    Returns:
        A single string: the system preamble, then every past turn wrapped as
        ``[INST] user [/INST] reply</s> ``, then the new message wrapped in
        ``[INST] ... [/INST]``.
    """
    # Collect the fragments and join once instead of repeated concatenation.
    segments = ["You're a helpful assistant."]
    for past_question, past_answer in history:
        segments.append(f"[INST] {past_question} [/INST]")
        segments.append(f" {past_answer}</s> ")
    segments.append(f"[INST] {message} [/INST]")
    return "".join(segments)

def generate(prompt, history, temperature=0.9, max_new_tokens=2048, top_p=0.95, repetition_penalty=1.0):
    """Stream a model completion for ``prompt``, yielding the text so far.

    This is a generator function: calling it returns a generator, not a
    string. Each iteration yields the cumulative response text after one more
    streamed token has been appended.

    Args:
        prompt: The new user message.
        history: Iterable of ``(user_prompt, bot_response)`` pairs, passed to
            ``format_prompt`` together with ``prompt``.
        temperature: Sampling temperature; converted to ``float`` and raised
            to ``1e-2`` if it is smaller than that.
        max_new_tokens: Forwarded unchanged to the client.
        top_p: Nucleus-sampling value; converted to ``float``.
        repetition_penalty: Forwarded unchanged to the client.

    Yields:
        The accumulated response text after each streamed token.

    Returns:
        The complete response text, available only as the generator's
        ``StopIteration`` value.
    """
    # Clamp to a small positive floor; values below 1e-2 are replaced by it.
    temperature = max(float(temperature), 1e-2)
    top_p = float(top_p)

    sampling_options = {
        "temperature": temperature,
        "max_new_tokens": max_new_tokens,
        "top_p": top_p,
        "repetition_penalty": repetition_penalty,
        "do_sample": True,
        # A fresh random seed is drawn for every request.
        "seed": random.randint(0, 10**7),
    }

    full_prompt = format_prompt(prompt, history)

    token_stream = client.text_generation(
        full_prompt,
        **sampling_options,
        stream=True,
        details=True,
        return_full_text=False,
    )

    text_so_far = ""
    for chunk in token_stream:
        text_so_far = text_so_far + chunk.token.text
        yield text_so_far
    return text_so_far

def load_database():
    """Read the question/answer store from ``database.json``.

    The file is looked up relative to the current working directory and read
    as UTF-8 JSON.

    Returns:
        The decoded JSON content, or an empty list when the file does not
        exist or does not contain valid JSON (a message is printed in that
        case).
    """
    try:
        with open("database.json", "r", encoding="utf-8") as handle:
            records = json.load(handle)
    except (FileNotFoundError, json.JSONDecodeError):
        # A missing or unparsable file is treated as "no data yet".
        print("Error loading database: File not found or invalid format. Creating an empty database.")
        return []
    return records

def save_database(data):
    """Write ``data`` to ``database.json`` as indented UTF-8 JSON.

    Saving is best-effort: failures are reported with a printed message and
    never raised to the caller.

    Args:
        data: JSON-serializable object to persist (the callers in this file
            pass a list of question/response pairs).
    """
    try:
        # Serialize before opening the file: mode "w" truncates on open, so a
        # serialization failure afterwards would wipe the existing database.
        payload = json.dumps(data, indent=4)
        with open("database.json", "w", encoding="utf-8") as f:
            f.write(payload)
    except (OSError, TypeError, ValueError):
        # json raises TypeError for unsupported types and ValueError for
        # circular references; the json module has no JSONEncodeError, so
        # naming it here would itself raise AttributeError.
        print("Error saving database: Encountered an issue while saving.")

def chat_interface(message):
    """Answer ``message``, reusing a stored response when one exists.

    The database is a list of ``(question, response)`` pairs; after a JSON
    round trip the pairs come back as two-element lists, so entries are
    matched on their first element rather than compared as whole tuples.

    Args:
        message: The user's question.

    Returns:
        The stored response if the question was asked before, otherwise the
        freshly generated response text (which is also saved).
    """
    database = load_database()

    # Look the question up by its text; skip entries that are not pairs.
    cached = next(
        (
            entry
            for entry in database
            if isinstance(entry, (list, tuple)) and len(entry) == 2 and entry[0] == message
        ),
        None,
    )
    if cached is not None:
        return cached[1]

    # generate() is a generator that yields the cumulative text after each
    # token; drain it and keep the last value to get the complete answer.
    response = ""
    for partial_text in generate(message, history=[]):
        response = partial_text

    database.append((message, response))
    save_database(database)
    return response

# Build the UI: one textbox for the question and a read-only one for the answer.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # Use "Textbox" components for both input and output
    input_textbox = gr.Textbox(label="Your question")
    # `interactive=False` makes the output read-only (Textbox has no `editable` option).
    output_textbox = gr.Textbox(label="Assistant's response", value="", interactive=False)

    # Pressing Enter in the input box runs chat_interface and shows its result.
    # Handlers are attached through component events; launch() does not take
    # fn/inputs/outputs.
    input_textbox.submit(fn=chat_interface, inputs=input_textbox, outputs=output_textbox)

# Use demo.launch instead of demo.queue().launch(); start the server once the
# layout context has been closed.
demo.launch()