DeIT-Dreamer / app.py
import gradio as gr
import torch
from torch.nn import BCEWithLogitsLoss
import numpy as np
from transformers import ViTImageProcessor, ViTForImageClassification
from PIL import Image
# Load model and feature extractor outside the function
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
feature_extractor = ViTImageProcessor.from_pretrained('google/vit-large-patch32-384')
model = ViTForImageClassification.from_pretrained('google/vit-large-patch32-384')
model.to(device)
model.eval()
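
# The loop below treats the input image itself as the tensor being optimised:
# the ViT weights are never updated; each iteration backpropagates a BCE loss
# against a random multi-hot target over the model's 1000 ImageNet classes and
# takes a gradient step directly on the pixel tensor.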
def get_encoder_activations(x):
    # Run the ViT encoder and return the final-layer [CLS] token embedding.
    encoder_output = model.vit(x)
    final_activations = encoder_output.last_hidden_state[:, 0, :]
    return final_activations
def process_image(input_image, learning_rate, iterations, n_targets, seed):
    if input_image is None:
        return None
    image = input_image.convert('RGB')
    pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values
    pixel_values = pixel_values.to(device)
    pixel_values.requires_grad_(True)

    # Multi-hot target vector: n_targets random ImageNet classes set to 1.
    torch.manual_seed(int(seed))
    random_one_logits = torch.zeros(1000)
    random_one_logits[torch.randperm(1000)[:int(n_targets)]] = 1
    random_one_logits = random_one_logits.to(pixel_values.device)

    for iteration in range(int(iterations)):
        model.zero_grad()
        if pixel_values.grad is not None:
            pixel_values.grad.data.zero_()
        final_activations = get_encoder_activations(pixel_values)
        logits = model.classifier(final_activations[0])
        original_loss = BCEWithLogitsLoss(reduction='sum')(logits, random_one_logits)
        original_loss.backward()
        with torch.no_grad():
            # Gradient step on the input pixels, then clamp to the processor's
            # normalised range (mean 0.5, std 0.5 => values in [-1, 1]).
            pixel_values.data += learning_rate * pixel_values.grad.data
            pixel_values.data = torch.clamp(pixel_values.data, -1, 1)

    # Undo the normalisation back to uint8 HWC pixels for display.
    updated_pixel_values_np = 127.5 + pixel_values.squeeze().permute(1, 2, 0).detach().cpu() * 127.5
    updated_pixel_values_np = updated_pixel_values_np.numpy().astype(np.uint8)
    return updated_pixel_values_np
iface = gr.Interface(
    fn=process_image,
    inputs=[
        gr.Image(type="pil"),
        gr.Number(value=1.0, minimum=0, label="Learning Rate"),
        gr.Number(value=2, minimum=1, label="Iterations"),
        # Input order matches the process_image signature: n_targets, then seed.
        gr.Number(value=250, minimum=1, maximum=1000, label="Number of Random Target Class Activations to Maximise"),
        gr.Number(value=420, minimum=0, label="Seed"),
    ],
    outputs=[gr.Image(type="numpy", label="ViT-Dreamed Image")],
)
iface.launch()
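
# A minimal usage sketch for driving process_image directly (e.g. local testing
# without the Gradio UI). "sample.jpg" is a placeholder path, not part of the app;
# uncomment and adapt as needed.
#
# dreamed = process_image(Image.open("sample.jpg"), learning_rate=1.0,
#                         iterations=2, n_targets=250, seed=420)
# Image.fromarray(dreamed).save("dreamed.jpg")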