# trip_planner/app.py — AI Travel Planner (Streamlit + T5)
# Author: Abdulla Fahem
import os
import sys
import torch
import pandas as pd
import streamlit as st
from datetime import datetime
from transformers import (
T5ForConditionalGeneration,
T5Tokenizer,
Trainer,
TrainingArguments,
DataCollatorForSeq2Seq
)
from torch.utils.data import Dataset
import random
# Ensure reproducibility across torch's and Python's RNGs.
torch.manual_seed(42)
random.seed(42)
# Environment setup: tolerate duplicate OpenMP runtimes (a common
# torch/MKL conflict on some platforms) instead of aborting at import.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
class TravelDataset(Dataset):
    """Dataset pairing a travel `query` with its `reference_information` target.

    Each item is tokenized to fixed-length tensors suitable for T5-style
    seq2seq fine-tuning. Padding positions in the labels are masked with
    -100 so the cross-entropy loss ignores them.
    """

    def __init__(self, data, tokenizer, max_length=512):
        """
        Args:
            data: DataFrame with 'query' and 'reference_information' columns.
            tokenizer: callable returning 'input_ids'/'attention_mask' tensors.
            max_length: fixed sequence length for both inputs and targets.
        """
        self.tokenizer = tokenizer
        self.data = data
        self.max_length = max_length
        print(f"Dataset loaded with {len(data)} samples")
        print("Columns:", list(data.columns))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        row = self.data.iloc[idx]
        # Input: query
        input_text = row['query']
        # Target: reference_information
        target_text = row['reference_information']
        # Tokenize inputs
        input_encodings = self.tokenizer(
            input_text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        # Tokenize targets
        target_encodings = self.tokenizer(
            target_text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        labels = target_encodings['input_ids'].squeeze()
        # Fix: mask pad positions with -100 so the loss ignores them;
        # otherwise the model is explicitly trained to predict pad tokens.
        pad_id = getattr(self.tokenizer, 'pad_token_id', None)
        if pad_id is not None:
            labels = labels.clone()
            labels[labels == pad_id] = -100
        return {
            'input_ids': input_encodings['input_ids'].squeeze(),
            'attention_mask': input_encodings['attention_mask'].squeeze(),
            'labels': labels
        }
def load_dataset():
    """Fetch the TravelPlanner training split and validate its schema.

    Returns:
        A DataFrame guaranteed to contain the 'query' and
        'reference_information' columns.

    Exits the process (status 1) on any download or validation failure.
    """
    try:
        frame = pd.read_csv("hf://datasets/osunlp/TravelPlanner/train.csv")
        # Fail fast if the expected columns are absent.
        for col in ('query', 'reference_information'):
            if col not in frame.columns:
                raise ValueError(f"Missing required column: {col}")
        print(f"Dataset loaded successfully with {len(frame)} rows.")
        return frame
    except Exception as e:
        print(f"Error loading dataset: {e}")
        sys.exit(1)
def train_model():
    """Fine-tune t5-base on the TravelPlanner data and persist the result.

    Returns:
        (model, tokenizer) on success, or (None, None) if anything fails.
        Checkpoints and the final model are written to
        ./trained_travel_planner; logs go to ./logs.
    """
    try:
        data = load_dataset()

        print("Initializing T5 model and tokenizer...")
        tokenizer = T5Tokenizer.from_pretrained('t5-base', legacy=False)
        model = T5ForConditionalGeneration.from_pretrained('t5-base')

        # 80/20 train/validation split, in load order (no shuffle).
        split_at = int(0.8 * len(data))
        train_dataset = TravelDataset(data[:split_at], tokenizer)
        val_dataset = TravelDataset(data[split_at:], tokenizer)

        args = TrainingArguments(
            output_dir="./trained_travel_planner",
            num_train_epochs=3,
            per_device_train_batch_size=4,
            per_device_eval_batch_size=4,
            evaluation_strategy="steps",
            eval_steps=50,
            save_steps=100,
            weight_decay=0.01,
            logging_dir="./logs",
            logging_steps=10,
            load_best_model_at_end=True,
        )
        collator = DataCollatorForSeq2Seq(
            tokenizer=tokenizer,
            model=model,
            padding=True
        )
        trainer = Trainer(
            model=model,
            args=args,
            train_dataset=train_dataset,
            eval_dataset=val_dataset,
            data_collator=collator
        )

        print("Training model...")
        trainer.train()

        model.save_pretrained("./trained_travel_planner")
        tokenizer.save_pretrained("./trained_travel_planner")
        print("Model training complete!")
        return model, tokenizer
    except Exception as e:
        print(f"Training error: {e}")
        return None, None
def generate_travel_plan(query, model, tokenizer):
    """
    Generate a travel plan using the trained model.

    Runs beam-search decoding on `query` and returns the decoded text.
    Moves both model and inputs to the GPU when one is available. Any
    failure is reported as an error string rather than raised.
    """
    try:
        encoded = tokenizer(
            query,
            return_tensors="pt",
            max_length=512,
            padding="max_length",
            truncation=True
        )
        # Shift model and tensors onto the GPU together when present.
        if torch.cuda.is_available():
            model = model.cuda()
            encoded = {name: tensor.cuda() for name, tensor in encoded.items()}
        generated = model.generate(
            **encoded,
            max_length=512,
            num_beams=4,
            no_repeat_ngram_size=3,
            num_return_sequences=1
        )
        return tokenizer.decode(generated[0], skip_special_tokens=True)
    except Exception as e:
        return f"Error generating travel plan: {e}"
def main():
    """Streamlit entry point: sidebar model management plus plan generation UI.

    Caches the trained model/tokenizer in st.session_state so reruns do not
    retrain. Only successful loads are cached (fix: the previous version
    stored None on failure, making every later generation attempt fail
    with a confusing downstream error).
    """
    st.set_page_config(
        page_title="AI Travel Planner",
        page_icon="✈️",
        layout="wide"
    )
    st.title("✈️ AI Travel Planner")

    # Sidebar to train model on demand
    with st.sidebar:
        st.header("Model Management")
        if st.button("Retrain Model"):
            with st.spinner("Training the model..."):
                model, tokenizer = train_model()
                if model:
                    st.session_state['model'] = model
                    st.session_state['tokenizer'] = tokenizer
                    st.success("Model retrained successfully!")
                else:
                    st.error("Model retraining failed.")

    # Load model if not already loaded — cache only on success.
    if 'model' not in st.session_state:
        with st.spinner("Loading model..."):
            model, tokenizer = train_model()
            if model:
                st.session_state['model'] = model
                st.session_state['tokenizer'] = tokenizer
            else:
                st.error("Model loading failed. Use 'Retrain Model' to retry.")

    # Input query
    st.subheader("Plan Your Trip")
    query = st.text_area("Enter your trip query (e.g., 'Plan a 3-day trip to Paris focusing on culture and food')")
    if st.button("Generate Plan"):
        if not query:
            st.error("Please enter a query.")
        elif 'model' not in st.session_state:
            # Guard: a failed initial load means there is nothing to run.
            st.error("Model is not available. Please retrain the model first.")
        else:
            with st.spinner("Generating your travel plan..."):
                travel_plan = generate_travel_plan(
                    query,
                    st.session_state['model'],
                    st.session_state['tokenizer']
                )
            st.subheader("Your Travel Plan")
            st.write(travel_plan)

if __name__ == "__main__":
    main()