import os

import gradio as gr
from huggingface_hub import InferenceClient

HF_TOKEN = os.environ.get("HF_TOKEN", None)
API_URL = "https://api-inference.huggingface.co/models/DataAnalyticsLab/PersianGPT-FT-Grover"
BOT_NAME = "PersianGPT-FT"

STOP_SEQUENCES = ["<|endoftext|>"]

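# Example prompts: a poetic form tag (ghazal, qasida, or masnavi) between <$ and $,
# optionally followed by '@' and an opening line for the model to continue.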
EXAMPLES = [
    ["<$غزل$@بر لبم هر ذره داغی می توان کردن"],
    ["<$غزل$"],
    ["<$قصیده$"],
    ["<$مثنوی$"],
    ["<$غزل$@دراین سرای بی کسی، کسی به در نمی زند"]
    ]

client = InferenceClient(
    API_URL,
    headers={"Authorization": f"Bearer {HF_TOKEN}"},
)

def format_prompt(message, history, system_prompt):
    # Concatenate the optional system prompt, the chat history, and the new
    # message into a single raw prompt string for the model.
    prompt = ""
    if system_prompt:
        prompt += f"{system_prompt}"
    for user_prompt, bot_response in history:
        prompt += f"{user_prompt}"
        prompt += f"{bot_response}"
    prompt += f"{message}"
    return prompt

def generate(
    prompt, history, system_prompt="<|endoftext|>", temperature=0.9, max_new_tokens=100, top_p=0.95, repetition_penalty=1.0, seed=42,
):
    temperature = float(temperature)
    if temperature < 1e-2:
        temperature = 1e-2
    top_p = float(top_p)
    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        stop_sequences=STOP_SEQUENCES,
        do_sample=True,
        #seed=seed,
    )
    #seed = seed + 1
    formatted_prompt = format_prompt(prompt, history, system_prompt)

    # Stream tokens from the Inference API so partial output reaches the UI as it is generated
    stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
    output = ""

    for response in stream:
        output += response.token.text

        # Trim any stop sequence from the end of the accumulated output before yielding
        for stop_str in STOP_SEQUENCES:
            if output.endswith(stop_str):
                output = output[:-len(stop_str)]
                output = output.rstrip()
                yield output
        yield output
    return output


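# Extra generation controls exposed in the "Additional inputs" section of the chat UI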
additional_inputs = [
    gr.Textbox("", label="Optional system prompt"),
    gr.Slider(
        label="Temperature",
        value=0.9,
        minimum=0.0,
        maximum=1.0,
        step=0.05,
        interactive=True,
        info="Higher values produce more diverse outputs",
    ),
    gr.Slider(
        label="Max new tokens",
        value=100,
        minimum=0,
        maximum=250,
        step=64,
        interactive=True,
        info="The maximum number of new tokens to generate",
    ),
    gr.Slider(
        label="Top-p (nucleus sampling)",
        value=0.90,
        minimum=0.0,
        maximum=1,
        step=0.05,
        interactive=True,
        info="Higher values sample more low-probability tokens",
    ),
    gr.Slider(
        label="Repetition penalty",
        value=1.2,
        minimum=1.0,
        maximum=2.0,
        step=0.05,
        interactive=True,
        info="Penalize repeated tokens",
    )
]


with gr.Blocks(css=".gradio-container {direction: rtl; white-space: pre-line;}") as demo:
    with gr.Row():
        with gr.Column():
            gr.Markdown(
                """
                PERSIAN GPT Trained by Mojtaba Valipour @ Data Analytics Lab
                """
            )

    # RTL styling is applied on the outer gr.Blocks above so it affects the launched app.
    gr.ChatInterface(
        generate,
        examples=EXAMPLES,
        additional_inputs=additional_inputs,
    )

demo.queue(concurrency_count=100, api_open=False).launch(show_api=False)  # pass share=True to launch() for a public link