Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import cv2 | |
| import os | |
| import numpy as np | |
| from basicsr.archs.rrdbnet_arch import RRDBNet | |
| from basicsr.utils.download_util import load_file_from_url | |
| from realesrgan import RealESRGANer | |
| from realesrgan.archs.srvgg_arch import SRVGGNetCompact | |
| from gfpgan import GFPGANer | |
# Function to load the model
def load_model(model_name, model_path, denoise_strength, tile, tile_pad, pre_pad, fp32, gpu_id):
    """Build a RealESRGANer upsampler for one of the supported Real-ESRGAN models.

    Args:
        model_name: One of 'RealESRGAN_x4plus', 'RealESRGAN_x4plus_anime_6B'
            or 'RealESRGAN_x2plus'.
        model_path: Path to a local ``.pth`` weights file, or ``None`` to use
            ``weights/<model_name>.pth`` (downloaded into ``./weights`` if missing).
        denoise_strength: DNI blend weight. It is only consulted for
            'realesr-general-x4v3', which no branch below currently builds,
            so it has no effect for the supported models.
        tile, tile_pad, pre_pad: Tiling/padding options forwarded to RealESRGANer.
        fp32: If True, run in full precision; otherwise ``half=True`` is passed.
        gpu_id: GPU index forwarded unchanged to RealESRGANer.

    Returns:
        A configured ``RealESRGANer`` instance.

    Raises:
        ValueError: If ``model_name`` is not one of the supported models.
    """
    if model_name == 'RealESRGAN_x4plus':
        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
        netscale = 4
        file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth']
    elif model_name == 'RealESRGAN_x4plus_anime_6B':
        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
        netscale = 4
        file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth']
    elif model_name == 'RealESRGAN_x2plus':
        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
        netscale = 2
        file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth']
    else:
        # Without this, `model`/`netscale`/`file_url` stay unbound and the call
        # fails later with an opaque UnboundLocalError.
        raise ValueError(f"Unsupported model name: {model_name!r}")

    # Determine model path: only fall back to ./weights (and download) when none was given.
    if model_path is None:
        model_path = os.path.join('weights', model_name + '.pth')
        if not os.path.isfile(model_path):
            for url in file_url:
                # load_file_from_url returns the path of the downloaded file.
                model_path = load_file_from_url(
                    url=url, model_dir=os.path.join(os.getcwd(), 'weights'), progress=True,
                    file_name=model_name + '.pth')

    # Use DNI to blend the normal and "wdn" weights according to denoise_strength.
    # Only applies to 'realesr-general-x4v3', which the branches above do not
    # select, so this stays inactive for the supported models.
    dni_weight = None
    if model_name == 'realesr-general-x4v3' and denoise_strength != 1:
        wdn_model_path = model_path.replace('realesr-general-x4v3', 'realesr-general-wdn-x4v3')
        model_path = [model_path, wdn_model_path]
        dni_weight = [denoise_strength, 1 - denoise_strength]

    # Restorer
    upsampler = RealESRGANer(
        scale=netscale,
        model_path=model_path,
        dni_weight=dni_weight,
        model=model,
        tile=tile,
        tile_pad=tile_pad,
        pre_pad=pre_pad,
        half=not fp32,
        gpu_id=gpu_id)
    return upsampler
# Function to download model weights if not present
def ensure_model_weights(model_name):
    """Return the path to ``weights/<model_name>.pth``, downloading it if absent.

    The ``weights`` directory is created if it does not exist.

    Args:
        model_name: Model identifier; the weights file is ``<model_name>.pth``.

    Returns:
        ``os.path.join('weights', '<model_name>.pth')`` when that file already
        exists; otherwise the path returned by ``load_file_from_url`` after
        downloading it.

    Raises:
        ValueError: If the file is missing and no download URL is known for
            ``model_name``.
    """
    release_urls = {
        'RealESRGAN_x4plus':
            'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth',
        'RealESRGAN_x4plus_anime_6B':
            'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth',
        'RealESRGAN_x2plus':
            'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth',
    }
    weights_dir = 'weights'
    model_file = f"{model_name}.pth"
    model_path = os.path.join(weights_dir, model_file)
    os.makedirs(weights_dir, exist_ok=True)
    if not os.path.isfile(model_path):
        file_url = release_urls.get(model_name)
        if file_url is None:
            # Previously this path hit an UnboundLocalError on `file_url`.
            raise ValueError(f"No download URL known for model {model_name!r}")
        model_path = load_file_from_url(
            url=file_url, model_dir=weights_dir, progress=True, file_name=model_file)
    return model_path
