# Hugging Face Space page residue ("Spaces: Sleeping") — not part of the program.
import torch | |
import torch.nn as nn | |
from torch.utils.data import Dataset, DataLoader | |
from torchvision import transforms | |
from transformers import ViTForImageClassification | |
from PIL import Image | |
import torch.optim as optim | |
import os | |
import pandas as pd | |
from sklearn.model_selection import train_test_split | |
class CustomDataset(Dataset):
    """Map-style dataset over a DataFrame of (image path, label) rows.

    Column 0 of ``dataframe`` is read as the image file path and column 1 as
    the label. Columns are selected by position, not by name.
    """

    def __init__(self, dataframe, transform=None):
        # Rows to serve; accessed with ``iloc`` so the frame's index labels
        # do not matter (e.g. after a shuffle/split).
        self.dataframe = dataframe
        # Optional callable applied to each loaded RGB PIL image.
        self.transform = transform

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the row at position ``idx``.

        The image is loaded in RGB mode and passed through ``transform`` when
        one was supplied. The label is returned exactly as stored.
        """
        image_path = self.dataframe.iloc[idx, 0]
        # Open in a context manager so the file handle is always released;
        # ``convert`` returns a new image that remains usable after the
        # source file is closed.
        with Image.open(image_path) as source:
            image = source.convert('RGB')
        if self.transform is not None:
            image = self.transform(image)
        label = self.dataframe.iloc[idx, 1]
        return image, label
def shuffle_and_split_data(dataframe, test_size=0.2, random_state=59):
    """Deterministically shuffle ``dataframe`` and split it into two parts.

    Args:
        dataframe: Rows to split.
        test_size: Size of the validation part, forwarded unchanged to
            ``train_test_split``.
        random_state: Seed used for both the initial shuffle and the split.

    Returns:
        A ``(train_df, val_df)`` pair.
    """
    # Reorder every row with a fixed seed, then renumber the index 0..n-1.
    reordered = dataframe.sample(frac=1, random_state=random_state)
    reordered = reordered.reset_index(drop=True)
    # NOTE(review): train_test_split presumably shuffles again by default,
    # making the sample() above redundant — kept so the split is unchanged.
    train_part, val_part = train_test_split(
        reordered,
        test_size=test_size,
        random_state=random_state,
    )
    return train_part, val_part
class Custom_VIT_Model:
    """Binary image classifier built on a frozen pretrained ViT backbone.

    Only the replacement 2-way classification head is trained. Labelled
    samples are appended to ``user_data.csv`` through :meth:`add_data`. The
    head is retrained on all stored samples each time their count reaches a
    multiple of ``RETRAIN_THRESHOLD``.
    """

    # Retrain whenever the number of stored samples reaches a multiple of this.
    RETRAIN_THRESHOLD = 100

    def __init__(self):
        # Use the GPU when CUDA is available, otherwise the CPU.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # Load the pretrained ViT model.
        self.model = ViTForImageClassification.from_pretrained(
            'google/vit-base-patch16-224'
        ).to(self.device)
        # Freeze all pretrained parameters. This runs before the head is
        # replaced, so the new head below stays trainable.
        for param in self.model.parameters():
            param.requires_grad = False
        # New classifier head with 2 outputs (classes 0 and 1).
        self.model.classifier = nn.Linear(self.model.config.hidden_size, 2).to(self.device)
        # Optimize only the trainable head; the frozen backbone never gets gradients.
        self.optimizer = optim.Adam(self.model.classifier.parameters(), lr=0.001)
        # Resize to 224x224 and convert to a tensor (ToTensor scales pixels to
        # [0, 1]). Then normalize with mean = std = 0.5 per channel, mapping
        # values to [-1, 1]. This matches the image processor published with
        # google/vit-base-patch16-224. Without it the frozen backbone sees
        # inputs on a different scale than it was pretrained on.
        self.preprocess = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])
        # Load previously collected user data, or start with an empty table.
        self.data_file = 'user_data.csv'
        if os.path.exists(self.data_file):
            self.df = pd.read_csv(self.data_file)
        else:
            self.df = pd.DataFrame(columns=['image_path', 'label'])

    def add_data(self, image_path: str, label: int):
        """Record one labelled image, persist the table, and retrain when due.

        Args:
            image_path: Path to the image file.
            label: Class index (0 or 1) for the image.

        Side effects:
            Rewrites ``self.data_file`` with the full table and prints it.
            Calls :meth:`retrain_model` when the sample count reaches a
            multiple of ``RETRAIN_THRESHOLD``.
        """
        new_entry = pd.DataFrame({'image_path': [image_path], 'label': [label]})
        # Concatenating onto an empty frame is deprecated in recent pandas and
        # would leave the columns with object dtype, so start from the new row.
        if self.df.empty:
            self.df = new_entry
        else:
            self.df = pd.concat([self.df, new_entry], ignore_index=True)
        # Persist the whole table so collected data survives restarts.
        self.df.to_csv(self.data_file, index=False)
        # Print the current state of the training data for debugging.
        print("Current training data:")
        print(self.df)
        # Samples arrive one at a time, so the count hits every multiple of the
        # threshold. A plain ``>= threshold`` check would retrain on every call
        # once the threshold had been reached.
        if len(self.df) % self.RETRAIN_THRESHOLD == 0:
            print("Retraining the model as we have enough data.")
            self.retrain_model()

    def retrain_model(self, num_epochs: int = 10, batch_size: int = 32):
        """Retrain the classification head on all collected user data.

        Splits ``self.df`` into train/validation parts and trains for
        ``num_epochs`` epochs. After each epoch it prints the mean training
        loss and the validation accuracy. Finally it saves the model's
        ``state_dict`` to ``trained_model.pth``.

        Args:
            num_epochs: Number of passes over the training split.
            batch_size: Batch size for both the training and validation loaders.
        """
        train_df, val_df = shuffle_and_split_data(self.df)
        train_loader = DataLoader(
            CustomDataset(train_df, transform=self.preprocess),
            batch_size=batch_size,
            shuffle=True,
        )
        val_loader = DataLoader(
            CustomDataset(val_df, transform=self.preprocess),
            batch_size=batch_size,
        )
        criterion = nn.CrossEntropyLoss()
        for epoch in range(num_epochs):
            avg_loss = self._train_one_epoch(train_loader, criterion)
            print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {avg_loss}")
            accuracy = self._evaluate(val_loader)
            print(f"Validation Accuracy: {accuracy}")
        # Save the retrained model.
        torch.save(self.model.state_dict(), 'trained_model.pth')
        print("Model retrained and updated!")

    def _train_one_epoch(self, loader, criterion):
        """Run one optimization pass over ``loader``; return mean batch loss (0.0 if empty)."""
        self.model.train()
        running_loss = 0.0
        num_batches = 0
        for images, labels in loader:
            images, labels = images.to(self.device), labels.to(self.device)
            self.optimizer.zero_grad()
            logits = self.model(images).logits
            loss = criterion(logits, labels)
            loss.backward()
            self.optimizer.step()
            running_loss += loss.item()
            num_batches += 1
        return running_loss / num_batches if num_batches else 0.0

    def _evaluate(self, loader):
        """Return classification accuracy over ``loader`` (0.0 if it yields nothing)."""
        self.model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for images, labels in loader:
                images, labels = images.to(self.device), labels.to(self.device)
                predicted = self.model(images).logits.argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        return correct / total if total else 0.0
if __name__ == "__main__":
    # Running the file directly only constructs the wrapper. That loads the
    # pretrained backbone and any previously saved user data; nothing else runs.
    vit_wrapper = Custom_VIT_Model()