abdullafahem commited on
Commit
12c8cdf
·
verified ·
1 Parent(s): d79e153

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -74
app.py CHANGED
@@ -17,10 +17,6 @@ from datetime import datetime
17
  import numpy as np
18
  from random import choice
19
 
20
- # Initialize session state for model
21
- if 'initialized' not in st.session_state:
22
- st.session_state.initialized = False
23
-
24
  class TravelDataset(Dataset):
25
  def __init__(self, data, tokenizer, max_length=512):
26
  """
@@ -142,54 +138,23 @@ def create_sample_data():
142
 
143
  return pd.DataFrame(data)
144
 
145
- @st.cache_resource(show_spinner=False)
146
  def load_or_train_model():
147
- """Load trained model or train new one"""
148
- # Check if model exists in session state first
149
- if hasattr(st.session_state, 'model') and hasattr(st.session_state, 'tokenizer'):
150
- return st.session_state.model, st.session_state.tokenizer
151
-
152
- model_path = "./trained_travel_planner"
153
 
154
- if os.path.exists(model_path):
155
- try:
156
- # Load existing model and tokenizer
157
- model = T5ForConditionalGeneration.from_pretrained(model_path)
158
- tokenizer = T5Tokenizer.from_pretrained(model_path)
159
- if torch.cuda.is_available():
160
- model = model.cuda()
161
- # Store in session state
162
- st.session_state.model = model
163
- st.session_state.tokenizer = tokenizer
164
- return model, tokenizer
165
- except Exception as e:
166
- st.error(f"Error loading model: {str(e)}")
167
 
168
- # If no model exists, train new one
169
- st.warning("No trained model found. Training new model...")
170
- model, tokenizer = train_model()
171
- # Store in session state
172
- st.session_state.model = model
173
- st.session_state.tokenizer = tokenizer
174
- return model, tokenizer
175
-
176
- # @st.cache_resource
177
- # def load_or_train_model():
178
- # """Load trained model or train new one"""
179
- # model_path = "./trained_travel_planner"
180
-
181
- # if os.path.exists(model_path):
182
- # try:
183
- # model = T5ForConditionalGeneration.from_pretrained(model_path)
184
- # tokenizer = T5Tokenizer.from_pretrained(model_path)
185
- # if torch.cuda.is_available():
186
- # model = model.cuda()
187
- # return model, tokenizer
188
- # except Exception as e:
189
- # st.error(f"Error loading trained model: {str(e)}")
190
-
191
- # # If no trained model exists or loading fails, train new model
192
- # return train_model()
193
 
194
  def train_model():
195
  """Train the T5 model on travel planning data"""
@@ -446,22 +411,6 @@ def main():
446
  st.title("✈️ AI Travel Planner")
447
  st.markdown("### Plan your perfect trip with AI assistance!")
448
 
449
- # Move model loading to a initialization section
450
- if 'initialized' not in st.session_state:
451
- with st.spinner("Loading AI model... Please wait..."):
452
- model, tokenizer = load_or_train_model()
453
- st.session_state.initialized = True
454
-
455
- # # Load model only if not in session state
456
- # if 'model' not in st.session_state or 'tokenizer' not in st.session_state:
457
- # with st.spinner("Loading AI model... Please wait..."):
458
- # model, tokenizer = load_or_train_model()
459
- # if model is None or tokenizer is None:
460
- # st.error("Failed to load/train the AI model. Please try again.")
461
- # return
462
- # st.session_state['model'] = model
463
- # st.session_state['tokenizer'] = tokenizer
464
-
465
  # Add training button in sidebar only
466
  with st.sidebar:
467
  st.header("Model Management")
@@ -485,15 +434,15 @@ def main():
485
  - 5 interest combinations
486
  """)
487
 
488
- # # Load or train model
489
- # if 'model' not in st.session_state:
490
- # with st.spinner("Loading AI model... Please wait..."):
491
- # model, tokenizer = load_or_train_model()
492
- # if model is None or tokenizer is None:
493
- # st.error("Failed to load/train the AI model. Please try again.")
494
- # return
495
- # st.session_state.model = model
496
- # st.session_state.tokenizer = tokenizer
497
 
498
  # Create two columns for input form
499
  col1, col2 = st.columns([2, 1])
 
17
  import numpy as np
18
  from random import choice
19
 
 
 
 
 
20
  class TravelDataset(Dataset):
21
  def __init__(self, data, tokenizer, max_length=512):
22
  """
 
138
 
139
  return pd.DataFrame(data)
140
 
141
+ @st.cache_resource
142
  def load_or_train_model():
143
+ """Load trained model or train new one"""
144
+ model_path = "./trained_travel_planner"
 
 
 
 
145
 
146
+ if os.path.exists(model_path):
147
+ try:
148
+ model = T5ForConditionalGeneration.from_pretrained(model_path)
149
+ tokenizer = T5Tokenizer.from_pretrained(model_path)
150
+ if torch.cuda.is_available():
151
+ model = model.cuda()
152
+ return model, tokenizer
153
+ except Exception as e:
154
+ st.error(f"Error loading trained model: {str(e)}")
 
 
 
 
155
 
156
+ # If no trained model exists or loading fails, train new model
157
+ return train_model()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
158
 
159
  def train_model():
160
  """Train the T5 model on travel planning data"""
 
411
  st.title("✈️ AI Travel Planner")
412
  st.markdown("### Plan your perfect trip with AI assistance!")
413
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
414
  # Add training button in sidebar only
415
  with st.sidebar:
416
  st.header("Model Management")
 
434
  - 5 interest combinations
435
  """)
436
 
437
+ # Load or train model
438
+ if 'model' not in st.session_state:
439
+ with st.spinner("Loading AI model... Please wait..."):
440
+ model, tokenizer = load_or_train_model()
441
+ if model is None or tokenizer is None:
442
+ st.error("Failed to load/train the AI model. Please try again.")
443
+ return
444
+ st.session_state.model = model
445
+ st.session_state.tokenizer = tokenizer
446
 
447
  # Create two columns for input form
448
  col1, col2 = st.columns([2, 1])