import os
from threading import Thread
from typing import Iterator, List, Tuple
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
# Constants
MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
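
# Because the limit is read from the environment, it can be raised per
# deployment without code changes, e.g. (hypothetical invocation):
#   MAX_INPUT_TOKEN_LENGTH=8192 python app.py
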
DESCRIPTION = """\
# DeepCode-6.7B-Chat
This Space demonstrates [DeepCode-AI](https://huggingface.co/deepcode-ai/deepcode-ai-6.7b-instruct),
a 6.7B-parameter code model by DeepCode, fine-tuned for chat instructions.
"""
model_id = "deepcode-ai/deepcode-ai-6.7b-instruct"

if not torch.cuda.is_available():
    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
    model = None
else:
    model = AutoModelForCausalLM.from_pretrained(
        model_id, torch_dtype=torch.bfloat16, device_map="auto"
    )

# The tokenizer loads fine on CPU; generate() returns early when model is None.
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.use_default_system_prompt = False
def trim_input_ids(input_ids: torch.Tensor) -> torch.Tensor:
    """Trim input_ids to fit within MAX_INPUT_TOKEN_LENGTH."""
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input as it exceeded {MAX_INPUT_TOKEN_LENGTH} tokens.")
    return input_ids
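
# Example (illustrative shapes): with MAX_INPUT_TOKEN_LENGTH = 4096, a
# (1, 5000) input_ids tensor is sliced to its last 4096 columns, keeping the
# most recent conversation context and dropping the oldest tokens.
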
def build_conversation(message: str, chat_history: List[Tuple[str, str]], system_prompt: str) -> List[dict]:
    """Build the conversation structure for the chat model."""
    conversation = []
    if system_prompt:
        conversation.append({"role": "system", "content": system_prompt})
    for user, assistant in chat_history:
        conversation.extend([
            {"role": "user", "content": user},
            {"role": "assistant", "content": assistant},
        ])
    conversation.append({"role": "user", "content": message})
    return conversation
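
# For illustration, build_conversation("Fix this bug", [("Hi", "Hello!")], "Be concise")
# would return:
#   [{"role": "system", "content": "Be concise"},
#    {"role": "user", "content": "Hi"},
#    {"role": "assistant", "content": "Hello!"},
#    {"role": "user", "content": "Fix this bug"}]
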
def generate(
    message: str,
    chat_history: List[Tuple[str, str]],
    system_prompt: str,
    max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.0,
) -> Iterator[str]:
    if model is None:
        yield "GPU is unavailable. This demo does not run on CPU."
        return

    conversation = build_conversation(message, chat_history, system_prompt)
    input_ids = tokenizer.apply_chat_template(
        conversation, return_tensors="pt", add_generation_prompt=True
    )
    input_ids = trim_input_ids(input_ids.to(model.device))

    streamer = TextIteratorStreamer(
        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,  # sample so the temperature/top-p/top-k sliders take effect
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        num_beams=1,
        repetition_penalty=repetition_penalty,
        eos_token_id=tokenizer.eos_token_id,
    )
    # Run generation on a background thread so tokens can be streamed as they arrive.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    outputs = []
    try:
        for text in streamer:
            outputs.append(text)
            yield "".join(outputs).replace("<|EOT|>", "")
    except Exception as e:
        yield f"Error during generation: {e}"
# Gradio interface. Note: additional_inputs are bound to generate()'s extra
# parameters positionally, so the controls below must follow the order of its
# keyword arguments (system_prompt, max_new_tokens, temperature, top_p, top_k,
# repetition_penalty). A Temperature slider is included so the mapping lines up.
chat_interface = gr.ChatInterface(
    fn=generate,
    additional_inputs=[
        gr.Textbox(label="System prompt", lines=6),
        gr.Slider(
            label="Max new tokens",
            minimum=1,
            maximum=MAX_MAX_NEW_TOKENS,
            step=1,
            value=DEFAULT_MAX_NEW_TOKENS,
        ),
        gr.Slider(
            label="Temperature",
            minimum=0.1,
            maximum=4.0,
            step=0.1,
            value=0.6,
        ),
        gr.Slider(
            label="Top-p (nucleus sampling)",
            minimum=0.05,
            maximum=1.0,
            step=0.05,
            value=0.9,
        ),
        gr.Slider(
            label="Top-k",
            minimum=1,
            maximum=1000,
            step=1,
            value=50,
        ),
        gr.Slider(
            label="Repetition penalty",
            minimum=1.0,
            maximum=2.0,
            step=0.05,
            value=1.0,
        ),
    ],
    examples=[
        ["Implement snake game using pygame"],
        ["Can you explain what the Python programming language is?"],
        ["Write a program to find the factorial of a number"],
    ],
)
with gr.Blocks(css="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    chat_interface.render()

if __name__ == "__main__":
    demo.queue().launch(share=True)