|
|
import os |
|
|
import logging |
|
|
import torch |
|
|
import asyncio |
|
|
import aiohttp |
|
|
import requests |
|
|
from huggingface_hub import hf_hub_download |
|
|
|
|
|
|
|
|
# Configure the root logger for the whole process: DEBUG and above, each record
# prefixed with a timestamp and level name. (basicConfig is a no-op when the
# root logger already has handlers.)
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.DEBUG,
)

# Logger for this module.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
# Root of all on-disk data; overridable through the DATA_ROOT environment variable.
DATA_ROOT = os.getenv('DATA_ROOT', '/tmp/data')

# Directory that hf_hub_download writes model files into (used as local_dir).
MODELS_DIR = os.path.join(DATA_ROOT, "models")

# Prefer the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

# Hugging Face repository that every entry of MODEL_FILES is fetched from.
HF_REPO_ID = "jbilcke-hf/model-cocktail"
|
|
|
|
|
|
|
|
# Repository-relative paths (inside HF_REPO_ID) of every file that
# download_all_models() makes sure exists under MODELS_DIR.
MODEL_FILES = [
    # dwpose/
    "dwpose/dw-ll_ucoco_384.pth",
    # face-detector/
    "face-detector/s3fd-619a316812.pth",
    # liveportrait/
    "liveportrait/spade_generator.pth",
    "liveportrait/warping_module.pth",
    "liveportrait/motion_extractor.pth",
    "liveportrait/stitching_retargeting_module.pth",
    "liveportrait/appearance_feature_extractor.pth",
    "liveportrait/landmark.onnx",
    # insightface/
    "insightface/models/buffalo_l.zip",
    "insightface/buffalo_l/det_10g.onnx",
    "insightface/buffalo_l/2d106det.onnx",
    # sd-vae-ft-mse/
    "sd-vae-ft-mse/diffusion_pytorch_model.bin",
    "sd-vae-ft-mse/diffusion_pytorch_model.safetensors",
    "sd-vae-ft-mse/config.json",
]
|
|
|
|
|
def create_directory(directory):
    """Create ``directory`` (including missing parents) and log its status.

    Args:
        directory: Path of the directory to create.

    If something already exists at ``directory``, nothing is changed and an
    "already exists" message is logged instead.
    """
    # EAFP instead of exists()-then-makedirs(): the old check could race with
    # another process/task creating the directory in between, and it crashed on
    # broken symlinks (exists() is False but makedirs() raises).
    try:
        os.makedirs(directory)
    except FileExistsError:
        logger.info(f" Directory already exists: {directory}")
    else:
        logger.info(f" Directory created: {directory}")
|
|
|
|
|
def print_directory_structure(startpath):
    """Log (at INFO level) the directory tree rooted at ``startpath``.

    Each directory is logged as ``name/`` and each file as ``name``. Entries
    are indented by four spaces per nesting level below ``startpath``.

    Args:
        startpath: Root directory to walk. A trailing path separator is
            tolerated.
    """
    for root, _dirs, files in os.walk(startpath):
        # Depth is derived from the path relative to startpath. The previous
        # str.replace() removed *every* occurrence of startpath (not only the
        # prefix) and miscounted levels when startpath ended with a separator.
        rel = os.path.relpath(root, startpath)
        level = 0 if rel == os.curdir else rel.count(os.sep) + 1
        indent = ' ' * 4 * level
        # normpath drops a trailing separator so basename() is never empty.
        logger.info(f"{indent}{os.path.basename(os.path.normpath(root))}/")
        subindent = ' ' * 4 * (level + 1)
        for f in files:
            logger.info(f"{subindent}{f}")
