Update README.md
README.md
---
language:
- "en"
tags:
...
metrics:
- KL
- CER
library_name: diffusers
---

## Emuru Convolutional VAE

This repository hosts the **Emuru Convolutional VAE**, described in our paper. The model features a convolutional encoder and decoder, each with four layers. The output channels for these layers are 32, 64, 128, and 256, respectively. The encoder downsamples an input RGB image (3 channels, height H, width W) to a latent representation with a single channel and spatial dimensions H/8 × W/8. This design compresses the style information in the image, allowing a lightweight Transformer Decoder to efficiently process the latent features.
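To make the encoder's output shape concrete, here is a minimal sanity check; the 64×256 input size is an arbitrary example, and the model is loaded exactly as in the usage snippet further down:

```python
from diffusers import AutoencoderKL
import torch

model = AutoencoderKL.from_pretrained("vpippi/emuru_vae")

# A dummy batch: 1 RGB image, height 64, width 256, values in [0, 1].
dummy = torch.rand(1, 3, 64, 256)

with torch.no_grad():
    latents = model.encode(dummy).latent_dist.sample()

# Expected, given the description above: 1 latent channel at 1/8 resolution,
# i.e. torch.Size([1, 1, 8, 32]).
print(latents.shape)
```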
### Training Details

- **Hardware:** NVIDIA RTX 4090
- **Iterations:** 60,000
- **Optimizer:** AdamW with a learning rate of 0.0001
- **Loss Components:** (combined into the training objective as sketched below)
  - **MAE Loss:** weight of 1
  - **WID Loss:** weight of 0.005
  - **HTR Loss:** weight of 0.3 (using noisy teacher-forcing with a probability of 0.3)
  - **KL Loss:** with a beta parameter set to 1e-6
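Purely as an illustration of how the listed weights relate to one another (the values below are placeholders, not the actual MAE/WID/HTR/KL computations from the paper), the training objective would be a weighted sum along these lines:

```python
import torch

# Placeholder scalars standing in for the real loss terms computed during training.
mae_loss = torch.tensor(0.12)    # MAE term (weight 1)
wid_loss = torch.tensor(3.40)    # WID term (weight 0.005)
htr_loss = torch.tensor(1.75)    # HTR term (weight 0.3)
kl_loss = torch.tensor(250.0)    # KL term (beta = 1e-6)

total_loss = 1.0 * mae_loss + 0.005 * wid_loss + 0.3 * htr_loss + 1e-6 * kl_loss
print(total_loss)
```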
### Auxiliary Networks

...

```python
from diffusers import AutoencoderKL
import torch
from torchvision.transforms.functional import to_tensor, to_pil_image
from PIL import Image

# Load the pre-trained Emuru VAE from Hugging Face Hub.
model = AutoencoderKL.from_pretrained("vpippi/emuru_vae")

# Function to preprocess an RGB image:
# Loads the image, converts it to RGB, and transforms it to a tensor normalized to [0, 1].
def preprocess_image(image_path):
    image = Image.open(image_path).convert("RGB")
    image_tensor = to_tensor(image).unsqueeze(0)  # Add batch dimension
    return image_tensor

# Function to postprocess a tensor back to a PIL image for visualization:
# Clamps the tensor to [0, 1] and converts it to a PIL image.
def postprocess_tensor(tensor):
    tensor = torch.clamp(tensor, 0, 1).squeeze(0)  # Remove batch dimension
    return to_pil_image(tensor)

# Example: Encode and decode an image.
# Replace with your image path.
image_path = "/path/to/image"
input_image = preprocess_image(image_path)

# Encode the image to the latent space.
# The encode() method returns an object with a 'latent_dist' attribute.
with torch.no_grad():
    latents = model.encode(input_image).latent_dist.sample()

# Decode the latents back to image space.
with torch.no_grad():
    reconstructed = model.decode(latents).sample

# Load the original image for comparison.
original_image = Image.open(image_path).convert("RGB")
# Convert the reconstructed tensor back to a PIL image.
reconstructed_image = postprocess_tensor(reconstructed)
```
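As a quick follow-up to the example above (not part of the original snippet, and assuming the input height and width are divisible by 8 so the shapes match), the reconstruction can be compared against the input tensor with a pixel-space mean absolute error:

```python
# Continues from the example above: `input_image` and `reconstructed` are defined there.
with torch.no_grad():
    pixel_mae = torch.mean(torch.abs(reconstructed.clamp(0, 1) - input_image))
print(f"Pixel-space MAE: {pixel_mae.item():.4f}")
```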
If you'd like to test with images hosted directly on the Hugging Face Hub, consider the following:

```python
from huggingface_hub import hf_hub_download
from PIL import Image

# Replace 'vpippi/emuru_vae' and 'samples/lam_sample.jpg' with your details.
image_path = hf_hub_download(repo_id="vpippi/emuru_vae", filename="samples/lam_sample.jpg")
sample_image = Image.open(image_path).convert("RGB")
sample_image.show()
```
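For an end-to-end check (again, a sketch rather than part of the model card), the downloaded sample can be pushed through the same encode/decode round-trip shown earlier:

```python
from diffusers import AutoencoderKL
from huggingface_hub import hf_hub_download
from torchvision.transforms.functional import to_tensor, to_pil_image
from PIL import Image
import torch

model = AutoencoderKL.from_pretrained("vpippi/emuru_vae")

# Download the sample image shipped with the repository and preprocess it as above.
image_path = hf_hub_download(repo_id="vpippi/emuru_vae", filename="samples/lam_sample.jpg")
image = to_tensor(Image.open(image_path).convert("RGB")).unsqueeze(0)

# Round-trip: image -> latents -> reconstruction.
with torch.no_grad():
    latents = model.encode(image).latent_dist.sample()
    reconstruction = model.decode(latents).sample

to_pil_image(reconstruction.clamp(0, 1).squeeze(0)).save("reconstruction.png")
```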