Spaces:
Runtime error
Runtime error
import time | |
from config import * | |
import cv2 | |
import glob | |
import numpy as np | |
import os | |
from basicsr.utils import imwrite | |
from pathos.pools import ParallelPool | |
import subprocess | |
import platform | |
from mutagen.wave import WAVE | |
import tqdm | |
from p_tqdm import * | |
import torch | |
from PIL import Image | |
from RealESRGAN import RealESRGAN | |
def vid2frames(vidPath, framesOutPath):
    """Decode every frame of a video into numbered PNG files.

    Frames are written as ``00001.png``, ``00002.png``, ... (1-based,
    zero-padded to five digits) inside *framesOutPath*. The directory is
    created if it does not exist yet.

    Args:
        vidPath: Path of the video to decode with OpenCV.
        framesOutPath: Directory that receives the extracted frames.

    Returns:
        The number of frames written.

    Raises:
        OSError: If the video cannot be opened or a frame cannot be written.
    """
    print(vidPath)
    print(framesOutPath)
    vidcap = cv2.VideoCapture(vidPath)
    # Without this check, an unreadable video silently yields zero frames and
    # the failure only surfaces later in the pipeline.
    if not vidcap.isOpened():
        vidcap.release()
        raise OSError(f"Could not open video: {vidPath}")
    written = 0
    try:
        os.makedirs(framesOutPath, exist_ok=True)
        while True:
            success, image = vidcap.read()
            if not success:
                break
            written += 1
            outPath = os.path.join(framesOutPath, str(written).zfill(5) + '.png')
            # cv2.imwrite reports failure via its return value, not an exception.
            if not cv2.imwrite(outPath, image):
                raise OSError(f"Failed to write frame: {outPath}")
    finally:
        # Release the capture handle even if writing a frame fails.
        vidcap.release()
    return written
def restore_frames(audiofilePath, videoOutPath, improveOutputPath):
    """Reassemble numbered PNG frames plus a WAV track into a video with ffmpeg.

    The frame rate is chosen so the frames span exactly the length of the
    audio: ``fps = frame_count / audio_duration``.

    Args:
        audiofilePath: WAV file muxed in as the audio stream.
        videoOutPath: Output video path (overwritten if present, via ``-y``).
        improveOutputPath: Directory holding frames named ``00001.png``,
            ``00002.png``, ... (the naming vid2frames produces).

    Returns:
        ffmpeg's exit status (0 on success).

    Raises:
        ValueError: If the frame directory is empty or the audio length is
            not positive (the frame rate would be undefined or zero).
    """
    import shlex

    no_of_frames = count_files(improveOutputPath)
    audio_duration = get_audio_duration(audiofilePath)
    if no_of_frames == 0:
        raise ValueError(f"No frames found in {improveOutputPath!r}")
    if audio_duration <= 0:
        raise ValueError(
            f"Audio {audiofilePath!r} has non-positive duration {audio_duration}"
        )
    # ffmpeg image2 sequence pattern matching the zero-padded frame names.
    framesPath = os.path.join(improveOutputPath, "%05d.png")
    fps = no_of_frames / audio_duration
    # An argument list (no shell) hands paths containing spaces or shell
    # metacharacters to ffmpeg verbatim, identically on every platform.
    command = [
        "ffmpeg", "-y",
        "-r", str(fps),
        "-f", "image2",
        "-i", framesPath,
        "-i", audiofilePath,
        "-vcodec", "mpeg4",
        "-b:v", "20000k",
        videoOutPath,
    ]
    print(" ".join(shlex.quote(arg) for arg in command))
    return subprocess.call(command)
def get_audio_duration(audioPath):
    """Return the playback length of the WAV file at *audioPath*.

    The value is mutagen's ``info.length`` for the file (seconds).
    """
    wav_info = WAVE(audioPath).info
    return wav_info.length
def count_files(directory):
    """Count the entries directly inside *directory* that are files.

    Subdirectories are neither counted nor descended into. Symlinks count
    when they point at a file.
    """
    entry_names = os.listdir(directory)
    return sum(os.path.isfile(os.path.join(directory, entry)) for entry in entry_names)
def improve(disassembledPath, improvedPath, scale=4):
    """Upscale every PNG frame in a directory with Real-ESRGAN.

    Args:
        disassembledPath: Directory containing the source ``*.png`` frames.
        improvedPath: Directory that receives the upscaled frames, saved under
            their original file names. Created if missing.
        scale: Upscaling factor passed to ``RealESRGAN``. It also selects the
            weights file ``weights/RealESRGAN_x{scale}.pth``. Defaults to 4,
            the previously hard-coded value.

    Returns:
        The list of per-frame values returned by ``real_esrgan``, in the
        (sorted) order the frames were processed.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = RealESRGAN(device, scale=scale)
    model.load_weights(f'weights/RealESRGAN_x{scale}.pth', download=True)
    # Saving into a missing directory would fail on the first frame.
    os.makedirs(improvedPath, exist_ok=True)
    # Sorted so frames are processed (and progress reported) in frame order.
    files = sorted(glob.glob(os.path.join(disassembledPath, "*.png")))
    count = len(files)
    return t_map(real_esrgan, files, [model] * count, [improvedPath] * count)
def real_esrgan(img_path, model, improvedPath):
    """Upscale one image file with *model* and save it into *improvedPath*.

    The output keeps the input's base name, so frame numbering is preserved.

    Args:
        img_path: Path of the image to upscale.
        model: Object whose ``predict`` takes an RGB PIL image and returns an
            image with a ``save`` method (a loaded ``RealESRGAN`` in this file).
        improvedPath: Existing directory to write the result into.

    Returns:
        The path the upscaled image was written to.
    """
    # Closing the source explicitly avoids relying on GC to release the file.
    with Image.open(img_path) as src:
        image = src.convert('RGB')
    sr_image = model.predict(image)
    out_path = os.path.join(improvedPath, os.path.basename(img_path))
    sr_image.save(out_path)
    return out_path