import os
from threading import Thread
from typing import Iterator, List, Tuple

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# Constants
MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))

DESCRIPTION = """\
# DeepCode-6.7B-Chat
This Space demonstrates [DeepCode-AI](https://huggingface.co/deepcode-ai/deepcode-ai-6.7b-instruct)
by DeepCode, a 6.7B-parameter code model fine-tuned for chat instructions.
"""

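# Generation requires a GPU; on CPU the UI still loads but the model is disabled.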
if not torch.cuda.is_available():
    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
    model = None
else:
    model_id = "deepcode-ai/deepcode-ai-6.7b-instruct"
    model = AutoModelForCausalLM.from_pretrained(
        model_id, torch_dtype=torch.bfloat16, device_map="auto"
    )
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    tokenizer.use_default_system_prompt = False


def trim_input_ids(input_ids: torch.Tensor) -> torch.Tensor:
    """
    Trim input_ids to fit within the MAX_INPUT_TOKEN_LENGTH.
    """
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input as it exceeded {MAX_INPUT_TOKEN_LENGTH} tokens.")
    return input_ids


def build_conversation(message: str, chat_history: List[Tuple[str, str]], system_prompt: str) -> List[dict]:
    """
    Build the conversation structure for the chat model.
    """
    conversation = []
    if system_prompt:
        conversation.append({"role": "system", "content": system_prompt})
    for user, assistant in chat_history:
        conversation.extend([
            {"role": "user", "content": user},
            {"role": "assistant", "content": assistant}
        ])
    conversation.append({"role": "user", "content": message})
    return conversation


def generate(
    message: str,
    chat_history: List[Tuple[str, str]],
    system_prompt: str,
    max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.0,
) -> Iterator[str]:
    if model is None:
        yield "GPU is unavailable. This demo does not run on CPU."
        return

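    # Build the conversation and apply the tokenizer's chat template to get input ids.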
    conversation = build_conversation(message, chat_history, system_prompt)
    input_ids = tokenizer.apply_chat_template(
        conversation, return_tensors="pt", add_generation_prompt=True
    )
    input_ids = trim_input_ids(input_ids.to(model.device))

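    # Stream decoded tokens back as they are generated so the UI updates incrementally.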
    streamer = TextIteratorStreamer(
        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        num_beams=1,
        repetition_penalty=repetition_penalty,
        eos_token_id=tokenizer.eos_token_id,
    )
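    # Run generation in a background thread so this generator can consume the streamer.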
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

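    # Accumulate streamed chunks and yield the running text, stripping the "<|EOT|>" marker.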
    outputs = []
    try:
        for text in streamer:
            outputs.append(text)
            yield "".join(outputs).replace("<|EOT|>", "")
    except Exception as e:
        yield f"Error during generation: {e}"


# Gradio Interface
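# Additional inputs are passed positionally to generate() after (message, chat_history).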
chat_interface = gr.ChatInterface(
    fn=generate,
    additional_inputs=[
        gr.Textbox(label="System prompt", lines=6),
        gr.Slider(
            label="Max new tokens",
            minimum=1,
            maximum=MAX_MAX_NEW_TOKENS,
            step=1,
            value=DEFAULT_MAX_NEW_TOKENS,
        ),
        gr.Slider(
            label="Temperature",
            minimum=0.1,
            maximum=4.0,
            step=0.1,
            value=0.6,
        ),
        gr.Slider(
            label="Top-p (nucleus sampling)",
            minimum=0.05,
            maximum=1.0,
            step=0.05,
            value=0.9,
        ),
        gr.Slider(
            label="Top-k",
            minimum=1,
            maximum=1000,
            step=1,
            value=50,
        ),
        gr.Slider(
            label="Repetition penalty",
            minimum=1.0,
            maximum=2.0,
            step=0.05,
            value=1.0,
        ),
    ],
    examples=[
        ["Implement snake game using pygame"],
        ["Can you explain what the Python programming language is?"],
        ["Write a program to find the factorial of a number"],
    ],
)

with gr.Blocks(css="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    chat_interface.render()

if __name__ == "__main__":
    demo.queue().launch(share=True)