# Streamlit app: upload an image, pick a Real-ESRGAN model and options,
# run the enhancement, then show and offer the result for download.
st.title("Real-ESRGAN Image Enhancement")
uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "png", "jpeg"])

# User selects model name, denoise strength, and other parameters
model_name = st.selectbox("Model Name", ['RealESRGAN_x4plus', 'RealESRGAN_x4plus_anime_6B', 'RealESRGAN_x2plus'])
# NOTE(review): load_model only uses denoise_strength for 'realesr-general-x4v3',
# which is not among the choices above, so this slider currently has no effect.
denoise_strength = st.slider("Denoise Strength", 0.0, 1.0, 0.5)
outscale = st.slider("Output Scale", 1, 4, 4)
# Tiling/padding settings forwarded to RealESRGANer via load_model.
tile = 0
tile_pad = 10
pre_pad = 0
face_enhance = st.checkbox("Face Enhance")
fp32 = st.checkbox("Use FP32 Precision")
gpu_id = None  # or set to 0, 1, etc. if you have multiple GPUs

if uploaded_file is not None:
    col1, col2 = st.columns(2)
    with col1:
        st.write("### Original Image")
        st.image(uploaded_file, use_column_width=True)
    run_button = st.button("Run")

    # getvalue() returns the whole upload regardless of the current stream
    # position, unlike read(), so the same bytes serve both uses below.
    uploaded_bytes = uploaded_file.getvalue()

    # Save uploaded image to disk
    input_image_path = os.path.join("temp", "input_image.png")
    os.makedirs("temp", exist_ok=True)
    with open(input_image_path, "wb") as f:
        f.write(uploaded_bytes)

    if not run_button:
        st.warning("Click the 'Run' button to start the enhancement process.")
    else:
        # Ensure model weights are downloaded
        model_path = ensure_model_weights(model_name)
        # Load the model
        upsampler = load_model(model_name, model_path, denoise_strength, tile, tile_pad, pre_pad, fp32, gpu_id)
        # Decode the uploaded bytes; IMREAD_UNCHANGED keeps any alpha channel.
        img = cv2.imdecode(np.frombuffer(uploaded_bytes, np.uint8), cv2.IMREAD_UNCHANGED)
        if img is None:
            st.error("Error loading image. Please try again.")
        else:
            try:
                if face_enhance:
                    # GFPGAN restores faces and uses the Real-ESRGAN upsampler for the background.
                    face_enhancer = GFPGANer(
                        model_path='https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth',
                        upscale=outscale,
                        arch='clean',
                        channel_multiplier=2,
                        bg_upsampler=upsampler)
                    _, _, output = face_enhancer.enhance(img, has_aligned=False, only_center_face=False, paste_back=True)
                else:
                    output, _ = upsampler.enhance(img, outscale=outscale)
            except RuntimeError as error:
                st.error(f"Error: {error}")
                st.error('If you encounter CUDA out of memory, try to set --tile with a smaller number.')
            else:
                # Save and display the output image
                output_image_path = os.path.join("temp", "output_image.png")
                if not cv2.imwrite(output_image_path, output):
                    st.error("Failed to save the enhanced image.")
                else:
                    with col2:
                        st.write("### Enhanced Image")
                        st.image(output_image_path, use_column_width=True)
                    # Read via a context manager so the file handle is closed.
                    with open(output_image_path, "rb") as output_file:
                        output_bytes = output_file.read()
                    st.download_button("Download Enhanced Image", data=output_bytes, file_name="output_image.png", mime="image/png")