File size: 3,882 Bytes
4375b7f
4e683ec
76a154f
f4e3549
8c0a7e8
76a154f
b1c12fa
76a154f
d534002
4e683ec
76a154f
 
4375b7f
76a154f
 
ab421f4
aba3816
 
af33034
76a154f
 
 
 
e12dd90
0da65a1
2fe7e62
4e683ec
 
 
 
 
 
 
 
e12dd90
4e683ec
 
2fe7e62
6111f2c
 
 
 
4e683ec
 
 
6111f2c
4e683ec
 
2fe7e62
4e683ec
 
 
 
76a154f
4e683ec
 
 
 
76a154f
 
4e683ec
 
 
a400f4b
4e683ec
 
76a154f
 
 
 
4e683ec
 
 
76a154f
 
 
2fe7e62
4e683ec
 
 
76a154f
 
 
4e683ec
 
 
 
76a154f
 
 
 
4e683ec
 
 
 
 
 
 
 
 
 
 
af33034
 
4e683ec
cadad8a
e12dd90
4e683ec
 
bb98bf8
4e683ec
6336a63
 
 
 
 
 
 
 
 
 
 
 
4e683ec
76a154f
4e683ec
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import os
from threading import Thread
from typing import Iterator
import os
from huggingface_hub import login
import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# Upper bound of the "Max new tokens" slider in the UI.
MAX_MAX_NEW_TOKENS = 2048
# Initial position of the "Max new tokens" slider in the UI.
DEFAULT_MAX_NEW_TOKENS = 1024
# Prompts longer than this many tokens are truncated in generate(); can be
# overridden through the MAX_INPUT_TOKEN_LENGTH environment variable.
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))


# The model and tokenizer are loaded once at import time and shared by every
# request handled by generate().
model_id = "stabilityai/ar-stablelm-2-chat"
# NOTE(review): trust_remote_code=True presumably lets code shipped in the
# model repository run locally; no dtype is specified here, so the library
# default is used -- confirm both are intended.
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
# NOTE(review): presumably prevents the tokenizer from injecting a built-in
# system prompt when none is supplied -- verify that this attribute is still
# honored by the installed transformers version.
tokenizer.use_default_system_prompt = False


def generate(
    message: str,
    chat_history: list[dict],
    system_prompt: str = "",
    max_new_tokens: int = 128,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
    """Stream the model's reply to ``message`` given the prior conversation.

    Args:
        message: The new user message.
        chat_history: Earlier turns as ``{"role": ..., "content": ...}`` dicts;
            they are placed between the system prompt and the new message.
        system_prompt: Optional system message; skipped when empty.
        max_new_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature. A value ``<= 0`` switches to
            greedy decoding, in which case ``top_p``/``top_k`` are not used.
        top_p: Nucleus-sampling probability mass (sampling only).
        top_k: Number of highest-probability tokens kept (sampling only).
        repetition_penalty: Penalty forwarded to ``model.generate``.

    Yields:
        The reply accumulated so far, growing with every streamed chunk.
    """
    conversation = []
    if system_prompt:
        conversation.append({"role": "system", "content": system_prompt})
    conversation += chat_history
    conversation.append({"role": "user", "content": message})

    input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        # Keep the most recent tokens so the newest turns survive the cut.
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(model.device)

    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)

    # The UI exposes temperature/top-p/top-k/repetition-penalty controls, so
    # they must actually reach model.generate. Sampling requires a strictly
    # positive temperature; otherwise decode greedily and leave the
    # sampling-only knobs out.
    use_sampling = temperature > 0
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=use_sampling,
        num_beams=1,
        repetition_penalty=repetition_penalty,
    )
    if use_sampling:
        generate_kwargs.update(temperature=temperature, top_p=top_p, top_k=top_k)

    # Generation runs in a background thread; this generator consumes the
    # streamer as text chunks arrive.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)


# One row per sampling control: (label, minimum, maximum, step, initial value).
# The row order must match the order of generate()'s parameters that follow
# ``system_prompt``, because the values are passed positionally.
_SLIDER_SPECS = (
    ("Max new tokens", 1, MAX_MAX_NEW_TOKENS, 1, DEFAULT_MAX_NEW_TOKENS),
    ("Temperature", 0.1, 4.0, 0.1, 0.7),
    ("Top-p (nucleus sampling)", 0.05, 1.0, 0.05, 0.9),
    ("Top-k", 1, 1000, 1, 50),
    ("Repetition penalty", 1.0, 2.0, 0.05, 1.2),
)

# Chat UI backed by generate(); the extra inputs are the system prompt box
# followed by one slider per row of _SLIDER_SPECS.
chat_interface = gr.ChatInterface(
    fn=generate,
    additional_inputs=[
        gr.Textbox(label="System prompt", lines=6),
        *(
            gr.Slider(label=label, minimum=low, maximum=high, step=step, value=initial)
            for label, low, high, step, initial in _SLIDER_SPECS
        ),
    ],
    stop_btn=None,
    examples=[
        ["السلام عليكم"],
        ["اعرب الجملة التالية: ذهبت الى السوق"],
    ],
    cache_examples=False,
    type="messages",
)

def authenticate_token(token):
    """Log in to the Hugging Face Hub with ``token`` and report the outcome.

    Args:
        token: Access token typed into the UI; may be empty or ``None``.

    Returns:
        A short status message for display in the output textbox.
    """
    # Imported here so the handler's except clause always has the name in
    # scope; the module-level imports only bring in ``login``.
    from huggingface_hub.utils import HfHubHTTPError

    token = (token or "").strip()
    if not token:
        # Never call login() without a token from a web handler.
        return "Invalid token. Please try again."
    try:
        # NOTE(review): login() registers the token for the whole server
        # process, not just the current visitor -- confirm this is intended
        # for a shared deployment.
        login(token=token)
    except (HfHubHTTPError, ValueError, OSError):
        # The exception type raised for a rejected token differs between
        # huggingface_hub releases (ValueError or an HTTP error).
        return "Invalid token. Please try again."
    return "Authenticated successfully"


with gr.Blocks(css_paths="style.css", fill_height=True) as demo:
    gr.DuplicateButton(value="Duplicate Space for private use", elem_id="duplicate-button")

    # Components
    token_input = gr.Textbox(label="Hugging Face Access Token", type="password", placeholder="Enter your token here...")
    auth_button = gr.Button("Authenticate")
    output = gr.Textbox(label="Output")
    # Wire the button so the entered token is validated and the result shown.
    auth_button.click(fn=authenticate_token, inputs=token_input, outputs=output)
    chat_interface.render()

if __name__ == "__main__":
    demo.queue(max_size=20).launch()