Image Classification
Transformers
English
art
benjaminStreltzin committed on
Commit
be92907
·
verified ·
1 Parent(s): f6991e3

Update vit_model_training.py

Browse files
Files changed (1) hide show
  1. vit_model_training.py +6 -6
vit_model_training.py CHANGED
@@ -8,7 +8,7 @@ import torch.optim as optim
8
  import os
9
  import pandas as pd
10
  from sklearn.model_selection import train_test_split
11
- ## working 18.5.24
12
 
13
 
14
  def labeling(path_real, path_fake):
@@ -36,13 +36,13 @@ class CustomDataset(Dataset):
36
  return len(self.dataframe)
37
 
38
  def __getitem__(self, idx):
39
- image_path = self.dataframe.iloc[idx, 0] # Image path is in the first column
40
  image = Image.open(image_path).convert('RGB') # Convert to RGB format
41
 
42
  if self.transform:
43
  image = self.transform(image)
44
 
45
- label = self.dataframe.iloc[idx, 1] # Label is in the second column
46
  return image, label
47
 
48
 
@@ -62,21 +62,21 @@ if __name__ == "__main__":
62
  # Check for GPU availability
63
  device = torch.device('cuda')
64
 
65
- # Load the pre-trained ViT model and move it to GPU
66
  model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224').to(device)
67
 
68
  # Freeze pre-trained layers
69
  for param in model.parameters():
70
  param.requires_grad = False
71
 
72
- # Define a new classifier and move it to GPU
73
  model.classifier = nn.Linear(model.config.hidden_size, 2).to(device) # Two output classes: 'REAL' and 'FAKE'
74
 
75
  print(model)
76
  # Define the optimizer
77
  optimizer = optim.Adam(model.parameters(), lr=0.001)
78
 
79
- # Define the image preprocessing pipeline
80
  preprocess = transforms.Compose([
81
  transforms.Resize((224, 224)),
82
  transforms.ToTensor()
 
8
  import os
9
  import pandas as pd
10
  from sklearn.model_selection import train_test_split
11
+ ## working 18.9.24
12
 
13
 
14
  def labeling(path_real, path_fake):
 
36
  return len(self.dataframe)
37
 
38
  def __getitem__(self, idx):
39
+ image_path = self.dataframe.iloc[idx, 0]
40
  image = Image.open(image_path).convert('RGB') # Convert to RGB format
41
 
42
  if self.transform:
43
  image = self.transform(image)
44
 
45
+ label = self.dataframe.iloc[idx, 1]
46
  return image, label
47
 
48
 
 
62
  # Check for GPU availability
63
  device = torch.device('cuda')
64
 
65
+ # Load the pre-trained ViT model
66
  model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224').to(device)
67
 
68
  # Freeze pre-trained layers
69
  for param in model.parameters():
70
  param.requires_grad = False
71
 
72
+ # Define a new classifier
73
  model.classifier = nn.Linear(model.config.hidden_size, 2).to(device) # Two output classes: 'REAL' and 'FAKE'
74
 
75
  print(model)
76
  # Define the optimizer
77
  optimizer = optim.Adam(model.parameters(), lr=0.001)
78
 
79
+ # Resize the image to 224x224 and convert it to a tensor
80
  preprocess = transforms.Compose([
81
  transforms.Resize((224, 224)),
82
  transforms.ToTensor()