Spaces:
Runtime error
Runtime error
File size: 4,062 Bytes
4375b7f 4e683ec 76a154f f4e3549 8c0a7e8 76a154f b1c12fa 76a154f d534002 4e683ec f073a1c 76a154f 4375b7f 1f2b852 76a154f ad4597f 1f2b852 ad4597f f073a1c 76a154f e12dd90 0da65a1 2fe7e62 4e683ec e12dd90 4e683ec 2fe7e62 6111f2c 4e683ec 6111f2c 4e683ec f073a1c 4e683ec f073a1c 4e683ec 76a154f 4e683ec 76a154f 4e683ec a400f4b 4e683ec 76a154f 4e683ec 76a154f 2fe7e62 4e683ec 76a154f 4e683ec 76a154f 4e683ec af33034 4e683ec cadad8a e12dd90 4e683ec bb98bf8 6336a63 7621468 66c2c87 f073a1c d6a1a38 6336a63 a636dc2 f5aa776 a636dc2 76a154f 4e683ec f073a1c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 |
import os
from threading import Thread
from typing import Iterator
import os
from huggingface_hub import login
import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
# Generation limits used by the UI slider and the input-trimming logic.
# Bug fix: the default must never exceed the slider maximum (previously
# DEFAULT_MAX_NEW_TOKENS was 1024 while MAX_MAX_NEW_TOKENS was 128, leaving
# the "Max new tokens" slider with a default above its allowed range).
MAX_MAX_NEW_TOKENS = 128
DEFAULT_MAX_NEW_TOKENS = min(1024, MAX_MAX_NEW_TOKENS)
# Prompts longer than this many tokens are trimmed from the left in generate().
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))

# Populated lazily by load_model() once the user has authenticated.
model = None
tokenizer = None
def load_model():
    """Download the chat model and tokenizer and store them in the
    module-level ``model`` / ``tokenizer`` globals.

    Called after a successful Hub login, since the checkpoint may be gated.
    """
    global model, tokenizer
    repo = "stabilityai/ar-stablelm-2-chat"
    tokenizer = AutoTokenizer.from_pretrained(repo, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(repo, device_map="auto", trust_remote_code=True)
    # Reuse EOS as the pad token so generate() has a pad_token_id to work with.
    model.generation_config.pad_token_id = model.generation_config.eos_token_id
def generate(
    message: str,
    chat_history: list[dict],
    system_prompt: str = "",
    max_new_tokens: int = 128,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
    """Stream a model reply for *message*, given the running chat history.

    Args:
        message: The latest user turn.
        chat_history: Prior turns as ``{"role": ..., "content": ...}`` dicts.
        system_prompt: Optional system message prepended to the conversation.
        max_new_tokens: Upper bound on generated tokens.
        temperature, top_p, top_k, repetition_penalty: Sampling parameters
            forwarded to ``model.generate``.

    Yields:
        The accumulated response text, growing as tokens stream in.

    Requires ``load_model()`` to have populated the module-level
    ``model`` / ``tokenizer`` globals.
    """
    conversation = []
    if system_prompt:
        conversation.append({"role": "system", "content": system_prompt})
    conversation += chat_history
    conversation.append({"role": "user", "content": message})

    input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        # Keep only the most recent tokens so the prompt fits the context window.
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(model.device)

    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        eos_token_id=tokenizer.eos_token_id,  # Stop generation at <EOS>
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        # Bug fix: repetition_penalty was accepted (and exposed via a UI
        # slider) but never forwarded to model.generate.
        repetition_penalty=repetition_penalty,
    )
    # Run generation on a worker thread so tokens can be streamed as they arrive.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)
# Chat UI wired to generate(); the extra sliders map 1:1 onto its
# sampling parameters (system_prompt, max_new_tokens, temperature,
# top_p, top_k, repetition_penalty — order must match the signature).
chat_interface = gr.ChatInterface(
    fn=generate,
    additional_inputs=[
        gr.Textbox(label="System prompt", lines=6),
        gr.Slider(
            label="Max new tokens",
            minimum=1,
            maximum=MAX_MAX_NEW_TOKENS,
            step=1,
            # Bug fix: clamp the default so it can never exceed the slider
            # maximum (DEFAULT_MAX_NEW_TOKENS may be larger than
            # MAX_MAX_NEW_TOKENS, which leaves the slider misconfigured).
            value=min(DEFAULT_MAX_NEW_TOKENS, MAX_MAX_NEW_TOKENS),
        ),
        gr.Slider(
            label="Temperature",
            minimum=0.1,
            maximum=4.0,
            step=0.1,
            # NOTE(review): this UI default (0.7) overrides generate()'s own
            # default of 0.6 — confirm which value is intended.
            value=0.7,
        ),
        gr.Slider(
            label="Top-p (nucleus sampling)",
            minimum=0.05,
            maximum=1.0,
            step=0.05,
            value=0.9,
        ),
        gr.Slider(
            label="Top-k",
            minimum=1,
            maximum=1000,
            step=1,
            value=50,
        ),
        gr.Slider(
            label="Repetition penalty",
            minimum=1.0,
            maximum=2.0,
            step=0.05,
            value=1.2,
        ),
    ],
    stop_btn=None,
    examples=[
        ["السلام عليكم"],
        ["اعرب الجملة التالية: ذهبت الى السوق"]
    ],
    cache_examples=False,
    type="messages",
)
with gr.Blocks(css_paths="style.css", fill_height=True) as demo:
    def authenticate_token(token):
        """Log in to the Hugging Face Hub with *token*, then load the model.

        Returns a human-readable status string for the output textbox.
        """
        try:
            login(token=token)
            load_model()
            return "Authenticated successfully"
        # Bug fix: a bare `except:` also swallowed SystemExit and
        # KeyboardInterrupt; catch only real errors (bad token, gated
        # repo, download failure) and report them in the UI.
        except Exception:
            return "Invalid token. Please try again."

    # Components
    token_input = gr.Textbox(label="Hugging Face Access Token", type="password", placeholder="Enter your token here...")
    auth_button = gr.Button("Authenticate")
    output = gr.Textbox(label="Output")

    auth_button.click(fn=authenticate_token, inputs=token_input, outputs=output)
    chat_interface.render()
if __name__ == "__main__":
    # Launch the app with request queueing capped at 20 pending requests.
    demo.queue(max_size=20).launch()