To load and initialize the `Generator` model from the repository, follow these steps:

1. **Install Required Packages**: Install the necessary Python packages from your shell:

```bash
pip install torch omegaconf huggingface_hub
```

2. **Download Model Files**: Retrieve the `generator.pth`, `config.json`, and `model.py` files from the Hugging Face repository. You can use the `huggingface_hub` library for this:

```python
from huggingface_hub import hf_hub_download

repo_id = "Kiwinicki/sat2map-generator"

# Each call returns the local path of the file in the Hugging Face cache
generator_path = hf_hub_download(repo_id=repo_id, filename="generator.pth")
config_path = hf_hub_download(repo_id=repo_id, filename="config.json")
model_path = hf_hub_download(repo_id=repo_id, filename="model.py")
```

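The three download calls in step 2 can also be collapsed into one loop. The following is a minimal sketch, not part of the repository: `fetch_files` and `MODEL_FILES` are hypothetical names, and the download function is passed in as an argument so the helper can be tried out with a stand-in before it touches the network. With `huggingface_hub` installed, passing `hf_hub_download` as `download` should reproduce the paths obtained in step 2.

```python
from pathlib import Path
from typing import Callable, Dict, Iterable

# The three files named in step 2.
MODEL_FILES = ("generator.pth", "config.json", "model.py")


def fetch_files(
    repo_id: str,
    filenames: Iterable[str],
    download: Callable[..., str],
) -> Dict[str, Path]:
    """Download each file and return a {filename: local path} mapping.

    `download` is any callable accepting repo_id=... and filename=...
    keyword arguments and returning a local path string, for example
    huggingface_hub.hf_hub_download.
    """
    return {
        name: Path(download(repo_id=repo_id, filename=name))
        for name in filenames
    }


# Intended use (needs network access and huggingface_hub installed):
#   from huggingface_hub import hf_hub_download
#   paths = fetch_files("Kiwinicki/sat2map-generator", MODEL_FILES, hf_hub_download)
#   generator_path = paths["generator.pth"]
```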
3. **Load the Model**: Make the downloaded `model.py` importable so that it provides the `Generator` class, then load the configuration and the model's state dictionary:

```python
import json
import sys
from pathlib import Path

import torch
from omegaconf import OmegaConf

# model.py was downloaded into the Hugging Face cache, so add its folder
# to the import path before importing the Generator class
sys.path.append(str(Path(model_path).parent))
from model import Generator

# Load configuration
with open(config_path, "r") as f:
    config_dict = json.load(f)
cfg = OmegaConf.create(config_dict)

# Initialize and load the generator model
# (map_location="cpu" lets the checkpoint load on machines without a GPU)
generator = Generator(cfg)
generator.load_state_dict(torch.load(generator_path, map_location="cpu"))
generator.eval()

# Run a forward pass on a random input to check that the model works
x = torch.randn([1, cfg['channels'], 256, 256])
with torch.no_grad():
    out = generator(x)
```

Here, `generator` is the initialized model ready for inference. |
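
Because `hf_hub_download` stores `model.py` in the Hugging Face cache rather than in the working directory, a plain `from model import Generator` only works if that folder is on the import path. One alternative is to load the file directly from its path with the standard-library `importlib` machinery. The sketch below is an illustration rather than the repository's own method: `load_module_from_path` and the module name `sat2map_model` are hypothetical, and it assumes `model.py` does not import sibling files from the repository.

```python
import importlib.util
from pathlib import Path


def load_module_from_path(path, module_name="sat2map_model"):
    """Import a Python source file by its location and return the module."""
    path = Path(path)
    spec = importlib.util.spec_from_file_location(module_name, path)
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot load a module from {path}")
    module = importlib.util.module_from_spec(spec)
    # Execute the file's top-level code so its classes become attributes
    spec.loader.exec_module(module)
    return module


# Intended use with the path obtained in step 2:
#   Generator = load_module_from_path(model_path).Generator
#   generator = Generator(cfg)
```

Loading by path leaves `sys.path` untouched, which avoids accidentally shadowing another module named `model` elsewhere in the project.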