File size: 2,717 Bytes
b7ebb88
03f7bd7
 
 
 
b7ebb88
37ebd45
 
 
 
 
 
03f7bd7
f93fa3d
 
 
 
 
 
 
 
 
 
8c65b05
07b1c90
353541c
37ebd45
03f7bd7
 
37ebd45
03f7bd7
 
a4244e1
 
534e187
a4244e1
5c39195
03f7bd7
 
 
 
37ebd45
a4244e1
f93fa3d
8c65b05
f93fa3d
8c65b05
f93fa3d
03f7bd7
 
37ebd45
03f7bd7
 
 
 
 
 
b7ebb88
 
 
68fa56c
03f7bd7
8c65b05
 
 
f93fa3d
8c65b05
68fa56c
353541c
b7ebb88
 
37ebd45
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import gradio as gr
import torch
import numpy as np
from transformers import ViTImageProcessor, ViTForImageClassification
from PIL import Image

# Load the image processor and classifier once at import time so that every
# request reuses the same objects instead of reloading the checkpoint.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
feature_extractor = ViTImageProcessor.from_pretrained('google/vit-large-patch32-384')
model = ViTForImageClassification.from_pretrained('google/vit-large-patch32-384')
model.to(device)
model.eval()
# Only the input pixels are optimised, never the weights, so freeze them.
# Without this, a plain `.backward()` through the model would also compute and
# store a gradient for every parameter, wasting time and memory.
model.requires_grad_(False)

def get_encoder_activations(x):
    """Return the encoder's final hidden state at token position 0.

    Args:
        x: Pixel-value tensor passed straight to ``model.vit`` (as produced by
            ``feature_extractor``).

    Returns:
        ``last_hidden_state[:, 0, :]`` -- one vector per batch item. In the
        Hugging Face ViT this position is presumably the prepended [CLS]
        token; the caller feeds it to ``model.classifier``.
    """
    encoder_states = model.vit(x).last_hidden_state
    first_token = 0
    return encoder_states[:, first_token, :]

def total_variation_loss(img):
    """Compute the anisotropic (L1) total-variation penalty of an image batch.

    Sums the absolute differences between neighbours along dimension 2 and
    along dimension 3 of ``img``.

    Args:
        img: A 4-D tensor whose dims 2 and 3 are the spatial axes
            (presumably ``(batch, channels, height, width)`` -- confirm).

    Returns:
        A scalar tensor: the sum of both absolute-difference totals.
    """
    along_rows = torch.diff(img, dim=2).abs().sum()
    along_cols = torch.diff(img, dim=3).abs().sum()
    return along_rows + along_cols

def process_image(input_image, learning_rate, tv_weight, iterations, n_targets, seed):
    """Run ViT "dreaming" on an image by gradient ascent on random class logits.

    The image is preprocessed with ``feature_extractor``. Its pixel values are
    then repeatedly nudged in the direction that increases the summed logits
    of ``n_targets`` randomly chosen classes, minus a weighted total-variation
    penalty.

    Args:
        input_image: A PIL image, or ``None``.
        learning_rate: Step size applied to the raw input gradient.
        tv_weight: Weight of the total-variation penalty subtracted from the
            objective.
        iterations: Number of ascent steps (converted with ``int``).
        n_targets: Number of distinct target classes (converted with ``int``;
            values above the class count are effectively capped).
        seed: Seed that selects the target classes (converted with ``int``).

    Returns:
        A height x width x channels ``uint8`` numpy array, or ``None`` when no
        image is given.
    """
    if input_image is None:
        return None

    image = input_image.convert('RGB')
    pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values
    pixel_values = pixel_values.to(device).requires_grad_(True)

    # Use a private generator so a request does not reseed torch's global RNG,
    # which is shared with everything else in the process. A fresh CPU
    # generator seeded this way yields the same permutation the global CPU
    # generator would after torch.manual_seed(seed).
    generator = torch.Generator().manual_seed(int(seed))
    num_classes = model.config.num_labels
    target_indices = torch.randperm(num_classes, generator=generator)[:int(n_targets)]
    target_indices = target_indices.to(device)

    for _ in range(int(iterations)):
        cls_activations = get_encoder_activations(pixel_values)
        logits = model.classifier(cls_activations[0])

        objective = (
            logits[target_indices].sum()
            - tv_weight * total_variation_loss(pixel_values)
        )
        # Differentiate w.r.t. the input only. This never touches parameter
        # .grad buffers, so no zero_grad bookkeeping is needed.
        (grad,) = torch.autograd.grad(objective, pixel_values)

        with torch.no_grad():
            # Gradient *ascent* step, then keep values in [-1, 1], the range the
            # conversion below maps onto [0, 255].
            pixel_values.add_(grad, alpha=learning_rate)
            pixel_values.clamp_(-1.0, 1.0)

    # Take batch item 0, move channels last, map [-1, 1] -> [0, 255], and
    # round (rather than truncate) before casting to uint8.
    result = pixel_values.detach()[0].permute(1, 2, 0)
    result = (result * 127.5 + 127.5).round().clamp_(0, 255)
    return result.to(torch.uint8).cpu().numpy()

# Build the web UI. Gradio passes the input values to `fn` positionally, so the
# order of `inputs` must match the signature of process_image:
# (input_image, learning_rate, tv_weight, iterations, n_targets, seed).
iface = gr.Interface(
    fn=process_image,
    inputs=[
        gr.Image(type="pil"),
        gr.Number(value=16.0, minimum=0, label="Learning Rate"),
        gr.Number(value=0.0001, label="Total Variation Loss Weight"),
        gr.Number(value=4, minimum=1, label="Iterations"),
        gr.Number(value=500, minimum=1, maximum=1000, label="Number of Random Target Class Activations to Maximise"),
        gr.Number(value=420, minimum=0, label="Seed"),
    ],
    outputs=[gr.Image(type="numpy", label="ViT-Dreamed Image")]
)

iface.launch()