File size: 4,510 Bytes
4375b7f
4e683ec
76a154f
f4e3549
be21dc7
76a154f
b1c12fa
76a154f
d534002
be21dc7
4e683ec
f073a1c
76a154f
4375b7f
1f2b852
 
76a154f
665015d
be21dc7
 
 
 
 
665015d
 
 
 
 
f073a1c
76a154f
 
 
 
e12dd90
0da65a1
2fe7e62
4e683ec
 
 
 
 
 
 
 
e12dd90
4e683ec
 
2fe7e62
6111f2c
 
 
 
4e683ec
 
 
6111f2c
4e683ec
 
f073a1c
 
4e683ec
f073a1c
 
4e683ec
 
 
76a154f
4e683ec
 
 
 
76a154f
 
4e683ec
 
 
a400f4b
4e683ec
 
76a154f
 
 
 
4e683ec
 
 
76a154f
 
 
2fe7e62
4e683ec
 
 
76a154f
 
 
4e683ec
 
 
 
76a154f
 
 
 
4e683ec
 
 
 
 
 
 
 
 
 
 
af33034
665015d
1898074
665015d
4e683ec
cadad8a
e12dd90
4e683ec
 
bb98bf8
665015d
 
 
 
 
 
6336a63
665015d
 
 
 
 
f5aa776
a636dc2
 
76a154f
4e683ec
be21dc7
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
import os
from threading import Thread
from typing import Iterator
import os
from huggingface_hub import login,whoami
import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import argparse

# Upper bound exposed by the "Max new tokens" slider in the chat UI.
MAX_MAX_NEW_TOKENS = 128
# Initial value of that slider. Clamped so it can never exceed the slider's
# maximum (previously 1024 > 128, which started the slider out of range).
DEFAULT_MAX_NEW_TOKENS = min(1024, MAX_MAX_NEW_TOKENS)
# Prompts longer than this many tokens are trimmed from the left before
# generation; overridable through the MAX_INPUT_TOKEN_LENGTH environment variable.
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
# Placeholders; both are assigned the real objects once the model is loaded.
model = None
tokenizer = None

# Token used for an explicit Hub login when no usable credentials are found.
my_token = os.getenv("HF_AUTH_TOKEN")

# Probe for existing Hub credentials first; only when whoami() fails with an
# OSError do we log in explicitly (also storing the token as a git credential).
try:
    username = whoami()
except OSError:
    login(add_to_git_credential=True, token=my_token)

# Load the Arabic StableLM 2 chat model and its tokenizer from the Hub.
# NOTE(review): trust_remote_code=True executes Python shipped with the model
# repository; only keep it for repositories you trust.
model_id = "stabilityai/ar-stablelm-2-chat"
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    trust_remote_code=True,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(
    model_id,
    trust_remote_code=True,
)
# Reuse the EOS id as the padding id for generation.
model.generation_config.pad_token_id = model.generation_config.eos_token_id



def generate(
    message: str,
    chat_history: list[dict],
    system_prompt: str = "",
    max_new_tokens: int = 128,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
    """Stream the model's reply to ``message`` given the earlier chat turns.

    Used as the ``fn`` of the Gradio ``ChatInterface``; the arguments after
    ``chat_history`` are supplied, in order, by its ``additional_inputs``.

    Args:
        message: The new user turn.
        chat_history: Previous turns as role/content dicts (the interface is
            created with ``type="messages"``); appended to the prompt as-is.
        system_prompt: Optional system message; skipped when empty.
        max_new_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature passed to ``model.generate``.
        top_p: Nucleus-sampling threshold passed to ``model.generate``.
        top_k: Top-k sampling cutoff passed to ``model.generate``.
        repetition_penalty: Repetition penalty passed to ``model.generate``.

    Yields:
        The cumulative decoded reply text after each streamed chunk.
    """
    conversation = []
    if system_prompt:
        conversation.append({"role": "system", "content": system_prompt})
    conversation += chat_history
    conversation.append({"role": "user", "content": message})

    input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        # Keep only the most recent tokens; the oldest context (which may
        # include the system prompt) is dropped.
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(model.device)

    # Generation runs on a background thread so this generator can yield text
    # as soon as the streamer produces it.
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        eos_token_id=tokenizer.eos_token_id,  # Stop generation at <EOS>
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        # Previously accepted but never forwarded, so the UI slider had no effect.
        repetition_penalty=repetition_penalty,
    )
    worker = Thread(target=model.generate, kwargs=generate_kwargs)
    worker.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)
    # The streamer is exhausted once generation has finished; reap the thread.
    worker.join()


# Chat UI backed by generate(). The additional_inputs controls are passed to
# generate() positionally after (message, chat_history): system_prompt,
# max_new_tokens, temperature, top_p, top_k, repetition_penalty.
chat_interface = gr.ChatInterface(
    fn=generate,
    additional_inputs=[
        gr.Textbox(label="System prompt", lines=6),
        gr.Slider(
            label="Max new tokens",
            minimum=1,
            maximum=MAX_MAX_NEW_TOKENS,
            step=1,
            # Clamp so the initial value can never start above the maximum.
            value=min(DEFAULT_MAX_NEW_TOKENS, MAX_MAX_NEW_TOKENS),
        ),
        gr.Slider(
            label="Temperature",
            minimum=0.1,
            maximum=4.0,
            step=0.1,
            value=0.7,
        ),
        gr.Slider(
            label="Top-p (nucleus sampling)",
            minimum=0.05,
            maximum=1.0,
            step=0.05,
            value=0.9,
        ),
        gr.Slider(
            label="Top-k",
            minimum=1,
            maximum=1000,
            step=1,
            value=50,
        ),
        gr.Slider(
            label="Repetition penalty",
            minimum=1.0,
            maximum=2.0,
            step=0.05,
            value=1.2,
        ),
    ],
    stop_btn=None,
    # Arabic example prompts: a greeting, a grammatical-parsing request,
    # a diacritization request, and a question about Arabic poetic meters.
    examples=[
        ["السلام عليكم"],
        ["اعرب الجملة التالية: ذهبت الى السوق"],
        ["اضف تشكيل للجملة التالية: ضرب زيدا عمر"],
        ["كم عدد بحور الشعر العربي؟"]
    ],
    cache_examples=False,
    type="messages",
)

with gr.Blocks(css_paths="style.css", fill_height=True) as demo:
    # Mount the pre-built chat interface inside the page layout.
    # NOTE(review): an in-UI Hugging Face token form used to be sketched here as
    # commented-out code; it was removed because authentication happens at
    # start-up from the HF_AUTH_TOKEN environment variable.
    chat_interface.render()

    

if __name__ == "__main__":
    # Command-line entry point: optional --share flag for a public Gradio link.
    cli = argparse.ArgumentParser(description="Gradio App with Sharing")
    cli.add_argument(
        "--share",
        action="store_true",
        help="Enable public sharing",
    )
    options = cli.parse_args()

    # Cap the request queue at 20 pending events, then start the server.
    queued_app = demo.queue(max_size=20)
    queued_app.launch(share=options.share)