Spaces:
Runtime error
Runtime error
| import os | |
| import torch | |
| import torch.nn as nn | |
| from torchvision import models, transforms | |
| from huggingface_hub import hf_hub_url, hf_hub_download | |
| from .convnext import ConvNeXt | |
| from wmdetection.utils import FP16Module | |
def get_convnext_model(name):
    """Build the ConvNeXt watermark classifier architecture and its transforms.

    Both supported names share one architecture:
    - a ``ConvNeXt`` backbone with ``depths=[3, 3, 9, 3]`` and
      ``dims=[96, 192, 384, 768]``;
    - a ``head`` replaced by a 768 -> 512 -> 256 -> 2 MLP with GELU activations.

    This function does not load a checkpoint itself.

    Args:
        name: ``'convnext-tiny'`` or ``'convnext-wm_1102'``.

    Returns:
        ``(model, transforms)``. ``transforms`` resizes to 256x256, converts to
        a tensor and normalizes with the standard ImageNet mean/std values.

    Raises:
        ValueError: If ``name`` is not a supported ConvNeXt variant.
    """
    # Fail fast with a clear message for unsupported names. Otherwise
    # ``model_ft`` would never be assigned and the call would fail later with
    # an unrelated error.
    if name not in ('convnext-tiny', 'convnext-wm_1102'):
        raise ValueError(f"Unknown ConvNeXt model name: {name}")

    model_ft = ConvNeXt(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768])
    # Replace the classification head with a small 2-class MLP.
    model_ft.head = nn.Sequential(
        nn.Linear(in_features=768, out_features=512),
        nn.GELU(),
        nn.Linear(in_features=512, out_features=256),
        nn.GELU(),
        nn.Linear(in_features=256, out_features=2),
    )
    detector_transforms = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    return model_ft, detector_transforms
def get_resnext_model(name):
    """Build a torchvision ResNeXt watermark classifier and its transforms.

    The backbone is created with ``pretrained=False``. Its ``fc`` layer is
    replaced by a 2-output linear layer. This function does not load a
    checkpoint.

    Args:
        name: ``'resnext50_32x4d-small'`` or ``'resnext101_32x8d-large'``.

    Returns:
        ``(model, transforms)``. ``transforms`` resizes to 320x320, converts to
        a tensor and normalizes with the standard ImageNet mean/std values.

    Raises:
        ValueError: If ``name`` is not a supported ResNeXt variant.
    """
    if name == 'resnext50_32x4d-small':
        model_ft = models.resnext50_32x4d(pretrained=False)
    elif name == 'resnext101_32x8d-large':
        model_ft = models.resnext101_32x8d(pretrained=False)
    else:
        # An unknown name would otherwise leave ``model_ft`` unassigned and
        # fail later with an unrelated error.
        raise ValueError(f"Unknown ResNeXt model name: {name}")

    # Replace the stock ``fc`` layer with a 2-class linear head.
    model_ft.fc = nn.Linear(model_ft.fc.in_features, 2)
    detector_transforms = transforms.Compose([
        transforms.Resize((320, 320)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    return model_ft, detector_transforms
def get_watermarks_detection_model(name, device='cpu', fp16=True, pretrained=True, cache_dir='/tmp/watermark-detection'):
    """Build a watermark-detection classifier and its preprocessing transforms.

    Args:
        name: Key into ``MODELS``. Selects the constructor, the Hub repo and the
            checkpoint filename.
        device: Device the model is moved to. Also used as ``map_location``
            when loading the checkpoint.
        fp16: Wrap the model in ``FP16Module``. Not allowed for names starting
            with ``'convnext'``.
        pretrained: Download the checkpoint from the Hugging Face Hub and load
            it as the model's state dict.
        cache_dir: Cache directory passed to ``hf_hub_download``.

    Returns:
        ``(model, transforms)``: the model in eval mode on ``device`` and the
        transform pipeline returned by the model's constructor.

    Raises:
        AssertionError: If ``name`` is not in ``MODELS``, or if ``fp16`` is
            requested for a ConvNeXt model. These checks are skipped when
            Python runs with ``-O``.
    """
    assert name in MODELS, f"Unknown model name: {name}"
    assert not (fp16 and name.startswith('convnext')), "Can`t use fp16 mode with convnext models"
    config = MODELS[name]
    model_ft, detector_transforms = config['constructor'](name)

    if pretrained:
        # Load from the path hf_hub_download returns instead of rebuilding it
        # as ``cache_dir/filename``. That layout relied on the deprecated
        # ``force_filename`` argument. Newer huggingface_hub releases either
        # reject it or no longer place the file there.
        weights_path = hf_hub_download(
            repo_id=config['repo_id'],
            filename=config['filename'],
            cache_dir=cache_dir,
        )
        weights = torch.load(weights_path, map_location=device)
        model_ft.load_state_dict(weights)

    if fp16:
        model_ft = FP16Module(model_ft)

    model_ft.eval()
    model_ft = model_ft.to(device)
    return model_ft, detector_transforms
# Registry of available detectors. Each entry maps a model name to:
#   constructor - callable taking the name and returning (model, transforms)
#   repo_id     - Hugging Face Hub repository holding the checkpoint
#   filename    - checkpoint file inside that repository
MODELS = {
    'convnext-tiny': {
        'constructor': get_convnext_model,
        'repo_id': 'boomb0om/watermark-detectors',
        'filename': 'convnext-tiny_watermarks_detector.pth',
    },
    'convnext-wm_1102': {
        'constructor': get_convnext_model,
        'repo_id': 'Inf009/wm_1102',
        'filename': 'convnext_v1_9.pth',
    },
    'resnext101_32x8d-large': {
        'constructor': get_resnext_model,
        'repo_id': 'boomb0om/watermark-detectors',
        'filename': 'watermark_classifier-resnext101_32x8d-input_size320-4epochs_c097_w082.pth',
    },
    'resnext50_32x4d-small': {
        'constructor': get_resnext_model,
        'repo_id': 'boomb0om/watermark-detectors',
        'filename': 'watermark_classifier-resnext50_32x4d-input_size320-4epochs_c082_w078.pth',
    },
}