Segizu commited on
Commit
eaa5320
·
1 Parent(s): fb57428
Files changed (3) hide show
  1. __pycache__/utils.cpython-39.pyc +0 -0
  2. requirements.txt +2 -1
  3. utils.py +50 -3
__pycache__/utils.cpython-39.pyc CHANGED
Binary files a/__pycache__/utils.cpython-39.pyc and b/__pycache__/utils.cpython-39.pyc differ
 
requirements.txt CHANGED
@@ -2,4 +2,5 @@ git+https://github.com/huggingface/community-events.git@3fea10c5d5a50c69f509e34c
2
  transformers
3
  faiss-cpu
4
  paddlehub
5
- paddlepaddle
 
 
2
  transformers
3
  faiss-cpu
4
  paddlehub
5
+ paddlepaddle
6
+ torch
utils.py CHANGED
@@ -1,14 +1,61 @@
1
  import numpy as np
2
  import torch
3
  from huggan.pytorch.lightweight_gan.lightweight_gan import LightweightGAN
 
 
 
 
 
 
 
 
 
 
4
 
5
  ## Cargamos el modelo desde el Hub de Hugging Face
6
- def carga_modelo(model_name="ceyda/butterfly_cropped_uniq1K_512", model_version=None):
7
- gan = LightweightGAN.from_pretrained(model_name, version=model_version, use_auth_token=None)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  gan.eval()
9
  return gan
10
 
11
-
12
  ## Usamos el modelo GAN para generar imágenes
13
  def genera(gan, batch_size=1):
14
  with torch.no_grad():
 
1
  import numpy as np
2
  import torch
3
  from huggan.pytorch.lightweight_gan.lightweight_gan import LightweightGAN
4
+ from huggingface_hub import hf_hub_download
5
+
6
+ CONFIG_NAME = "config.json"
7
+ revision = None
8
+ cache_dir = None
9
+ force_download = False
10
+ proxies = None
11
+ resume_download = False
12
+ local_files_only = False
13
+ token = None
14
 
15
  ## Cargamos el modelo desde el Hub de Hugging Face
16
+ def carga_modelo(model_name="ceyda/butterfly_cropped_uniq1K_512"):
17
+ """
18
+ Loads a pre-trained LightweightGAN model from Hugging Face Model Hub.
19
+ Args:
20
+ model_name (str): The name of the pre-trained model to load. Defaults to "ceyda/butterfly_cropped_uniq1K_512".
21
+ model_version (str): The version of the pre-trained model to load. Defaults to None.
22
+ Returns:
23
+ LightweightGAN: The loaded pre-trained model.
24
+ """
25
+ # Load the config
26
+ config_file = hf_hub_download(
27
+ repo_id=str(model_name),
28
+ filename=CONFIG_NAME,
29
+ revision=revision,
30
+ cache_dir=cache_dir,
31
+ force_download=force_download,
32
+ proxies=proxies,
33
+ resume_download=resume_download,
34
+ token=token,
35
+ local_files_only=local_files_only,
36
+ )
37
+ with open(config_file, "r", encoding="utf-8") as f:
38
+ config = json.load(f)
39
+
40
+ # Call the _from_pretrained with all the needed arguments
41
+ gan = LightweightGAN(latent_dim=256, image_size=512)
42
+
43
+ gan = gan._from_pretrained(
44
+ model_id=str(model_name),
45
+ revision=revision,
46
+ cache_dir=cache_dir,
47
+ force_download=force_download,
48
+ proxies=proxies,
49
+ resume_download=resume_download,
50
+ local_files_only=local_files_only,
51
+ token=token,
52
+ use_auth_token=False,
53
+ config=config, # usually in **model_kwargs
54
+ )
55
+
56
  gan.eval()
57
  return gan
58
 
 
59
  ## Usamos el modelo GAN para generar imágenes
60
  def genera(gan, batch_size=1):
61
  with torch.no_grad():