# VIT_Demo / vit_Training.py — author: benjaminStreltzin
# Last commit 68c6ddf (verified): "Rename Vit_Training.py to vit_Training.py" — file size 4.97 kB
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from transformers import ViTForImageClassification
from PIL import Image
import torch.optim as optim
import os
import pandas as pd
from sklearn.model_selection import train_test_split
class CustomDataset(Dataset):
    """Map-style dataset yielding ``(image, label)`` pairs from a DataFrame.

    The DataFrame is read positionally: column 0 holds the image file path and
    column 1 holds the class label.

    Args:
        dataframe: Table of samples; only its first two columns are used.
        transform: Optional callable applied to the RGB image (e.g. a
            torchvision ``Compose``). When falsy, the PIL image is returned as is.
    """

    def __init__(self, dataframe, transform=None):
        self.dataframe = dataframe
        self.transform = transform

    def __len__(self):
        """Return the number of rows (samples) in the DataFrame."""
        return len(self.dataframe)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, label)`` with ``label`` as ``int``.

        Raises:
            ValueError / TypeError: if the label cell cannot be converted to ``int``.
        """
        image_path = self.dataframe.iloc[idx, 0]  # Image path is in the first column
        # The context manager releases the file handle even if decoding fails;
        # ``convert`` returns an independent image, so it stays valid after close.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert('RGB')
        if self.transform:
            image = self.transform(image)
        # Labels that went through CSV round-trips / concatenation may arrive as
        # numpy scalars, floats or object-dtype cells; CrossEntropyLoss needs
        # integer class indices, so normalise to a plain int here.
        label = int(self.dataframe.iloc[idx, 1])  # Label is in the second column
        return image, label
def shuffle_and_split_data(dataframe, test_size=0.2, random_state=59):
    """Shuffle ``dataframe`` and split it into training and validation parts.

    Rows are first permuted with ``DataFrame.sample`` (index reset to 0..n-1),
    then partitioned by scikit-learn's ``train_test_split``. Both steps use the
    same ``random_state``, so repeated calls on the same data give the same split.

    Args:
        dataframe: Table to split.
        test_size: Validation share, interpreted as ``train_test_split`` does
            (fraction if float, absolute row count if int).
        random_state: Seed shared by the shuffle and the split.

    Returns:
        tuple: ``(train_df, val_df)``.
    """
    permuted = dataframe.sample(frac=1, random_state=random_state)
    permuted = permuted.reset_index(drop=True)
    train_part, val_part = train_test_split(
        permuted,
        test_size=test_size,
        random_state=random_state,
    )
    return train_part, val_part
class Custom_VIT_Model:
    """Two-class image classifier built on a frozen pre-trained ViT backbone.

    Only a freshly initialised 2-way linear head is trained. Labelled samples
    are appended to a CSV file via :meth:`add_data`, and the head is retrained
    on all collected samples each time the row count reaches a multiple of
    ``retrain_every``.

    Args:
        data_file: CSV path where ``image_path``/``label`` rows are persisted.
        model_path: Where :meth:`retrain_model` writes the model ``state_dict``.
        retrain_every: Retrain whenever the sample count hits a multiple of
            this value (100, 200, ... by default). Must be >= 1.
        num_epochs: Training epochs per retrain.
        batch_size: Batch size for the training and validation loaders.
        lr: Adam learning rate for the classification head.

    Raises:
        ValueError: if ``retrain_every`` is smaller than 1.
    """

    def __init__(self, data_file='user_data.csv', model_path='trained_model.pth',
                 retrain_every=100, num_epochs=10, batch_size=32, lr=0.001):
        # Validate before any expensive model loading happens.
        if retrain_every < 1:
            raise ValueError(f"retrain_every must be >= 1, got {retrain_every}")
        self.retrain_every = retrain_every
        self.num_epochs = num_epochs
        self.batch_size = batch_size
        self.model_path = model_path

        # Check for GPU availability
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # Load the pre-trained ViT model and move it to the device
        self.model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224').to(self.device)
        # Freeze every pre-trained parameter; only the head created below trains.
        for param in self.model.parameters():
            param.requires_grad = False
        # New 2-class head; created after the freeze loop, so it stays trainable.
        self.model.classifier = nn.Linear(self.model.config.hidden_size, 2).to(self.device)
        # Optimise only the trainable head: frozen backbone tensors never get
        # gradients, so passing them to Adam only adds bookkeeping.
        self.optimizer = optim.Adam(self.model.classifier.parameters(), lr=lr)
        # NOTE(review): no Normalize step is applied here; the checkpoint's own
        # image processor may expect normalised input (verify against its
        # preprocessor_config.json) — keep this in sync with inference code.
        self.preprocess = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
        ])
        # Initialize DataFrame for user data, resuming from disk when present.
        self.data_file = data_file
        if os.path.exists(self.data_file):
            self.df = pd.read_csv(self.data_file)
        else:
            self.df = pd.DataFrame(columns=['image_path', 'label'])

    def add_data(self, image_path: str, label: int):
        """Record a labelled sample, persist it, and retrain on schedule.

        The row is appended to ``self.df`` and the whole table is rewritten to
        ``self.data_file``. When the row count becomes a multiple of
        ``self.retrain_every``, :meth:`retrain_model` is called.
        """
        new_entry = pd.DataFrame({'image_path': [image_path], 'label': [label]})
        if self.df.empty:
            # pandas >= 2.1 warns when concatenating an empty frame and the
            # resulting dtype is version-dependent; start from the new row.
            self.df = new_entry
        else:
            self.df = pd.concat([self.df, new_entry], ignore_index=True)
        self.df.to_csv(self.data_file, index=False)
        # Retrain once per `retrain_every` new samples rather than on every
        # single call once the threshold has been passed.
        if len(self.df) % self.retrain_every == 0:
            self.retrain_model()

    def retrain_model(self):
        """Train the head on all collected samples, validate, and save.

        The data is shuffled and split (80/20 by default) with
        ``shuffle_and_split_data``, the head is trained for ``self.num_epochs``
        epochs, evaluated once on the held-out split, and the model
        ``state_dict`` is written to ``self.model_path``.
        """
        train_df, val_df = shuffle_and_split_data(self.df)
        train_loader = DataLoader(CustomDataset(train_df, transform=self.preprocess),
                                  batch_size=self.batch_size, shuffle=True)
        val_loader = DataLoader(CustomDataset(val_df, transform=self.preprocess),
                                batch_size=self.batch_size)
        criterion = nn.CrossEntropyLoss().to(self.device)

        for epoch in range(self.num_epochs):
            epoch_loss = self._train_one_epoch(train_loader, criterion)
            print(f"Epoch {epoch+1}/{self.num_epochs}, Loss: {epoch_loss}")

        accuracy = self._evaluate(val_loader)
        if accuracy is None:
            print("Validation skipped: empty validation split")
        else:
            print(f"Validation Accuracy: {accuracy}")

        # Save the retrained model
        torch.save(self.model.state_dict(), self.model_path)
        print("Model retrained and updated!")

    def _train_one_epoch(self, loader, criterion):
        """Run one optimisation pass over ``loader`` and return the mean batch loss."""
        self.model.train()
        running_loss = 0.0
        for images, labels in loader:
            images, labels = images.to(self.device), labels.to(self.device)
            self.optimizer.zero_grad()
            logits = self.model(images).logits
            loss = criterion(logits, labels)
            loss.backward()
            self.optimizer.step()
            running_loss += loss.item()
        # max(..., 1) avoids ZeroDivisionError on an empty loader.
        return running_loss / max(len(loader), 1)

    def _evaluate(self, loader):
        """Return top-1 accuracy on ``loader``, or ``None`` if it yields no samples."""
        self.model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for images, labels in loader:
                images, labels = images.to(self.device), labels.to(self.device)
                predicted = self.model(images).logits.argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        return correct / total if total else None
if __name__ == "__main__":
    # Build the classifier wrapper: loads the pre-trained ViT weights and any
    # user data previously saved to the CSV file.
    vit_wrapper = Custom_VIT_Model()
    # To label a new image (0 = real, 1 = fake), which may trigger retraining:
    # vit_wrapper.add_data('path/to/image.jpg', 0)