Spaces:
Runtime error
Runtime error
File size: 4,510 Bytes
4375b7f 4e683ec 76a154f f4e3549 be21dc7 76a154f b1c12fa 76a154f d534002 be21dc7 4e683ec f073a1c 76a154f 4375b7f 1f2b852 76a154f 665015d be21dc7 665015d f073a1c 76a154f e12dd90 0da65a1 2fe7e62 4e683ec e12dd90 4e683ec 2fe7e62 6111f2c 4e683ec 6111f2c 4e683ec f073a1c 4e683ec f073a1c 4e683ec 76a154f 4e683ec 76a154f 4e683ec a400f4b 4e683ec 76a154f 4e683ec 76a154f 2fe7e62 4e683ec 76a154f 4e683ec 76a154f 4e683ec af33034 665015d 1898074 665015d 4e683ec cadad8a e12dd90 4e683ec bb98bf8 665015d 6336a63 665015d f5aa776 a636dc2 76a154f 4e683ec be21dc7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 |
import os
from threading import Thread
from typing import Iterator
import os
from huggingface_hub import login,whoami
import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import argparse
# --- Configuration ----------------------------------------------------------
MAX_MAX_NEW_TOKENS = 128  # upper bound of the "Max new tokens" slider
# BUG FIX: the previous default (1024) exceeded MAX_MAX_NEW_TOKENS (128),
# putting the slider's initial value outside its allowed range. Clamp it so
# the default can never exceed the maximum.
DEFAULT_MAX_NEW_TOKENS = min(1024, MAX_MAX_NEW_TOKENS)
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))

# --- Hugging Face authentication -------------------------------------------
my_token = os.getenv("HF_AUTH_TOKEN")
try:
    # Succeeds when a token is already cached for this environment.
    username = whoami()
except OSError:
    # NOTE(review): presumably this catches huggingface_hub's
    # token-not-found error (an OSError subclass) — confirm. Fall back to
    # logging in with the token from the environment.
    login(token=my_token, add_to_git_credential=True)

# --- Model & tokenizer ------------------------------------------------------
model_id = "stabilityai/ar-stablelm-2-chat"
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
# Use EOS as the pad token to avoid the "pad_token_id not set" warning path
# during generation.
model.generation_config.pad_token_id = model.generation_config.eos_token_id
def generate(
    message: str,
    chat_history: list[dict],
    system_prompt: str = "",
    max_new_tokens: int = 128,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
    """Stream a chat completion for *message* given the prior conversation.

    Builds a chat-template prompt from ``system_prompt`` + ``chat_history`` +
    the new user ``message``, runs ``model.generate`` on a background thread,
    and yields the accumulated response text after each streamed fragment.
    """
    conversation = []
    if system_prompt:
        conversation.append({"role": "system", "content": system_prompt})
    conversation += chat_history
    conversation.append({"role": "user", "content": message})

    input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        # Keep only the most recent tokens so the newest turns survive.
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(model.device)

    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        {"input_ids": input_ids},
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        eos_token_id=tokenizer.eos_token_id,  # Stop generation at <EOS>
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        # BUG FIX: repetition_penalty was accepted here (and exposed as a UI
        # slider) but never forwarded to model.generate(); pass it through.
        repetition_penalty=repetition_penalty,
    )
    # Generation runs on a worker thread so this generator can stream results.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        # Yield the full response so far; Gradio re-renders the whole message.
        yield "".join(outputs)
# Sampling/system-prompt controls shown under the chat box; the order must
# match the extra parameters of `generate` after (message, chat_history).
_extra_controls = [
    gr.Textbox(label="System prompt", lines=6),
    gr.Slider(label="Max new tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS),
    gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.7),
    gr.Slider(label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0, step=0.05, value=0.9),
    gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50),
    gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.2),
]

# Messages-style chat UI wired to the streaming `generate` function, with a
# few Arabic example prompts.
chat_interface = gr.ChatInterface(
    fn=generate,
    additional_inputs=_extra_controls,
    stop_btn=None,
    examples=[
        ["السلام عليكم"],
        ["اعرب الجملة التالية: ذهبت الى السوق"],
        ["اضف تشكيل للجملة التالية: ضرب زيدا عمر"],
        ["كم عدد بحور الشعر العربي؟"],
    ],
    cache_examples=False,
    type="messages",
)
# Wrap the chat interface in a Blocks container so the custom stylesheet
# (style.css) and full-height layout apply to it.
# (Removed a commented-out, unused token-authentication prototype.)
with gr.Blocks(css_paths="style.css", fill_height=True) as demo:
    chat_interface.render()
if __name__ == "__main__":
    # `--share` exposes the app through a public Gradio link.
    cli = argparse.ArgumentParser(description="Gradio App with Sharing")
    cli.add_argument("--share", action="store_true", help="Enable public sharing")
    opts = cli.parse_args()
    # Queue up to 20 concurrent requests before rejecting new ones.
    demo.queue(max_size=20).launch(share=opts.share)