# app.py — Hugging Face Space by nafisneehal
# Last commit: "Update app.py" (revision 95b1f7b, verified), 6.71 kB.
import gc
import json
import os

import gradio as gr
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    GenerationConfig,
    StoppingCriteriaList,
)
import spaces
from huggingface_hub import login
# Hugging Face authentication.
# Space secrets are exposed to the app as plain environment variables named
# after the secret (``HF_TOKEN``). The old "Secrets.HF_TOKEN" lookup is kept
# only as a fallback so existing deployments keep working.
HF_TOKEN = os.getenv("HF_TOKEN") or os.getenv("Secrets.HF_TOKEN")
if HF_TOKEN:
    try:
        login(token=HF_TOKEN)
    except Exception as e:
        # Best-effort: public models can still be loaded without a login.
        print(f"Error logging in to Hugging Face: {str(e)}")
else:
    # Calling login() without a token would fall back to an interactive
    # prompt, which a headless Space cannot answer, so skip it instead.
    print("HF_TOKEN is not set; skipping Hugging Face login (gated/private models may fail to load).")
# Text file listing the selectable model repo ids, one per line.
MODEL_FILE = "model_links.txt"


def load_model_links(path=MODEL_FILE):
    """Return the model ids listed in *path*, one per non-blank line.

    Surrounding whitespace is stripped from every line, blank lines are
    skipped, and file order is preserved.

    Args:
        path: Text file to read. Defaults to ``MODEL_FILE``.

    Returns:
        A list of model id strings (possibly empty).

    Raises:
        FileNotFoundError: if *path* does not exist.
    """
    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's default locale encoding.
    with open(path, "r", encoding="utf-8") as f:
        stripped_lines = (line.strip() for line in f)
        return [entry for entry in stripped_lines if entry]
class ModelManager:
    """Hold at most one causal-LM model/tokenizer pair at a time.

    ``current_model``, ``current_tokenizer`` and ``current_model_name`` are
    either all set (a model loaded successfully) or all ``None``.
    """

    def __init__(self):
        self.current_model = None
        self.current_tokenizer = None
        # Name passed to the last successful load_model() call; None when empty.
        self.current_model_name = None
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

    def _release(self):
        """Drop the current model/tokenizer and return cached GPU memory."""
        # Rebind to None rather than `del`: deleting the attributes would make
        # later reads raise AttributeError if the next load fails.
        self.current_model = None
        self.current_tokenizer = None
        self.current_model_name = None
        # Collect first so the dropped tensors are really freed before the
        # CUDA caching allocator is asked to give memory back.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def load_model(self, model_name):
        """Load *model_name* with 4-bit quantization, replacing any previous model.

        Args:
            model_name: Model id or path to pass to ``from_pretrained``.

        Returns:
            A human-readable status string. Failures are reported in the
            string instead of raised, so the UI can display them.
        """
        if not model_name:
            return "Error loading model: no model selected"
        # Free the old model first so two models never need to fit in memory
        # at once. If the load below fails, the manager is left empty rather
        # than holding a stale name or a mismatched tokenizer/model pair.
        self._release()
        try:
            tokenizer = AutoTokenizer.from_pretrained(model_name)
            model = AutoModelForCausalLM.from_pretrained(
                model_name,
                quantization_config=BitsAndBytesConfig(load_in_4bit=True),
                device_map="auto",
            )
        except Exception as e:
            return f"Error loading model: {str(e)}"
        self.current_tokenizer = tokenizer
        self.current_model = model
        self.current_model_name = model_name
        return f"Successfully loaded model: {model_name}"
# System prompt pre-filled in the UI; it asks the model to always answer
# using the JSON layout spelled out below.
default_system_message = """You are a helpful AI assistant. You must ALWAYS return your response in valid JSON format.
Each response should be formatted as follows:
{
"response": {
"main_answer": "Your primary response here",
"additional_details": "Any additional information or context",
"confidence": 0.0 to 1.0,
"tags": ["relevant", "tags", "here"]
},
"metadata": {
"response_type": "type of response",
"source": "basis of response if applicable"
}
}
Ensure EVERY response strictly follows this JSON structure."""

