File size: 4,062 Bytes
4375b7f
4e683ec
76a154f
f4e3549
8c0a7e8
76a154f
b1c12fa
76a154f
d534002
4e683ec
f073a1c
76a154f
4375b7f
1f2b852
 
76a154f
ad4597f
1f2b852
ad4597f
 
 
f073a1c
 
76a154f
 
 
 
e12dd90
0da65a1
2fe7e62
4e683ec
 
 
 
 
 
 
 
e12dd90
4e683ec
 
2fe7e62
6111f2c
 
 
 
4e683ec
 
 
6111f2c
4e683ec
 
f073a1c
 
4e683ec
f073a1c
 
4e683ec
 
 
76a154f
4e683ec
 
 
 
76a154f
 
4e683ec
 
 
a400f4b
4e683ec
 
76a154f
 
 
 
4e683ec
 
 
76a154f
 
 
2fe7e62
4e683ec
 
 
76a154f
 
 
4e683ec
 
 
 
76a154f
 
 
 
4e683ec
 
 
 
 
 
 
 
 
 
 
af33034
 
4e683ec
cadad8a
e12dd90
4e683ec
 
bb98bf8
6336a63
 
7621468
66c2c87
f073a1c
d6a1a38
6336a63
 
 
 
 
 
a636dc2
f5aa776
a636dc2
 
76a154f
4e683ec
f073a1c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
import os
from threading import Thread
from typing import Iterator
import os
from huggingface_hub import login
import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# Hard upper bound offered by the "Max new tokens" slider in the UI.
MAX_MAX_NEW_TOKENS = 128
# Initial value of that slider. It is clamped to the cap above because the
# unclamped value (1024) was larger than the slider's maximum (128).
DEFAULT_MAX_NEW_TOKENS = min(1024, MAX_MAX_NEW_TOKENS)
# Prompts longer than this many tokens are trimmed before generation; the
# limit can be overridden with the MAX_INPUT_TOKEN_LENGTH environment variable.
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
# Filled in by load_model(); both remain None until it has run.
model = None
tokenizer = None

def load_model():
    """Load the chat model and its tokenizer into the module-level globals.

    The call is a no-op when both globals are already populated, so invoking
    it again does not reload the weights. Both objects are loaded into local
    variables first and published together at the end; if either download or
    load step raises, ``model`` and ``tokenizer`` keep their previous values
    instead of ending up half-initialised.

    Raises:
        Whatever ``from_pretrained`` raises for the model or the tokenizer;
        exceptions are not handled here.
    """
    global model, tokenizer
    if model is not None and tokenizer is not None:
        return

    model_id = "stabilityai/ar-stablelm-2-chat"
    loaded_model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="auto",
        trust_remote_code=True,
    )
    loaded_tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)

    # Reuse the end-of-sequence id as the padding id. The generation config
    # may hold several EOS ids as a list; the padding id must be a single id,
    # so take the first one in that case.
    eos_token_id = loaded_model.generation_config.eos_token_id
    if isinstance(eos_token_id, (list, tuple)):
        eos_token_id = eos_token_id[0] if eos_token_id else None
    loaded_model.generation_config.pad_token_id = eos_token_id

    model, tokenizer = loaded_model, loaded_tokenizer



def generate(
    message: str,
    chat_history: list[dict],
    system_prompt: str = "",
    max_new_tokens: int = 128,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
    """Stream the model's reply to ``message``.

    Args:
        message: The new user message.
        chat_history: Earlier turns as ``{"role": ..., "content": ...}`` dicts;
            they are placed between the optional system prompt and ``message``.
        system_prompt: Optional system message; skipped when empty.
        max_new_tokens: Forwarded to ``model.generate``.
        temperature: Forwarded to ``model.generate``.
        top_p: Forwarded to ``model.generate``.
        top_k: Forwarded to ``model.generate``.
        repetition_penalty: Forwarded to ``model.generate``.

    Yields:
        The reply accumulated so far, once for every chunk the streamer emits.

    Raises:
        gr.Error: If the model or tokenizer has not been loaded yet.
    """
    # Both globals stay None until load_model() has run; fail with a readable
    # message instead of an AttributeError on None.
    if model is None or tokenizer is None:
        raise gr.Error("The model is not loaded yet. Please authenticate with your Hugging Face token first.")

    conversation = []
    if system_prompt:
        conversation.append({"role": "system", "content": system_prompt})
    conversation.extend(chat_history)
    conversation.append({"role": "user", "content": message})

    input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        # Keep only the last MAX_INPUT_TOKEN_LENGTH tokens of the prompt.
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(model.device)

    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = {
        "input_ids": input_ids,
        "streamer": streamer,
        "max_new_tokens": max_new_tokens,
        "do_sample": True,
        "eos_token_id": tokenizer.eos_token_id,  # Stop generation at <EOS>
        "temperature": temperature,
        "top_p": top_p,
        "top_k": top_k,
        # Previously accepted by this function but never passed on, so the
        # caller's value had no effect on generation.
        "repetition_penalty": repetition_penalty,
    }
    # Run generation in a background thread so the streamer can be consumed
    # here while tokens are still being produced.
    worker = Thread(target=model.generate, kwargs=generate_kwargs)
    worker.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)


# Chat UI wired to generate(). The additional inputs are listed in the same
# order as generate()'s parameters after (message, chat_history):
# system_prompt, max_new_tokens, temperature, top_p, top_k, repetition_penalty.
chat_interface = gr.ChatInterface(
    fn=generate,
    additional_inputs=[
        gr.Textbox(label="System prompt", lines=6),
        gr.Slider(
            label="Max new tokens",
            minimum=1,
            maximum=MAX_MAX_NEW_TOKENS,
            step=1,
            # Clamp the initial value so it can never fall outside the
            # slider's own [minimum, maximum] range.
            value=max(1, min(DEFAULT_MAX_NEW_TOKENS, MAX_MAX_NEW_TOKENS)),
        ),
        gr.Slider(
            label="Temperature",
            minimum=0.1,
            maximum=4.0,
            step=0.1,
            value=0.7,
        ),
        gr.Slider(
            label="Top-p (nucleus sampling)",
            minimum=0.05,
            maximum=1.0,
            step=0.05,
            value=0.9,
        ),
        gr.Slider(
            label="Top-k",
            minimum=1,
            maximum=1000,
            step=1,
            value=50,
        ),
        gr.Slider(
            label="Repetition penalty",
            minimum=1.0,
            maximum=2.0,
            step=0.05,
            value=1.2,
        ),
    ],
    stop_btn=None,
    examples=[
        ["السلام عليكم"],
        ["اعرب الجملة التالية: ذهبت الى السوق"],
    ],
    cache_examples=False,
    type="messages",
)

with gr.Blocks(css_paths="style.css", fill_height=True) as demo:
    def authenticate_token(token):
        """Log in to the Hugging Face Hub with ``token`` and load the model.

        Args:
            token: Access token typed into the token textbox.

        Returns:
            A status message for the output textbox: success, an invalid-token
            notice when the login step fails (or the token is empty), or a
            separate message when login worked but loading the model failed.
        """
        import logging

        logger = logging.getLogger(__name__)

        token = (token or "").strip()
        if not token:
            return "Invalid token. Please try again."

        # Login and model loading are handled separately so that a model
        # loading failure is not reported as a bad token.
        try:
            login(token=token)
        except Exception:
            logger.exception("Hugging Face login failed")
            return "Invalid token. Please try again."

        try:
            load_model()
        except Exception as exc:
            logger.exception("Model loading failed after successful login")
            return f"Authenticated, but the model could not be loaded: {exc}"

        return "Authenticated successfully"

    # Components
    token_input = gr.Textbox(label="Hugging Face Access Token", type="password", placeholder="Enter your token here...")
    auth_button = gr.Button("Authenticate")
    output = gr.Textbox(label="Output")
    # Clicking the button sends the token to authenticate_token and shows the
    # returned status message in the output textbox.
    auth_button.click(fn=authenticate_token, inputs=token_input, outputs=output)
    chat_interface.render()

    

if __name__ == "__main__":
    # Enable the request queue with max_size=20, then start the app.
    queued_demo = demo.queue(max_size=20)
    queued_demo.launch()