# old_napoleon / app.py
from threading import Thread

import torch
import spaces
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TextIteratorStreamer
# Model configuration for Mistral Small 24B (instruct checkpoint, which ships a chat template)
MODEL_ID = "mistralai/Mistral-Small-24B-Instruct-2501"
CHAT_TEMPLATE = "mistral"  # the tokenizer carries Mistral's own chat template
MODEL_NAME = MODEL_ID.split("/")[-1]
CONTEXT_LENGTH = 32768 # Mistral supports longer context
COLOR = "black"
EMOJI = "🌪️" # Mistral-themed emoji
DESCRIPTION = f"This is {MODEL_NAME} model, a powerful 24B parameter language model from Mistral AI."
def load_system_message():
    try:
        with open('system_message.txt', 'r', encoding='utf-8') as file:
            return file.read().strip()
    except FileNotFoundError:
        print("Warning: system_message.txt not found. Using default message.")
        return "You are a helpful assistant. First recognize the user request and then reply carefully with thinking."
    except Exception as e:
        print(f"Error loading system message: {e}")
        return "You are a helpful assistant. First recognize the user request and then reply carefully with thinking."
SYSTEM_MESSAGE = load_system_message()
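# SYSTEM_MESSAGE is read once at startup. A hypothetical system_message.txt for this
# Space (illustrative only, not taken from the repository) might look like:
#
#     You are Napoleon Bonaparte. Stay in character, think through the request first,
#     and then give a concise reply.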
@spaces.GPU()
def predict(message, history, system_prompt, temperature, max_new_tokens, top_k, repetition_penalty, top_p):
    # Rebuild the conversation in the role/content format expected by the chat template
    messages = [{"role": "system", "content": system_prompt or SYSTEM_MESSAGE}]
    for user, assistant in history:
        messages.append({"role": "user", "content": user})
        messages.append({"role": "assistant", "content": assistant})
    messages.append({"role": "user", "content": message})

    # Render the conversation with the tokenizer's built-in Mistral chat template and
    # append the generation prompt so the model continues as the assistant
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    enc = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)
    input_ids, attention_mask = enc.input_ids, enc.attention_mask
    if input_ids.shape[1] > CONTEXT_LENGTH:
        # Keep only the most recent CONTEXT_LENGTH tokens of the prompt
        input_ids = input_ids[:, -CONTEXT_LENGTH:]
        attention_mask = attention_mask[:, -CONTEXT_LENGTH:]

    generate_kwargs = dict(
        input_ids=input_ids.to(device),
        attention_mask=attention_mask.to(device),
        streamer=streamer,
        do_sample=True,
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
        top_p=top_p,
    )

    # Run generation in a background thread and stream partial output back to the UI
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()
    outputs = []
    for new_token in streamer:
        outputs.append(new_token)
        yield "".join(outputs)
# Load the model with 4-bit quantization so the 24B checkpoint fits on a single GPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,  # double quantization for a smaller memory footprint
    bnb_4bit_quant_type="nf4"  # NF4 keeps better precision than plain 4-bit
)

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
if tokenizer.pad_token is None:
    # Mistral tokenizers ship without a pad token; reuse EOS so padding in predict() works
    tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",
    quantization_config=quantization_config,
    attn_implementation="flash_attention_2",  # Flash Attention 2 for faster prefill and decoding
    torch_dtype=torch.bfloat16
)
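# Rough memory estimate (an assumption, not a measurement): 24B parameters at ~0.5 bytes
# each under NF4 is on the order of 12 GB of weight memory, plus KV cache and activations,
# so a 24 GB-class GPU is a reasonable target for this configuration.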
# Build the Gradio chat interface
gr.ChatInterface(
    predict,
    title=EMOJI + " " + MODEL_NAME,
    description=DESCRIPTION,
    examples=[
        ['What are the key differences between classical and quantum computing?'],
        ['Explain the concept of recursive neural networks in simple terms.'],
        ['How does transfer learning work in large language models?'],
        ['What are the ethical considerations in AI development?']
    ],
    additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False),
    additional_inputs=[
        gr.Textbox(SYSTEM_MESSAGE, label="System prompt", visible=False),  # hidden system prompt
        gr.Slider(0, 1, 0.7, label="Temperature"),  # default tuned down for Mistral
        gr.Slider(0, 32768, 12000, label="Max new tokens"),  # larger cap for the longer context
        gr.Slider(1, 100, 50, label="Top K sampling"),
        gr.Slider(0, 2, 1.1, label="Repetition penalty"),
        gr.Slider(0, 1, 0.95, label="Top P sampling"),
    ],
).queue().launch()
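# Local usage sketch (assumptions: a CUDA GPU, flash-attn installed for the
# attn_implementation above, and either the Hugging Face `spaces` package available or the
# @spaces.GPU() decorator removed when running outside a Space):
#
#     pip install torch transformers accelerate bitsandbytes gradio
#     python app.py
#
# Gradio then serves the chat UI on http://127.0.0.1:7860 by default.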