Abdulla Fahem commited on
Commit
07913c3
·
1 Parent(s): ce5a0ef

Add application file

Browse files
Files changed (1) hide show
  1. app.py +18 -18
app.py CHANGED
@@ -138,6 +138,24 @@ def create_sample_data():
138
 
139
  return pd.DataFrame(data)
140
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
141
  def train_model():
142
  """Train the T5 model on travel planning data"""
143
  try:
@@ -207,24 +225,6 @@ def train_model():
207
  st.error(f"Error during model training: {str(e)}")
208
  return None, None
209
 
210
- @st.cache_resource
211
- def load_or_train_model():
212
- """Load trained model or train new one"""
213
- model_path = "./trained_travel_planner"
214
-
215
- if os.path.exists(model_path):
216
- try:
217
- model = T5ForConditionalGeneration.from_pretrained(model_path)
218
- tokenizer = T5Tokenizer.from_pretrained(model_path)
219
- if torch.cuda.is_available():
220
- model = model.cuda()
221
- return model, tokenizer
222
- except Exception as e:
223
- st.error(f"Error loading trained model: {str(e)}")
224
-
225
- # If no trained model exists or loading fails, train new model
226
- return train_model()
227
-
228
  def generate_travel_plan(destination, days, interests, budget, model, tokenizer):
229
  """Generate a travel plan using the trained model with enhanced features"""
230
  try:
 
138
 
139
  return pd.DataFrame(data)
140
 
141
+ @st.cache_resource
142
+ def load_or_train_model():
143
+ """Load trained model or train new one"""
144
+ model_path = "./trained_travel_planner"
145
+
146
+ if os.path.exists(model_path):
147
+ try:
148
+ model = T5ForConditionalGeneration.from_pretrained(model_path)
149
+ tokenizer = T5Tokenizer.from_pretrained(model_path)
150
+ if torch.cuda.is_available():
151
+ model = model.cuda()
152
+ return model, tokenizer
153
+ except Exception as e:
154
+ st.error(f"Error loading trained model: {str(e)}")
155
+
156
+ # If no trained model exists or loading fails, train new model
157
+ return train_model()
158
+
159
  def train_model():
160
  """Train the T5 model on travel planning data"""
161
  try:
 
225
  st.error(f"Error during model training: {str(e)}")
226
  return None, None
227
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
228
  def generate_travel_plan(destination, days, interests, budget, model, tokenizer):
229
  """Generate a travel plan using the trained model with enhanced features"""
230
  try: