"""Streamlit demo: upload a grayscale image or a video and inspect the model output.

NOTE(review): the page is titled "Colorization", but the network loaded below is
torchvision's DeepLabV3-ResNet101 semantic-segmentation model, so the "Colorized"
view is a color-coded segmentation map — presumably a stand-in until a real
colorization model is wired in.
"""

from typing import Optional, Tuple

import numpy as np
import streamlit as st
import torch
import torchvision
import torchvision.transforms as transforms
from PIL import Image
from streamlit_image_comparison import image_comparison

######################################### Utils ########################################

video_extensions = ["mp4"]
# "jpeg" is listed alongside "jpg" so both spellings are recognized as images.
image_extensions = ["png", "jpg", "jpeg"]


def check_type(file_name: str) -> Optional[str]:
    """Classify a file by its extension.

    Args:
        file_name: Name of the uploaded file, e.g. ``"photo.JPG"``.

    Returns:
        ``"image"`` if the extension is in ``image_extensions``, ``"video"`` if
        it is in ``video_extensions``, otherwise ``None``. The comparison is
        case-insensitive and only looks at the text after the last dot, so
        ``"clip.MP4"`` is a video while ``"notapng"`` is unrecognized.
    """
    _, dot, extension = file_name.rpartition(".")
    extension = extension.lower() if dot else ""
    if extension in image_extensions:
        return "image"
    if extension in video_extensions:
        return "video"
    return None


# Preprocessing for the model: resize, scale pixels to [0, 1] (ToTensor), then
# apply the ImageNet mean/std normalization used with torchvision's pretrained weights.
transform = transforms.Compose(
    [
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ]
)

# Per-channel multipliers used to derive a distinct RGB color for each class index
# (same scheme as the PyTorch Hub DeepLabV3 visualization example).
_PALETTE_BASE = np.array([2**25 - 1, 2**15 - 1, 2**21 - 1])


def _segmentation_to_image(logits: torch.Tensor, size: Tuple[int, int]) -> Image.Image:
    """Render per-class scores of shape (classes, H, W) as an RGB image of ``size`` (W, H)."""
    class_map = logits.argmax(dim=0).byte().cpu().numpy()
    palette = (np.arange(logits.shape[0])[:, None] * _PALETTE_BASE) % 255
    # NEAREST keeps class labels intact instead of blending neighboring indices.
    mask = Image.fromarray(class_map).resize(size, resample=Image.NEAREST)
    # putpalette turns the 8-bit class map into a palette image, one color per class.
    mask.putpalette(palette.astype(np.uint8).reshape(-1).tolist())
    return mask.convert("RGB")


###################################### Load model ######################################


@st.cache_resource
def load_model():
    """Load DeepLabV3-ResNet101 with pretrained weights in eval mode, cached across reruns."""
    # NOTE(review): `pretrained=True` is deprecated in newer torchvision in favor of
    # `weights=...`; kept here for compatibility with older releases.
    model = torchvision.models.segmentation.deeplabv3_resnet101(pretrained=True)
    model.eval()
    return model


model = load_model()

########################################## UI ##########################################

st.title("Colorization")

uploaded_file = st.file_uploader(
    "Upload grayscale image or video", type=image_extensions + video_extensions
)

if uploaded_file is not None:
    file_type = check_type(file_name=uploaded_file.name)

    # Image
    if file_type == "image":
        # Force 3-channel RGB: grayscale uploads (mode "L") have no channel axis and
        # RGBA PNGs have four channels; both break the 3-channel normalization.
        image = Image.open(uploaded_file).convert("RGB")
        input_tensor = transform(image).unsqueeze(0)  # add the batch dimension

        if st.button("Process"):
            with st.spinner("Từ từ coi..."):
                # Inference only: skip autograd bookkeeping to save memory and time.
                with torch.no_grad():
                    prediction = model(input_tensor)
                segment = _segmentation_to_image(prediction["out"][0], size=image.size)

            st.image(segment)
            st.image(image)
            image_comparison(
                img1=image,
                img2=segment,
                label1="Grayscale",
                label2="Colorized",
                make_responsive=True,
                show_labels=True,
            )

    # Video
    elif file_type == "video":
        # NOTE(review): video processing is not implemented yet; play back the upload as-is.
        st.video(uploaded_file)

    else:
        st.warning(f"Unsupported file type: {uploaded_file.name}")
# Raw CSS/HTML injected into the page to adjust Streamlit's default chrome.
# NOTE(review): the style string below is whitespace-only, so nothing is changed
# on the page yet — presumably a <style> rule belongs here; confirm the intent.
hide_menu_style = """ """
st.markdown(body=hide_menu_style, unsafe_allow_html=True)