import os
from threading import Thread
from typing import Iterator, List, Tuple
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
# Constants
MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
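# MAX_INPUT_TOKEN_LENGTH can be overridden via the environment before launch,
# e.g. (hypothetical value): MAX_INPUT_TOKEN_LENGTH=8192 python app.py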
DESCRIPTION = """\
# DeepCode-6.7B-Chat
This Space demonstrates [DeepCode-AI](https://huggingface.co/deepcode-ai/deepcode-ai-6.7b-instruct),
a 6.7B-parameter code model by DeepCode fine-tuned to follow chat instructions.
"""

if not torch.cuda.is_available():
    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
    model = None
else:
    model_id = "deepcode-ai/deepcode-ai-6.7b-instruct"
    model = AutoModelForCausalLM.from_pretrained(
        model_id, torch_dtype=torch.bfloat16, device_map="auto"
    )
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    tokenizer.use_default_system_prompt = False


def trim_input_ids(input_ids: torch.Tensor) -> torch.Tensor:
    """Trim input_ids to fit within MAX_INPUT_TOKEN_LENGTH, keeping the most recent tokens."""
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input as it exceeded {MAX_INPUT_TOKEN_LENGTH} tokens.")
    return input_ids
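
# For example, with MAX_INPUT_TOKEN_LENGTH=4096 an input of shape (1, 5000)
# keeps only its most recent 4096 tokens (shape (1, 4096)); the oldest turns
# are dropped from the front of the prompt.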


def build_conversation(message: str, chat_history: List[Tuple[str, str]], system_prompt: str) -> List[dict]:
    """Build the conversation structure expected by the tokenizer's chat template."""
    conversation = []
    if system_prompt:
        conversation.append({"role": "system", "content": system_prompt})
    for user, assistant in chat_history:
        conversation.extend([
            {"role": "user", "content": user},
            {"role": "assistant", "content": assistant},
        ])
    conversation.append({"role": "user", "content": message})
    return conversation
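
# Illustrative output (hypothetical values):
#   build_conversation("Fix this bug", [("Hi", "Hello!")], "You are a coding assistant")
#   -> [{"role": "system", "content": "You are a coding assistant"},
#       {"role": "user", "content": "Hi"},
#       {"role": "assistant", "content": "Hello!"},
#       {"role": "user", "content": "Fix this bug"}]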


def generate(
    message: str,
    chat_history: List[Tuple[str, str]],
    system_prompt: str,
    max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.0,
) -> Iterator[str]:
    if model is None:
        yield "GPU is unavailable. This demo does not run on CPU."
        return

    conversation = build_conversation(message, chat_history, system_prompt)
    input_ids = tokenizer.apply_chat_template(
        conversation, return_tensors="pt", add_generation_prompt=True
    )
    input_ids = trim_input_ids(input_ids.to(model.device))

    # Stream partial results back to the UI as tokens are generated.
    streamer = TextIteratorStreamer(
        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )
    # Decoding is greedy (do_sample=False), so the temperature, top_p and top_k
    # values from the UI only take effect if sampling is enabled here.
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=False,
        num_beams=1,
        repetition_penalty=repetition_penalty,
        eos_token_id=tokenizer.eos_token_id,
    )
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    outputs = []
    try:
        for text in streamer:
            outputs.append(text)
            yield "".join(outputs).replace("<|EOT|>", "")
    except Exception as e:
        yield f"Error during generation: {e}"


# Gradio interface: additional_inputs map positionally onto generate()'s
# parameters after (message, chat_history): system_prompt, max_new_tokens,
# temperature, top_p, top_k, repetition_penalty.
chat_interface = gr.ChatInterface(
    fn=generate,
    additional_inputs=[
        gr.Textbox(label="System prompt", lines=6),
        gr.Slider(
            label="Max new tokens",
            minimum=1,
            maximum=MAX_MAX_NEW_TOKENS,
            step=1,
            value=DEFAULT_MAX_NEW_TOKENS,
        ),
        # Temperature slider added so the sliders below feed the parameters
        # they are labelled for; the range here is an assumed default.
        gr.Slider(
            label="Temperature",
            minimum=0.1,
            maximum=2.0,
            step=0.1,
            value=0.6,
        ),
        gr.Slider(
            label="Top-p (nucleus sampling)",
            minimum=0.05,
            maximum=1.0,
            step=0.05,
            value=0.9,
        ),
        gr.Slider(
            label="Top-k",
            minimum=1,
            maximum=1000,
            step=1,
            value=50,
        ),
        gr.Slider(
            label="Repetition penalty",
            minimum=1.0,
            maximum=2.0,
            step=0.05,
            value=1.0,
        ),
    ],
    examples=[
        ["Implement snake game using pygame"],
        ["Can you explain what the Python programming language is?"],
        ["Write a program to find the factorial of a number"],
    ],
)

with gr.Blocks(css="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    chat_interface.render()

if __name__ == "__main__":
    demo.queue().launch(share=True)