Spaces:
Sleeping
Sleeping
| from sklearn.metrics import roc_auc_score | |
| from torchmetrics import Accuracy, Recall | |
| import pytorch_lightning as pl | |
| import timm | |
| import torch | |
| import torch.nn.functional as F | |
| import logging | |
| from PIL import Image | |
| import torchvision.transforms as transforms | |
| from torchvision.transforms import v2 | |
# Route INFO-and-above records to training.log. filemode='w' starts the file
# afresh on every run; force=True replaces any handlers that were already
# attached to the root logger.
logging.basicConfig(
    level=logging.INFO,
    filename='training.log',
    filemode='w',
    force=True,
)

# Checkpoint that image_generation_detection restores the classifier from.
CHECKPOINT = "models/image_classifier/image-classifier-step=8008-val_loss=0.11.ckpt"
class ImageClassifier(pl.LightningModule):
    """Binary image classifier: a timm ResNet-50 with a single output logit.

    ``torch.sigmoid`` of the logit is used as the score in validation and
    prediction.  Training minimises binary cross-entropy on the logits plus
    an optional penalty on the mean squared logit, weighted by ``lmd``.

    Every batch is unpacked into three items.  The third one is ignored by
    training/validation and returned untouched by ``predict_step``.
    """

    def __init__(self, lmd=0):
        """Build the backbone and the validation metrics.

        Args:
            lmd: Weight of the squared-logit penalty added to the training
                loss.  ``0`` (the default) disables the penalty.
        """
        super().__init__()
        self.model = timm.create_model('resnet50', pretrained=True, num_classes=1)
        self.accuracy = Accuracy(task='binary', threshold=0.5)
        self.recall = Recall(task='binary', threshold=0.5)
        # Per-batch validation results, consumed by on_validation_epoch_end.
        self.validation_outputs = []
        self.lmd = lmd

    def forward(self, x):
        """Return the raw logits of the backbone for ``x``."""
        return self.model(x)

    def training_step(self, batch):
        """Compute the training loss for one batch.

        Returns:
            Binary cross-entropy on the logits plus ``lmd`` times the mean
            squared logit.
        """
        images, labels, _ = batch
        # The model has num_classes=1, so only the trailing logit axis is
        # dropped.  A bare squeeze() would also drop the batch axis for a
        # one-image batch and break the shape match with ``labels``.
        # NOTE(review): assumes labels are 1-D (batch,) — confirm with the dataset.
        outputs = self.forward(images).squeeze(-1)
        print(f"Shape of outputs (training): {outputs.shape}")
        print(f"Shape of labels (training): {labels.shape}")
        loss = F.binary_cross_entropy_with_logits(outputs, labels.float())
        logging.info(f"Training Step - ERM loss: {loss.item()}")
        loss += self.lmd * (outputs ** 2).mean()  # SD loss penalty
        logging.info(f"Training Step - SD loss: {loss.item()}")
        return loss

    def validation_step(self, batch):
        """Evaluate one validation batch and buffer it for the epoch-end AUC.

        Logs ``val_loss``, ``val_acc`` and ``val_recall``.

        Returns:
            Dict with keys ``"val_loss"``, ``"preds"`` (sigmoid scores) and
            ``"labels"``.
        """
        images, labels, _ = batch
        # squeeze(-1) keeps the batch axis, so one-image batches are
        # evaluated like any other instead of being skipped.
        outputs = self.forward(images).squeeze(-1)
        print(f"Shape of outputs (validation): {outputs.shape}")
        print(f"Shape of labels (validation): {labels.shape}")
        loss = F.binary_cross_entropy_with_logits(outputs, labels.float())
        preds = torch.sigmoid(outputs)
        self.log('val_loss', loss, prog_bar=True, sync_dist=True)
        self.log('val_acc', self.accuracy(preds, labels.int()), prog_bar=True, sync_dist=True)
        self.log('val_recall', self.recall(preds, labels.int()), prog_bar=True, sync_dist=True)
        output = {"val_loss": loss, "preds": preds, "labels": labels}
        self.validation_outputs.append(output)
        logging.info(f"Validation Step - Batch loss: {loss.item()}")
        return output

    def predict_step(self, batch):
        """Return ``(scores, label, domain)`` for one batch.

        ``scores`` are the sigmoid of the logits, one per image; ``label``
        and ``domain`` are the second and third batch items, unchanged.
        """
        images, label, domain = batch
        outputs = self.forward(images).squeeze(-1)
        preds = torch.sigmoid(outputs)
        return preds, label, domain

    def on_validation_epoch_end(self):
        """Log ROC-AUC over all buffered validation batches.

        The buffer is emptied on every path, including the early returns,
        so results from one validation run never leak into the next.
        AUC is skipped, with a warning, when nothing was buffered or when
        the labels contain a single class.
        """
        # Take the accumulated batches and reset the buffer up front so no
        # return or exception below can leave stale entries behind.
        batch_outputs, self.validation_outputs = self.validation_outputs, []
        if not batch_outputs:
            logging.warning("No outputs in validation step to process")
            return
        preds = torch.cat([x['preds'] for x in batch_outputs])
        labels = torch.cat([x['labels'] for x in batch_outputs])
        if labels.unique().size(0) == 1:
            logging.warning("Only one class in validation step")
            return
        auc_score = roc_auc_score(labels.cpu(), preds.cpu())
        self.log('val_auc', auc_score, prog_bar=True, sync_dist=True)
        logging.info(f"Validation Epoch End - AUC score: {auc_score}")

    def configure_optimizers(self):
        """Return an Adam optimizer over the backbone parameters (lr=0.0005)."""
        optimizer = torch.optim.Adam(self.model.parameters(), lr=0.0005)
        return optimizer
