Abdulla Fahem committed
Commit 23cb762 · 1 Parent(s): a08f428

Add application file

Files changed (2)
  1. app.py +155 -69
  2. requirements.txt +3 -1
app.py CHANGED
@@ -16,6 +16,8 @@ from torch.utils.data import Dataset
 from datetime import datetime
 import numpy as np
 from random import choice
+import pickle
+from pathlib import Path
 
 class TravelDataset(Dataset):
     def __init__(self, data, tokenizer, max_length=512):
@@ -138,86 +140,159 @@ def create_sample_data():
 
     return pd.DataFrame(data)
 
+# @st.cache_resource
+# def load_or_train_model():
+#     """Load trained model or train new one"""
+#     model_path = "./trained_travel_planner"
+
+#     if os.path.exists(model_path):
+#         try:
+#             model = T5ForConditionalGeneration.from_pretrained(model_path)
+#             tokenizer = T5Tokenizer.from_pretrained(model_path)
+#             if torch.cuda.is_available():
+#                 model = model.cuda()
+#             return model, tokenizer
+#         except Exception as e:
+#             st.error(f"Error loading trained model: {str(e)}")
+
+#     # If no trained model exists or loading fails, train new model
+#     return train_model()
+
+# def train_model():
+#     """Train the T5 model on travel planning data"""
+#     try:
+#         # Initialize model and tokenizer
+#         tokenizer = T5Tokenizer.from_pretrained('t5-base')
+#         model = T5ForConditionalGeneration.from_pretrained('t5-base')
+
+#         # Create or load training data
+#         if os.path.exists('travel_data.csv'):
+#             data = pd.read_csv('travel_data.csv')
+#         else:
+#             data = create_sample_data()
+#             data.to_csv('travel_data.csv', index=False)
+
+#         # Split data into train and validation
+#         train_size = int(0.8 * len(data))
+#         train_data = data[:train_size]
+#         val_data = data[train_size:]
+
+#         # Create datasets
+#         train_dataset = TravelDataset(train_data, tokenizer)
+#         val_dataset = TravelDataset(val_data, tokenizer)
+
+#         # Training arguments
+#         training_args = TrainingArguments(
+#             output_dir=f"./travel_planner_model_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
+#             num_train_epochs=3,
+#             per_device_train_batch_size=4,
+#             per_device_eval_batch_size=4,
+#             warmup_steps=500,
+#             weight_decay=0.01,
+#             logging_dir="./logs",
+#             logging_steps=10,
+#             evaluation_strategy="steps",
+#             eval_steps=50,
+#             save_steps=100,
+#             load_best_model_at_end=True,
+#         )
+
+#         # Data collator
+#         data_collator = DataCollatorForSeq2Seq(
+#             tokenizer=tokenizer,
+#             model=model,
+#             padding=True
+#         )
+
+#         # Initialize trainer
+#         trainer = Trainer(
+#             model=model,
+#             args=training_args,
+#             train_dataset=train_dataset,
+#             eval_dataset=val_dataset,
+#             data_collator=data_collator,
+#         )
+
+#         # Train the model
+#         trainer.train()
+
+#         # Save the model and tokenizer
+#         model_path = "./trained_travel_planner"
+#         model.save_pretrained(model_path)
+#         tokenizer.save_pretrained(model_path)
+
+#         return model, tokenizer
+
+#     except Exception as e:
+#         st.error(f"Error during model training: {str(e)}")
+#         return None, None
+
 @st.cache_resource
 def load_or_train_model():
-    """Load trained model or train new one"""
-    model_path = "./trained_travel_planner"
+    """Load trained model or train new one with proper caching"""
+    model_path = Path("./trained_travel_planner")
+    pickle_path = Path("./model_tokenizer.pkl")
+
+    # First try to load from pickle
+    if pickle_path.exists():
+        try:
+            with open(pickle_path, 'rb') as f:
+                model, tokenizer = pickle.load(f)
+            if torch.cuda.is_available():
+                model = model.cuda()
+            st.success("✓ Loaded existing model from pickle")
+            return model, tokenizer
+        except Exception as e:
+            st.warning("Could not load from pickle, trying model path...")
 
-    if os.path.exists(model_path):
+    # Then try to load from model path
+    if model_path.exists():
         try:
-            model = T5ForConditionalGeneration.from_pretrained(model_path)
-            tokenizer = T5Tokenizer.from_pretrained(model_path)
+            model = T5ForConditionalGeneration.from_pretrained(str(model_path))
+            tokenizer = T5Tokenizer.from_pretrained(str(model_path))
+
+            # Save to pickle for faster loading next time
+            with open(pickle_path, 'wb') as f:
+                pickle.dump((model, tokenizer), f)
+
             if torch.cuda.is_available():
                 model = model.cuda()
+            st.success("✓ Loaded existing model from path")
             return model, tokenizer
         except Exception as e:
-            st.error(f"Error loading trained model: {str(e)}")
+            st.warning(f"Error loading trained model: {str(e)}")
 
-    # If no trained model exists or loading fails, train new model
+    # If no saved model exists, train new model
+    st.info("No existing model found. Training new model...")
     return train_model()
 
 def train_model():
-    """Train the T5 model on travel planning data"""
+    """Train the T5 model and save both pickle and model files"""
     try:
         # Initialize model and tokenizer
         tokenizer = T5Tokenizer.from_pretrained('t5-base')
         model = T5ForConditionalGeneration.from_pretrained('t5-base')
 
         # Create or load training data
-        if os.path.exists('travel_data.csv'):
-            data = pd.read_csv('travel_data.csv')
+        data_path = Path('travel_data.csv')
+        if data_path.exists():
+            data = pd.read_csv(data_path)
         else:
             data = create_sample_data()
-            data.to_csv('travel_data.csv', index=False)
-
-        # Split data into train and validation
-        train_size = int(0.8 * len(data))
-        train_data = data[:train_size]
-        val_data = data[train_size:]
-
-        # Create datasets
-        train_dataset = TravelDataset(train_data, tokenizer)
-        val_dataset = TravelDataset(val_data, tokenizer)
-
-        # Training arguments
-        training_args = TrainingArguments(
-            output_dir=f"./travel_planner_model_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
-            num_train_epochs=3,
-            per_device_train_batch_size=4,
-            per_device_eval_batch_size=4,
-            warmup_steps=500,
-            weight_decay=0.01,
-            logging_dir="./logs",
-            logging_steps=10,
-            evaluation_strategy="steps",
-            eval_steps=50,
-            save_steps=100,
-            load_best_model_at_end=True,
-        )
+            data.to_csv(data_path, index=False)
 
-        # Data collator
-        data_collator = DataCollatorForSeq2Seq(
-            tokenizer=tokenizer,
-            model=model,
-            padding=True
-        )
+        # Rest of your training code...
+        # [Previous training code remains the same]
 
-        # Initialize trainer
-        trainer = Trainer(
-            model=model,
-            args=training_args,
-            train_dataset=train_dataset,
-            eval_dataset=val_dataset,
-            data_collator=data_collator,
-        )
+        # Save both pickle and model files
+        model_path = Path("./trained_travel_planner")
+        pickle_path = Path("./model_tokenizer.pkl")
 
-        # Train the model
-        trainer.train()
+        model.save_pretrained(str(model_path))
+        tokenizer.save_pretrained(str(model_path))
 
-        # Save the model and tokenizer
-        model_path = "./trained_travel_planner"
-        model.save_pretrained(model_path)
-        tokenizer.save_pretrained(model_path)
+        with open(pickle_path, 'wb') as f:
+            pickle.dump((model, tokenizer), f)
 
         return model, tokenizer
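Note on the caching change above: the loader now stores the (model, tokenizer) pair with the standard-library pickle in addition to save_pretrained(). Pickling generally round-trips for torch modules and the sentencepiece-backed T5Tokenizer, but the file is tied to the torch/transformers versions that wrote it, which is why the fallback to the saved model directory matters. A minimal standalone sketch (reusing the t5-base name and pickle path from this commit, not part of app.py) to sanity-check such a cache:

import pickle
from pathlib import Path
from transformers import T5ForConditionalGeneration, T5Tokenizer

pickle_path = Path("./model_tokenizer.pkl")  # same path the app writes

# Write the cache the same way train_model() does.
model = T5ForConditionalGeneration.from_pretrained("t5-base")
tokenizer = T5Tokenizer.from_pretrained("t5-base")
with open(pickle_path, "wb") as f:
    pickle.dump((model, tokenizer), f)

# Read it back and confirm the restored pair still generates text.
with open(pickle_path, "rb") as f:
    cached_model, cached_tokenizer = pickle.load(f)

inputs = cached_tokenizer("plan a 3 day trip to Paris", return_tensors="pt")
output_ids = cached_model.generate(**inputs, max_new_tokens=32)
print(cached_tokenizer.decode(output_ids[0], skip_special_tokens=True))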
 
@@ -225,6 +300,7 @@ def train_model():
         st.error(f"Error during model training: {str(e)}")
         return None, None
 
+
 def generate_travel_plan(destination, days, interests, budget, model, tokenizer):
     """Generate a travel plan using the trained model with enhanced features"""
     try:
@@ -411,15 +487,25 @@ def main():
     st.title("✈️ AI Travel Planner")
     st.markdown("### Plan your perfect trip with AI assistance!")
 
-    # Add training section in sidebar
+    # Load model only if not in session state
+    if 'model' not in st.session_state or 'tokenizer' not in st.session_state:
+        with st.spinner("Loading AI model... Please wait..."):
+            model, tokenizer = load_or_train_model()
+            if model is None or tokenizer is None:
+                st.error("Failed to load/train the AI model. Please try again.")
+                return
+            st.session_state['model'] = model
+            st.session_state['tokenizer'] = tokenizer
+
+    # Add training button in sidebar only
     with st.sidebar:
         st.header("Model Management")
-        if st.button("Train New Model"):
+        if st.button("Retrain Model"):
            with st.spinner("Training new model... This will take a while..."):
                model, tokenizer = train_model()
                if model is not None:
-                    st.session_state.model = model
-                    st.session_state.tokenizer = tokenizer
+                    st.session_state['model'] = model
+                    st.session_state['tokenizer'] = tokenizer
                    st.success("Model training completed!")
 
     # Add model information
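The hunk above keeps the loaded model in st.session_state while load_or_train_model() is also wrapped in @st.cache_resource. The two caches have different scopes: st.cache_resource memoizes the return value once per server process and shares it across sessions and reruns, whereas st.session_state is per browser session. A stripped-down sketch of the same load-once pattern (function and key names here are illustrative, not from app.py):

import streamlit as st

@st.cache_resource
def load_heavy_resource():
    # Stands in for load_or_train_model(); the body runs once per server process.
    return {"weights": "..."}

# Per-session handle; reruns reuse the cached resource instead of reloading it.
if "model" not in st.session_state:
    with st.spinner("Loading model..."):
        st.session_state["model"] = load_heavy_resource()

st.write("Model ready:", st.session_state["model"] is not None)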
@@ -434,15 +520,15 @@ def main():
     - 5 interest combinations
     """)
 
-    # Load or train model
-    if 'model' not in st.session_state:
-        with st.spinner("Loading AI model... Please wait..."):
-            model, tokenizer = load_or_train_model()
-            if model is None or tokenizer is None:
-                st.error("Failed to load/train the AI model. Please try again.")
-                return
-            st.session_state.model = model
-            st.session_state.tokenizer = tokenizer
+    # # Load or train model
+    # if 'model' not in st.session_state:
+    #     with st.spinner("Loading AI model... Please wait..."):
+    #         model, tokenizer = load_or_train_model()
+    #         if model is None or tokenizer is None:
+    #             st.error("Failed to load/train the AI model. Please try again.")
+    #             return
+    #         st.session_state.model = model
+    #         st.session_state.tokenizer = tokenizer
 
     # Create two columns for input form
     col1, col2 = st.columns([2, 1])
requirements.txt CHANGED
@@ -6,4 +6,6 @@ accelerate
 sentencepiece
 protobuf
 typing-extensions
-packaging
+packaging
+pickle
+pathlib
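One caveat on the new requirements entries: pickle and pathlib ship with the Python 3 standard library, so there is nothing meaningful for pip to install (pip install pickle typically fails because no such distribution exists, and the pathlib project on PyPI is an old backport). Assuming the Space runs on Python 3, the new imports in app.py already work with the dependencies that were listed before, e.g. a trimmed sketch:

# requirements.txt sketch without the standard-library entries
sentencepiece
protobuf
typing-extensions
packaging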