Spaces:
Sleeping
Sleeping
Abdulla Fahem
commited on
Commit
·
27e1d61
1
Parent(s):
b121792
Add application file
Browse files
app.py
CHANGED
@@ -17,6 +17,17 @@ from datetime import datetime
|
|
17 |
import numpy as np
|
18 |
from random import choice
|
19 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
20 |
class TravelDataset(Dataset):
|
21 |
def __init__(self, data, tokenizer, max_length=512):
|
22 |
"""
|
@@ -138,23 +149,54 @@ def create_sample_data():
|
|
138 |
|
139 |
return pd.DataFrame(data)
|
140 |
|
141 |
-
@st.cache_resource
|
142 |
def load_or_train_model():
|
143 |
"""Load trained model or train new one"""
|
|
|
|
|
|
|
|
|
144 |
model_path = "./trained_travel_planner"
|
145 |
|
146 |
if os.path.exists(model_path):
|
147 |
try:
|
|
|
148 |
model = T5ForConditionalGeneration.from_pretrained(model_path)
|
149 |
tokenizer = T5Tokenizer.from_pretrained(model_path)
|
150 |
if torch.cuda.is_available():
|
151 |
model = model.cuda()
|
|
|
|
|
|
|
152 |
return model, tokenizer
|
153 |
except Exception as e:
|
154 |
-
st.error(f"Error loading
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
155 |
|
156 |
-
|
157 |
-
|
158 |
|
159 |
def train_model():
|
160 |
"""Train the T5 model on travel planning data"""
|
@@ -411,15 +453,21 @@ def main():
|
|
411 |
st.title("✈️ AI Travel Planner")
|
412 |
st.markdown("### Plan your perfect trip with AI assistance!")
|
413 |
|
414 |
-
#
|
415 |
-
if '
|
416 |
with st.spinner("Loading AI model... Please wait..."):
|
417 |
model, tokenizer = load_or_train_model()
|
418 |
-
|
419 |
-
|
420 |
-
|
421 |
-
|
422 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
423 |
|
424 |
# Add training button in sidebar only
|
425 |
with st.sidebar:
|
|
|
17 |
import numpy as np
|
18 |
from random import choice
|
19 |
|
20 |
+
# Add these settings at the top level of your script
|
21 |
+
st.set_page_config(
|
22 |
+
page_title="AI Travel Planner",
|
23 |
+
page_icon="✈️",
|
24 |
+
layout="wide"
|
25 |
+
)
|
26 |
+
|
27 |
+
# Initialize session state for model
|
28 |
+
if 'initialized' not in st.session_state:
|
29 |
+
st.session_state.initialized = False
|
30 |
+
|
31 |
class TravelDataset(Dataset):
|
32 |
def __init__(self, data, tokenizer, max_length=512):
|
33 |
"""
|
|
|
149 |
|
150 |
return pd.DataFrame(data)
|
151 |
|
152 |
+
@st.cache_resource(show_spinner=False)
|
153 |
def load_or_train_model():
|
154 |
"""Load trained model or train new one"""
|
155 |
+
# Check if model exists in session state first
|
156 |
+
if hasattr(st.session_state, 'model') and hasattr(st.session_state, 'tokenizer'):
|
157 |
+
return st.session_state.model, st.session_state.tokenizer
|
158 |
+
|
159 |
model_path = "./trained_travel_planner"
|
160 |
|
161 |
if os.path.exists(model_path):
|
162 |
try:
|
163 |
+
# Load existing model and tokenizer
|
164 |
model = T5ForConditionalGeneration.from_pretrained(model_path)
|
165 |
tokenizer = T5Tokenizer.from_pretrained(model_path)
|
166 |
if torch.cuda.is_available():
|
167 |
model = model.cuda()
|
168 |
+
# Store in session state
|
169 |
+
st.session_state.model = model
|
170 |
+
st.session_state.tokenizer = tokenizer
|
171 |
return model, tokenizer
|
172 |
except Exception as e:
|
173 |
+
st.error(f"Error loading model: {str(e)}")
|
174 |
+
|
175 |
+
# If no model exists, train new one
|
176 |
+
st.warning("No trained model found. Training new model...")
|
177 |
+
model, tokenizer = train_model()
|
178 |
+
# Store in session state
|
179 |
+
st.session_state.model = model
|
180 |
+
st.session_state.tokenizer = tokenizer
|
181 |
+
return model, tokenizer
|
182 |
+
|
183 |
+
# @st.cache_resource
|
184 |
+
# def load_or_train_model():
|
185 |
+
# """Load trained model or train new one"""
|
186 |
+
# model_path = "./trained_travel_planner"
|
187 |
+
|
188 |
+
# if os.path.exists(model_path):
|
189 |
+
# try:
|
190 |
+
# model = T5ForConditionalGeneration.from_pretrained(model_path)
|
191 |
+
# tokenizer = T5Tokenizer.from_pretrained(model_path)
|
192 |
+
# if torch.cuda.is_available():
|
193 |
+
# model = model.cuda()
|
194 |
+
# return model, tokenizer
|
195 |
+
# except Exception as e:
|
196 |
+
# st.error(f"Error loading trained model: {str(e)}")
|
197 |
|
198 |
+
# # If no trained model exists or loading fails, train new model
|
199 |
+
# return train_model()
|
200 |
|
201 |
def train_model():
|
202 |
"""Train the T5 model on travel planning data"""
|
|
|
453 |
st.title("✈️ AI Travel Planner")
|
454 |
st.markdown("### Plan your perfect trip with AI assistance!")
|
455 |
|
456 |
+
# Move model loading to a initialization section
|
457 |
+
if 'initialized' not in st.session_state:
|
458 |
with st.spinner("Loading AI model... Please wait..."):
|
459 |
model, tokenizer = load_or_train_model()
|
460 |
+
st.session_state.initialized = True
|
461 |
+
|
462 |
+
# # Load model only if not in session state
|
463 |
+
# if 'model' not in st.session_state or 'tokenizer' not in st.session_state:
|
464 |
+
# with st.spinner("Loading AI model... Please wait..."):
|
465 |
+
# model, tokenizer = load_or_train_model()
|
466 |
+
# if model is None or tokenizer is None:
|
467 |
+
# st.error("Failed to load/train the AI model. Please try again.")
|
468 |
+
# return
|
469 |
+
# st.session_state['model'] = model
|
470 |
+
# st.session_state['tokenizer'] = tokenizer
|
471 |
|
472 |
# Add training button in sidebar only
|
473 |
with st.sidebar:
|