import gradio as gr
import torch
import numpy as np
from transformers import ViTImageProcessor, ViTForImageClassification
from PIL import Image
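
# A Gradio demo of DeepDream-style activation maximization on a Vision
# Transformer: the input image itself is optimized by gradient ascent so that
# the encoder's [CLS] activations grow, exaggerating the patterns the model
# responds to most strongly.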

# Load the ViT model and image processor once at startup so they are shared
# across all requests instead of being reloaded on every call.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
feature_extractor = ViTImageProcessor.from_pretrained('google/vit-large-patch32-384')
model = ViTForImageClassification.from_pretrained('google/vit-large-patch32-384')
model.to(device)
model.eval()

def process_image(input_image, learning_rate, iterations):
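    """Iteratively adjust the input image to maximize ViT [CLS] activations.

    Returns the "dreamed" image as an HWC uint8 numpy array, or None when no
    image was provided.
    """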
    if input_image is None:
        return None
    
    def get_encoder_activations(x):
        # Run the ViT encoder and take the final-layer [CLS] token embedding,
        # which summarizes the encoder's representation of the whole image.
        encoder_output = model.vit(x)
        final_activations = encoder_output.last_hidden_state[:, 0, :]
        return final_activations

    # Preprocess the image into normalized pixel values and enable gradients on
    # them so the image itself can be optimized.
    image = input_image.convert('RGB')
    pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values
    pixel_values = pixel_values.to(device)
    pixel_values.requires_grad_(True)

    # DeepDream-style gradient ascent: repeatedly nudge the pixels in the
    # direction that increases the summed [CLS] activations.
    for iteration in range(int(iterations)):
        model.zero_grad()
        if pixel_values.grad is not None:
            pixel_values.grad.data.zero_()

        final_activations = get_encoder_activations(pixel_values)
        target_sum = final_activations.sum()
        target_sum.backward()

        with torch.no_grad():
            # Ascent step, then clamp to the normalized pixel range [-1, 1].
            pixel_values.data += learning_rate * pixel_values.grad.data
            pixel_values.data = torch.clamp(pixel_values.data, -1, 1)

    # Map the normalized pixel range [-1, 1] back to [0, 255] and convert
    # CHW -> HWC uint8 for display.
    updated_pixel_values = 127.5 + pixel_values.squeeze().permute(1, 2, 0).detach().cpu() * 127.5
    updated_pixel_values_np = updated_pixel_values.numpy().astype(np.uint8)

    return updated_pixel_values_np

iface = gr.Interface(
    fn=process_image,
    inputs=[
        gr.Image(type="pil", label="Input Image"),
        gr.Number(value=4.0, label="Learning Rate"),
        gr.Number(value=4, label="Iterations"),
    ],
    outputs=[gr.Image(type="numpy", label="ViT-Dreamed Image")],
)

iface.launch()