# Shared, module-wide manager: the UI's load button fills it and
# generate_response reads the model/tokenizer from it.
model_manager = ModelManager()
def _extract_json(generated_text):
    """Parse the JSON object the model produced as a continuation of the prompt.

    The prompt ends with an opening ``{``, so the continuation is normally the
    rest of an object and is parsed with that brace prepended. If the model
    opened its own object instead, parsing is retried from the first ``{`` in
    the continuation. Any text after the first complete object is ignored.

    Raises:
        json.JSONDecodeError: (a ``ValueError``) if neither attempt decodes.
    """
    decoder = json.JSONDecoder()
    candidates = ["{" + generated_text]
    first_brace = generated_text.find("{")
    if first_brace != -1:
        candidates.append(generated_text[first_brace:])
    last_error = None
    for candidate in candidates:
        try:
            # raw_decode stops at the end of the first JSON value instead of
            # rejecting trailing text the model may have emitted.
            parsed, _ = decoder.raw_decode(candidate.strip())
            return parsed
        except json.JSONDecodeError as err:
            last_error = err
    raise last_error


@spaces.GPU
def generate_response(model_name, system_instruction, user_input):
    """Generate a JSON-formatted reply from the currently loaded model.

    Args:
        model_name: Model selected in the UI; must match the loaded model.
        system_instruction: Instruction text placed under "### Instruction:".
        user_input: User message placed under "### Input:".

    Returns:
        A pretty-printed JSON string: the model's parsed JSON on success, or
        an object with an ``"error"`` key describing what went wrong.
    """
    if model_manager.current_model_name != model_name:
        return json.dumps({"error": "Please load the model first using the 'Load Selected Model' button."}, indent=2)
    if model_manager.current_model is None:
        return json.dumps({"error": "No model loaded. Please load a model first."}, indent=2)

    model = model_manager.current_model
    tokenizer = model_manager.current_tokenizer

    # Prepare the prompt with explicit JSON formatting
    prompt = f"""### Instruction:
{system_instruction}
Remember to ALWAYS format your response as valid JSON.
### Input:
{user_input}
### Response:
{{"""  # The trailing "{" nudges the model to continue a JSON object.

    # Greedy decoding. Sampling-only settings such as temperature are
    # deliberately omitted: they are ignored when do_sample=False.
    generation_config = GenerationConfig(
        do_sample=False,
        max_new_tokens=512,
        repetition_penalty=1.1,
        use_cache=True,
        pad_token_id=tokenizer.eos_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )

    try:
        # Put the inputs on the device the model reports (device_map="auto"
        # decides placement), not on the manager's generic device string.
        inputs = tokenizer([prompt], return_tensors="pt").to(model.device)
        with torch.no_grad():
            outputs = model.generate(**inputs, generation_config=generation_config)
        # Decode only the newly generated tokens, so prompt text (e.g. a user
        # message containing "### Response:") cannot leak into the result.
        prompt_length = inputs["input_ids"].shape[1]
        generated_text = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
    except Exception as e:
        return json.dumps({
            "error": f"Error generating response: {str(e)}",
            "details": "An unexpected error occurred during generation"
        }, indent=2)

    try:
        json_response = _extract_json(generated_text)
    except ValueError:
        return json.dumps({
            "error": "Failed to generate valid JSON",
            "raw_response": ("{" + generated_text).strip(),
        }, indent=2, ensure_ascii=False)
    # ensure_ascii=False keeps non-ASCII model text readable in the UI.
    return json.dumps(json_response, indent=2, ensure_ascii=False)
# ---------------------------------------------------------------------------
# Gradio UI: controls on the left, the model's JSON reply on the right.
# ---------------------------------------------------------------------------
with gr.Blocks() as demo:
    gr.Markdown("# Chat Interface with Model Selection (JSON Output)")

    with gr.Row():
        with gr.Column():
            # Choices come from load_model_links(), evaluated once at build time.
            model_dropdown = gr.Dropdown(
                label="Select Model",
                info="Choose a model from the list",
                choices=load_model_links(),
            )
            load_button = gr.Button("Load Selected Model")
            model_status = gr.Textbox(label="Model Status")

            system_instruction = gr.Textbox(
                label="System Instruction",
                placeholder="Enter system instruction here...",
                value=default_system_message,
                lines=3,
            )
            user_input = gr.Textbox(
                label="Your Message",
                placeholder="Type your message here...",
                lines=3,
            )
            submit_btn = gr.Button("Submit")

        with gr.Column():
            # Read-only output pane.
            response_display = gr.Textbox(
                label="Bot Response (JSON)",
                placeholder="Response will appear here in JSON format.",
                interactive=False,
                lines=10,
            )

    # Loading a model reports its status; submitting sends the selected model,
    # instruction and message to generate_response.
    load_button.click(
        inputs=[model_dropdown],
        outputs=[model_status],
        fn=model_manager.load_model,
    )
    submit_btn.click(
        inputs=[model_dropdown, system_instruction, user_input],
        outputs=[response_display],
        fn=generate_response,
    )
# Launch the app only when this file is executed as a script
# (e.g. `python app.py`), so importing the module does not start a server.
if __name__ == "__main__":
    demo.launch()