|
|
|
|
|
async def download_hf_file(filename: str) -> None:
    """Download ``filename`` from ``HF_REPO_ID`` into ``MODELS_DIR``.

    The file ends up at ``MODELS_DIR/filename``. If that path already exists,
    the download is skipped.

    Args:
        filename: Path of the file inside the Hugging Face repository.

    Raises:
        Exception: Any error raised by ``hf_hub_download`` is logged and
            re-raised, after a file left at the destination path (if any) is
            removed.
    """
    dest = os.path.join(MODELS_DIR, filename)
    os.makedirs(os.path.dirname(dest), exist_ok=True)

    if os.path.exists(dest):
        logger.debug(f" ✅ {filename}")
        return

    logger.info(f" ⏳ Downloading {HF_REPO_ID}/{filename}")

    try:
        # hf_hub_download blocks; run it in a worker thread so that downloads
        # started together via asyncio.gather can overlap. to_thread replaces
        # get_event_loop().run_in_executor, which is discouraged inside a
        # coroutine, and matches how the rest of this module offloads work.
        await asyncio.to_thread(
            hf_hub_download,
            repo_id=HF_REPO_ID,
            filename=filename,
            local_dir=MODELS_DIR,
        )
    except Exception as e:
        # Include the filename: several downloads run concurrently.
        logger.error(f"🚨 Error downloading {filename} from Hugging Face: {e}")
        # Remove a possibly incomplete file; otherwise the exists() check
        # above would treat it as a finished download on the next run.
        try:
            if os.path.exists(dest):
                os.remove(dest)
        except OSError as cleanup_error:
            # Do not let a cleanup failure mask the original download error.
            logger.warning(f"Could not remove {dest}: {cleanup_error}")
        raise
    else:
        logger.info(f" ✅ Downloaded {filename}")
|
|
|
|
|
async def download_all_models():
    """Make sure every file listed in ``MODEL_FILES`` is present locally.

    All downloads are started concurrently with ``asyncio.gather``. If one of
    them fails, its exception propagates to the caller; the remaining
    downloads are not cancelled by gather.
    """
    logger.info(" 🔍 Looking for models...")
    await asyncio.gather(*(download_hf_file(name) for name in MODEL_FILES))
    logger.info(" ✅ All models are available")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ModelLoader:
    """A class responsible for loading and initializing all required models."""

    def __init__(self):
        # Copied from module-level configuration.
        # NOTE(review): load_live_portrait() does not currently read these;
        # confirm whether the pipeline is meant to receive them.
        self.device = DEVICE
        self.models_dir = MODELS_DIR

    async def load_live_portrait(self):
        """Construct and return a ``LivePortraitPipeline``.

        The ``liveportrait`` package is imported inside this method, so it is
        only needed when this method is called. The pipeline is constructed in
        a worker thread via ``asyncio.to_thread`` to avoid blocking the event
        loop.

        Returns:
            The constructed ``LivePortraitPipeline`` instance.
        """
        from liveportrait.config.inference_config import InferenceConfig
        from liveportrait.config.crop_config import CropConfig
        from liveportrait.live_portrait_pipeline import LivePortraitPipeline

        logger.info(" ⏳ Loading LivePortrait models...")
        live_portrait_pipeline = await asyncio.to_thread(
            LivePortraitPipeline,
            inference_cfg=InferenceConfig(
                flag_stitching=True,
                flag_relative=True,
                flag_pasteback=True,
                flag_do_crop=True,
                flag_do_rot=True,
            ),
            crop_cfg=CropConfig(),
        )
        logger.info(" ✅ LivePortrait models loaded successfully.")
        return live_portrait_pipeline
|
|
|
|
|
async def initialize_models():
    """Download any missing model files, then load the LivePortrait pipeline.

    Returns:
        The pipeline returned by ``ModelLoader().load_live_portrait()``.
    """
    logger.info("🚀 Starting model initialization...")

    # Fetch every file in MODEL_FILES before anything tries to load them.
    await download_all_models()

    loader = ModelLoader()
    live_portrait = await loader.load_live_portrait()

    logger.info("✅ Model initialization completed.")
    return live_portrait
|
|
|
|
|
|
|
|
# Import-time side effect: make sure MODELS_DIR exists before any download.
logger.info("📁 Setting up storage directories...")
create_directory(MODELS_DIR)
logger.info("✅ Storage directories setup completed.")
|
|
|