Sorvad / app.py
vpcom's picture
fix: comment out the seed, since the model does not appear to support it at the moment; to be investigated further later
81301d6
raw
history blame
3.37 kB
import json
import os
import shutil
import requests
import gradio as gr
from huggingface_hub import Repository, InferenceClient
# Hugging Face access token read from the environment; None when the
# variable is not set.
HF_TOKEN = os.environ.get("HF_TOKEN", None)
# Hosted inference endpoint of the model this app queries.
API_URL = "https://api-inference.huggingface.co/models/DataAnalyticsLab/PersianGPT-FT-Grover"
# Display name for the bot.
# NOTE(review): not referenced anywhere in the visible part of this file.
BOT_NAME = "PersianGPT-FT"
# Markers that end a generation: passed to the API as stop sequences and
# trimmed from the end of the generated text.
STOP_SEQUENCES = ["<|endoftext|>"]
# Example prompts offered in the chat UI (one single-item list per example).
# Each prompt looks like "<$FORM$", optionally followed by "@" and an opening
# line. NOTE(review): the FORM tags are presumably Persian poetic forms
# (ghazal, qasida, masnavi) -- confirm against the model's training format.
EXAMPLES = [
["<$غزل$@بر لبم هر ذره داغی می توان کردن"],
["<$غزل$"],
["<$قصیده$"],
["<$مثنوی$"],
["<$غزل$@دراین سرای بی کسی، کسی به در نمی زند"]
]
# Shared inference client for the hosted model.
# Only attach an Authorization header when a token is actually configured;
# otherwise the literal string "Bearer None" would be sent with every request.
client = InferenceClient(
    API_URL,
    headers={"Authorization": f"Bearer {HF_TOKEN}"} if HF_TOKEN else None,
)
def format_prompt(message, history, system_prompt):
    """Build the flat text prompt sent to the model.

    The pieces are concatenated with no separator, in this order: the
    optional system prompt, every (user, bot) turn of ``history``, and
    finally the new ``message``.

    Args:
        message: The new user message.
        history: Iterable of ``(user_prompt, bot_response)`` pairs from
            earlier turns; ``None`` or empty means no earlier turns.
        system_prompt: Text placed at the very start of the prompt; skipped
            when empty or ``None``.

    Returns:
        The concatenated prompt string.
    """
    parts = []
    if system_prompt:
        parts.append(f"{system_prompt}")
    for user_prompt, bot_response in history or []:
        # A turn may be only half filled (one side is None); skip the missing
        # side instead of injecting the literal text "None" into the prompt.
        parts.extend(
            f"{text}" for text in (user_prompt, bot_response) if text is not None
        )
    parts.append(f"{message}")
    return "".join(parts)
def _strip_stop_sequences(text):
    """Remove a trailing stop sequence (and the whitespace before it) from text."""
    for stop_str in STOP_SEQUENCES:
        if text.endswith(stop_str):
            text = text[: -len(stop_str)].rstrip()
    return text


def generate(
    prompt, history, system_prompt="<|endoftext|>", temperature=0.9, max_new_tokens=100, top_p=0.95, repetition_penalty=1.0, seed=42,
):
    """Generate the bot reply for ``prompt`` and yield it for the chat UI.

    Args:
        prompt: The new user message.
        history: Earlier ``(user, bot)`` turns, forwarded to ``format_prompt``.
        system_prompt: Text placed at the start of the model prompt.
        temperature: Sampling temperature; values below 0.01 are raised to 0.01.
        max_new_tokens: Upper bound on generated tokens; raised to at least 1.
        top_p: Nucleus-sampling probability mass.
        repetition_penalty: Penalty applied to repeated tokens.
        seed: Accepted for interface compatibility but currently NOT sent to
            the API (see the note in the body).

    Yields:
        The reply text accumulated so far, with any trailing stop sequence
        removed. For a non-streaming result this is a single value.

    Returns:
        The final reply text (as the generator's return value).
    """
    temperature = max(float(temperature), 1e-2)
    top_p = float(top_p)

    generate_kwargs = dict(
        temperature=temperature,
        # The UI slider can deliver 0; a zero-token request cannot produce
        # any output, so ask for at least one token.
        max_new_tokens=max(int(max_new_tokens), 1),
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        stop_sequences=STOP_SEQUENCES,
        do_sample=True,
        # NOTE: `seed` is intentionally not forwarded; per the change history
        # the model did not seem to support it. Re-enable once verified.
    )

    formatted_prompt = format_prompt(prompt, history, system_prompt)
    result = client.text_generation(
        formatted_prompt,
        **generate_kwargs,
        stream=False,
        details=True,
        return_full_text=False,
    )

    # With stream=False the client returns one complete result, not a token
    # stream: either a plain string or an object carrying `generated_text`.
    # The stream branch is kept so the function still works if streaming is
    # switched back on.
    if isinstance(result, str):
        chunks = [result]
    elif getattr(result, "generated_text", None) is not None:
        chunks = [result.generated_text]
    else:
        chunks = (response.token.text for response in result)

    output = ""
    for chunk in chunks:
        output = _strip_stop_sequences(output + chunk)
        yield output
    return output
# Extra controls shown under the chat box. Their order must match the
# parameters of `generate` that follow (prompt, history): system_prompt,
# temperature, max_new_tokens, top_p, repetition_penalty.
additional_inputs = [
    gr.Textbox("", label="Optional system prompt"),
    gr.Slider(
        label="Temperature",
        value=0.9,
        minimum=0.0,
        maximum=1.0,
        step=0.05,
        interactive=True,
        info="Higher values produce more diverse outputs",
    ),
    gr.Slider(
        label="Max new tokens",
        value=100,
        # minimum=1 / step=1: with the previous minimum=0 and step=64 neither
        # the default (100) nor the maximum (250) was on the slider grid, and
        # 0 asked the model for no tokens at all.
        minimum=1,
        maximum=250,
        step=1,
        interactive=True,
        info="The maximum numbers of new tokens",
    ),
    gr.Slider(
        label="Top-p (nucleus sampling)",
        value=0.90,
        minimum=0.0,
        maximum=1,
        step=0.05,
        interactive=True,
        info="Higher values sample more low-probability tokens",
    ),
    gr.Slider(
        label="Repetition penalty",
        value=1.2,
        minimum=1.0,
        maximum=2.0,
        step=0.05,
        interactive=True,
        info="Penalize repeated tokens",
    ),
]
# Page layout: a header row with the attribution text, then the chat widget
# wired to `generate` with the example prompts and the extra controls.
with gr.Blocks() as demo:
    # Same text as before, written with explicit newlines.
    header_text = "\nPERSIAN GPT Trained by Mojtaba Valipour @ Data Analytics Lab\n"
    with gr.Row():
        with gr.Column():
            gr.Markdown(header_text)
    gr.ChatInterface(
        generate,
        examples=EXAMPLES,
        additional_inputs=additional_inputs,
    )

# Enable request queuing, then start the server with the API page hidden.
demo.queue(concurrency_count=100, api_open=False)
demo.launch(show_api=False)