Spaces:
Sleeping
Sleeping
File size: 1,859 Bytes
b7ebb88 03f7bd7 b7ebb88 37ebd45 03f7bd7 37ebd45 03f7bd7 37ebd45 03f7bd7 37ebd45 03f7bd7 37ebd45 03f7bd7 37ebd45 03f7bd7 37ebd45 03f7bd7 b7ebb88 68fa56c 03f7bd7 68fa56c 03f7bd7 b7ebb88 37ebd45 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 |
import gradio as gr
import torch
import numpy as np
from transformers import ViTImageProcessor, ViTForImageClassification
from PIL import Image
# The device, image processor and ViT classifier are built once at module
# level (not inside the request handler) and shared by every call.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

feature_extractor = ViTImageProcessor.from_pretrained(
    'google/vit-large-patch32-384'
)
# ``Module.to`` and ``Module.eval`` both return the module itself, so the
# chained call leaves ``model`` bound to the loaded network, moved to
# ``device`` and switched to evaluation mode.
model = ViTForImageClassification.from_pretrained(
    'google/vit-large-patch32-384'
).to(device).eval()
def process_image(input_image, learning_rate, iterations):
    """Modify an image by gradient ascent on the ViT encoder's output.

    The image is preprocessed with the module-level ``feature_extractor``;
    its pixel values are then repeatedly moved along the gradient of the
    summed ``last_hidden_state`` of ``model.vit`` and finally mapped back to
    8-bit values.

    Args:
        input_image: PIL image to modify; it is converted to RGB first.
        learning_rate: Step size multiplied with the gradient on each step.
        iterations: Number of ascent steps. Coerced with ``int`` because
            numeric UI widgets may deliver a float such as ``1.0``, which
            ``range`` rejects. Values <= 0 perform no steps.

    Returns:
        A ``numpy.ndarray`` of dtype ``uint8`` holding the modified image,
        channel axis last.

    Raises:
        ValueError: If ``input_image`` is ``None`` (e.g. nothing uploaded).
    """
    if input_image is None:
        raise ValueError("process_image requires an input image, got None")

    step_count = int(iterations)
    step_size = float(learning_rate)

    image = input_image.convert('RGB')
    pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values
    pixel_values = pixel_values.to(device)
    pixel_values.requires_grad_(True)

    for _ in range(step_count):
        final_activations = model.vit(pixel_values).last_hidden_state
        target_sum = final_activations.sum()
        # Differentiate with respect to the pixels only: unlike ``backward()``
        # this does not accumulate gradients on the model weights, so no
        # ``zero_grad`` bookkeeping is needed between steps.
        (pixel_grad,) = torch.autograd.grad(target_sum, pixel_values)
        with torch.no_grad():
            # In-place updates keep ``pixel_values`` the same leaf tensor.
            pixel_values += step_size * pixel_grad
            pixel_values.clamp_(-1, 1)

    # NOTE(review): the 127.5 scaling assumes the processor normalises pixels
    # to [-1, 1] (mean = std = 0.5) and returns a (1, C, H, W) batch — confirm
    # against the checkpoint's preprocessor config.
    rescaled = 127.5 + pixel_values.detach().squeeze(0).permute(1, 2, 0).cpu() * 127.5
    # Clip before the cast so out-of-range values cannot wrap around in uint8.
    return rescaled.clamp(0, 255).numpy().astype(np.uint8)
# Gradio UI: one image upload plus two numeric controls, wired to
# ``process_image``; the result is shown as an image.
iface = gr.Interface(
    fn=process_image,
    inputs=[
        gr.Image(type="pil"),
        gr.Number(value=0.01, label="Learning Rate"),
        # precision=0 makes Gradio round the value and pass it as an int,
        # which is what a loop count needs.
        gr.Number(value=1, label="Iterations", precision=0),
    ],
    outputs=gr.Image(type="numpy", label="Processed Image"),
)
iface.launch()
|