Abdulla Fahem commited on
Commit
27e1d61
·
1 Parent(s): b121792

Add application file

Browse files
Files changed (1) hide show
  1. app.py +59 -11
app.py CHANGED
@@ -17,6 +17,17 @@ from datetime import datetime
17
  import numpy as np
18
  from random import choice
19
 
 
 
 
 
 
 
 
 
 
 
 
20
  class TravelDataset(Dataset):
21
  def __init__(self, data, tokenizer, max_length=512):
22
  """
@@ -138,23 +149,54 @@ def create_sample_data():
138
 
139
  return pd.DataFrame(data)
140
 
141
- @st.cache_resource
142
  def load_or_train_model():
143
  """Load trained model or train new one"""
 
 
 
 
144
  model_path = "./trained_travel_planner"
145
 
146
  if os.path.exists(model_path):
147
  try:
 
148
  model = T5ForConditionalGeneration.from_pretrained(model_path)
149
  tokenizer = T5Tokenizer.from_pretrained(model_path)
150
  if torch.cuda.is_available():
151
  model = model.cuda()
 
 
 
152
  return model, tokenizer
153
  except Exception as e:
154
- st.error(f"Error loading trained model: {str(e)}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
155
 
156
- # If no trained model exists or loading fails, train new model
157
- return train_model()
158
 
159
  def train_model():
160
  """Train the T5 model on travel planning data"""
@@ -411,15 +453,21 @@ def main():
411
  st.title("✈️ AI Travel Planner")
412
  st.markdown("### Plan your perfect trip with AI assistance!")
413
 
414
- # Load model only if not in session state
415
- if 'model' not in st.session_state or 'tokenizer' not in st.session_state:
416
  with st.spinner("Loading AI model... Please wait..."):
417
  model, tokenizer = load_or_train_model()
418
- if model is None or tokenizer is None:
419
- st.error("Failed to load/train the AI model. Please try again.")
420
- return
421
- st.session_state['model'] = model
422
- st.session_state['tokenizer'] = tokenizer
 
 
 
 
 
 
423
 
424
  # Add training button in sidebar only
425
  with st.sidebar:
 
17
  import numpy as np
18
  from random import choice
19
 
20
+ # Add these settings at the top level of your script
21
+ st.set_page_config(
22
+ page_title="AI Travel Planner",
23
+ page_icon="✈️",
24
+ layout="wide"
25
+ )
26
+
27
+ # Initialize session state for model
28
+ if 'initialized' not in st.session_state:
29
+ st.session_state.initialized = False
30
+
31
  class TravelDataset(Dataset):
32
  def __init__(self, data, tokenizer, max_length=512):
33
  """
 
149
 
150
  return pd.DataFrame(data)
151
 
152
+ @st.cache_resource(show_spinner=False)
153
  def load_or_train_model():
154
  """Load trained model or train new one"""
155
+ # Check if model exists in session state first
156
+ if hasattr(st.session_state, 'model') and hasattr(st.session_state, 'tokenizer'):
157
+ return st.session_state.model, st.session_state.tokenizer
158
+
159
  model_path = "./trained_travel_planner"
160
 
161
  if os.path.exists(model_path):
162
  try:
163
+ # Load existing model and tokenizer
164
  model = T5ForConditionalGeneration.from_pretrained(model_path)
165
  tokenizer = T5Tokenizer.from_pretrained(model_path)
166
  if torch.cuda.is_available():
167
  model = model.cuda()
168
+ # Store in session state
169
+ st.session_state.model = model
170
+ st.session_state.tokenizer = tokenizer
171
  return model, tokenizer
172
  except Exception as e:
173
+ st.error(f"Error loading model: {str(e)}")
174
+
175
+ # If no model exists, train new one
176
+ st.warning("No trained model found. Training new model...")
177
+ model, tokenizer = train_model()
178
+ # Store in session state
179
+ st.session_state.model = model
180
+ st.session_state.tokenizer = tokenizer
181
+ return model, tokenizer
182
+
183
+ # @st.cache_resource
184
+ # def load_or_train_model():
185
+ # """Load trained model or train new one"""
186
+ # model_path = "./trained_travel_planner"
187
+
188
+ # if os.path.exists(model_path):
189
+ # try:
190
+ # model = T5ForConditionalGeneration.from_pretrained(model_path)
191
+ # tokenizer = T5Tokenizer.from_pretrained(model_path)
192
+ # if torch.cuda.is_available():
193
+ # model = model.cuda()
194
+ # return model, tokenizer
195
+ # except Exception as e:
196
+ # st.error(f"Error loading trained model: {str(e)}")
197
 
198
+ # # If no trained model exists or loading fails, train new model
199
+ # return train_model()
200
 
201
  def train_model():
202
  """Train the T5 model on travel planning data"""
 
453
  st.title("✈️ AI Travel Planner")
454
  st.markdown("### Plan your perfect trip with AI assistance!")
455
 
456
+ # Move model loading to a initialization section
457
+ if 'initialized' not in st.session_state:
458
  with st.spinner("Loading AI model... Please wait..."):
459
  model, tokenizer = load_or_train_model()
460
+ st.session_state.initialized = True
461
+
462
+ # # Load model only if not in session state
463
+ # if 'model' not in st.session_state or 'tokenizer' not in st.session_state:
464
+ # with st.spinner("Loading AI model... Please wait..."):
465
+ # model, tokenizer = load_or_train_model()
466
+ # if model is None or tokenizer is None:
467
+ # st.error("Failed to load/train the AI model. Please try again.")
468
+ # return
469
+ # st.session_state['model'] = model
470
+ # st.session_state['tokenizer'] = tokenizer
471
 
472
  # Add training button in sidebar only
473
  with st.sidebar: