import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
import spaces
import os
import json
from huggingface_hub import login

# Hugging Face authentication (Space secrets are exposed as plain environment variables)
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN:
    try:
        login(token=HF_TOKEN)
    except Exception as e:
        print(f"Error logging in to Hugging Face: {str(e)}")
else:
    print("HF_TOKEN is not set; gated models may fail to load.")

# File to store model links
MODEL_FILE = "model_links.txt"
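# The file holds one Hugging Face model repo id per line, e.g. "meta-llama/Llama-2-7b-chat-hf"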

def load_model_links():
    """Load model links from file"""
    if not os.path.exists(MODEL_FILE):
        # Create default file with some example models
        with open(MODEL_FILE, "w") as f:
            f.write("meta-llama/Llama-2-7b-chat-hf\n")
    
    with open(MODEL_FILE, "r") as f:
        return [line.strip() for line in f.readlines() if line.strip()]

class ModelManager:
    def __init__(self):
        self.current_model = None
        self.current_tokenizer = None
        self.current_model_name = None
        # Don't initialize CUDA in __init__
        self.device = None
    
    def load_model(self, model_name):
        """Load a model, freeing the previous model's memory first"""
        if self.current_model is not None:
            # Drop references so the old weights can be garbage-collected
            self.current_model = None
            self.current_tokenizer = None
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        try:
            self.current_tokenizer = AutoTokenizer.from_pretrained(model_name)
            self.current_model = AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype=torch.bfloat16,
                device_map="auto"  # Let accelerate decide the device placement
            )
            self.current_model_name = model_name
            return f"Successfully loaded model: {model_name}"
        except Exception as e:
            return f"Error loading model: {str(e)}"

    def generate(self, prompt):
        """Tokenize the prompt and move it to the model's device"""
        inputs = self.current_tokenizer(prompt, return_tensors="pt")
        # With device_map="auto" the weights may live on the GPU, so the inputs must follow
        return inputs.to(self.current_model.device)


# Initialize model manager
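# No model is loaded at startup; one is loaded on demand via the "Load Selected Model" button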
model_manager = ModelManager()

# Default system message for JSON output
default_system_message = """You are a helpful AI assistant. You must ALWAYS return your response in valid JSON format.
Each response should be formatted as follows:
{
    "response": {
        "main_answer": "Your primary response here",
        "additional_details": "Any additional information or context",
        "confidence": 0.0 to 1.0,
        "tags": ["relevant", "tags", "here"]
    },
    "metadata": {
        "response_type": "type of response",
        "source": "basis of response if applicable"
    }
}
Ensure EVERY response strictly follows this JSON structure."""

@spaces.GPU  # This decorator handles the GPU allocation
def generate_response(model_name, system_instruction, user_input):
    """Generate response with GPU support and JSON formatting"""
    if model_manager.current_model_name != model_name:
        return json.dumps({"error": "Please load the model first using the 'Load Selected Model' button."}, indent=2)
    
    if model_manager.current_model is None:
        return json.dumps({"error": "No model loaded. Please load a model first."}, indent=2)

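    # The prompt ends with a literal "{" so the model continues a JSON object; that brace
    # is still part of the text that gets split on "### Response:" below.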
    prompt = f"""### Instruction:
{system_instruction}
Remember to ALWAYS format your response as valid JSON.
### Input:
{user_input}
### Response:
{{"""

    try:
        inputs = model_manager.generate(prompt)
        
        meta_config = {
            "do_sample": False,
            "temperature": 0.0,
            "max_new_tokens": 512,
            "repetition_penalty": 1.1,
            "use_cache": True,
            "pad_token_id": model_manager.current_tokenizer.eos_token_id,
            "eos_token_id": model_manager.current_tokenizer.eos_token_id
        }
        generation_config = GenerationConfig(**meta_config)

        with torch.no_grad():
            outputs = model_manager.current_model.generate(
                **inputs,
                generation_config=generation_config
            )
            
            decoded_output = model_manager.current_tokenizer.batch_decode(
                outputs, 
                skip_special_tokens=True
            )[0]
            
            assistant_response = decoded_output.split("### Response:")[-1].strip()
            
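            # Trim to the last closing brace to drop trailing text, then validate the JSON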
            try:
                last_brace = assistant_response.rindex('}')
                assistant_response = assistant_response[:last_brace + 1]
                json_response = json.loads(assistant_response)
                return json.dumps(json_response, indent=2)
            except (json.JSONDecodeError, ValueError):
                return json.dumps({
                    "error": "Failed to generate valid JSON",
                    "raw_response": assistant_response
                }, indent=2)
                
    except Exception as e:
        return json.dumps({
            "error": f"Error generating response: {str(e)}",
            "details": "An unexpected error occurred during generation"
        }, indent=2)


# Gradio interface setup
with gr.Blocks() as demo:
    gr.Markdown("# Chat Interface with Model Selection (JSON Output)")

    with gr.Row():
        # Left column for inputs
        with gr.Column():
            model_dropdown = gr.Dropdown(
                choices=load_model_links(),
                label="Select Model",
                info="Choose a model from the list"
            )
            load_button = gr.Button("Load Selected Model")
            model_status = gr.Textbox(label="Model Status")
            
            system_instruction = gr.Textbox(
                value=default_system_message,
                placeholder="Enter system instruction here...",
                label="System Instruction",
                lines=3
            )
            user_input = gr.Textbox(
                placeholder="Type your message here...",
                label="Your Message",
                lines=3
            )
            submit_btn = gr.Button("Submit")

        # Right column for bot response
        with gr.Column():
            response_display = gr.Textbox(
                label="Bot Response (JSON)", 
                interactive=False, 
                placeholder="Response will appear here in JSON format.",
                lines=10
            )

    # Event handlers
    load_button.click(
        fn=model_manager.load_model,
        inputs=[model_dropdown],
        outputs=[model_status]
    )
    
    submit_btn.click(
        fn=generate_response,
        inputs=[model_dropdown, system_instruction, user_input],
        outputs=[response_display]
    )

# Launch the app
demo.launch()