def load_image(image_path, transform=None):
    """Open an image as RGB and optionally apply a transform to it.

    Args:
        image_path: Anything ``PIL.Image.open`` accepts.
        transform: Optional callable applied to the RGB image.  ``None``
            (the default) skips this step.

    Returns:
        The RGB ``PIL.Image`` when ``transform`` is ``None``; otherwise
        whatever ``transform`` returns for it.
    """
    # convert() produces an independent in-memory image, so the underlying
    # file can be released as soon as the conversion is done.
    with Image.open(image_path) as source:
        image = source.convert('RGB')
    # Compare against None explicitly: a transform object may define its
    # own truthiness (e.g. via __len__).
    if transform is not None:
        image = transform(image)
    return image
def predict_single_image(image_path, model, transform=None):
    """Score a single image file with ``model``.

    Side effects: ``model`` is moved to CUDA when available (CPU otherwise)
    and switched to eval mode; it is left in that state.

    Args:
        image_path: Path of the image, forwarded to ``load_image``.
        model: Callable module that returns a single logit for a batch of
            one image.
        transform: Optional preprocessing forwarded to ``load_image``.  If
            the result is not a tensor (e.g. ``transform`` is ``None``),
            it is converted with ``transforms.ToTensor()``.

    Returns:
        ``sigmoid(logit)`` as a Python float.
    """
    image = load_image(image_path, transform)
    if not isinstance(image, torch.Tensor):
        # load_image returns a PIL image when no tensor-producing transform
        # is supplied; PIL images have no .to(), so convert before moving.
        image = transforms.ToTensor()(image)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    image = image.to(device)
    model.eval()
    with torch.no_grad():
        image = image.unsqueeze(0)  # add the batch axis
        output = model(image).squeeze()
        prediction = torch.sigmoid(output).item()
    return prediction
def image_generation_detection(image_path, threshold=0.2):
    """Classify an image file as ``"HUMAN"`` or ``"MACHINE"``.

    The classifier is restored from ``CHECKPOINT``, the image is converted
    to a tensor and centre-cropped to 256x256, and the resulting score is
    compared against ``threshold``.

    NOTE: the checkpoint is reloaded on every call.

    Args:
        image_path: Path of the image to classify.
        threshold: Scores less than or equal to this value are labelled
            ``"HUMAN"``; higher scores are labelled ``"MACHINE"``.
            Defaults to ``0.2``.

    Returns:
        ``(label, confidence)`` where ``confidence`` is
        ``0.5 + |score - threshold|`` capped at ``1.0``, i.e. a fraction in
        ``[0.5, 1.0]``.
    """
    model = ImageClassifier.load_from_checkpoint(CHECKPOINT)
    transform = v2.Compose([
        transforms.ToTensor(),
        v2.CenterCrop((256, 256)),
    ])
    prediction = predict_single_image(image_path, model, transform)
    image_prediction_label = "HUMAN" if prediction <= threshold else "MACHINE"
    # Cap with a float so the confidence is always a float, never the int 1.
    image_confidence = min(1.0, 0.5 + abs(prediction - threshold))
    return image_prediction_label, image_confidence
if __name__ == "__main__":
    import sys

    # Use the first command-line argument when given; otherwise fall back
    # to the placeholder path.
    if len(sys.argv) > 1:
        image_path = sys.argv[1]
    else:
        image_path = "path_to_your_image.jpg"  # Replace with your image path
    image_prediction_label, image_confidence = image_generation_detection(
        image_path,
    )
    # Report the outcome; previously it was computed and silently discarded.
    print(f"{image_prediction_label} with confidence = {round(image_confidence * 100, 2)}%")