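"""Gradio chat demo that streams responses from CreitinGameplays/Mistral-Nemo-12B-R1-v0.1."""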
import os
from threading import Thread
from typing import Iterator
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import bitsandbytes  # noqa: F401 -- fail fast at startup if 8-bit quantization support is missing
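# Generation limits; MAX_INPUT_TOKEN_LENGTH can be overridden via the environment.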
MAX_MAX_NEW_TOKENS = 4096
DEFAULT_MAX_NEW_TOKENS = 1024
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "8192"))
DESCRIPTION = """\
# Chat
"""
# Load the model with a device-appropriate configuration.
def load_model():
    model_id = "CreitinGameplays/Mistral-Nemo-12B-R1-v0.1"
    device = "cuda" if torch.cuda.is_available() else "cpu"

    if device == "cpu":
        # On CPU, load in 32-bit to avoid potential issues with 16-bit operations.
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=torch.float32,
            low_cpu_mem_usage=True,
        )
    else:
        # On GPU, quantize the weights to 8-bit (via bitsandbytes) to reduce memory usage.
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=torch.float16,
            device_map="auto",
            load_in_8bit=True,
        )

    tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="left")
    tokenizer.use_default_system_prompt = False
    return model, tokenizer, device
model, tokenizer, device = load_model()
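# Default system prompt shown in the UI; users can edit it per session.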
system_prompt_text = "You are a helpful AI assistant."
def generate(
    message: str,
    chat_history: list[tuple[str, str]],
    system_prompt: str = system_prompt_text,
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.1,
) -> Iterator[str]:
    """Stream a chat completion, yielding the accumulated text after each new chunk."""
    # Rebuild the conversation in the message format expected by the chat template.
    conversation = []
    if system_prompt:
        conversation.append({"role": "system", "content": system_prompt})
    for user, assistant in chat_history:
        conversation.extend(
            [{"role": "user", "content": user}, {"role": "assistant", "content": assistant}]
        )
    conversation.append({"role": "user", "content": message})

    input_ids = tokenizer.apply_chat_template(
        conversation, add_generation_prompt=True, return_tensors="pt"
    )
    # Keep only the most recent tokens if the conversation exceeds the context budget.
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(device)

    # Run generation on a background thread and stream tokens back through the streamer.
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
        repetition_penalty=repetition_penalty,
    )
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)
chat_interface = gr.ChatInterface(
    fn=generate,
    additional_inputs=[
        gr.Textbox(label="System prompt", lines=6, value=system_prompt_text),
        gr.Slider(
            label="Max new tokens",
            minimum=1,
            maximum=MAX_MAX_NEW_TOKENS,
            step=1,
            value=DEFAULT_MAX_NEW_TOKENS,
        ),
        gr.Slider(
            label="Temperature",
            minimum=0.1,
            maximum=4.0,
            step=0.1,
            value=0.6,
        ),
        gr.Slider(
            label="Top-p (nucleus sampling)",
            minimum=0.05,
            maximum=1.0,
            step=0.05,
            value=0.9,
        ),
        gr.Slider(
            label="Top-k",
            minimum=0,
            maximum=1000,
            step=1,
            value=50,
        ),
        gr.Slider(
            label="Repetition penalty",
            minimum=1.0,
            maximum=2.0,
            step=0.05,
            value=1.1,
        ),
    ],
    stop_btn=None,
    examples=[
        ["Hello there! How are you doing?"],
        ["Can you explain briefly to me what is the Python programming language?"],
        ["Explain the plot of Cinderella in a sentence."],
        ["How many hours does it take a man to eat a Helicopter?"],
        ["Write a 100-word article on 'Benefits of Open-Source in AI research'"],
    ],
)
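# Compose the page: description header followed by the chat interface.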
with gr.Blocks() as demo:
    gr.Markdown(DESCRIPTION)
    chat_interface.render()
if __name__ == "__main__":
    demo.queue(max_size=10).launch(show_error=True)