"""Gradio demo: translate an input image with a pretrained sat2map generator.

The generator weights, its JSON config and the ``model.py`` that defines the
``Generator`` class are all downloaded from the Hugging Face Hub at import
time. The UI shows an input image, a gallery of sample photos read from
``photos_folder`` and the generated result.
"""

import json
import os
import sys

import gradio as gr
import torch
import torchvision.transforms as transforms
from huggingface_hub import hf_hub_download
from omegaconf import OmegaConf
from PIL import Image

photos_folder = "Photos"

# Download the model weights, its config and the model definition.
repo_id = "Kiwinicki/sat2map-generator"
generator_path = hf_hub_download(repo_id=repo_id, filename="generator.pth")
config_path = hf_hub_download(repo_id=repo_id, filename="config.json")
model_path = hf_hub_download(repo_id=repo_id, filename="model.py")

# Make the downloaded model.py importable.
# NOTE(review): this executes code fetched from the Hub at import time; only
# safe while the repository is trusted (consider pinning a revision).
sys.path.append(os.path.dirname(model_path))
from model import Generator  # noqa: E402  (must follow the sys.path change)

# Load the configuration.
with open(config_path, "r", encoding="utf-8") as f:
    config_dict = json.load(f)
cfg = OmegaConf.create(config_dict)

# Initialise the model.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
generator = Generator(cfg).to(device)
# NOTE(review): torch.load unpickles the checkpoint; if the installed torch
# supports it, passing weights_only=True would be the safer way to load a
# plain state_dict from a downloaded file.
generator.load_state_dict(torch.load(generator_path, map_location=device))
generator.eval()

# Input preprocessing: resize to 256x256 and map pixel values to [-1, 1].
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])


def process_image(image):
    """Run the generator on ``image`` and return the result as a PIL image.

    Args:
        image: PIL image coming from the input component, or ``None`` when
            the input has been cleared.

    Returns:
        The generated PIL image, or ``None`` if ``image`` is ``None``.
    """
    # The Clear button sets the input to None, which also fires ``change``.
    if image is None:
        return None

    # Normalize() is configured for 3 channels, so force RGB (drops alpha,
    # expands grayscale/palette images).
    image_tensor = transform(image.convert("RGB")).unsqueeze(0).to(device)

    # Inference.
    with torch.no_grad():
        output_tensor = generator(image_tensor)

    # Undo the [-1, 1] normalization and clamp so that any out-of-range value
    # does not wrap around when ToPILImage converts to 8-bit.
    output_image = output_tensor.squeeze(0).cpu()
    output_image = (output_image * 0.5 + 0.5).clamp(0.0, 1.0)
    return transforms.ToPILImage()(output_image)


def load_images_from_folder(folder):
    """Load every PNG/JPEG image found directly inside ``folder``.

    Args:
        folder: Path of the directory to scan (not recursive).

    Returns:
        A list of ``(PIL image, filename)`` tuples sorted by filename; empty
        if ``folder`` does not exist.
    """
    if not os.path.isdir(folder):
        return []

    images = []
    # os.listdir order is arbitrary; sort for a stable gallery order.
    for filename in sorted(os.listdir(folder)):
        if filename.lower().endswith(('.png', '.jpg', '.jpeg')):
            img_path = os.path.join(folder, filename)
            # Image.open is lazy; load the pixels now and release the file
            # handle instead of keeping one open per gallery image.
            with Image.open(img_path) as img:
                img.load()
            images.append((img, filename))
    return images


def load_image_from_gallery(images, index):
    """Return the image at ``index`` in ``images``, or ``None`` if invalid.

    Gallery entries may be ``(image, caption)`` tuples, in which case only
    the image part is returned.
    """
    if images and 0 <= index < len(images):
        image = images[index]
        if isinstance(image, tuple):
            image = image[0]
        return image
    return None


def gallery_click_event(images, evt: gr.SelectData):
    """Return the gallery image the user clicked on.

    Args:
        images: Current value of the gallery component.
        evt: Selection event; ``evt.index`` is used as the position in
            ``images``.
    """
    index = evt.index
    selected_img = load_image_from_gallery(images, index)
    return selected_img


def clear_image():
    """Return ``None`` so the bound image component is emptied."""
    return None


def app():
    """Build the Gradio interface and launch it."""
    images = load_images_from_folder(photos_folder)

    with gr.Blocks(css=".container { background-color: white; }") as demo:
        with gr.Row():
            with gr.Column():
                selected_image = gr.Image(label="Input Image", type="pil")
                clear_button = gr.Button("Clear")
            with gr.Column():
                image_gallery = gr.Gallery(
                    label="Image Gallery",
                    elem_id="gallery",
                    type="pil",
                    value=[img for img, _ in images],
                )
            with gr.Column():
                result_image = gr.Image(label="Result Image", type="pil")

        # Clicking a gallery item copies it into the input image ...
        image_gallery.select(
            fn=gallery_click_event,
            inputs=image_gallery,
            outputs=selected_image,
        )
        # ... and every change of the input image re-runs the generator.
        selected_image.change(
            fn=process_image,
            inputs=selected_image,
            outputs=result_image,
        )
        clear_button.click(
            fn=clear_image,
            inputs=None,
            outputs=selected_image,
        )

    demo.launch()


if __name__ == "__main__":
    app()