openfree committed
Commit 9d31513 · verified · 1 Parent(s): f40c908

Update app.py

Files changed (1)
  1. app.py +123 -232
app.py CHANGED
@@ -1,244 +1,135 @@
-import os, math, torch, cv2
 from PIL import Image
-from omegaconf import OmegaConf
-from tqdm import tqdm
 
-from diffusers import AutoencoderKLTemporalDecoder
-from diffusers.schedulers import EulerDiscreteScheduler
-from transformers import WhisperModel, CLIPVisionModelWithProjection, AutoFeatureExtractor
 
-from src.utils.util import save_videos_grid, seed_everything
-from src.dataset.test_preprocess import process_bbox, image_audio_to_tensor
-from src.models.base.unet_spatio_temporal_condition import (
-    UNetSpatioTemporalConditionModel, add_ip_adapters,
 )
-from src.pipelines.pipeline_sonic import SonicPipeline
-from src.models.audio_adapter.audio_proj import AudioProjModel
-from src.models.audio_adapter.audio_to_bucket import Audio2bucketModel
-from src.utils.RIFE.RIFE_HDv3 import RIFEModel
-from src.dataset.face_align.align import AlignImage
-
-BASE_DIR = os.path.dirname(os.path.abspath(__file__))
 
 # ------------------------------------------------------------------
-# single image + speech → video-tensor generator
 # ------------------------------------------------------------------
-def test(
-    pipe, config, wav_enc, audio_pe, audio2bucket, image_encoder,
-    width, height, batch,
-):
-    # ---- add a batch dimension ------------------------------------
-    for k, v in batch.items():
-        if isinstance(v, torch.Tensor):
-            batch[k] = v.unsqueeze(0).to(pipe.device).float()
-
-    ref_img = batch["ref_img"]
-    clip_img = batch["clip_images"]
-    face_mask = batch["face_mask"]
-    image_embeds = image_encoder(clip_img).image_embeds  # (1, 1024)
-
-    audio_feature = batch["audio_feature"]  # (1, 80, T)
-    audio_len = int(batch["audio_len"])
-    step = int(config.step)
-
-    window = 16_000  # 1-sec chunks
-    audio_prompts, last_prompts = [], []
-
-    for i in range(0, audio_feature.shape[-1], window):
-        chunk = audio_feature[:, :, i : i + window]  # (1, 80, win)
-        layers = wav_enc.encoder(chunk, output_hidden_states=True).hidden_states
-        last = wav_enc.encoder(chunk).last_hidden_state.unsqueeze(-2)
-        audio_prompts.append(torch.stack(layers, dim=2))  # (1, w, L, 384)
-        last_prompts.append(last)
-
-    if not audio_prompts:
-        raise ValueError("[ERROR] No speech recognised in the provided audio.")
-
-    audio_prompts = torch.cat(audio_prompts, dim=1)
-    last_prompts = torch.cat(last_prompts, dim=1)
-
-    # padding rules
-    audio_prompts = torch.cat(
-        [torch.zeros_like(audio_prompts[:, :4]), audio_prompts,
-         torch.zeros_like(audio_prompts[:, :6])], dim=1)
-    last_prompts = torch.cat(
-        [torch.zeros_like(last_prompts[:, :24]), last_prompts,
-         torch.zeros_like(last_prompts[:, :26])], dim=1)
-
-    total_tokens = audio_prompts.shape[1]
-    num_chunks = max(1, math.ceil(total_tokens / (2 * step)))
-
-    ref_list, audio_list, uncond_list, motion_buckets = [], [], [], []
-
-    for i in tqdm(range(num_chunks)):
-        start = i * 2 * step
-
-        # ------------ cond_clip : (1,1,10,5,384) ------------------
-        clip_raw = audio_prompts[:, start : start + 10]  # (1, ≤10, L, 384)
-
-        # ★ W-padding must be applied along dim=1!
-        if clip_raw.shape[1] < 10:
-            pad_w = torch.zeros_like(clip_raw[:, : 10 - clip_raw.shape[1]])
-            clip_raw = torch.cat([clip_raw, pad_w], dim=1)
-
-        # ★ L-padding goes along dim=2
-        while clip_raw.shape[2] < 5:
-            clip_raw = torch.cat([clip_raw, clip_raw[:, :, -1:]], dim=2)
-        clip_raw = clip_raw[:, :, :5]  # (1,10,5,384)
-
-        cond_clip = clip_raw.unsqueeze(1)  # (1,1,10,5,384)
-
-        # ------------ bucket_clip : (1,1,50,1,384) -----------------
-        bucket_raw = last_prompts[:, start : start + 50]
-        if bucket_raw.shape[1] < 50:  # ★ dim=1
-            pad_w = torch.zeros_like(bucket_raw[:, : 50 - bucket_raw.shape[1]])
-            bucket_raw = torch.cat([bucket_raw, pad_w], dim=1)
-        bucket_clip = bucket_raw.unsqueeze(1)  # (1,1,50,1,384)
-
-        motion = audio2bucket(bucket_clip, image_embeds) * 16 + 16
-
-        ref_list.append(ref_img[0])
-        audio_list.append(audio_pe(cond_clip).squeeze(0))  # (50, 1024)
-        uncond_list.append(audio_pe(torch.zeros_like(cond_clip)).squeeze(0))
-        motion_buckets.append(motion[0])
-
-    # ---- call Stable-Video-Diffusion ------------------------------
-    video = pipe(
-        ref_img, clip_img, face_mask,
-        audio_list, uncond_list, motion_buckets,
-        height=height, width=width,
-        num_frames=len(audio_list),
-        decode_chunk_size=config.decode_chunk_size,
-        motion_bucket_scale=config.motion_bucket_scale,
-        fps=config.fps,
-        noise_aug_strength=config.noise_aug_strength,
-        min_guidance_scale1=config.min_appearance_guidance_scale,
-        max_guidance_scale1=config.max_appearance_guidance_scale,
-        min_guidance_scale2=config.audio_guidance_scale,
-        max_guidance_scale2=config.audio_guidance_scale,
-        overlap=config.overlap,
-        shift_offset=config.shift_offset,
-        frames_per_batch=config.n_sample_frames,
-        num_inference_steps=config.num_inference_steps,
-        i2i_noise_strength=config.i2i_noise_strength,
-    ).frames
-
-    video = (video * 0.5 + 0.5).clamp(0, 1)
-    return video.to(pipe.device).unsqueeze(0).cpu()
 
 # ------------------------------------------------------------------
-# Sonic class
 # ------------------------------------------------------------------
-class Sonic:
-    config_file = os.path.join(BASE_DIR, "config/inference/sonic.yaml")
-    config = OmegaConf.load(config_file)
-
-    def __init__(self, device_id: int = 0, enable_interpolate_frame: bool = True):
-        cfg = self.config
-        cfg.use_interframe = enable_interpolate_frame
-        self.device = f"cuda:{device_id}" if device_id >= 0 and torch.cuda.is_available() else "cpu"
-        cfg.pretrained_model_name_or_path = os.path.join(BASE_DIR, cfg.pretrained_model_name_or_path)
-
-        self._load_models(cfg)
-        print("Sonic init done")
-
-    # --------------------------------------------------------------
-    def _load_models(self, cfg):
-        dtype = {"fp16": torch.float16, "fp32": torch.float32, "bf16": torch.bfloat16}[cfg.weight_dtype]
-
-        vae = AutoencoderKLTemporalDecoder.from_pretrained(cfg.pretrained_model_name_or_path, subfolder="vae", variant="fp16")
-        sched = EulerDiscreteScheduler.from_pretrained(cfg.pretrained_model_name_or_path, subfolder="scheduler")
-        img_e = CLIPVisionModelWithProjection.from_pretrained(cfg.pretrained_model_name_or_path, subfolder="image_encoder", variant="fp16")
-        unet = UNetSpatioTemporalConditionModel.from_pretrained(cfg.pretrained_model_name_or_path, subfolder="unet", variant="fp16")
-        add_ip_adapters(unet, [32], [cfg.ip_audio_scale])
-
-        a2t = AudioProjModel(10, 5, 384, 1024, 1024, 32).to(self.device)
-        a2b = Audio2bucketModel(50, 1, 384, 1024, 1024, 1, 2).to(self.device)
-
-        unet.load_state_dict(torch.load(os.path.join(BASE_DIR, cfg.unet_checkpoint_path), map_location="cpu"))
-        a2t.load_state_dict(torch.load(os.path.join(BASE_DIR, cfg.audio2token_checkpoint_path), map_location="cpu"))
-        a2b.load_state_dict(torch.load(os.path.join(BASE_DIR, cfg.audio2bucket_checkpoint_path), map_location="cpu"))
-
-        whisper = WhisperModel.from_pretrained(os.path.join(BASE_DIR, "checkpoints/whisper-tiny")).to(self.device).eval()
-        whisper.requires_grad_(False)
-
-        self.feature_extractor = AutoFeatureExtractor.from_pretrained(os.path.join(BASE_DIR, "checkpoints/whisper-tiny"))
-        self.face_det = AlignImage(self.device, det_path=os.path.join(BASE_DIR, "checkpoints/yoloface_v5m.pt"))
-        if cfg.use_interframe:
-            self.rife = RIFEModel(device=self.device)
-            self.rife.load_model(os.path.join(BASE_DIR, "checkpoints/RIFE/"))
-
-        img_e.to(dtype); vae.to(dtype); unet.to(dtype)
-
-        self.pipe = SonicPipeline(unet=unet, image_encoder=img_e, vae=vae, scheduler=sched).to(device=self.device, dtype=dtype)
-        self.image_encoder = img_e
-        self.audio2token = a2t
-        self.audio2bucket = a2b
-        self.whisper = whisper
-
-    # --------------------------------------------------------------
-    def preprocess(self, img_path: str, expand_ratio: float = 1.0):
-        img = cv2.imread(img_path)
-        h, w = img.shape[:2]
-        _, _, faces = self.face_det(img, maxface=True)
-        if faces:
-            x1, y1, ww, hh = faces[0]
-            return {"face_num": 1, "crop_bbox": process_bbox((x1, y1, x1 + ww, y1 + hh), expand_ratio, h, w)}
-        return {"face_num": 0, "crop_bbox": None}
-
-    # --------------------------------------------------------------
-    @torch.no_grad()
-    def process(
-        self,
-        img_path: str,
-        audio_path: str,
-        out_path: str,
-        min_resolution: int = 512,
-        inference_steps: int = 25,
-        dynamic_scale: float = 1.0,
-        keep_resolution: bool = False,
-        seed: int | None = None,
-    ):
-        cfg = self.config
-        if seed is not None: cfg.seed = seed
-        cfg.num_inference_steps = inference_steps
-        cfg.motion_bucket_scale = dynamic_scale
-        seed_everything(cfg.seed)
-
-        sample = image_audio_to_tensor(
-            self.face_det, self.feature_extractor,
-            img_path, audio_path,
-            limit=-1, image_size=min_resolution, area=cfg.area,
-        )
-        if sample is None:
-            return -1
-
-        h, w = sample["ref_img"].shape[-2:]
-        resolution = (f"{(Image.open(img_path).width // 2) * 2}x{(Image.open(img_path).height // 2) * 2}"
-                      if keep_resolution else f"{w}x{h}")
-
-        video = test(
-            self.pipe, cfg, self.whisper, self.audio2token,
-            self.audio2bucket, self.image_encoder,
-            w, h, sample,
-        )
-
-        if cfg.use_interframe:
-            out = video.to(self.device)
-            frames = []
-            for i in tqdm(range(out.shape[2] - 1), ncols=0):
-                mid = self.rife.inference(out[:, :, i], out[:, :, i + 1]).clamp(0, 1).detach()
-                frames.extend([out[:, :, i], mid])
-            frames.append(out[:, :, -1])
-            video = torch.stack(frames, 2).cpu()
-
-        tmp = out_path.replace(".mp4", "_noaudio.mp4")
-        save_videos_grid(video, tmp, n_rows=video.shape[0], fps=cfg.fps * (2 if cfg.use_interframe else 1))
-        os.system(
-            f"ffmpeg -i '{tmp}' -i '{audio_path}' -s {resolution} "
-            f"-vcodec libx264 -acodec aac -crf 18 -shortest '{out_path}' -y -loglevel error"
-        )
-        os.remove(tmp)
-        return 0
 
+# ---------------------------------------------------------
+# app.py – Gradio UI + inference wrapper for the revised Sonic
+# ---------------------------------------------------------
+import os, io, hashlib
+import numpy as np
+from pydub import AudioSegment
 from PIL import Image
+import gradio as gr
+import spaces
 
+from sonic import Sonic  # ← uses the newly revised sonic.py
 
+# ------------------------------------------------------------------
+# 1. Auto-download the required resources (models) ── HF Spaces reuses its cache
+# ------------------------------------------------------------------
+os.system(
+    'python3 -m pip install "huggingface_hub[cli]" accelerate -q; '
+    'huggingface-cli download LeonJoe13/Sonic '
+    ' --local-dir checkpoints -q; '
+    'huggingface-cli download stabilityai/stable-video-diffusion-img2vid-xt '
+    ' --local-dir checkpoints/stable-video-diffusion-img2vid-xt -q; '
+    'huggingface-cli download openai/whisper-tiny '
+    ' --local-dir checkpoints/whisper-tiny -q'
 )
 
+pipe = Sonic()  # claims GPU memory immediately
 
 # ------------------------------------------------------------------
+# 2. Utilities
 # ------------------------------------------------------------------
+def md5(b: bytes) -> str:
+    return hashlib.md5(b).hexdigest()
 
+ TMP_DIR = "./tmp_path"; os.makedirs(TMP_DIR, exist_ok=True)
35
+ RES_DIR = "./res_path"; os.makedirs(RES_DIR, exist_ok=True)
36
+
37
+ # ------------------------------------------------------------------
38
+ # 3. Sonic 실행 (GPU 태그 10 min)
39
+ # ------------------------------------------------------------------
40
+ @spaces.GPU(duration=600)
41
+ def get_video_res(img_path, wav_path, out_path, dyn_scale=1.0):
42
+ """실제 Sonic 파이프라인 실행."""
43
+ audio = AudioSegment.from_file(wav_path)
44
+ dur_s = len(audio) / 1000.0 # 초
45
+
46
+ # 프레임 수 ≈ 초당 12.5 → inference_steps
47
+ inf_steps = max(25, min(int(dur_s * 12.5), 750))
48
+ print(f"[INFO] Audio duration: {dur_s:.2f}s → inference_steps={inf_steps}")
49
+
50
+ # 얼굴 사전 검출(디버그용 로그)
51
+ face_info = pipe.preprocess(img_path)
52
+ print(f"[INFO] Face detection info: {face_info}")
53
+
54
+ if face_info["face_num"] == 0:
55
+ return -1 # 얼굴 없음
56
+
57
+ os.makedirs(os.path.dirname(out_path), exist_ok=True)
58
+ pipe.process(
59
+ img_path, wav_path, out_path,
60
+ inference_steps=inf_steps,
61
+ dynamic_scale=dyn_scale,
62
+ min_resolution=512,
63
+ )
64
+ return out_path
65
 
66
  # ------------------------------------------------------------------
67
+ # 4. Gradio 인터페이스
68
  # ------------------------------------------------------------------
69
+ def process_sonic(img: Image.Image, audio_tuple, dyn_scale):
70
+ if img is None:
71
+ raise gr.Error("Please upload an image.")
72
+ if audio_tuple is None:
73
+ raise gr.Error("Please upload an audio file.")
74
+
75
+ # ---- 캐싱 키 ----------------------------------------------------
76
+ img_bytes = io.BytesIO(); img.save(img_bytes, format="PNG")
77
+ img_key = md5(img_bytes.getvalue())
78
+
79
+ rate, arr = audio_tuple
80
+ if arr.ndim == 1:
81
+ arr = arr[:, None]
82
+ segment = AudioSegment(
83
+ arr.tobytes(), frame_rate=rate,
84
+ sample_width=arr.dtype.itemsize, channels=arr.shape[1]
85
+ ).set_channels(1).set_frame_rate(16_000)
86
+
87
+ segment = segment[:60_000] # ≤60 s
88
+ buf_audio = io.BytesIO(); segment.export(buf_audio, format="wav")
89
+ aud_key = md5(buf_audio.getvalue())
90
+
91
+ img_path = os.path.join(TMP_DIR, f"{img_key}.png")
92
+ wav_path = os.path.join(TMP_DIR, f"{aud_key}.wav")
93
+ out_path = os.path.join(RES_DIR, f"{img_key}_{aud_key}_{dyn_scale}.mp4")
94
+
95
+ # ---- 캐시 저장 --------------------------------------------------
96
+ if not os.path.exists(img_path):
97
+ with open(img_path, "wb") as f: f.write(img_bytes.getvalue())
98
+ if not os.path.exists(wav_path):
99
+ with open(wav_path, "wb") as f: f.write(buf_audio.getvalue())
100
+
101
+ if os.path.exists(out_path):
102
+ print(f"[INFO] Using cached result: {out_path}")
103
+ return out_path
104
+
105
+ print(f"[INFO] Generating new video with dynamic_scale={dyn_scale}")
106
+ res = get_video_res(img_path, wav_path, out_path, dyn_scale)
107
+ if res == -1:
108
+ raise gr.Error("No face detected in the image.")
109
+ return res
110
+
111
+ # ---- Gradio UI -----------------------------------------------------
112
+ CSS = """
113
+ .gradio-container {font-family: 'Arial', sans-serif;}
114
+ .main-header {text-align:center;color:#2a2a2a;margin-bottom:2em;}
115
+ """
116
+
117
+ with gr.Blocks(css=CSS) as demo:
118
+ gr.HTML("""
119
+ <div class="main-header">
120
+ <h1>🎭 Sonic Portrait Animation (≤60 s audio)</h1>
121
+ <p>Still image → talking-head video, driven by your voice.</p>
122
+ </div>
123
+ """)
124
+
125
+ with gr.Row():
126
+ with gr.Column():
127
+ img_in = gr.Image(type="pil", label="Portrait Image")
128
+ aud_in = gr.Audio(type="numpy", label="Voice (≤1 min)")
129
+ dyn_sl = gr.Slider(0.5, 2.0, 1.0, 0.1, label="Animation Intensity")
130
+ btn_go = gr.Button("Generate", variant="primary")
131
+ vid_out = gr.Video(label="Result")
132
+
133
+ btn_go.click(process_sonic, inputs=[img_in, aud_in, dyn_sl], outputs=vid_out)
134
+
135
+ demo.launch(share=True)
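
For reference only (not part of the commit): a minimal standalone sketch of the audio-duration → inference-steps mapping and the md5-keyed cache-path scheme introduced in the new app.py, so the constants (12.5 steps per second, the 25–750 clamp, the "{img_key}_{aud_key}_{dyn_scale}.mp4" naming) can be sanity-checked offline. The helper names inference_steps and cache_paths are illustrative and do not exist in the repository.

    # sketch.py – mirrors the step-count and cache-key logic of app.py
    import hashlib
    import os

    STEPS_PER_SECOND = 12.5   # "frames ≈ 12.5 per second → inference_steps"
    MIN_STEPS, MAX_STEPS = 25, 750

    def inference_steps(duration_s: float) -> int:
        """Clamp duration * 12.5 into [25, 750], as get_video_res() does."""
        return max(MIN_STEPS, min(int(duration_s * STEPS_PER_SECOND), MAX_STEPS))

    def cache_paths(img_bytes: bytes, wav_bytes: bytes, dyn_scale: float,
                    tmp_dir: str = "./tmp_path", res_dir: str = "./res_path"):
        """Build the md5-keyed temp/result paths used by process_sonic()."""
        img_key = hashlib.md5(img_bytes).hexdigest()
        aud_key = hashlib.md5(wav_bytes).hexdigest()
        return (
            os.path.join(tmp_dir, f"{img_key}.png"),
            os.path.join(tmp_dir, f"{aud_key}.wav"),
            os.path.join(res_dir, f"{img_key}_{aud_key}_{dyn_scale}.mp4"),
        )

    if __name__ == "__main__":
        print(inference_steps(1.0))    # 25  – short clips are floored at 25 steps
        print(inference_steps(30.0))   # 375
        print(inference_steps(60.0))   # 750 – the 60-second audio cap hits the ceiling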