# NOTE: non-code residue from a Hugging Face Spaces page scrape (status banner,
# file size, git blame hashes, line-number gutter) was removed here so that
# the module parses as Python.
import json
import os
import shutil
import requests
import gradio as gr
from huggingface_hub import Repository, InferenceClient
# Hugging Face API token; read from the environment so it is never hard-coded.
HF_TOKEN = os.environ.get("HF_TOKEN", None)

# Serverless Inference API endpoint for the fine-tuned Persian GPT model.
API_URL = "https://api-inference.huggingface.co/models/DataAnalyticsLab/PersianGPT-FT-Grover"
BOT_NAME = "PersianGPT-FT"

# Sequences that should truncate generation; currently disabled (empty list).
STOP_SEQUENCES = [] #["<|endoftext|>",">"]

# Example prompts for the chat UI. The markers encode the requested form/poet:
# "$...$" appears to select a poetic form (e.g. ghazal, masnavi), "%...%" a
# poet, and "@" introduces an opening verse — presumably the model's prompt
# grammar; verify against the model card.
EXAMPLES = [
    ["<$غزل$@بر لبم هر ذره داغی می توان کردن"],
    ["<$غزل$"],
    ["<%حافظ%"],
    ["<$مثنوی$%مولوی%@قدسی"],
    ["<$غزل$@دراین سرای بی کسی، کسی به در نمی زند"]
]

# Shared client for the Inference API; the bearer token authorizes requests.
client = InferenceClient(
    API_URL,
    headers={"Authorization": f"Bearer {HF_TOKEN}"},
)
def format_prompt(message, history, system_prompt):
    """Build the raw model prompt by concatenating, in order, the optional
    system prompt, every (user, bot) turn of the history, and the new message.

    No separators are inserted between the pieces; the model's own prompt
    grammar is expected to carry the structure.
    """
    pieces = []
    if system_prompt:
        pieces.append(system_prompt)
    for user_turn, bot_turn in history:
        pieces.append(user_turn)
        pieces.append(bot_turn)
    pieces.append(message)
    # str() mirrors the original f-string interpolation of each piece.
    return "".join(str(piece) for piece in pieces)
def generate(
    prompt, history, system_prompt="<|endoftext|>", temperature=0.9, max_new_tokens=100, top_p=0.95, repetition_penalty=1.0, seed=42,
):
    """Stream a completion from the Inference API for the chat UI.

    Yields the accumulated output string after each generated token so
    gr.ChatInterface can render it incrementally.

    Parameters mirror the UI controls: sampling temperature, token budget,
    nucleus top_p, and repetition penalty. `seed` is currently unused
    (seeding is commented out below).
    """
    temperature = float(temperature)
    if temperature < 1e-2:
        # The API rejects non-positive temperatures; clamp to a small minimum.
        temperature = 1e-2
    top_p = float(top_p)

    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        stop_sequences=STOP_SEQUENCES,
        do_sample=True,
        #seed=seed,
    )
    #seed = seed + 1
    formatted_prompt = format_prompt(prompt, history, system_prompt)

    # BUG FIX: the original called text_generation(..., stream=False,
    # details=True), which returns a single TextGenerationResponse object,
    # and then iterated it token-by-token — failing at runtime. With
    # stream=True the call yields TextGenerationStreamResponse items whose
    # token text we accumulate.
    stream = client.text_generation(
        formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False
    )
    output = ""
    for response in stream:
        output += response.token.text
        # Trim a trailing stop sequence as soon as it appears (no-op while
        # STOP_SEQUENCES is empty).
        for stop_str in STOP_SEQUENCES:
            if output.endswith(stop_str):
                output = output[:-len(stop_str)]
                output = output.rstrip()
                yield output
        yield output
    yield output
# Extra controls appended to the ChatInterface; order must match the extra
# parameters of generate() after (prompt, history).
additional_inputs=[
    gr.Textbox("", label="Optional system prompt"),
    gr.Slider(
        label="Temperature",
        value=1.0,
        minimum=0.0,
        maximum=1.0,
        step=0.05,
        interactive=True,
        info="Higher values produce more diverse outputs",
    ),
    gr.Slider(
        label="Max new tokens",
        value=100,
        minimum=0,
        maximum=250,
        # NOTE(review): step=64 from minimum=0 cannot land exactly on the
        # default value 100 — confirm the intended step/default.
        step=64,
        interactive=True,
        info="The maximum numbers of new tokens",
    ),
    gr.Slider(
        label="Top-p (nucleus sampling)",
        # NOTE(review): UI default 0.05 disagrees with generate()'s
        # top_p=0.95 default — presumably a typo; verify.
        value=0.05,
        minimum=0.0,
        maximum=1,
        step=0.05,
        interactive=True,
        info="Higher values sample more low-probability tokens",
    ),
    gr.Slider(
        label="Repetition penalty",
        value=1.0,
        minimum=1.0,
        maximum=2.0,
        step=0.05,
        interactive=True,
        info="Penalize repeated tokens",
    )
]
# Custom CSS forcing right-to-left rendering for Persian text.
# NOTE(review): the #component-11 / #component-12 selectors target Gradio's
# auto-generated element ids, which shift whenever the layout changes —
# fragile; prefer elem_id on the components if this breaks.
CSS = """
.gradio-container textarea {direction: rtl; white-space: pre-line;}
#component-11 #component-12 {direction: rtl; white-space: pre-line;}
p {direction: rtl; white-space: pre-line;}
"""
# Assemble the UI: a title banner plus the chat interface wired to generate().
with gr.Blocks(css=CSS) as demo:
    with gr.Row():
        with gr.Column():
            gr.Markdown(
                """
                PERSIAN GPT Trained by Mojtaba Valipour @ Data Analytics Lab
                """
            )
    gr.ChatInterface(
        generate,
        examples=EXAMPLES,
        additional_inputs=additional_inputs
    )

# Queue requests and start the app; the public API endpoint is disabled.
# NOTE(review): queue(concurrency_count=...) is the pre-4.x Gradio signature —
# confirm the pinned gradio version supports it.
demo.queue(concurrency_count=100, api_open=False).launch(show_api=False) #, share=True)