Image Classification
Transformers
English
art
Vit / vit_model_traning.py
benjaminStreltzin's picture
image path updated
9d763af
raw
history blame
4.71 kB
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from transformers import ViTForImageClassification
from PIL import Image
import torch.optim as optim
import os
import pandas as pd
from sklearn.model_selection import train_test_split
## working 18.5.24
def labeling(path_real, path_fake):
    """Build a labeled dataset index from two image directories.

    Args:
        path_real: Directory containing real images (labeled 0).
        path_fake: Directory containing fake images (labeled 1).

    Returns:
        pd.DataFrame with columns 'image_path' (full path) and 'label' (int).
    """
    image_paths = []
    labels = []
    for directory, label in ((path_real, 0), (path_fake, 1)):
        # Sort for a deterministic row order (os.listdir order is OS-dependent).
        for filename in sorted(os.listdir(directory)):
            full_path = os.path.join(directory, filename)
            # Skip subdirectories and other non-file entries so that
            # Image.open() downstream never receives a directory path.
            if os.path.isfile(full_path):
                image_paths.append(full_path)
                labels.append(label)
    return pd.DataFrame({'image_path': image_paths, 'label': labels})
class CustomDataset(Dataset):
    """Expose a two-column (image_path, label) DataFrame as a PyTorch Dataset."""

    def __init__(self, dataframe, transform=None):
        # Column 0 holds image paths, column 1 holds integer labels.
        self.dataframe = dataframe
        self.transform = transform

    def __len__(self):
        """Return the number of samples in the dataset."""
        return len(self.dataframe)

    def __getitem__(self, idx):
        """Load the image at row *idx*, apply the transform, return (image, label)."""
        row = self.dataframe.iloc[idx]
        # Force RGB so grayscale/RGBA files all yield three channels.
        image = Image.open(row.iloc[0]).convert('RGB')
        image = self.transform(image) if self.transform else image
        return image, row.iloc[1]
def shuffle_and_split_data(dataframe, test_size=0.2, random_state=59):
    """Shuffle *dataframe* and split it into train and validation sets.

    Args:
        dataframe: The full labeled dataset.
        test_size: Fraction of rows assigned to the validation split.
        random_state: Seed used for both the shuffle and the split.

    Returns:
        (train_df, val_df) tuple of DataFrames.
    """
    # Randomize row order with a fresh 0..n-1 index, then delegate
    # the actual partitioning to sklearn.
    shuffled = (
        dataframe
        .sample(frac=1, random_state=random_state)
        .reset_index(drop=True)
    )
    train_df, val_df = train_test_split(
        shuffled, test_size=test_size, random_state=random_state
    )
    return train_df, val_df
if __name__ == "__main__":
    # Actually check for GPU availability: the original hard-coded 'cuda',
    # which raises on CPU-only machines.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Load the pre-trained ViT backbone and move it to the selected device.
    model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224').to(device)

    # Freeze every pre-trained parameter so only the new head learns.
    for param in model.parameters():
        param.requires_grad = False

    # Replace the classification head: two output classes, 'REAL' (0) and 'FAKE' (1).
    # nn.Linear's fresh parameters default to requires_grad=True.
    model.classifier = nn.Linear(model.config.hidden_size, 2).to(device)
    print(model)

    # Optimize only the trainable head; feeding frozen parameters to Adam
    # is harmless but wasteful.
    optimizer = optim.Adam(model.classifier.parameters(), lr=0.001)

    # Preprocessing: resize to the ViT input resolution and normalize with
    # the mean/std used when google/vit-base-patch16-224 was pre-trained
    # (the original pipeline skipped normalization, shifting the input
    # distribution the frozen backbone sees).
    preprocess = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])

    # Index the images (0 = real, 1 = fake) and split into train/validation.
    train_real_folder = 'datasets/training_set/real'
    train_fake_folder = 'datasets/training_set/fake'
    train_dataset_df = labeling(train_real_folder, train_fake_folder)
    train_dataset_df, val_dataset_df = shuffle_and_split_data(train_dataset_df)

    train_dataset = CustomDataset(train_dataset_df, transform=preprocess)
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
    val_dataset = CustomDataset(val_dataset_df, transform=preprocess)
    val_loader = DataLoader(val_dataset, batch_size=32)

    criterion = nn.CrossEntropyLoss().to(device)

    num_epochs = 10
    for epoch in range(num_epochs):
        # --- Training phase ---
        model.train()
        running_loss = 0.0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            # HF models return a ModelOutput object; the raw scores live in .logits.
            logits = outputs.logits
            loss = criterion(logits, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {running_loss / len(train_loader)}")

        # --- Validation phase (per epoch, so training progress is visible) ---
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                logits = outputs.logits
                _, predicted = torch.max(logits, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        print(f"Validation Accuracy: {correct / total}")

    # Persist the fine-tuned weights.
    torch.save(model.state_dict(), 'trained_model.pth')