llSourcell committed
Commit 9624517 · 1 Parent(s): 7ff5563

Update app.py

Files changed (1)
  1. app.py +151 -4
app.py CHANGED
@@ -1,7 +1,154 @@
  import gradio as gr

- def greet(name):
-     return "Hello " + name + "!!"

- iface = gr.Interface(fn=greet, inputs="text", outputs="text")
- iface.launch()
+ # -*- coding: utf-8 -*-
+ from flask import Flask
+ import gc
+ import math
  import gradio as gr
+ import numpy as np
+ import torch
+ from encoded_video import EncodedVideo, write_video
+ from PIL import Image
+ from torchvision.transforms.functional import center_crop, to_tensor
+
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+ print("🧠 Loading Model...")
+ model = torch.hub.load(
+     "AK391/animegan2-pytorch:main",
+     "generator",
+     pretrained=True,
+     device=device,
+     progress=True,
+ )
+
+
+ def face2paint(model: torch.nn.Module, img: Image.Image, size: int = 512, device: str = device):
+     w, h = img.size
+     s = min(w, h)
+     img = img.crop(((w - s) // 2, (h - s) // 2, (w + s) // 2, (h + s) // 2))
+     img = img.resize((size, size), Image.LANCZOS)
+
+     with torch.no_grad():
+         input = to_tensor(img).unsqueeze(0) * 2 - 1
+         output = model(input.to(device)).cpu()[0]
+
+     output = (output * 0.5 + 0.5).clip(0, 1) * 255.0
+
+     return output
+
+
+ # This function is taken from pytorchvideo!
+ def uniform_temporal_subsample(x: torch.Tensor, num_samples: int, temporal_dim: int = -3) -> torch.Tensor:
+     """
+     Uniformly subsamples num_samples indices from the temporal dimension of the video.
+     When num_samples is larger than the size of temporal dimension of the video, it
+     will sample frames based on nearest neighbor interpolation.
+     Args:
+         x (torch.Tensor): A video tensor with dimension larger than one with torch
+             tensor type includes int, long, float, complex, etc.
+         num_samples (int): The number of equispaced samples to be selected
+         temporal_dim (int): dimension of temporal to perform temporal subsample.
+     Returns:
+         An x-like Tensor with subsampled temporal dimension.
+     """
+     t = x.shape[temporal_dim]
+     assert num_samples > 0 and t > 0
+     # Sample by nearest neighbor interpolation if num_samples > t.
+     indices = torch.linspace(0, t - 1, num_samples)
+     indices = torch.clamp(indices, 0, t - 1).long()
+     return torch.index_select(x, temporal_dim, indices)
+
+
+ # This function is taken from pytorchvideo!
+ def short_side_scale(
+     x: torch.Tensor,
+     size: int,
+     interpolation: str = "bilinear",
+ ) -> torch.Tensor:
+     """
+     Determines the shorter spatial dim of the video (i.e. width or height) and scales
+     it to the given size. To maintain aspect ratio, the longer side is then scaled
+     accordingly.
+     Args:
+         x (torch.Tensor): A video tensor of shape (C, T, H, W) and type torch.float32.
+         size (int): The size the shorter side is scaled to.
+         interpolation (str): Algorithm used for upsampling,
+             options: 'nearest' | 'linear' | 'bilinear' | 'bicubic' | 'trilinear' | 'area'
+     Returns:
+         An x-like Tensor with scaled spatial dims.
+     """
+     assert len(x.shape) == 4
+     assert x.dtype == torch.float32
+     c, t, h, w = x.shape
+     if w < h:
+         new_h = int(math.floor((float(h) / w) * size))
+         new_w = size
+     else:
+         new_h = size
+         new_w = int(math.floor((float(w) / h) * size))
+
+     return torch.nn.functional.interpolate(x, size=(new_h, new_w), mode=interpolation, align_corners=False)
+
+
+ def inference_step(vid, start_sec, duration, out_fps):
+
+     clip = vid.get_clip(start_sec, start_sec + duration)
+     video_arr = torch.from_numpy(clip['video']).permute(3, 0, 1, 2)
+     audio_arr = np.expand_dims(clip['audio'], 0)
+     audio_fps = None if not vid._has_audio else vid._container.streams.audio[0].sample_rate
+
+     x = uniform_temporal_subsample(video_arr, duration * out_fps)
+     x = center_crop(short_side_scale(x, 512), 512)
+     x /= 255.0
+     x = x.permute(1, 0, 2, 3)
+     with torch.no_grad():
+         output = model(x.to(device)).detach().cpu()
+         output = (output * 0.5 + 0.5).clip(0, 1) * 255.0
+         output_video = output.permute(0, 2, 3, 1).numpy()
+
+     return output_video, audio_arr, out_fps, audio_fps
+
+
+ def predict_fn(filepath, start_sec, duration):
+     out_fps = 18
+     vid = EncodedVideo.from_path(filepath)
+     for i in range(duration):
+         print(f"🖼️ Processing step {i + 1}/{duration}...")
+         video, audio, fps, audio_fps = inference_step(vid=vid, start_sec=i + start_sec, duration=1, out_fps=out_fps)
+         gc.collect()
+         if i == 0:
+             video_all = video
+             audio_all = audio
+         else:
+             video_all = np.concatenate((video_all, video))
+             audio_all = np.hstack((audio_all, audio))
+
+     print(f"💾 Writing output video...")
+
+     try:
+         write_video('out.mp4', video_all, fps=fps, audio_array=audio_all, audio_fps=audio_fps, audio_codec='aac')
+     except:
+         print("❌ Error when writing with audio...trying without audio")
+         write_video('out.mp4', video_all, fps=fps)
+
+     print(f"✅ Done!")
+     del video_all
+     del audio_all
+
+     return 'out.mp4'
+
+
+ iface_file = gr.Interface(
+     predict_fn,
+     inputs=[
+         gr.inputs.Video(source="upload"),
+         gr.inputs.Slider(minimum=0, maximum=300, step=1, default=0),
+         gr.inputs.Slider(minimum=1, maximum=1000, step=1, default=2),
+     ],
+     outputs=gr.outputs.Video(),
+     title='Animusica Studio',
+     description="",
+     article="",
+     css="footer {visibility: hidden}",
+     allow_flagging='never',
+     theme="default",
+ ).launch(enable_queue=True, share=True)
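
A minimal sketch of how the new pipeline could be smoke-tested outside the Gradio UI, run in a session where the definitions above are in scope (for example just before the final gr.Interface(...).launch(...) call). The random tensor shape and the clip name sample.mp4 are assumptions for illustration only; they are not part of this commit:

# Assumed local smoke test -- not part of the commit.
# Check the pytorchvideo-derived helpers on a dummy (C, T, H, W) float32 tensor,
# then stylize the first two seconds of a hypothetical local clip.
frames = torch.rand(3, 48, 360, 640)              # (C, T, H, W), float32
sampled = uniform_temporal_subsample(frames, 36)  # keep 36 evenly spaced frames
scaled = short_side_scale(sampled, 512)           # shorter spatial side -> 512 px
print(sampled.shape, scaled.shape)                # [3, 36, 360, 640] and [3, 36, 512, 910]

out_path = predict_fn("sample.mp4", start_sec=0, duration=2)
print(out_path)                                   # 'out.mp4'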