File size: 6,714 Bytes
addaa24
 
95ce3bb
 
addaa24
95ce3bb
d913f1a
 
 
95b1f7b
d913f1a
 
 
 
addaa24
 
 
 
 
 
 
 
 
95ce3bb
 
addaa24
 
 
 
 
 
 
 
 
95ce3bb
addaa24
 
 
 
 
 
 
 
95ce3bb
 
 
 
 
 
 
 
 
 
 
addaa24
 
 
 
95ce3bb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
addaa24
95ce3bb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
addaa24
95ce3bb
 
addaa24
 
 
 
 
 
95ce3bb
 
 
 
 
 
addaa24
 
95ce3bb
 
 
addaa24
 
95ce3bb
 
 
 
 
 
 
 
 
addaa24
95ce3bb
addaa24
 
 
 
 
 
 
95ce3bb
 
 
 
addaa24
 
 
95ce3bb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig, StoppingCriteriaList
import spaces
import os
import json
from huggingface_hub import login

# Hugging Face authentication.
# Spaces expose repository secrets as plain environment variables named after
# the secret (e.g. HF_TOKEN). The legacy "Secrets.HF_TOKEN" name is checked
# first so any deployment that already sets it keeps working unchanged.
HF_TOKEN = os.getenv("Secrets.HF_TOKEN") or os.getenv("HF_TOKEN")
if HF_TOKEN:
    try:
        login(token=HF_TOKEN)
    except Exception as e:
        print(f"Error logging in to Hugging Face: {str(e)}")
else:
    # Calling login() without a token would try to prompt for one
    # interactively, which cannot work in a headless app; skip it instead.
    print("HF_TOKEN is not set; skipping Hugging Face login.")

# Text file listing the selectable model identifiers, one per line
# (read by load_model_links to populate the model dropdown).
MODEL_FILE = "model_links.txt"

def load_model_links(path=None):
    """Return the model identifiers listed in the model-links file.

    The file holds one model identifier per line. Surrounding whitespace is
    stripped and blank lines are skipped; order is preserved.

    Args:
        path: file to read. Defaults to ``MODEL_FILE`` when omitted.

    Returns:
        list[str]: the non-blank, stripped lines in file order.

    Raises:
        FileNotFoundError: if the file does not exist. The UI builds its
            model dropdown from this list at startup, so a missing file
            stops the app from starting.
    """
    if path is None:
        path = MODEL_FILE
    # Explicit UTF-8 so decoding does not depend on the host's locale.
    with open(path, "r", encoding="utf-8") as f:
        return [name for name in (line.strip() for line in f) if name]

class ModelManager:
    """Hold at most one loaded causal-LM model/tokenizer pair at a time.

    Attributes:
        current_model: the loaded model, or None when nothing is loaded.
        current_tokenizer: the tokenizer paired with ``current_model``, or None.
        current_model_name: identifier given to the last successful
            ``load_model`` call, or None when nothing is loaded.
        device: "cuda" if torch reported CUDA as available when the manager
            was constructed, otherwise "cpu".
    """

    def __init__(self):
        self.current_model = None
        self.current_tokenizer = None
        self.current_model_name = None
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

    def _unload(self):
        """Drop the current model/tokenizer and release cached GPU memory."""
        import gc

        # Rebind to None rather than `del`: the attributes must keep existing
        # so later `is None` checks (here and in callers) stay valid.
        self.current_model = None
        self.current_tokenizer = None
        self.current_model_name = None
        # Collect first so the model's tensors are actually freed before the
        # CUDA caching allocator is asked to hand its free blocks back.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def load_model(self, model_name):
        """Load ``model_name`` in 4-bit, replacing any previously loaded model.

        Any previous model is unloaded first to free memory. If loading then
        fails, the manager is left empty (no model, no name), never with
        stale or half-updated state.

        Args:
            model_name: model identifier passed to ``from_pretrained``.

        Returns:
            str: a UI status message starting with "Successfully loaded model:"
            or "Error loading model:".
        """
        if self.current_model is not None:
            self._unload()

        try:
            # Local import: BitsAndBytesConfig is the supported way to request
            # 4-bit loading; passing `load_in_4bit` directly is deprecated.
            from transformers import BitsAndBytesConfig

            tokenizer = AutoTokenizer.from_pretrained(model_name)
            model = AutoModelForCausalLM.from_pretrained(
                model_name,
                quantization_config=BitsAndBytesConfig(load_in_4bit=True),
                device_map="auto",
            )
        except Exception as e:
            return f"Error loading model: {str(e)}"

        # Publish only after both pieces loaded, so the tokenizer and model
        # always belong to the same checkpoint.
        self.current_tokenizer = tokenizer
        self.current_model = model
        self.current_model_name = model_name
        return f"Successfully loaded model: {model_name}"

# Shared manager instance used by the UI callbacks below.
model_manager = ModelManager()

# Default system message for JSON output. The model is asked to imitate the
# example object, so the example itself must be strictly valid JSON; the
# allowed range for "confidence" is therefore stated in prose after it.
default_system_message = """You are a helpful AI assistant. You must ALWAYS return your response in valid JSON format.
Each response should be formatted as follows:

{
    "response": {
        "main_answer": "Your primary response here",
        "additional_details": "Any additional information or context",
        "confidence": 0.9,
        "tags": ["relevant", "tags", "here"]
    },
    "metadata": {
        "response_type": "type of response",
        "source": "basis of response if applicable"
    }
}

The "confidence" value must be a number between 0.0 and 1.0.
Ensure EVERY response strictly follows this JSON structure."""

