Rahatara's picture
Update app.py
666dec2 verified
raw
history blame
2.87 kB
import streamlit as st
from PIL import Image
import numpy as np
import io
import zipfile
def mock_encoder(image):
    """Return three random stand-in encoder outputs for *image*.

    Placeholder for a trained encoder: *image* is ignored. Each of the three
    returned arrays is drawn independently from a standard normal distribution
    with shape (1, 100). In this file the result is unpacked as
    ``z_mean, z_log_var, _`` (the third value is discarded).
    """
    # Three draws in sequence, same order as three separate np.random.normal calls.
    return tuple(np.random.normal(0, 1, (1, 100)) for _ in range(3))
def mock_decoder(latent_representation):
    """Return a random stand-in decoded image for *latent_representation*.

    Placeholder for a trained decoder: the latent input is ignored and a
    random float array of shape (28, 28, 1) is returned.

    Values lie in [0, 1), matching the normalized pixel range used by the app.
    The app divides input pixels by 255 before encoding and multiplies the
    decoded output by 255 before casting to uint8.
    """
    # Returning values already scaled to [0, 255] would be multiplied by 255 a
    # second time by the caller and overflow uint8, so stay normalized.
    return np.random.rand(28, 28, 1)
def latent_space_augmentation(image, encoder, decoder, noise_scale=0.1):
    """Decode a noisy latent sample drawn around *image*'s encoded mean.

    Samples z = mean + exp(log_var / 2) * eps * noise_scale with
    eps ~ N(0, 1), then decodes z.

    Args:
        image: Input passed unchanged to *encoder*.
        encoder: Callable returning a 3-tuple whose first two items are the
            latent mean and log-variance; the third item is ignored.
        decoder: Callable mapping a latent array to an image array.
        noise_scale: Multiplier on the sampled noise; 0 decodes the mean.

    Returns:
        The decoder output with all size-1 axes removed (``np.squeeze``).
    """
    mean, log_var, _unused = encoder(image)
    eps = np.random.normal(size=mean.shape)
    std = np.exp(0.5 * log_var)
    # Same evaluation order as (std * eps) * noise_scale to keep results bit-identical.
    latent = mean + std * eps * noise_scale
    return np.squeeze(decoder(latent))
def create_downloadable_zip(augmented_images):
    """Pack encoded image payloads into an in-memory ZIP archive.

    Args:
        augmented_images: Iterable of bytes-like image payloads. Entries are
            named ``augmented_image_<n>.jpeg`` with n starting at 1. The
            ``.jpeg`` extension is fixed regardless of the payload's content.

    Returns:
        An ``io.BytesIO`` containing the archive, rewound to position 0 so it
        can be read directly (e.g. passed to ``st.download_button``).
    """
    zip_buffer = io.BytesIO()
    # "w" creates a fresh archive. ZIP64 stays enabled (the default) so very
    # large batches do not raise zipfile.LargeZipFile.
    with zipfile.ZipFile(zip_buffer, "w", compression=zipfile.ZIP_DEFLATED) as zip_file:
        for idx, image_data in enumerate(augmented_images, start=1):
            # writestr accepts bytes directly; no intermediate BytesIO copy needed.
            zip_file.writestr(f"augmented_image_{idx}.jpeg", image_data)
    zip_buffer.seek(0)
    return zip_buffer
st.title("Batch Image Augmentation with Latent Space Manipulation")
uploaded_files = st.file_uploader("Choose images (1-10)", accept_multiple_files=True, type=["jpg", "jpeg", "png"])
augmentations_count = st.number_input("Number of augmented samples per image", min_value=1, max_value=10, value=3)

# st.file_uploader does not cap how many files may be selected, so enforce
# the limit promised in the uploader label.
MAX_UPLOADS = 10
if uploaded_files and len(uploaded_files) > MAX_UPLOADS:
    st.warning(f"Only the first {MAX_UPLOADS} images will be processed.")
    uploaded_files = uploaded_files[:MAX_UPLOADS]

if uploaded_files and st.button("Generate Augmented Images"):
    all_augmented_images = []
    for uploaded_file in uploaded_files:
        try:
            image = Image.open(uploaded_file).convert("RGB")
        except OSError:
            # PIL.UnidentifiedImageError subclasses OSError. Skip the bad file
            # instead of aborting the whole batch.
            st.warning(f"Skipping {uploaded_file.name}: could not read it as an image.")
            continue
        image = image.resize((28, 28))  # Resize for simplicity with the mock decoder
        # Convert to numpy and normalize pixels to [0, 1]
        image_np = np.array(image) / 255.0
        for _ in range(augmentations_count):
            augmented_image_np = latent_space_augmentation(image_np, mock_encoder, mock_decoder)
            # Denormalize. Clip first so out-of-range decoder output saturates
            # instead of wrapping around when cast to uint8.
            augmented_image = np.clip(augmented_image_np * 255, 0, 255).astype(np.uint8)
            augmented_images_io = io.BytesIO()
            Image.fromarray(augmented_image).save(augmented_images_io, format="JPEG")
            all_augmented_images.append(augmented_images_io.getvalue())
    if all_augmented_images:
        zip_buffer = create_downloadable_zip(all_augmented_images)
        st.download_button(
            label="Download Augmented Dataset",
            data=zip_buffer,
            file_name="augmented_images.zip",
            mime="application/zip"
        )