import gradio as gr
import torch
from huggingface_hub import hf_hub_download
import json
from omegaconf import OmegaConf
import sys
import os
from PIL import Image
import torchvision.transforms as transforms

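# Local folder with example input images shown in the gallery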
photos_folder = "Photos"

# Download the generator weights, config, and model definition from the Hugging Face Hub
repo_id = "Kiwinicki/sat2map-generator"
generator_path = hf_hub_download(repo_id=repo_id, filename="generator.pth")
config_path = hf_hub_download(repo_id=repo_id, filename="config.json")
model_path = hf_hub_download(repo_id=repo_id, filename="model.py")

# Make the downloaded model.py importable so the Generator class can be loaded
sys.path.append(os.path.dirname(model_path))
from model import Generator

# Load the model configuration
with open(config_path, "r") as f:
    config_dict = json.load(f)
cfg = OmegaConf.create(config_dict)

# Initialize the generator and load its weights
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
generator = Generator(cfg).to(device)
generator.load_state_dict(torch.load(generator_path, map_location=device))
generator.eval()

# Input transforms: resize to 256x256, convert to a tensor, normalize to [-1, 1]
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

def process_image(image):
    # The input can be None (e.g. right after pressing Clear); skip inference in that case
    if image is None:
        return None

    # Convert the PIL image to a normalized batch tensor
    image_tensor = transform(image.convert("RGB")).unsqueeze(0).to(device)

    # Inference
    with torch.no_grad():
        output_tensor = generator(image_tensor)

    # Prepare the output image
    output_image = output_tensor.squeeze(0).cpu()
    output_image = (output_image * 0.5 + 0.5).clamp(0, 1)  # Denormalize back to [0, 1]
    output_image = transforms.ToPILImage()(output_image)

    return output_image


def load_images_from_folder(folder):
    # Collect (image, filename) pairs for the gallery; a missing folder just yields an empty gallery
    images = []
    if not os.path.isdir(folder):
        return images
    for filename in sorted(os.listdir(folder)):
        if filename.lower().endswith(('.png', '.jpg', '.jpeg')):
            img_path = os.path.join(folder, filename)
            img = Image.open(img_path)
            images.append((img, filename))
    return images


def load_image_from_gallery(images, index):
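    # Gallery entries may be plain PIL images or (image, caption) tuples; handle both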
    if images and 0 <= index < len(images):
        image = images[index]
        if isinstance(image, tuple):
            image = image[0]
        return image
    return None


def gallery_click_event(images, evt: gr.SelectData):
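    # evt.index is the position of the clicked thumbnail in the gallery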
    index = evt.index
    selected_img = load_image_from_gallery(images, index)
    return selected_img


def clear_image():
    return None


def app():
    images = load_images_from_folder(photos_folder)

    with gr.Blocks(css=".container { background-color: white; }") as demo:
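        # Layout: input image + Clear button | example gallery | generated result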
        with gr.Row():
            with gr.Column():
                selected_image = gr.Image(label="Input Image", type="pil")
                clear_button = gr.Button("Clear")

            with gr.Column():
                image_gallery = gr.Gallery(
                    label="Image Gallery",
                    elem_id="gallery",
                    type="pil",
                    value=[img for img, _ in images],
                )

            with gr.Column():
                result_image = gr.Image(label="Result Image", type="pil")

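        # Clicking a gallery thumbnail copies it into the input image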
        image_gallery.select(
            fn=gallery_click_event,
            inputs=image_gallery,
            outputs=selected_image
        )

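        # Any change to the input image (including a gallery selection) runs the generator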
        selected_image.change(
            fn=process_image,
            inputs=selected_image,
            outputs=result_image
        )

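        # The Clear button resets the input; process_image ignores the resulting None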
        clear_button.click(
            fn=clear_image,
            inputs=None,
            outputs=selected_image
        )

    demo.launch()


if __name__ == "__main__":
    app()