Hugomartinezg commited on
Commit
654fcf7
·
verified ·
1 Parent(s): b6d8fe5

Update utils.py

Browse files
Files changed (1) hide show
  1. utils.py +62 -6
utils.py CHANGED
@@ -1,18 +1,74 @@
 
 
1
  import numpy as np
2
  import torch
3
  from huggan.pytorch.lightweight_gan.lightweight_gan import LightweightGAN
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- ## Cargamos el modelo desde el Hub de Hugging Face
7
- def carga_modelo(model_name="ceyda/butterfly_cropped_uniq1K_512", model_version=None):
8
- gan = LightweightGAN.from_pretrained(model_name, version=model_version)
9
  gan.eval()
10
  return gan
11
 
12
 
13
- ## Usamos el modelo GAN para generar imágenes
14
- def genera(gan, batch_size=1):
 
 
 
 
 
 
 
15
  with torch.no_grad():
16
  ims = gan.G(torch.randn(batch_size, gan.latent_dim)).clamp_(0.0, 1.0) * 255
17
  ims = ims.permute(0, 2, 3, 1).detach().cpu().numpy().astype(np.uint8)
18
- return ims
 
1
+ import json
2
+
3
  import numpy as np
4
  import torch
5
  from huggan.pytorch.lightweight_gan.lightweight_gan import LightweightGAN
6
+ from huggingface_hub import hf_hub_download
7
+
8
+ CONFIG_NAME = "config.json"
9
+ revision = None
10
+ cache_dir = None
11
+ force_download = False
12
+ proxies = None
13
+ resume_download = False
14
+ local_files_only = False
15
+ token = None
16
+
17
+
18
+ def load_model(model_name="ceyda/butterfly_cropped_uniq1K_512"):
19
+ """
20
+ Loads a pre-trained LightweightGAN model from Hugging Face Model Hub.
21
+ Args:
22
+ model_name (str): The name of the pre-trained model to load. Defaults to "ceyda/butterfly_cropped_uniq1K_512".
23
+ model_version (str): The version of the pre-trained model to load. Defaults to None.
24
+ Returns:
25
+ LightweightGAN: The loaded pre-trained model.
26
+ """
27
+ # Load the config
28
+ config_file = hf_hub_download(
29
+ repo_id=str(model_name),
30
+ filename=CONFIG_NAME,
31
+ revision=revision,
32
+ cache_dir=cache_dir,
33
+ force_download=force_download,
34
+ proxies=proxies,
35
+ resume_download=resume_download,
36
+ token=token,
37
+ local_files_only=local_files_only,
38
+ )
39
+ with open(config_file, "r", encoding="utf-8") as f:
40
+ config = json.load(f)
41
+
42
+ # Call the _from_pretrained with all the needed arguments
43
+ gan = LightweightGAN(latent_dim=256, image_size=512)
44
 
45
+ gan = gan._from_pretrained(
46
+ model_id=str(model_name),
47
+ revision=revision,
48
+ cache_dir=cache_dir,
49
+ force_download=force_download,
50
+ proxies=proxies,
51
+ resume_download=resume_download,
52
+ local_files_only=local_files_only,
53
+ token=token,
54
+ use_auth_token=False,
55
+ config=config, # usually in **model_kwargs
56
+ )
57
 
 
 
 
58
  gan.eval()
59
  return gan
60
 
61
 
62
+ def generate(gan, batch_size=1):
63
+ """
64
+ Generates images using the given GAN model.
65
+ Args:
66
+ gan (nn.Module): The GAN model to use for generating images.
67
+ batch_size (int, optional): The number of images to generate in each batch. Defaults to 1.
68
+ Returns:
69
+ numpy.ndarray: A numpy array of generated images.
70
+ """
71
  with torch.no_grad():
72
  ims = gan.G(torch.randn(batch_size, gan.latent_dim)).clamp_(0.0, 1.0) * 255
73
  ims = ims.permute(0, 2, 3, 1).detach().cpu().numpy().astype(np.uint8)
74
+ return ims