"""Incrementally fine-tune a ViT image classifier on user-labelled images.

Labelled image paths are accumulated in a CSV file; once enough samples
have been collected, the classification head of a pre-trained ViT is
retrained on them and the weights are written to disk.
"""

import os

import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from transformers import ViTForImageClassification


class CustomDataset(Dataset):
    """Dataset yielding ``(image, label)`` pairs described by a DataFrame.

    The first column of the frame holds the image path and the second
    column holds the label.
    """

    def __init__(self, dataframe, transform=None):
        """Store the backing frame and the optional per-image transform."""
        self.dataframe = dataframe
        self.transform = transform

    def __len__(self):
        """Return the number of rows (samples) in the backing frame."""
        return len(self.dataframe)

    def __getitem__(self, idx):
        """Load the image at row *idx*, transform it, and return it with its label."""
        image_path = self.dataframe.iloc[idx, 0]  # Image path is in the first column
        # The context manager releases the file handle as soon as the
        # pixel data has been copied into the converted RGB image.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert('RGB')
        if self.transform is not None:
            image = self.transform(image)
        label = self.dataframe.iloc[idx, 1]  # Label is in the second column
        return image, label


def shuffle_and_split_data(dataframe, test_size=0.2, random_state=59):
    """Shuffle *dataframe* and split it into training and validation frames.

    Args:
        dataframe: Frame to split.
        test_size: Fraction of rows assigned to the validation frame.
        random_state: Seed used for both the shuffle and the split.

    Returns:
        A ``(train_df, val_df)`` tuple.
    """
    shuffled_df = dataframe.sample(frac=1, random_state=random_state).reset_index(drop=True)
    train_df, val_df = train_test_split(
        shuffled_df, test_size=test_size, random_state=random_state
    )
    return train_df, val_df


class Custom_VIT_Model:
    """Pre-trained ViT with a fresh two-class head, retrained from user data."""

    def __init__(self):
        """Load the frozen ViT backbone, attach a new head, and load stored data."""
        # Check for GPU availability
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Load the pre-trained ViT model and move it to the device
        self.model = ViTForImageClassification.from_pretrained(
            'google/vit-base-patch16-224'
        ).to(self.device)

        # Freeze pre-trained layers
        for param in self.model.parameters():
            param.requires_grad = False

        # Define a new classifier and move it to the device. It is created
        # after the freeze, so its parameters remain trainable.
        self.model.classifier = nn.Linear(self.model.config.hidden_size, 2).to(self.device)

        # Only hand the optimizer the parameters that can actually receive
        # gradients (the new head); the frozen backbone never would.
        trainable_params = [p for p in self.model.parameters() if p.requires_grad]
        self.optimizer = optim.Adam(trainable_params, lr=0.001)

        # Define the image preprocessing pipeline
        self.preprocess = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
        ])

        # Initialize DataFrame for user data
        self.data_file = 'user_data.csv'
        if os.path.exists(self.data_file):
            self.df = pd.read_csv(self.data_file)
        else:
            self.df = pd.DataFrame(columns=['image_path', 'label'])

    def add_data(self, image_path: str, label: int):
        """Record a labelled image and retrain once 100 samples are stored.

        Args:
            image_path: Path of the image file to add.
            label: Class label for the image.
        """
        # Create a new DataFrame entry
        new_entry = pd.DataFrame({'image_path': [image_path], 'label': [label]})

        # Append the new entry to the existing DataFrame. Concatenating with
        # an empty frame is skipped so the label column keeps its integer
        # dtype instead of inheriting the empty frame's object dtype.
        if self.df.empty:
            self.df = new_entry
        else:
            self.df = pd.concat([self.df, new_entry], ignore_index=True)

        # Save the updated DataFrame to the specified CSV file
        self.df.to_csv(self.data_file, index=False)

        # Print the current state of the training data for debugging
        print("Current training data:")
        print(self.df)

        # Check if we have 100 images for retraining. (The frame must stay
        # intact here: it is both measured and consumed by retrain_model.)
        if len(self.df) >= 100:
            print("Retraining the model as we have enough data.")
            self.retrain_model()

    def retrain_model(self):
        """Train the classifier head on the stored data and save the weights.

        Runs 10 epochs, printing the mean training loss and the validation
        accuracy after each one, then writes the model state dict to
        ``trained_model.pth``.
        """
        # Shuffle and split the data
        train_df, val_df = shuffle_and_split_data(self.df)

        # Define the dataset and dataloaders
        train_dataset = CustomDataset(train_df, transform=self.preprocess)
        train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
        val_dataset = CustomDataset(val_df, transform=self.preprocess)
        val_loader = DataLoader(val_dataset, batch_size=32)

        # Define the loss function
        criterion = nn.CrossEntropyLoss().to(self.device)

        # Training loop
        num_epochs = 10
        for epoch in range(num_epochs):
            self.model.train()
            running_loss = 0.0
            for images, labels in train_loader:
                images, labels = images.to(self.device), labels.to(self.device)
                self.optimizer.zero_grad()
                outputs = self.model(images)
                logits = outputs.logits  # Extract logits from the output
                loss = criterion(logits, labels)
                loss.backward()
                self.optimizer.step()
                running_loss += loss.item()
            print(f"Epoch {epoch+1}/{num_epochs}, Loss: {running_loss / len(train_loader)}")

            # Validation loop
            self.model.eval()
            correct = 0
            total = 0
            with torch.no_grad():
                for images, labels in val_loader:
                    images, labels = images.to(self.device), labels.to(self.device)
                    outputs = self.model(images)
                    logits = outputs.logits
                    _, predicted = torch.max(logits, 1)
                    total += labels.size(0)
                    correct += (predicted == labels).sum().item()
            # An empty validation split would otherwise divide by zero.
            accuracy = correct / total if total else 0.0
            print(f"Validation Accuracy: {accuracy}")

        # Save the retrained model
        torch.save(self.model.state_dict(), 'trained_model.pth')
        print("Model retrained and updated!")


if __name__ == "__main__":
    # Initialize the model
    custom_model = Custom_VIT_Model()

    # Example usage: adding a new image and label
    # custom_model.add_data('path/to/image.jpg', 0)  # 0 for real, 1 for fake