ajsbsd committed
Commit 5995f30 · verified · 1 Parent(s): 82ce89a

Upload app.py

Files changed (1)
  1. app.py +177 -0
app.py ADDED
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# --- Configuration ---
# IMPORTANT: Replace with the path to your locally downloaded model or a Hugging Face model ID.
# Examples:
# LOCAL_MODEL_PATH = "/path/to/your/downloaded/qwen-1.5b-instruct"
# HUGGINGFACE_MODEL_ID = "Qwen/Qwen1.5-1.8B-Chat"  # A smaller Qwen model for local testing
HUGGINGFACE_MODEL_ID = "."  # "." loads the model files from the current directory (e.g., this repo)

# You may need to adjust TORCH_DTYPE based on your GPU and model support:
# torch.float16 (FP16) is common for inference; torch.bfloat16 for newer GPUs.
TORCH_DTYPE = torch.float16  # or torch.bfloat16 or torch.float32
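
# Optional: a sketch of choosing the dtype automatically instead of hard-coding it above.
# bfloat16 is preferred on GPUs that support it (torch.cuda.is_bf16_supported() requires a
# reasonably recent PyTorch); otherwise fall back to float16 on CUDA or float32 on CPU.
# TORCH_DTYPE = (
#     torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported()
#     else torch.float16 if torch.cuda.is_available()
#     else torch.float32
# )
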
# Generation parameters (can be adjusted for different response styles)
MAX_NEW_TOKENS = 512
DO_SAMPLE = True
TEMPERATURE = 0.7
TOP_K = 50
TOP_P = 0.95

# --- Global variables for models and tokenizers ---
tokenizer = None
model = None

# --- Load Models and Tokenizers Function ---
def load_model_and_tokenizer():
    """
    Loads the language model and tokenizer from Hugging Face Hub or a local path.
    This function will be called once when the Gradio app starts up.
    """
    global tokenizer, model

    if tokenizer is not None and model is not None:
        print("Model and tokenizer already loaded.")
        return

    print(f"Loading tokenizer from: {HUGGINGFACE_MODEL_ID}")
    try:
        tokenizer = AutoTokenizer.from_pretrained(HUGGINGFACE_MODEL_ID)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
            print(f"Set tokenizer.pad_token to tokenizer.eos_token ({tokenizer.pad_token_id})")

        print(f"Loading model from: {HUGGINGFACE_MODEL_ID}...")
        model = AutoModelForCausalLM.from_pretrained(
            HUGGINGFACE_MODEL_ID,
            torch_dtype=TORCH_DTYPE,
            device_map="auto"  # Automatically maps model to GPU if available, else CPU
        )
        model.eval()  # Set model to evaluation mode
        print("Model loaded successfully.")
    except Exception as e:
        print(f"Error loading model or tokenizer: {e}")
        print("Please ensure the model ID is correct and you have an internet connection for the initial download, or that the local path is valid.")
        tokenizer = None
        model = None
        raise RuntimeError("Failed to load model. Check your model ID/path and internet connection.")

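# A small sketch for checking where the model was placed after loading (useful with
# device_map="auto"); hf_device_map is set when the model is dispatched by accelerate:
# load_model_and_tokenizer()
# print(getattr(model, "hf_device_map", model.device))
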
# --- Generate Response Function ---
def generate_response(
    message: str,  # Current user message
    history: list  # Gradio Chatbot history format (list of dictionaries with 'role' and 'content')
) -> list:  # Returns updated history for the Chatbot
    """
    Generates a text response from the loaded model based on user input and chat history.
    """
    global tokenizer, model

    # Initialize the model if not already loaded
    if tokenizer is None or model is None:
        load_model_and_tokenizer()

    if tokenizer is None or model is None:  # Check again in case loading failed
        # For 'messages' type history, append dictionaries with 'role' and 'content'
        history.append({"role": "user", "content": message})
        history.append({"role": "assistant", "content": "Error: Chatbot model not loaded. Please check logs."})
        return history

    # Format messages for the model's chat template (e.g., for Instruct models).
    # With type='messages', 'history' already contains dictionaries in the correct format.
    messages = history  # Note: this aliases 'history', so appends below also update the Chatbot state
    messages.append({"role": "user", "content": message})  # Add current user message

    # Apply the chat template and tokenize
    try:
        input_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    except Exception as e:
        print(f"Error applying chat template: {e}")
        # Fallback if the chat template fails (e.g., for non-chat models):
        # reconstruct input_text manually from earlier turns, then append the current message.
        input_text = ""
        for item in messages[:-1]:  # exclude the user message just appended, to avoid duplicating it
            if item["role"] == "user":
                input_text += f"User: {item['content']}\n"
            elif item["role"] == "assistant":
                input_text += f"Assistant: {item['content']}\n"
        input_text += f"User: {message}\nAssistant:"

    inputs = tokenizer(input_text, return_tensors="pt").to(model.device)

    # Generate response
    with torch.no_grad():  # Disable gradient calculations for inference
        output_ids = model.generate(
            inputs.input_ids,
            attention_mask=inputs.attention_mask,  # Avoids the missing-attention-mask warning when pad == eos
            max_new_tokens=MAX_NEW_TOKENS,
            do_sample=DO_SAMPLE,
            temperature=TEMPERATURE,
            top_k=TOP_K,
            top_p=TOP_P,
            pad_token_id=tokenizer.eos_token_id  # Important for generation to stop cleanly
        )

    # Decode the generated text, excluding the input prompt part
    generated_token_ids = output_ids[0][inputs.input_ids.shape[-1]:]
    generated_text = tokenizer.decode(generated_token_ids, skip_special_tokens=True).strip()

    # --- Update Chat History ---
    # Append the latest generated response to the history with its role
    history.append({"role": "assistant", "content": generated_text})

    return history

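# A minimal sketch of a smoke test for generate_response outside the UI (assumes the model at
# HUGGINGFACE_MODEL_ID loads successfully). Left commented out so it does not run on startup:
# updated_history = generate_response("Hello! Who are you?", [])
# print(updated_history[-1]["content"])
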
# --- Gradio Interface ---
with gr.Blocks() as demo:
    gr.Markdown(
        """
        # Local Chatbot Powered by Hugging Face Transformers
        Type your message below and chat with the model loaded locally on your machine!
        """
    )

    # Set type='messages' for the chatbot to use OpenAI-style dictionaries
    chatbot = gr.Chatbot(label="Conversation", type='messages')
    with gr.Row():
        text_input = gr.Textbox(
            label="Your message",
            placeholder="Type your message here...",
            scale=4
        )
        submit_button = gr.Button("Send", scale=1)

    # Link the text input and button to the generation function.
    # Note: 'inputs' will be the current message and the full history (as 'messages' type);
    # 'outputs' will be the updated full history.
    submit_button.click(
        fn=generate_response,
        inputs=[text_input, chatbot],  # text_input is the new message, chatbot is the history
        outputs=[chatbot],
        queue=True  # Queue requests for better concurrency
    )
    text_input.submit(  # Also trigger on Enter key
        fn=generate_response,
        inputs=[text_input, chatbot],
        outputs=[chatbot],
        queue=True
    )

    # Clear button
    def clear_chat():
        # When type='messages', the clear function should return an empty list for the history
        # and an empty string for the text input.
        return [], ""

    clear_button = gr.Button("Clear Chat")
    clear_button.click(clear_chat, inputs=None, outputs=[chatbot, text_input])


# Load the model when the app starts. This will ensure it's ready when the first request comes in.
load_model_and_tokenizer()

# Launch the Gradio app
# demo.queue().launch()  # For local development, use launch()
demo.queue().launch(server_name="0.0.0.0")
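
# How to run this app locally (a sketch; exact package versions are not pinned here):
#   pip install gradio torch transformers accelerate   # accelerate is needed for device_map="auto"
#   python app.py
# Gradio serves on port 7860 by default; with server_name="0.0.0.0" the UI is reachable at
# http://<host>:7860 from other machines on the network.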