Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -17,10 +17,6 @@ from datetime import datetime
|
|
17 |
import numpy as np
|
18 |
from random import choice
|
19 |
|
20 |
-
# Initialize session state for model
|
21 |
-
if 'initialized' not in st.session_state:
|
22 |
-
st.session_state.initialized = False
|
23 |
-
|
24 |
class TravelDataset(Dataset):
|
25 |
def __init__(self, data, tokenizer, max_length=512):
|
26 |
"""
|
@@ -142,54 +138,23 @@ def create_sample_data():
|
|
142 |
|
143 |
return pd.DataFrame(data)
|
144 |
|
145 |
-
@st.cache_resource
|
146 |
def load_or_train_model():
|
147 |
-
|
148 |
-
|
149 |
-
if hasattr(st.session_state, 'model') and hasattr(st.session_state, 'tokenizer'):
|
150 |
-
return st.session_state.model, st.session_state.tokenizer
|
151 |
-
|
152 |
-
model_path = "./trained_travel_planner"
|
153 |
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
-
|
163 |
-
st.session_state.tokenizer = tokenizer
|
164 |
-
return model, tokenizer
|
165 |
-
except Exception as e:
|
166 |
-
st.error(f"Error loading model: {str(e)}")
|
167 |
|
168 |
-
|
169 |
-
|
170 |
-
model, tokenizer = train_model()
|
171 |
-
# Store in session state
|
172 |
-
st.session_state.model = model
|
173 |
-
st.session_state.tokenizer = tokenizer
|
174 |
-
return model, tokenizer
|
175 |
-
|
176 |
-
# @st.cache_resource
|
177 |
-
# def load_or_train_model():
|
178 |
-
# """Load trained model or train new one"""
|
179 |
-
# model_path = "./trained_travel_planner"
|
180 |
-
|
181 |
-
# if os.path.exists(model_path):
|
182 |
-
# try:
|
183 |
-
# model = T5ForConditionalGeneration.from_pretrained(model_path)
|
184 |
-
# tokenizer = T5Tokenizer.from_pretrained(model_path)
|
185 |
-
# if torch.cuda.is_available():
|
186 |
-
# model = model.cuda()
|
187 |
-
# return model, tokenizer
|
188 |
-
# except Exception as e:
|
189 |
-
# st.error(f"Error loading trained model: {str(e)}")
|
190 |
-
|
191 |
-
# # If no trained model exists or loading fails, train new model
|
192 |
-
# return train_model()
|
193 |
|
194 |
def train_model():
|
195 |
"""Train the T5 model on travel planning data"""
|
@@ -446,22 +411,6 @@ def main():
|
|
446 |
st.title("✈️ AI Travel Planner")
|
447 |
st.markdown("### Plan your perfect trip with AI assistance!")
|
448 |
|
449 |
-
# Move model loading to a initialization section
|
450 |
-
if 'initialized' not in st.session_state:
|
451 |
-
with st.spinner("Loading AI model... Please wait..."):
|
452 |
-
model, tokenizer = load_or_train_model()
|
453 |
-
st.session_state.initialized = True
|
454 |
-
|
455 |
-
# # Load model only if not in session state
|
456 |
-
# if 'model' not in st.session_state or 'tokenizer' not in st.session_state:
|
457 |
-
# with st.spinner("Loading AI model... Please wait..."):
|
458 |
-
# model, tokenizer = load_or_train_model()
|
459 |
-
# if model is None or tokenizer is None:
|
460 |
-
# st.error("Failed to load/train the AI model. Please try again.")
|
461 |
-
# return
|
462 |
-
# st.session_state['model'] = model
|
463 |
-
# st.session_state['tokenizer'] = tokenizer
|
464 |
-
|
465 |
# Add training button in sidebar only
|
466 |
with st.sidebar:
|
467 |
st.header("Model Management")
|
@@ -485,15 +434,15 @@ def main():
|
|
485 |
- 5 interest combinations
|
486 |
""")
|
487 |
|
488 |
-
|
489 |
-
|
490 |
-
|
491 |
-
|
492 |
-
|
493 |
-
|
494 |
-
|
495 |
-
|
496 |
-
|
497 |
|
498 |
# Create two columns for input form
|
499 |
col1, col2 = st.columns([2, 1])
|
|
|
17 |
import numpy as np
|
18 |
from random import choice
|
19 |
|
|
|
|
|
|
|
|
|
20 |
class TravelDataset(Dataset):
|
21 |
def __init__(self, data, tokenizer, max_length=512):
|
22 |
"""
|
|
|
138 |
|
139 |
return pd.DataFrame(data)
|
140 |
|
141 |
+
@st.cache_resource
|
142 |
def load_or_train_model():
|
143 |
+
"""Load trained model or train new one"""
|
144 |
+
model_path = "./trained_travel_planner"
|
|
|
|
|
|
|
|
|
145 |
|
146 |
+
if os.path.exists(model_path):
|
147 |
+
try:
|
148 |
+
model = T5ForConditionalGeneration.from_pretrained(model_path)
|
149 |
+
tokenizer = T5Tokenizer.from_pretrained(model_path)
|
150 |
+
if torch.cuda.is_available():
|
151 |
+
model = model.cuda()
|
152 |
+
return model, tokenizer
|
153 |
+
except Exception as e:
|
154 |
+
st.error(f"Error loading trained model: {str(e)}")
|
|
|
|
|
|
|
|
|
155 |
|
156 |
+
# If no trained model exists or loading fails, train new model
|
157 |
+
return train_model()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
158 |
|
159 |
def train_model():
|
160 |
"""Train the T5 model on travel planning data"""
|
|
|
411 |
st.title("✈️ AI Travel Planner")
|
412 |
st.markdown("### Plan your perfect trip with AI assistance!")
|
413 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
414 |
# Add training button in sidebar only
|
415 |
with st.sidebar:
|
416 |
st.header("Model Management")
|
|
|
434 |
- 5 interest combinations
|
435 |
""")
|
436 |
|
437 |
+
# Load or train model
|
438 |
+
if 'model' not in st.session_state:
|
439 |
+
with st.spinner("Loading AI model... Please wait..."):
|
440 |
+
model, tokenizer = load_or_train_model()
|
441 |
+
if model is None or tokenizer is None:
|
442 |
+
st.error("Failed to load/train the AI model. Please try again.")
|
443 |
+
return
|
444 |
+
st.session_state.model = model
|
445 |
+
st.session_state.tokenizer = tokenizer
|
446 |
|
447 |
# Create two columns for input form
|
448 |
col1, col2 = st.columns([2, 1])
|