# Page header captured with this file: author nicolasni1977,
# commit 4b9bb89 ("change ui.py to app.py"), size 4.92 kB.
import streamlit as st
import os
import subprocess
import time
from PIL import Image
from io import BytesIO
from persona import Persona
# Wav2Lip GAN checkpoint, downloaded once if it is not already on disk.
url = 'https://iiitaphyd-my.sharepoint.com/personal/radrabha_m_research_iiit_ac_in/_layouts/15/download.aspx?share=EdjI7bZlgApMqsVoEUUXpLsBxqXbn5z8VTmoxp55YNDcIA'
output_path = 'wav2lip/wav2lip_gan.pth'


def _download_checkpoint(source_url, destination):
    """Fetch ``source_url`` into ``destination`` with ``wget`` unless it already exists.

    wget writes to ``destination + '.part'``. That file is renamed onto
    ``destination`` only after wget exits successfully, so a failed or
    interrupted download never leaves a truncated file at ``destination``
    (which the existence check would otherwise treat as complete forever).

    Failures are reported with ``print`` and do not stop the app: the download
    is best-effort, as before.
    """
    # Check if the file does not exist
    if os.path.exists(destination):
        print(f"The file '{destination}' already exists.")
        return
    partial_path = destination + '.part'
    try:
        # wget -O does not create missing parent directories.
        os.makedirs(os.path.dirname(destination) or '.', exist_ok=True)
        subprocess.run(['wget', '-c', '--read-timeout=5', '-O', partial_path, source_url], check=True)
        os.replace(partial_path, destination)
    except (subprocess.CalledProcessError, OSError) as e:
        # OSError also covers a missing wget binary (FileNotFoundError) and
        # directory-creation or rename failures.
        print(f"An error occurred: {e}")


_download_checkpoint(url, output_path)
# Keep the sidebar state ('expanded' or 'collapsed') in session state so it
# persists across reruns; a first visit starts with the sidebar expanded.
if 'sidebar_state' not in st.session_state:
    st.session_state['sidebar_state'] = 'expanded'
# Streamlit page configuration; Streamlit expects this before other page elements.
st.set_page_config(
    page_title="Talking Head Generator",
    layout="wide",
    initial_sidebar_state=st.session_state['sidebar_state'],
)
@st.cache_data
def generate_talking_head(text_prompt, voice, speed, image_file, driver_video):
    """Render a talking-head video by delegating to ``Persona``.

    Args:
        text_prompt: Text for the head to speak.
        voice: Voice name selected in the sidebar.
        speed: Talking speed selected in the sidebar slider.
        image_file: Path of the saved source image.
        driver_video: Path of the saved driver video.

    Returns:
        The value returned by ``Persona(...)``. The caller treats it as the
        path of the generated video file.

    ``st.cache_data`` memoizes results per argument tuple. ``image_file`` and
    ``driver_video`` are paths, so a different upload saved under the same
    file name would hit the cached result (NOTE(review): confirm acceptable).
    """
    # No artificial delay here: the previous "simulated processing time"
    # sleep only postponed a result that was already available.
    return Persona(text_prompt, voice, speed, image_file, driver_video)
@st.cache_data
def save_uploaded_file(uploaded_file, destination_dir="temp"):
    """Write an uploaded file into ``destination_dir`` and return its path.

    Args:
        uploaded_file: Object returned by ``st.file_uploader``; only its
            ``name`` and ``getbuffer()`` are used.
        destination_dir: Directory to write into; created if it is missing.

    Returns:
        The path of the written file, or ``None`` (after showing an
        ``st.error``) if the directory or file could not be written.
    """
    # The file name is supplied by the client; keep only its final component
    # so a name such as "../../x" cannot escape destination_dir.
    file_name = os.path.basename(uploaded_file.name)
    file_path = os.path.join(destination_dir, file_name)
    try:
        # Create the destination directory if it doesn't exist
        os.makedirs(destination_dir, exist_ok=True)
        # Always write: the upload lives in memory, so there is no on-disk
        # source that could already be at file_path.
        with open(file_path, "wb") as f:
            f.write(uploaded_file.getbuffer())
    except (OSError, ValueError) as e:
        st.error(f"Error saving file: {e}")
        return None
    return file_path
# UI layout: a narrow spacer column, the wide main column and an output column.
col1, col2, col3 = st.columns([0.5, 4, 1])

# All user inputs live in the sidebar; col1 is left empty as a spacer.
with st.sidebar:
    text_prompt = st.text_area("Enter your text prompt:", height=200)
    voice = st.selectbox(
        'Choose Voices?',
        ('alloy', 'echo', 'fable', 'onyx', 'nova', 'shimmer'),
    )
    speed = st.slider('Talking speed?', min_value=1, max_value=10, value=1)
    uploaded_image = st.file_uploader("Upload an image", type=["png", "jpg", "jpeg"])
    uploaded_driver_video = st.file_uploader("Upload a driver video", type=["mp4"])
# Column 2: logo, the generate button and the currently selected voice.
with col2:
    st.image("images/logo1.jpg", width=650)
    generate_button = st.button("Generate Talking Head")
    # Decorative "voice note" waveform built from Myanmar punctuation
    # (U+104A/U+104B); \u200c is an invisible zero-width non-joiner, spelled
    # as an escape so it stays visible in the source.
    st.write('•၊၊||၊|။||||။\u200c\u200c\u200c\u200c\u200c၊| • Voice Selected :', voice)
# Column 3: placeholders that the click handler fills with the video player
# and the download button.
with col3:
    st.subheader("##🎥")
    display_video = st.empty()
    download_button = st.empty()
# Handling Button Click
# voice and speed always carry a widget default; the prompt and both uploads
# may still be missing when the button is pressed.
inputs_ready = bool(text_prompt and voice and speed and uploaded_image and uploaded_driver_video)
if generate_button and not inputs_ready:
    st.warning("Please enter a text prompt and upload both an image and a driver video.")
elif generate_button:
    video_path = None
    with st.spinner('Generating Talking Head...'):
        image_path = save_uploaded_file(uploaded_image)
        driver_video_path = save_uploaded_file(uploaded_driver_video)
        # save_uploaded_file shows its own st.error and returns None when it
        # cannot write a file; skip generation instead of passing None on.
        if image_path is not None and driver_video_path is not None:
            video_path = generate_talking_head(text_prompt, voice, speed, image_path, driver_video_path)

    if video_path is not None:
        # Display progress bar (cosmetic: generation has already returned).
        progress_bar = st.progress(0)
        for percent_complete in range(100):
            time.sleep(0.1)
            progress_bar.progress(percent_complete + 1)
        progress_bar.empty()

        if os.path.exists(video_path):
            # Read the video once, closing the handle, and reuse the bytes for
            # both the inline player and the download button.
            with open(video_path, 'rb') as video_file:
                video_bytes = video_file.read()
            with display_video:
                st.video(video_bytes)
            with download_button:
                st.download_button(
                    label="Download Video",
                    data=video_bytes,
                    file_name="talking_head.mp4",
                    mime="video/mp4")
        else:
            with display_video:
                st.error("Video file not found. Please ensure the video generation process completes successfully.")
# Placeholder for further external Python script integration.
# Video generation already goes through persona.Persona in generate_talking_head; add any additional script calls there.