from io import BytesIO
from threading import Thread
from typing import Iterator

import gradio as gr
import PyPDF2
import spaces
import torch
from fastchat.model import get_conversation_template
from transformers import AutoTokenizer, TextIteratorStreamer

MAX_MAX_NEW_TOKENS = 256
DEFAULT_MAX_NEW_TOKENS = 256  # must not exceed MAX_MAX_NEW_TOKENS, which caps the slider below
MAX_INPUT_TOKEN_LENGTH = 10240

DESCRIPTION = """\
# CLEX-7B-Chat-16K

This Space demonstrates the [CLEX-7B-Chat-16K](https://huggingface.co/DAMO-NLP-SG/CLEX-7B-Chat-16K) model, a Llama-2-7B model fine-tuned with our [CLEX](https://arxiv.org/abs/2310.16450) method. Feel free to play with it, or duplicate this Space to run generations without a queue! If you want to run your own service, you can also [deploy the model on Inference Endpoints](https://huggingface.co/inference-endpoints).

The model now supports a maximum input sequence length of 64k tokens.

"""

# LICENSE = """
# <p/>

# ---
# As a derivate work of [Llama-2-7b-chat](https://huggingface.co/meta-llama/Llama-2-7b-chat) by Meta,
# this demo is governed by the original [license](https://huggingface.co/spaces/huggingface-projects/llama-2-7b-chat/blob/main/LICENSE.txt) and [acceptable use policy](https://huggingface.co/spaces/huggingface-projects/llama-2-7b-chat/blob/main/USE_POLICY.md).
# """


CITE = """
If you find our project useful, please consider starring our repo and citing our paper:
```
@article{damonlpsg2023clex,
  author = {Chen, Guanzheng and Li, Xin and Meng, Zaiqiao and Liang, Shangsong and Bing, Lidong},
  title = {CLEX: Continuous Length Extrapolation for Large Language Models},
  year = 2023,
  url = {https://arxiv.org/abs/2310.16450}
}
```
"""

if not torch.cuda.is_available():
    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"


model_id = "DAMO-NLP-SG/CLEX-7b-Chat-16K"
# CLEX uses a customized Llama implementation shipped with this Space (modeling_llama.py).
from modeling_llama import LlamaForCausalLM

model = LlamaForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    trust_remote_code=True,
    low_cpu_mem_usage=True,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True, use_fast=False)
tokenizer.use_default_system_prompt = False


def process_pdf(input_pdf) -> str:
    """Extract plain text from an uploaded PDF file."""
    reader = PyPDF2.PdfReader(input_pdf.name)
    # Concatenate the text extracted from every page.
    text = ""
    for page in reader.pages:
        text += page.extract_text()
    return text


def build_chat(prompt: str) -> str:
    """Wrap a raw prompt in the Vicuna conversation template used by the chat model."""
    conv = get_conversation_template("vicuna")
    conv.append_message(conv.roles[0], prompt)
    conv.append_message(conv.roles[1], None)
    return conv.get_prompt()


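# spaces.GPU allocates a GPU from the Spaces ZeroGPU pool for the duration of each call.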
@spaces.GPU
def generate(
    message: str,
    chat_history: list[tuple[str, str]],
    system_prompt: str,
    max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
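    # Build a single-turn Vicuna-style prompt from the incoming message.
    # chat_history and system_prompt are accepted to match the ChatInterface signature
    # but are not currently folded into the prompt.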
    conv = get_conversation_template("vicuna")
    conv.append_message(conv.roles[0], message)
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()
    inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to("cuda")
    # Keep only the last MAX_INPUT_TOKEN_LENGTH tokens if the prompt is too long.
    if inputs["input_ids"].shape[-1] > MAX_INPUT_TOKEN_LENGTH:
        inputs["input_ids"] = inputs["input_ids"][:, -MAX_INPUT_TOKEN_LENGTH:]
        inputs["attention_mask"] = inputs["attention_mask"][:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")

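    # model.generate runs in a background thread; TextIteratorStreamer yields decoded
    # text chunks as they are produced so the UI can stream the response.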
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        inputs,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
        repetition_penalty=repetition_penalty,
    )
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)


def generate_with_pdf(
    message: str,
    chat_history: list[tuple[str, str]],
    system_prompt: str,
    input_pdf: BytesIO = None,
    max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
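    # If a PDF was uploaded, extract its text and append it to the user message
    # before delegating to generate(), which streams the response.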
    if input_pdf is not None:
        pdf_text = process_pdf(input_pdf)
        message += f"\nThis is the beginning of a pdf\n{pdf_text}\nThis is the end of a pdf\n"
    yield from generate(
        message,
        chat_history,
        system_prompt,
        max_new_tokens,
        temperature,
        top_p,
        top_k,
        repetition_penalty
    )

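# additional_inputs are passed to generate_with_pdf positionally after (message, chat_history):
# system prompt, the optional PDF upload, then the sampling controls.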
chat_interface = gr.ChatInterface(
    fn=generate_with_pdf,
    additional_inputs=[
        gr.Textbox(label="System prompt", lines=6),
        gr.File(label="PDF File", file_types=[".pdf"]),
        gr.Slider(
            label="Max new tokens",
            minimum=1,
            maximum=MAX_MAX_NEW_TOKENS,
            step=1,
            value=DEFAULT_MAX_NEW_TOKENS,
        ),
        gr.Slider(
            label="Temperature",
            minimum=0.1,
            maximum=4.0,
            step=0.1,
            value=0.6,
        ),
        gr.Slider(
            label="Top-p (nucleus sampling)",
            minimum=0.05,
            maximum=1.0,
            step=0.05,
            value=0.9,
        ),
        gr.Slider(
            label="Top-k",
            minimum=1,
            maximum=1000,
            step=1,
            value=50,
        ),
        gr.Slider(
            label="Repetition penalty",
            minimum=1.0,
            maximum=2.0,
            step=0.05,
            value=1.2,
        ),
    ],
    stop_btn=None,
    examples=[
        ["Hello there! How are you doing?"],
        ["Can you explain briefly to me what is the Python programming language?"],
        ["Explain the plot of Cinderella in a sentence."],
        ["How many hours does it take a man to eat a Helicopter?"],
        ["Write a 100-word article on 'Benefits of Open-Source in AI research'"],
    ],
)



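# Page layout: description, duplicate button, the chat interface, and the citation note.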
with gr.Blocks(css="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    gr.DuplicateButton(value="Duplicate Space for private use", elem_id="duplicate-button")

    chat_interface.render()
    gr.Markdown(CITE)

if __name__ == "__main__":
    demo.queue(max_size=20).launch(share=False)