"""Streamlit app that converts uploaded images between two domains with a pretrained CycleGAN."""

import streamlit as st
import torch
import numpy as np
from PIL import Image

from model import CycleGAN, get_val_transform, de_normalize

# Configure page
st.set_page_config(
    page_title="CycleGAN Image Converter",
    page_icon="🎨",
    layout="wide"
)


# Get the best available device
@st.cache_resource
def get_device():
    """Return the best available torch device: CUDA, then Apple MPS, then CPU.

    Also posts a status message in the sidebar naming the chosen backend.
    Wrapped in ``st.cache_resource`` so the detection is not redone on every
    Streamlit rerun.
    """
    if torch.cuda.is_available():
        device = torch.device("cuda")
        st.sidebar.success("Using GPU 🚀")
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
        st.sidebar.success("Using Apple Silicon 🍎")
    else:
        device = torch.device("cpu")
        st.sidebar.info("Using CPU 💻")
    return device


# Add custom CSS
# NOTE(review): the CSS payload below is empty, so this call injects nothing.
# The original <style> block appears to have been lost; restore it or drop the call.
st.markdown("""
""", unsafe_allow_html=True)

# Title and description
st.title("CycleGAN Image Converter 🎨")
st.markdown("""
Transform images between different domains using CycleGAN.
Upload an image and see it converted in real-time!

*Note: Images will be resized to 256x256 pixels during conversion.*
""")

# Available models and their configurations.
# "model_path" is handed to CycleGAN.from_pretrained (presumably a Hugging Face
# Hub repo id — confirm against model.py).
MODELS = [
    {
        "name": "Cezanne ↔ Photo",
        "id": "cezanne2photo",
        "model_path": "waleko/cyclegan",
        "description": "Convert between Cezanne's painting style and photographs"
    }
]

# Sidebar controls
with st.sidebar:
    st.header("Settings")

    # Model selection: the widget value is an index into MODELS
    selected_model = st.selectbox(
        "Conversion Type",
        options=range(len(MODELS)),
        format_func=lambda x: MODELS[x]["name"]
    )

    # Direction selection
    direction = st.radio(
        "Conversion Direction",
        options=["A → B", "B → A"],
        help="A → B: Convert from domain A to B\nB → A: Convert from domain B to A"
    )


# Load model
@st.cache_resource
def load_model(model_path):
    """Load a pretrained CycleGAN, move it to the best device and switch it to eval mode.

    Args:
        model_path: identifier passed to ``CycleGAN.from_pretrained``.

    Returns:
        The loaded model. Cached with ``st.cache_resource`` per ``model_path``,
        so the weights are not reloaded on every rerun.
    """
    device = get_device()
    model = CycleGAN.from_pretrained(model_path)
    model = model.to(device)
    model.eval()
    return model


# Process image
def process_image(image, model, direction):
    """Translate ``image`` with one of the model's two generators.

    Args:
        image: PIL image to convert (the UI below passes an RGB-converted image).
        model: CycleGAN returned by ``load_model``.
        direction: "A → B" selects ``model.generator_ab``; any other value
            selects ``model.generator_ba``.

    Returns:
        The value ``de_normalize`` produces for the single batch item; it is
        handed directly to ``st.image``.
    """
    # Prepare transform
    transform = get_val_transform(model, direction)

    # Convert PIL image to tensor; unsqueeze(0) adds a leading batch dimension of 1.
    # Assumes the transform accepts an HxWxC uint8 array — TODO confirm in model.py.
    tensor = transform(np.array(image)).unsqueeze(0)

    # Move to the device the model's weights live on
    tensor = tensor.to(next(model.parameters()).device)

    # Inference only: no autograd graph needed
    with torch.no_grad():
        if direction == "A → B":
            output = model.generator_ab(tensor)
        else:
            output = model.generator_ba(tensor)

    # Convert back to image (drop the batch dimension first)
    result = de_normalize(output[0], model, direction)
    return result


# Main interface
col1, col2 = st.columns(2)

# Stays None when no file is uploaded or the upload cannot be decoded.
input_image = None

with col1:
    st.subheader("Input Image")
    uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

    if uploaded_file is not None:
        try:
            # Force 3-channel RGB: PNG uploads are often RGBA, palette or
            # grayscale, which would give np.array() a 4-channel or 2-D array.
            # convert() also forces the lazy decode, so truncated files fail here.
            input_image = Image.open(uploaded_file).convert("RGB")
        except OSError as e:  # includes PIL.UnidentifiedImageError
            st.error(f"Could not read the uploaded file as an image: {e}")
        else:
            # NOTE(review): newer Streamlit releases deprecate use_column_width
            # in favour of use_container_width; switch once the pinned version allows.
            st.image(input_image, use_column_width=True)

with col2:
    st.subheader("Converted Image")

    if uploaded_file is None:
        st.info("Upload an image to see the conversion result")
    elif input_image is not None:
        # Top-level UI boundary: surface any model/conversion failure to the user.
        try:
            # Load and process
            model = load_model(MODELS[selected_model]["model_path"])
            result = process_image(input_image, model, direction)

            # Display
            st.image(result, use_column_width=True)
        except Exception as e:
            st.error(f"Error during conversion: {str(e)}")