def _parse_json_response(text):
    """Decode the JSON value at the start of *text*, ignoring anything after it.

    Generation stops only at EOS or the token limit, so the model can keep
    writing after it closes the JSON object. Parsing just the leading value
    tolerates that trailing text, whereas ``json.loads`` on the whole string
    (or on everything up to the last "}") would reject it.

    Args:
        text: model output expected to begin with a JSON value.

    Returns:
        The decoded JSON value.

    Raises:
        json.JSONDecodeError: if *text* does not begin with a valid JSON value.
    """
    value, _end = json.JSONDecoder().raw_decode(text.lstrip())
    return value


@spaces.GPU
def generate_response(model_name, system_instruction, user_input):
    """Generate a JSON-formatted reply from the currently loaded model.

    Args:
        model_name: model selected in the UI; must equal the name of the
            model most recently loaded through ``model_manager``.
        system_instruction: instruction text placed in the prompt.
        user_input: the user's message.

    Returns:
        str: pretty-printed JSON. On success, the object decoded from the
        model's output. Otherwise an object with an "error" key, plus
        "raw_response" when the output was not valid JSON or "details"
        when generation itself failed.
    """
    if model_manager.current_model_name != model_name:
        return json.dumps({"error": "Please load the model first using the 'Load Selected Model' button."}, indent=2)

    # getattr: never crash the callback if the manager lacks these attributes.
    model = getattr(model_manager, "current_model", None)
    tokenizer = getattr(model_manager, "current_tokenizer", None)
    if model is None or tokenizer is None:
        return json.dumps({"error": "No model loaded. Please load a model first."}, indent=2)

    # Prepare the prompt with explicit JSON formatting
    prompt = f"""### Instruction:
{system_instruction}
Remember to ALWAYS format your response as valid JSON.

### Input:
{user_input}

### Response:
{{"""  # Note the opening curly brace to hint JSON response

    # Greedy decoding; `temperature` is omitted because it only applies when
    # sampling, and setting it with do_sample=False just triggers a warning.
    generation_config = GenerationConfig(
        do_sample=False,
        max_new_tokens=512,
        repetition_penalty=1.1,
        use_cache=True,
        pad_token_id=tokenizer.eos_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )

    try:
        # With device_map="auto" the model decides its own placement, so the
        # inputs follow the model rather than a device chosen at startup.
        inputs = tokenizer([prompt], return_tensors="pt").to(model.device)
        prompt_length = inputs["input_ids"].shape[1]
        with torch.no_grad():
            outputs = model.generate(**inputs, generation_config=generation_config)
        # Decode only the newly generated tokens and re-attach the "{" that
        # primed the response at the end of the prompt.
        new_tokens = outputs[0][prompt_length:]
        assistant_response = "{" + tokenizer.decode(new_tokens, skip_special_tokens=True)
    except Exception as e:
        return json.dumps({
            "error": f"Error generating response: {str(e)}",
            "details": "An unexpected error occurred during generation"
        }, indent=2)

    try:
        json_response = _parse_json_response(assistant_response)
    except ValueError:  # json.JSONDecodeError is a ValueError subclass
        return json.dumps({
            "error": "Failed to generate valid JSON",
            "raw_response": assistant_response
        }, indent=2, ensure_ascii=False)
    # ensure_ascii=False keeps non-ASCII text readable in the UI.
    return json.dumps(json_response, indent=2, ensure_ascii=False)

# Gradio UI: a two-column layout. The left column selects/loads a model and
# collects the prompt; the right column shows the model's JSON reply.
with gr.Blocks() as demo:
    gr.Markdown("# Chat Interface with Model Selection (JSON Output)")

    with gr.Row():
        # Model picker, load status, and prompt inputs.
        with gr.Column():
            model_dropdown = gr.Dropdown(
                label="Select Model",
                info="Choose a model from the list",
                choices=load_model_links(),
            )
            load_button = gr.Button("Load Selected Model")
            model_status = gr.Textbox(label="Model Status")

            system_instruction = gr.Textbox(
                label="System Instruction",
                placeholder="Enter system instruction here...",
                value=default_system_message,
                lines=3,
            )
            user_input = gr.Textbox(
                label="Your Message",
                placeholder="Type your message here...",
                lines=3,
            )
            submit_btn = gr.Button("Submit")

        # Read-only output pane.
        with gr.Column():
            response_display = gr.Textbox(
                label="Bot Response (JSON)",
                placeholder="Response will appear here in JSON format.",
                lines=10,
                interactive=False,
            )

    # "Load" swaps the active model and reports the outcome in the status box.
    load_button.click(
        model_manager.load_model,
        inputs=[model_dropdown],
        outputs=[model_status],
    )

    # "Submit" generates a reply using the model currently selected.
    submit_btn.click(
        generate_response,
        inputs=[model_dropdown, system_instruction, user_input],
        outputs=[response_display],
    )

# Start the Gradio server.
demo.launch()