Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| from huggingface_hub import hf_hub_download | |
| import json | |
| from omegaconf import OmegaConf | |
| import sys | |
| import os | |
| from PIL import Image | |
| import torchvision.transforms as transforms | |
# Folder holding the example images shown in the gallery.
photos_folder = "Photos"

# Download the generator weights, its config and the model definition
# from the Hugging Face Hub.
repo_id = "Kiwinicki/sat2map-generator"
generator_path = hf_hub_download(repo_id=repo_id, filename="generator.pth")
config_path = hf_hub_download(repo_id=repo_id, filename="config.json")
model_path = hf_hub_download(repo_id=repo_id, filename="model.py")

# Make the downloaded model.py importable as ``model``.
# NOTE(review): this executes code fetched from the Hub, and the downloads are
# not pinned to a revision; consider passing ``revision=<commit sha>`` to
# hf_hub_download so a repo update cannot silently change what runs here.
# append() puts the directory LAST on sys.path, so any other importable
# module named ``model`` would be imported instead — confirm none exists.
sys.path.append(os.path.dirname(model_path))
from model import Generator  # noqa: E402  (depends on the sys.path change above)

# Load the model configuration (JSON -> OmegaConf object). JSON is UTF-8 by
# spec, so do not rely on the platform's default locale encoding.
with open(config_path, "r", encoding="utf-8") as f:
    config_dict = json.load(f)
cfg = OmegaConf.create(config_dict)

# Build the generator on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
generator = Generator(cfg).to(device)
# NOTE(review): torch.load unpickles the checkpoint file. Since only a
# state_dict is needed, consider ``weights_only=True`` (torch >= 1.13) to
# refuse arbitrary pickled objects from the remote file.
generator.load_state_dict(torch.load(generator_path, map_location=device))
# Switch layers such as dropout / batch-norm to inference behaviour.
generator.eval()
# Preprocessing applied to every input image:
#   1. resize to 256x256 (presumably the generator's training resolution —
#      confirm against the model config),
#   2. convert to a float tensor in [0, 1] (for 8-bit images),
#   3. map each RGB channel to [-1, 1] via (x - 0.5) / 0.5.
transform = transforms.Compose(
    [
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.5, 0.5, 0.5],
            std=[0.5, 0.5, 0.5],
        ),
    ]
)
def process_image(image):
    """Translate an input image into a generated map image.

    Args:
        image: PIL image from the Gradio input component, or ``None`` when
            that component has been cleared.

    Returns:
        The generator output as a PIL image, or ``None`` when ``image`` is
        ``None`` (which also empties the bound result component).
    """
    # Gradio fires ``change`` with None when the input is cleared (e.g. by the
    # Clear button); without this guard transform(None) raises.
    if image is None:
        return None

    # Drop alpha / palette / grayscale so the tensor has exactly the three
    # channels that the 3-value Normalize inside ``transform`` expects.
    image_tensor = transform(image.convert("RGB")).unsqueeze(0).to(device)

    with torch.no_grad():
        output_tensor = generator(image_tensor)

    # Undo the [-1, 1] normalisation, then clamp: ToPILImage multiplies float
    # tensors by 255 and casts to uint8, so out-of-range values would wrap.
    output_image = (output_tensor.squeeze(0).cpu() * 0.5 + 0.5).clamp(0.0, 1.0)
    return transforms.ToPILImage()(output_image)
def load_images_from_folder(folder, extensions=(".png", ".jpg", ".jpeg")):
    """Load every image file located directly inside *folder* (not recursive).

    Args:
        folder: Directory to scan.
        extensions: Filename suffixes (case-insensitive) treated as images.

    Returns:
        A list of ``(PIL.Image.Image, filename)`` tuples sorted by filename.
        An empty list is returned when *folder* is not an existing directory.
    """
    if not os.path.isdir(folder):
        # A missing examples folder should leave the gallery empty rather
        # than prevent the whole app from starting.
        return []

    suffixes = tuple(ext.lower() for ext in extensions)
    images = []
    # os.listdir order is arbitrary; sort so the gallery order is stable.
    for filename in sorted(os.listdir(folder)):
        img_path = os.path.join(folder, filename)
        # Skip non-image names and anything that is not a regular file
        # (e.g. a sub-directory whose name happens to end in ".png").
        if not filename.lower().endswith(suffixes) or not os.path.isfile(img_path):
            continue
        # Image.open is lazy and keeps the file handle open; copy the pixels
        # into memory inside a with-block so the handle is closed right away.
        with Image.open(img_path) as src:
            images.append((src.copy(), filename))
    return images
def load_image_from_gallery(images, index):
    """Return the entry at position *index* of a gallery value.

    Entries may be bare images or ``(image, caption)`` tuples; for a tuple
    only its first element (the image) is returned.

    Returns:
        The selected image, or ``None`` when *images* is empty/``None`` or
        *index* falls outside ``[0, len(images))`` (negative indices included).
    """
    if not images or not (0 <= index < len(images)):
        return None
    entry = images[index]
    return entry[0] if isinstance(entry, tuple) else entry
def gallery_click_event(images, evt: gr.SelectData):
    """Gradio ``select`` handler: return the gallery image that was clicked.

    ``evt.index`` identifies the clicked item; the lookup and its bounds
    checks are delegated to ``load_image_from_gallery``. The ``gr.SelectData``
    annotation is what makes Gradio pass the event object in.
    """
    return load_image_from_gallery(images, evt.index)
def clear_image():
    """Return ``None`` so Gradio resets the bound output component to empty."""
    empty_value = None
    return empty_value
def app():
    """Build the sat2map Gradio interface and launch it.

    Layout, left to right:
      * the input image plus a Clear button,
      * a gallery of the example photos loaded from ``photos_folder``,
      * the generated result image.
    """
    example_pairs = load_images_from_folder(photos_folder)
    # The gallery shows only the images; the filenames are not passed in.
    gallery_images = [pair[0] for pair in example_pairs]

    with gr.Blocks(css=".container { background-color: white; }") as demo:
        with gr.Row():
            with gr.Column():
                selected_image = gr.Image(label="Input Image", type="pil")
                clear_button = gr.Button("Clear")
            with gr.Column():
                image_gallery = gr.Gallery(
                    label="Image Gallery",
                    elem_id="gallery",
                    type="pil",
                    value=gallery_images,
                )
            with gr.Column():
                result_image = gr.Image(label="Result Image", type="pil")

        # Clicking a thumbnail copies that image into the input component.
        image_gallery.select(
            fn=gallery_click_event,
            inputs=image_gallery,
            outputs=selected_image,
        )
        # Whenever the input image changes, run the generator on it.
        selected_image.change(
            fn=process_image,
            inputs=selected_image,
            outputs=result_image,
        )
        # The Clear button resets the input component to empty.
        clear_button.click(
            fn=clear_image,
            inputs=None,
            outputs=selected_image,
        )

    demo.launch()
if __name__ == "__main__":
    # Build the UI and start the server only when executed as a script.
    app()