vpippi committed · verified · Commit edf61ce · 1 Parent(s): dd90e2c

Update README.md

Files changed (1)
  1. README.md +23 -21
README.md CHANGED
@@ -1,4 +1,5 @@
  ---
  language:
  - "en"
  tags:
@@ -14,22 +15,23 @@ metrics:
  - KL
  - CER
  library_name: diffusers
  ---

  ## Emuru Convolutional VAE

- This repository hosts the **Emuru Convolutional VAE**, described in our paper. The model features a convolutional encoder and decoder, each with four layers. The output channels for these layers are 32, 64, 128, and 256, respectively. The encoder downsamples an input RGB image \( I \in \mathbb{R}^{3 \times W \times H} \) to a latent representation with a single channel and spatial dimensions \( h \times w \) (where \( h = H/8 \) and \( w = W/8 \)). This design compresses the style information in the image, allowing a lightweight Transformer Decoder to efficiently process the latent features.

  ### Training Details

  - **Hardware:** NVIDIA RTX 4090
- - **Iterations:** 60k
- - **Optimizer:** AdamW with a learning rate of 1e-4
  - **Loss Components:**
-   - **MAE Loss (\(\mathcal{L}_{MAE}\))** with weight 1
-   - **WID Loss (\(\mathcal{L}_{WID}\))** with weight 0.005
-   - **HTR Loss (\(\mathcal{L}_{HTR}\))** with weight 0.3 (using noisy teacher-forcing with probability 0.3)
-   - **KL Loss (\(\mathcal{L}_{KL}\))** with weight \(\beta = 1\text{e-6}\)

  ### Auxiliary Networks

@@ -54,29 +56,27 @@ from diffusers import AutoencoderKL
  import torch
  from torchvision.transforms.functional import to_tensor, to_pil_image
  from PIL import Image
- import requests
- from io import BytesIO

  # Load the pre-trained Emuru VAE from Hugging Face Hub.
  model = AutoencoderKL.from_pretrained("vpippi/emuru_vae")

- # Function to load and preprocess an RGB image from a URL:
- # Fetches the image via requests, converts it to RGB, and transforms it to a tensor normalized to [0, 1].
- def preprocess_image_from_url(url):
-     response = requests.get(url)
-     image = Image.open(BytesIO(response.content)).convert("RGB")
-     image_tensor = to_tensor(image).unsqueeze(0)  # Add batch dimension.
      return image_tensor

  # Function to postprocess a tensor back to a PIL image for visualization:
  # Clamps the tensor to [0, 1] and converts it to a PIL image.
  def postprocess_tensor(tensor):
-     tensor = torch.clamp(tensor, 0, 1).squeeze(0)  # Remove batch dimension.
      return to_pil_image(tensor)

- # Example URL of the image.
- image_url = "https://aimagelab.ing.unimore.it/imagelab/uploadedImages/000883.jpg"
- input_image = preprocess_image_from_url(image_url)

  # Encode the image to the latent space.
  # The encode() method returns an object with a 'latent_dist' attribute.
@@ -89,6 +89,8 @@ with torch.no_grad():
  with torch.no_grad():
      reconstructed = model.decode(latents).sample

  # Convert the reconstructed tensor back to a PIL image.
  reconstructed_image = postprocess_tensor(reconstructed)

@@ -106,8 +108,8 @@ If you'd like to test with images hosted directly on the Hugging Face Hub, consi
  from huggingface_hub import hf_hub_download
  from PIL import Image

- # Replace 'vpippi/emuru_vae' and 'samples/example_image.jpg' with your details.
- image_path = hf_hub_download(repo_id="vpippi/emuru_vae", filename="samples/example_image.jpg")
  sample_image = Image.open(image_path).convert("RGB")
  sample_image.show()
  ```
 
  ---
+
  language:
  - "en"
  tags:
 
  - KL
  - CER
  library_name: diffusers
+
  ---

  ## Emuru Convolutional VAE

+ This repository hosts the **Emuru Convolutional VAE**, described in our paper. The model features a convolutional encoder and decoder, each with four layers; their output channels are 32, 64, 128, and 256, respectively. The encoder downsamples an input RGB image (three channels, height H, width W) to a single-channel latent representation whose spatial dimensions are one-eighth of the original (H/8 × W/8). This design compresses the style information in the image, allowing a lightweight Transformer Decoder to process the latent features efficiently.
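To make the compression factor concrete, here is a minimal sketch (the image path is a placeholder, and the input's sides are assumed to be multiples of 8) that encodes an image and prints the latent shape, which should come out as a single channel at one-eighth of the input resolution:

```python
import torch
from diffusers import AutoencoderKL
from torchvision.transforms.functional import to_tensor
from PIL import Image

model = AutoencoderKL.from_pretrained("vpippi/emuru_vae")

# Placeholder path: any RGB image whose width and height are multiples of 8.
image = Image.open("/path/to/image").convert("RGB")
x = to_tensor(image).unsqueeze(0)  # Shape [1, 3, H, W], values in [0, 1].

with torch.no_grad():
    latents = model.encode(x).latent_dist.sample()

print(x.shape, latents.shape)  # Expect [1, 3, H, W] -> [1, 1, H/8, W/8].
```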

  ### Training Details

  - **Hardware:** NVIDIA RTX 4090
+ - **Iterations:** 60,000
+ - **Optimizer:** AdamW with a learning rate of 0.0001
  - **Loss Components:**
+   - **MAE Loss:** weight of 1
+   - **WID Loss:** weight of 0.005
+   - **HTR Loss:** weight of 0.3 (using noisy teacher-forcing with a probability of 0.3)
+   - **KL Loss:** weight (the beta parameter) of 1e-6
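For illustration only, the snippet below sketches how these weighted terms could be combined into a single training objective. The weights are the ones listed above; the four loss values and the function name are hypothetical placeholders, not identifiers from the Emuru codebase.

```python
# Loss weights taken from the list above.
W_MAE = 1.0
W_WID = 0.005
W_HTR = 0.3
BETA_KL = 1e-6

def total_loss(mae_loss, wid_loss, htr_loss, kl_loss):
    # Each argument is a scalar loss (tensor or float) computed by the
    # corresponding objective: reconstruction (MAE), writer ID, HTR, and KL.
    return (W_MAE * mae_loss
            + W_WID * wid_loss
            + W_HTR * htr_loss
            + BETA_KL * kl_loss)
```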

  ### Auxiliary Networks

 
  import torch
  from torchvision.transforms.functional import to_tensor, to_pil_image
  from PIL import Image

  # Load the pre-trained Emuru VAE from Hugging Face Hub.
  model = AutoencoderKL.from_pretrained("vpippi/emuru_vae")

+ # Function to preprocess an RGB image:
+ # Loads the image, converts it to RGB, and transforms it to a tensor normalized to [0, 1].
+ def preprocess_image(image_path):
+     image = Image.open(image_path).convert("RGB")
+     image_tensor = to_tensor(image).unsqueeze(0)  # Add batch dimension
      return image_tensor

  # Function to postprocess a tensor back to a PIL image for visualization:
  # Clamps the tensor to [0, 1] and converts it to a PIL image.
  def postprocess_tensor(tensor):
+     tensor = torch.clamp(tensor, 0, 1).squeeze(0)  # Remove batch dimension
      return to_pil_image(tensor)

+ # Example: Encode and decode an image.
+ # Replace with your image path.
+ image_path = "/path/to/image"
+ input_image = preprocess_image(image_path)

  # Encode the image to the latent space.
  # The encode() method returns an object with a 'latent_dist' attribute.
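The hunk above stops just before the encode call itself. As a rough sketch of what that step typically looks like with the diffusers API, reusing the `model` and `input_image` defined above (and not necessarily the exact line in the README), you can sample latents from the returned posterior:

```python
# Sample a latent code from the posterior returned by encode().
with torch.no_grad():
    latents = model.encode(input_image).latent_dist.sample()
```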
 
  with torch.no_grad():
      reconstructed = model.decode(latents).sample

+ # Load the original image for comparison.
+ original_image = Image.open(image_path).convert("RGB")
  # Convert the reconstructed tensor back to a PIL image.
  reconstructed_image = postprocess_tensor(reconstructed)
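One simple way to do the visual comparison suggested above is to paste the two images onto one canvas; this is a sketch using plain PIL that reuses `original_image` and `reconstructed_image` from the snippet above:

```python
# Place original and reconstruction side by side for visual inspection.
w, h = original_image.size
canvas = Image.new("RGB", (2 * w, h), "white")
canvas.paste(original_image, (0, 0))
canvas.paste(reconstructed_image.resize((w, h)), (w, 0))
canvas.show()
```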
 
 
  from huggingface_hub import hf_hub_download
  from PIL import Image

+ # Replace 'vpippi/emuru_vae' and 'samples/lam_sample.jpg' with your details.
+ image_path = hf_hub_download(repo_id="vpippi/emuru_vae", filename="samples/lam_sample.jpg")
  sample_image = Image.open(image_path).convert("RGB")
  sample_image.show()
  ```
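Putting the pieces above together, here is a compact sketch that downloads the bundled sample, runs it through the VAE, and saves the reconstruction. It only combines calls already shown in this README; the output filename is arbitrary.

```python
import torch
from diffusers import AutoencoderKL
from huggingface_hub import hf_hub_download
from PIL import Image
from torchvision.transforms.functional import to_tensor, to_pil_image

model = AutoencoderKL.from_pretrained("vpippi/emuru_vae")
image_path = hf_hub_download(repo_id="vpippi/emuru_vae", filename="samples/lam_sample.jpg")

# Preprocess: RGB image -> [1, 3, H, W] tensor with values in [0, 1].
x = to_tensor(Image.open(image_path).convert("RGB")).unsqueeze(0)

with torch.no_grad():
    latents = model.encode(x).latent_dist.sample()
    reconstructed = model.decode(latents).sample

# Postprocess and save the reconstruction (arbitrary filename).
to_pil_image(torch.clamp(reconstructed, 0, 1).squeeze(0)).save("reconstruction.png")
```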