To load and initialize the `Generator` model from the repository, follow these steps:

1. **Install Required Packages**: Ensure you have the necessary Python packages installed:

    ```bash
    pip install torch omegaconf huggingface_hub
    ```

2. **Download Model Files**: Retrieve the `generator.pth`, `config.json`, and `model.py` files from the Hugging Face repository. You can use the `huggingface_hub` library for this:

    ```python
    from huggingface_hub import hf_hub_download

    repo_id = "Kiwinicki/sat2map-generator"
    generator_path = hf_hub_download(repo_id=repo_id, filename="generator.pth")
    config_path = hf_hub_download(repo_id=repo_id, filename="config.json")
    model_path = hf_hub_download(repo_id=repo_id, filename="model.py")
    ```

3. **Load the Model**: Make the downloaded `model.py` importable so that the `Generator` class is defined, then load the configuration and the model's state dictionary:

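    ```python
    import json
    import sys
    from pathlib import Path

    import torch
    from omegaconf import OmegaConf

    # hf_hub_download stores model.py in the local Hugging Face cache,
    # so put that directory first on the import path before importing it
    sys.path.insert(0, str(Path(model_path).parent))
    from model import Generator

    # Load configuration
    with open(config_path, "r") as f:
        config_dict = json.load(f)
    cfg = OmegaConf.create(config_dict)

    # Initialize the generator and load the trained weights
    # (map_location="cpu" lets this also run on machines without a GPU)
    generator = Generator(cfg)
    generator.load_state_dict(torch.load(generator_path, map_location="cpu"))
    generator.eval()

    # Sanity check: run a random 256x256 input through the model without tracking gradients
    x = torch.randn([1, cfg['channels'], 256, 256])
    with torch.no_grad():
        out = generator(x)
    ```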

    `generator` is now in evaluation mode and ready for inference, and `out` holds its output for the random test input.
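Because `hf_hub_download` saves `model.py` into the local Hugging Face cache rather than the working directory, `from model import Generator` only works once that file is importable. One way to do this without touching `sys.path` is to load the file directly with the standard library's `importlib`. This is a minimal sketch that assumes `model.py` defines `Generator` at module level, as the import in step 3 implies; the helper name `load_module_from_path` and the module name `sat2map_model` are hypothetical choices for illustration.

```python
import importlib.util
import sys
from pathlib import Path
from types import ModuleType


def load_module_from_path(path: str, module_name: str = "sat2map_model") -> ModuleType:
    """Import a Python source file from an explicit path without modifying sys.path.

    module_name is a hypothetical name, chosen so the loaded file does not
    clash with any other module called "model".
    """
    file_path = Path(path)
    spec = importlib.util.spec_from_file_location(module_name, file_path)
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot load a module from {file_path}")
    module = importlib.util.module_from_spec(spec)
    # Register the module before executing it, so code that looks itself up
    # in sys.modules while importing (e.g. dataclasses) can find it
    sys.modules[module_name] = module
    spec.loader.exec_module(module)
    return module
```

With this helper, the import in step 3 can be written as `Generator = load_module_from_path(model_path).Generator`.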