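"""Gradio app serving the Kiwinicki/sat2map-generator image generator.

Hugging Face Space status when this file was captured: Running.
"""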
import json
import os
import sys
from abc import ABC, abstractmethod

import gradio as gr
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from omegaconf import OmegaConf
from torch import tanh, Tensor
# Fetch the generator weights, its config, and the module that defines the
# architecture from the Hugging Face Hub. model.py has to be downloaded
# *before* it can be imported, and its directory must be on sys.path for the
# `from model import Generator` below to resolve.
repo_id = "Kiwinicki/sat2map-generator"
generator_path = hf_hub_download(repo_id=repo_id, filename="generator.pth")
config_path = hf_hub_download(repo_id=repo_id, filename="config.json")
model_path = hf_hub_download(repo_id=repo_id, filename="model.py")

sys.path.append(os.path.dirname(model_path))
from model import Generator  # noqa: E402 -- module only exists after download

with open(config_path, "r", encoding="utf-8") as f:
    config_dict = json.load(f)
cfg = OmegaConf.create(config_dict)

generator = Generator(cfg)
# map_location="cpu" lets the checkpoint load even if it was saved from a CUDA
# device and the Space has no GPU.
# NOTE(review): torch.load unpickles the downloaded file; only load
# checkpoints from a trusted repo (or pass weights_only=True on torch>=1.13).
state_dict = torch.load(generator_path, map_location="cpu")
generator.load_state_dict(state_dict)
generator.eval()  # inference mode for dropout/batch-norm layers, if any
from PIL import Image
import torchvision.transforms as transforms

# Per-channel statistics for the three colour channels. With mean = std = 0.5,
# Normalize maps each channel value x to (x - 0.5) / 0.5.
_CHANNEL_MEAN = [0.5, 0.5, 0.5]
_CHANNEL_STD = [0.5, 0.5, 0.5]

# Preprocessing applied to each uploaded image before it reaches the generator:
# resize to a fixed 256x256, convert to a float tensor, then normalize.
_preprocessing_steps = [
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=_CHANNEL_MEAN, std=_CHANNEL_STD),
]
transform = transforms.Compose(_preprocessing_steps)
def process_image(image):
    """Run one image through the generator and return the generated image.

    Args:
        image: The uploaded image, either a PIL image or a NumPy array (the
            default type Gradio's plain "image" input delivers). ``None`` is
            returned unchanged so an empty submission does not crash.

    Returns:
        A PIL image built from the generator output for the 256x256-resized
        input, or ``None`` if no image was given.
    """
    if image is None:
        return None
    if not isinstance(image, Image.Image):
        # torchvision's Resize accepts PIL images or tensors, not ndarrays.
        image = Image.fromarray(image)
    # Force 3 channels (drops alpha, expands grayscale) to match the 3-channel
    # Normalize in `transform`.
    image = image.convert("RGB")

    image_tensor = transform(image).unsqueeze(0)  # add a batch dimension
    with torch.no_grad():
        output_tensor = generator(image_tensor)
    output = output_tensor.squeeze(0)

    # Assumes the generator emits values in [-1, 1], mirroring the input
    # Normalize(0.5, 0.5). The file imports torch.tanh, which hints at a tanh
    # output layer. TODO confirm against model.py.
    # Undo that scaling and clamp so ToPILImage's float-to-byte conversion
    # (x * 255) does not wrap negative or >1 values.
    output = (output * 0.5 + 0.5).clamp(0.0, 1.0)
    return transforms.ToPILImage()(output)
# Request PIL images explicitly. Gradio's plain "image" component delivers a
# NumPy array, which the Resize step in `transform` does not accept.
iface = gr.Interface(
    fn=process_image,
    inputs=gr.Image(type="pil"),
    outputs=gr.Image(type="pil"),
    title="Image Generator",
)
iface.launch()