File size: 11,755 Bytes
357c94c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
import os
import io
import math
import uuid
import base64
import subprocess
from copy import deepcopy

import imageio
import torch
import torchvision
from PIL import Image
import numpy as np
from einops import rearrange
import torchvision.transforms as transforms
from torchvision.transforms import ToPILImage

from hymm_sp.data_kits.audio_dataset import get_audio_feature
from hymm_sp.data_kits.ffmpeg_utils import save_video

# Scratch directory (relative to the current working directory) that holds decoded
# request files and rendered videos. Created at import time; exist_ok makes this a
# no-op when it already exists, and it fails fast if "./temp" is a regular file
# (every later write into it would fail anyway).
TEMP_DIR = "./temp"
os.makedirs(TEMP_DIR, exist_ok=True)


def data_preprocess_server(args, image_path, audio_path, prompts, feature_extractor):
    """Build a single-sample inference batch from a reference image, an audio file and a prompt.

    Args:
        args: Namespace-like object; only ``args.image_size`` is read. It is the target
            length, in pixels, of the reference image's shorter side.
        image_path: Path to the reference image (converted to RGB).
        audio_path: Path to the driving audio, forwarded to ``get_audio_feature``.
        prompts: Optional extra prompt text appended to a fixed quality prefix;
            ``None`` uses the prefix alone.
        feature_extractor: Forwarded to ``get_audio_feature``.

    Returns:
        dict: ``text_prompt``, ``audio_path``, ``image_path`` and ``audio_len`` as
        one-element lists. ``fps`` is a float16 tensor of shape (1,) holding 25.
        ``audio_prompts`` is ``audio_input[0]`` with a leading batch dim, as float16.
        ``motion_bucket_id_exps`` / ``motion_bucket_id_heads`` are integer tensors of
        shape (1, 4) filled with 30 / 25. ``pixel_value_ref`` has shape (1, 1, 3, H, W)
        and holds float16 raw 0-255 values. ``pixel_value_ref_llava`` has shape
        (1, 1, 3, 336, 336) and holds normalized float16 values.
    """
    # NOTE(review): mean/std look like the CLIP vision-encoder statistics — confirm.
    llava_transform = transforms.Compose(
        [
            transforms.Resize((336, 336), interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.ToTensor(),
            transforms.Normalize((0.48145466, 0.4578275, 0.4082107), (0.26862954, 0.26130258, 0.27577711)),
        ]
    )

    # Build the text prompt: fixed quality prefix, optionally followed by the caller's text.
    if prompts is None:
        prompts = "Authentic, Realistic, Natural, High-quality, Lens-Fixed."
    else:
        prompts = "Authentic, Realistic, Natural, High-quality, Lens-Fixed, " + prompts

    fps = 25

    img_size = args.image_size
    # Context manager releases the file handle; convert() returns a new, fully
    # loaded image that remains valid after the source file is closed.
    with Image.open(image_path) as src_image:
        ref_image = src_image.convert('RGB')

    # Resize reference image: scale the short side to img_size, snap both sides to multiples of 64.
    w, h = ref_image.size
    scale = img_size / min(w, h)
    new_w = round(w * scale / 64) * 64
    new_h = round(h * scale / 64) * 64

    # Only img_size == 704 has a configured pixel-area budget (704 x 1216). Other
    # sizes have no budget, so the cap is skipped. Previously img_size_long was left
    # unbound for them, which raised UnboundLocalError.
    img_size_long = 1216 if img_size == 704 else None
    if img_size_long is not None and new_w * new_h > img_size * img_size_long:
        scale = math.sqrt(img_size * img_size_long / w / h)
        new_w = round(w * scale / 64) * 64
        new_h = round(h * scale / 64) * 64

    ref_image = ref_image.resize((new_w, new_h), Image.LANCZOS)

    ref_image = np.array(ref_image)          # (H, W, 3) uint8
    ref_image = torch.from_numpy(ref_image)

    audio_input, audio_len = get_audio_feature(feature_extractor, audio_path)
    audio_prompts = audio_input[0]

    motion_bucket_id_heads = np.array([25] * 4)
    motion_bucket_id_exps = np.array([30] * 4)
    motion_bucket_id_heads = torch.from_numpy(motion_bucket_id_heads)
    motion_bucket_id_exps = torch.from_numpy(motion_bucket_id_exps)
    fps = torch.from_numpy(np.array(fps))    # 0-d tensor

    to_pil = ToPILImage()
    pixel_value_ref = rearrange(ref_image.clone().unsqueeze(0), "b h w c -> b c h w")   # (b c h w)

    pixel_value_ref_llava = [llava_transform(to_pil(image)) for image in pixel_value_ref]
    pixel_value_ref_llava = torch.stack(pixel_value_ref_llava, dim=0)

    batch = {
        "text_prompt": [prompts],
        "audio_path": [audio_path],
        "image_path": [image_path],
        "fps": fps.unsqueeze(0).to(dtype=torch.float16),
        "audio_prompts": audio_prompts.unsqueeze(0).to(dtype=torch.float16),
        "audio_len": [audio_len],
        "motion_bucket_id_exps": motion_bucket_id_exps.unsqueeze(0),
        "motion_bucket_id_heads": motion_bucket_id_heads.unsqueeze(0),
        "pixel_value_ref": pixel_value_ref.unsqueeze(0).to(dtype=torch.float16),
        "pixel_value_ref_llava": pixel_value_ref_llava.unsqueeze(0).to(dtype=torch.float16)
    }

    return batch

def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=8, quality=8):
    """Tile a batch of videos into one grid image per frame and write them to ``path``.

    Args:
        videos: Tensor laid out ``(b, c, t, h, w)`` (required by the rearrange pattern).
            Values are expected in [0, 1], or in [-1, 1] when ``rescale`` is True.
        path: Output file. Its parent directory is created if needed; a bare
            filename writes to the current directory.
        rescale: Map values from [-1, 1] to [0, 1] before clamping.
        n_rows: Forwarded to ``make_grid`` as ``nrow``.
        fps: Forwarded to ``imageio.mimsave``.
        quality: Forwarded to ``imageio.mimsave``.
    """
    videos = rearrange(videos, "b c t h w -> t b c h w")
    outputs = []
    for x in videos:
        x = torchvision.utils.make_grid(x, nrow=n_rows)
        # (c, H, W) -> (H, W, c); squeeze(-1) drops a trailing size-1 channel axis if present.
        x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
        if rescale:
            x = (x + 1.0) / 2.0  # -1,1 -> 0,1
        x = torch.clamp(x, 0, 1)
        # detach().cpu() so GPU tensors and tensors requiring grad can be converted to NumPy.
        x = (x * 255).detach().cpu().numpy().astype(np.uint8)
        outputs.append(x)

    # os.makedirs("") raises, so only create a directory when path actually has one.
    parent_dir = os.path.dirname(path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    imageio.mimsave(path, outputs, fps=fps, quality=quality)

def encode_image_to_base64(image_path):
    """Return the bytes of the file at ``image_path`` as a UTF-8 Base64 string.

    Errors are printed rather than raised; in that case None is returned.
    """
    try:
        with open(image_path, 'rb') as handle:
            b64_text = base64.b64encode(handle.read()).decode('utf-8')
        print(f"Image file '{image_path}' has been successfully encoded to Base64.")
        return b64_text
    except Exception as err:
        print(f"Error encoding image: {err}")
        return None

def encode_video_to_base64(video_path):
    """Read the video file at ``video_path`` and return it Base64-encoded as ``str``.

    Any failure is printed and reported as a None return value.
    """
    try:
        with open(video_path, 'rb') as stream:
            payload = stream.read()
        text = base64.b64encode(payload).decode('utf-8')
        print(f"Video file '{video_path}' has been successfully encoded to Base64.")
    except Exception as exc:
        print(f"Error encoding video: {exc}")
        return None
    return text
    
def encode_wav_to_base64(wav_path):
    """Base64-encode the audio file at ``wav_path`` and return the text.

    Returns None (after printing the error) when the file cannot be read.
    """
    try:
        with open(wav_path, 'rb') as wav_file:
            encoded = base64.b64encode(wav_file.read())
        encoded = encoded.decode('utf-8')
        print(f"Audio file '{wav_path}' has been successfully encoded to Base64.")
        return encoded
    except Exception as problem:
        print(f"Error encoding audio: {problem}")
    return None
    
def encode_pkl_to_base64(pkl_path):
    """Return the raw bytes of the pickle file at ``pkl_path`` as Base64 text.

    The file is only read as bytes, never unpickled. Errors are printed and
    None is returned.
    """
    try:
        with open(pkl_path, 'rb') as source:
            blob = source.read()
        as_text = base64.b64encode(blob).decode('utf-8')
        print(f"Pickle file '{pkl_path}' has been successfully encoded to Base64.")
        return as_text
    except Exception as failure:
        print(f"Error encoding pickle: {failure}")
        return None
      
def decode_base64_to_image(base64_buffer_str):
    """Decode a Base64-encoded image file into a NumPy array of its pixels.

    Args:
        base64_buffer_str: Base64 text (or bytes) of an encoded image file.

    Returns:
        ``np.array`` of the decoded PIL image, or None if decoding fails; the
        error is printed.
    """
    try:
        image_data = base64.b64decode(base64_buffer_str)
        # np.array() forces the pixel data to load, so closing the image afterwards is safe.
        with Image.open(io.BytesIO(image_data)) as image:
            image_array = np.array(image)
        print("Image Base64 string has been successfully decoded to image.")
        return image_array
    except Exception as e:
        print(f"Error decoding image: {e}")
        return None
    
def decode_base64_to_video(base64_buffer_str):
    """Decode a Base64-encoded video file into a list of frames via imageio's ffmpeg reader.

    Args:
        base64_buffer_str: Base64 text (or bytes) of an encoded video file.

    Returns:
        list of frames as yielded by the imageio reader, or None if decoding fails;
        the error is printed.
    """
    try:
        video_data = base64.b64decode(base64_buffer_str)
        # Context manager closes the reader (and its ffmpeg subprocess) even on error.
        with imageio.get_reader(io.BytesIO(video_data), 'ffmpeg') as video_reader:
            video_frames = list(video_reader)
        return video_frames
    except Exception as e:
        print(f"Error decoding video: {e}")
        return None

    
def save_video_base64_to_local(video_path=None, base64_buffer=None, output_video_path=None):
    """Materialize Base64 video data as an ``.mp4`` file on disk and return its path.

    Exactly one of ``video_path`` (re-encoded through ``encode_video_to_base64``)
    or ``base64_buffer`` must be supplied. The file is written to
    ``output_video_path`` when given, else to a fresh UUID-named file in TEMP_DIR.
    Returns None for invalid arguments or when encoding the source file fails.
    """
    has_path = video_path is not None
    has_buffer = base64_buffer is not None
    if has_path == has_buffer:
        print("Please pass either 'video_path' or 'base64_buffer'")
        return None

    encoded = encode_video_to_base64(video_path) if has_path else deepcopy(base64_buffer)
    if encoded is None:
        return None

    decoded_bytes = base64.b64decode(encoded)
    if output_video_path is not None:
        destination = output_video_path
    else:
        destination = f'{TEMP_DIR}/{uuid.uuid4()}.mp4'
    with open(destination, 'wb') as sink:
        sink.write(decoded_bytes)
    return destination
    
def save_audio_base64_to_local(audio_path=None, base64_buffer=None):
    """Write Base64 audio data to a new UUID-named ``.wav`` file in TEMP_DIR.

    Exactly one of ``audio_path`` (re-encoded through ``encode_wav_to_base64``)
    or ``base64_buffer`` must be supplied. Returns the written path, or None for
    invalid arguments or when encoding the source file fails.
    """
    from_path = audio_path is not None
    from_buffer = base64_buffer is not None
    if from_path == from_buffer:
        print("Please pass either 'audio_path' or 'base64_buffer'")
        return None

    encoded = encode_wav_to_base64(audio_path) if from_path else deepcopy(base64_buffer)
    if encoded is None:
        return None

    raw_audio = base64.b64decode(encoded)
    target = f'{TEMP_DIR}/{uuid.uuid4()}.wav'
    with open(target, 'wb') as sink:
        sink.write(raw_audio)
    return target
    
def save_pkl_base64_to_local(pkl_path=None, base64_buffer=None):
    """Write Base64 pickle bytes to a new UUID-named ``.pkl`` file in TEMP_DIR.

    Exactly one of ``pkl_path`` (re-encoded through ``encode_pkl_to_base64``)
    or ``base64_buffer`` must be supplied. The bytes are written verbatim and
    never unpickled here. Returns the written path, or None for invalid
    arguments or when encoding the source file fails.
    """
    if (pkl_path is None) == (base64_buffer is None):
        print("Please pass either 'pkl_path' or 'base64_buffer'")
        return None

    if pkl_path is not None:
        encoded = encode_pkl_to_base64(pkl_path)
    else:
        encoded = deepcopy(base64_buffer)

    if encoded is None:
        return None

    pickle_bytes = base64.b64decode(encoded)
    out_path = f'{TEMP_DIR}/{uuid.uuid4()}.pkl'
    with open(out_path, 'wb') as fh:
        fh.write(pickle_bytes)
    return out_path
    
def remove_temp_fles(input_dict):
    """Delete the files referenced by ``*_path`` entries of ``input_dict``.

    Every key containing ``"_path"`` whose value is a path to an existing regular
    file is deleted, and a message is printed for it. Other keys, ``None`` or
    non-path values, missing paths and directories are skipped. Previously,
    directories and non-path values reached ``os.remove``, which raised.
    """
    for key, val in input_dict.items():
        if "_path" not in key:
            continue
        if not isinstance(val, (str, bytes, os.PathLike)):
            continue
        if not os.path.isfile(val):
            continue
        os.remove(val)
        print(f"Remove temporary {key} from {val}")


# Correctly spelled alias; the original (misspelled) name is kept for existing callers.
remove_temp_files = remove_temp_fles

def process_output_dict(output_dict):
    """Write the generated video to a temp MP4, mux in audio if available, and Base64-encode it.

    Args:
        output_dict: Mapping read for ``"video"`` (passed to ``save_video``),
            ``"save_fps"`` (optional, default 25), ``"audio"`` (optional path to an
            audio file; a missing key is treated like ``None``), ``"err_code"`` and
            ``"err_msg"``.

    Returns:
        dict: ``{"errCode": ..., "content": [{"buffer": <Base64 str or None>}], "info": ...}``.
        If muxing fails, the silent video is encoded instead. ``buffer`` is None
        only when the final video file cannot be read.
    """
    uuid_string = str(uuid.uuid4())
    temp_video_path = f'{TEMP_DIR}/{uuid_string}.mp4'
    save_video(output_dict["video"], temp_video_path, fps=output_dict.get("save_fps", 25))

    save_path = temp_video_path
    audio_path = output_dict.get("audio")
    # Add audio
    if audio_path is not None and os.path.exists(audio_path):
        muxed_path = temp_video_path.replace(".mp4", "_audio.mp4")
        print('='*100)
        print(f"output_path = {temp_video_path}\n audio_path = {audio_path}\n save_path = {muxed_path}")
        if _mux_audio_into_video(temp_video_path, audio_path, muxed_path):
            os.remove(temp_video_path)
            save_path = muxed_path
        elif os.path.exists(muxed_path):
            # Drop a partial output from a failed run and keep the silent video.
            os.remove(muxed_path)

    video_base64_buffer = encode_video_to_base64(save_path)
    # Only the Base64 payload is returned and the UUID path is not exposed,
    # so the file on disk is no longer needed.
    if os.path.exists(save_path):
        os.remove(save_path)

    encoded_output_dict = {
        "errCode": output_dict["err_code"],
        "content": [
            {
                "buffer": video_base64_buffer
            },
        ],
        "info": output_dict["err_msg"],
    }

    return encoded_output_dict


def _mux_audio_into_video(video_path, audio_path, save_path):
    """Run ffmpeg to combine ``video_path`` and ``audio_path`` into ``save_path``.

    Returns True when ffmpeg exits with status 0 and ``save_path`` exists.
    """
    # Argument list with no shell: paths containing quotes or shell
    # metacharacters are passed to ffmpeg verbatim.
    cmd = [
        "ffmpeg", "-y", "-loglevel", "quiet",
        "-i", video_path,
        "-i", audio_path,
        "-shortest", save_path,
    ]
    try:
        completed = subprocess.run(cmd, check=False)
    except OSError as err:  # e.g. the ffmpeg executable is not installed
        print(f"Failed to run ffmpeg: {err}")
        return False
    return completed.returncode == 0 and os.path.exists(save_path)


def save_image_base64_to_local(image_path=None, base64_buffer=None):
    """Write Base64 image data to a new UUID-named ``.png`` file in TEMP_DIR.

    Exactly one of ``image_path`` (re-encoded through ``encode_image_to_base64``)
    or ``base64_buffer`` must be supplied. The decoded bytes are written verbatim,
    so the file keeps whatever format the source used despite the ``.png`` name.
    Returns the written path, or None for invalid arguments or when encoding the
    source file fails.
    """
    use_path = image_path is not None
    use_buffer = base64_buffer is not None
    if use_path == use_buffer:
        print("Please pass either 'image_path' or 'base64_buffer'")
        return None

    # Obtain the Base64 text, either from disk or from the caller's buffer.
    encoded = encode_image_to_base64(image_path) if use_path else deepcopy(base64_buffer)
    if encoded is None:
        return None

    # Decode and persist under a collision-free name.
    image_bytes = base64.b64decode(encoded)
    png_path = f'{TEMP_DIR}/{uuid.uuid4()}.png'
    with open(png_path, 'wb') as fh:
        fh.write(image_bytes)
    return png_path
    
def process_input_dict(input_dict):
    """Convert a request payload into local file paths and generation settings.

    Reads the optional keys ``"save_fps"`` (default 25), ``"image_buffer"`` and
    ``"audio_buffer"`` (Base64 data saved to temp files via the save_* helpers)
    and ``"text"``. Returns a dict with keys ``save_fps``, ``image_path``,
    ``audio_path`` and ``prompt``. A path is None when its buffer is absent.
    """
    decoded = {"save_fps": input_dict.get("save_fps", 25)}

    image_buffer = input_dict.get("image_buffer", None)
    decoded["image_path"] = (
        None
        if image_buffer is None
        else save_image_base64_to_local(image_path=None, base64_buffer=image_buffer)
    )

    audio_buffer = input_dict.get("audio_buffer", None)
    decoded["audio_path"] = (
        None
        if audio_buffer is None
        else save_audio_base64_to_local(audio_path=None, base64_buffer=audio_buffer)
    )

    decoded["prompt"] = input_dict.get("text", None)
    return decoded