import os
import time
import torch
import torchaudio
import torchvision
import numpy as np
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
import sys

# Add parent directory to path to import the preprocess functions
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from preprocess import process_audio_data, process_image_data

# Print library versions
print(f"\033[92mINFO\033[0m: PyTorch version: {torch.__version__}")
print(f"\033[92mINFO\033[0m: Torchaudio version: {torchaudio.__version__}")
print(f"\033[92mINFO\033[0m: Torchvision version: {torchvision.__version__}")

# Device selection
device = torch.device(
    "cuda"
    if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available() else "cpu"
)
print(f"\033[92mINFO\033[0m: Using device: {device}")

# Hyperparameters
batch_size = 16
epochs = 2
learning_rate = 0.0001

# Model save directory
os.makedirs("models/", exist_ok=True)


class WatermelonDataset(Dataset):
    def __init__(self, data_dir):
        self.data_dir = data_dir
        self.samples = []

        # Walk through the directory structure
        for sweetness_dir in os.listdir(data_dir):
            sweetness_path = os.path.join(data_dir, sweetness_dir)
            if os.path.isdir(sweetness_path):
                # Convert the folder name to a label only after the isdir
                # check, so stray files (e.g. .DS_Store) don't crash float()
                sweetness = float(sweetness_dir)
                for id_dir in os.listdir(sweetness_path):
                    id_path = os.path.join(sweetness_path, id_dir)
                    if os.path.isdir(id_path):
                        audio_file = os.path.join(id_path, f"{id_dir}.wav")
                        image_file = os.path.join(id_path, f"{id_dir}.jpg")
                        if os.path.exists(audio_file) and os.path.exists(image_file):
                            self.samples.append((audio_file, image_file, sweetness))

        print(f"\033[92mINFO\033[0m: Loaded {len(self.samples)} samples from {data_dir}")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        audio_path, image_path, label = self.samples[idx]

        # Load and process audio
        try:
            waveform, sample_rate = torchaudio.load(audio_path)
            mfcc = process_audio_data(waveform, sample_rate)

            # Load and process image
            image = torchvision.io.read_image(image_path)
            image = image.float()
            processed_image = process_image_data(image)

            return mfcc, processed_image, torch.tensor(label).float()
        except Exception as e:
            print(f"\033[91mERR!\033[0m: Error processing sample {idx}: {e}")
            # Return a fallback sample or skip this sample
            # For simplicity, we'll return the first sample again
            if idx == 0:  # Prevent infinite recursion
                raise e
            return self.__getitem__(0)


class WatermelonModel(torch.nn.Module):
    def __init__(self):
        super(WatermelonModel, self).__init__()

        # LSTM for audio features
        self.lstm = torch.nn.LSTM(
            input_size=376, hidden_size=64, num_layers=2, batch_first=True
        )
        self.lstm_fc = torch.nn.Linear(
            64, 128
        )  # Convert LSTM output to 128-dim for merging

        # ResNet50 for image features
        self.resnet = torchvision.models.resnet50(
            weights=torchvision.models.ResNet50_Weights.DEFAULT
        )
        self.resnet.fc = torch.nn.Linear(
            self.resnet.fc.in_features, 128
        )  # Convert ResNet output to 128-dim for merging

        # Fully connected layers for final prediction
        self.fc1 = torch.nn.Linear(256, 64)
        self.fc2 = torch.nn.Linear(64, 1)
        self.relu = torch.nn.ReLU()

    def forward(self, mfcc, image):
        # LSTM branch
        lstm_output, _ = self.lstm(mfcc)
        lstm_output = lstm_output[:, -1, :]  # Use the output of the last time step
        lstm_output = self.lstm_fc(lstm_output)

        # ResNet branch
        resnet_output = self.resnet(image)

        # Concatenate LSTM and ResNet outputs
        merged = torch.cat((lstm_output, resnet_output), dim=1)

        # Fully connected layers
        output = self.relu(self.fc1(merged))
        output = self.fc2(output)
        return output
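
# Shape sketch through WatermelonModel.forward(), assuming
# process_audio_data yields MFCC tensors of shape (time_steps, 376) and
# process_image_data yields standard 3x224x224 ResNet inputs (both are
# assumptions about preprocess.py, which is not part of this file):
#   mfcc  (B, T, 376) --LSTM--> (B, T, 64) --last step--> (B, 64) --lstm_fc--> (B, 128)
#   image (B, 3, 224, 224) --ResNet50--> (B, 128)
#   concat -> (B, 256) --fc1+ReLU--> (B, 64) --fc2--> (B, 1) sweetness score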


def train_model(data_dir, output_dir="models/"):
    # Create dataset
    dataset = WatermelonDataset(data_dir)
    n_samples = len(dataset)

    # Split dataset
    train_size = int(0.7 * n_samples)
    val_size = int(0.2 * n_samples)
    test_size = n_samples - train_size - val_size
    train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(
        dataset, [train_size, val_size, test_size]
    )

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
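    # Note: random_split is unseeded here, so the 70/20/10 split differs
    # between runs; pass generator=torch.Generator().manual_seed(...) to
    # random_split if a reproducible split is needed.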

    # Initialize model
    model = WatermelonModel().to(device)

    # Loss function and optimizer
    criterion = torch.nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    # TensorBoard
    writer = SummaryWriter("runs/")
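    # Inspect the logged curves with: tensorboard --logdir runs/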
    global_step = 0

    print(f"\033[92mINFO\033[0m: Training model for {epochs} epochs")
    print(f"\033[92mINFO\033[0m: Training samples: {len(train_dataset)}")
    print(f"\033[92mINFO\033[0m: Validation samples: {len(val_dataset)}")
    print(f"\033[92mINFO\033[0m: Test samples: {len(test_dataset)}")
    print(f"\033[92mINFO\033[0m: Batch size: {batch_size}")

    # Training loop
    for epoch in range(epochs):
        print(f"\033[92mINFO\033[0m: Training epoch ({epoch+1}/{epochs})")
        model.train()
        running_loss = 0.0

        for i, (mfcc, image, label) in enumerate(train_loader):
            try:
                mfcc, image, label = mfcc.to(device), image.to(device), label.to(device)

                optimizer.zero_grad()
                output = model(mfcc, image)
                label = label.view(-1, 1).float()
                loss = criterion(output, label)
                loss.backward()
                optimizer.step()

                running_loss += loss.item()
                writer.add_scalar("Training Loss", loss.item(), global_step)
                global_step += 1

                if i % 10 == 0:
                    print(f"\033[92mINFO\033[0m: Batch {i}/{len(train_loader)}, Loss: {loss.item():.4f}")
            except Exception as e:
                print(f"\033[91mERR!\033[0m: Error in training batch {i}: {e}")
                continue

        # Validation phase
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for i, (mfcc, image, label) in enumerate(val_loader):
                try:
                    mfcc, image, label = mfcc.to(device), image.to(device), label.to(device)
                    output = model(mfcc, image)
                    label = label.view(-1, 1).float()
                    loss = criterion(output, label)
                    val_loss += loss.item()
                except Exception as e:
                    print(f"\033[91mERR!\033[0m: Error in validation batch {i}: {e}")
                    continue

        avg_train_loss = running_loss / len(train_loader) if len(train_loader) > 0 else float("inf")
        avg_val_loss = val_loss / len(val_loader) if len(val_loader) > 0 else float("inf")

        # Record validation loss
        writer.add_scalar("Validation Loss", avg_val_loss, epoch)

        print(
            f"Epoch [{epoch+1}/{epochs}], Training Loss: {avg_train_loss:.4f}, "
            f"Validation Loss: {avg_val_loss:.4f}"
        )

        # Save model checkpoint
        timestamp = time.strftime("%Y%m%d-%H%M%S")
        model_path = os.path.join(output_dir, f"model_{epoch+1}_{timestamp}.pt")
        torch.save(model.state_dict(), model_path)
        print(
            f"\033[92mINFO\033[0m: Model checkpoint epoch [{epoch+1}/{epochs}] saved: {model_path}"
        )

    # Save final model
    final_model_path = os.path.join(output_dir, "watermelon_model_final.pt")
    torch.save(model.state_dict(), final_model_path)
    print(f"\033[92mINFO\033[0m: Final model saved: {final_model_path}")
    print("\033[92mINFO\033[0m: Training complete")

    # Flush and close the TensorBoard writer before returning
    writer.close()
    return final_model_path
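

# Standard PyTorch pattern to reload a saved checkpoint for inference
# (the path below is the default final model written by train_model):
#   model = WatermelonModel().to(device)
#   model.load_state_dict(torch.load("models/watermelon_model_final.pt", map_location=device))
#   model.eval()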


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Train the Watermelon Sweetness Prediction Model")
    parser.add_argument(
        "--data_dir",
        type=str,
        default="../cleaned",
        help="Path to the cleaned dataset directory",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="models/",
        help="Directory to save model checkpoints and the final model",
    )
    args = parser.parse_args()

    # Ensure output directory exists
    os.makedirs(args.output_dir, exist_ok=True)

    # Train the model
    final_model_path = train_model(args.data_dir, args.output_dir)
    print("\033[92mINFO\033[0m: Training completed successfully!")
    print(f"\033[92mINFO\033[0m: Final model saved at: {final_model_path}")