charlesyu108 committed
Commit a3b1cb3 · verified · 1 Parent(s): 3650e42

Create app.py

Files changed (1)
app.py +261 -0
app.py ADDED
# %%
# Import necessary libraries
from moviepy.editor import VideoFileClip
import os
from PIL import Image
import numpy as np
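# Assumed runtime dependencies (not pinned anywhere in this commit): moviepy, Pillow,
# numpy, torch, transformers, openai, pydub, and gradio, with ffmpeg available on the
# PATH for moviepy/pydub. Exact packages and versions are an assumption, not part of
# the source.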


def extract_frames(video, frame_dir, n_samples, start=-1, end=-1):
    os.makedirs(frame_dir, exist_ok=True)

    if start == -1:
        start = 0
    if end == -1:
        end = video.duration

    duration = end - start
    interval = duration / n_samples

    for i in range(n_samples):
        frame_time = start + i * interval
        frame = video.get_frame(frame_time)
        frame_image = Image.fromarray(np.uint8(frame))
        frame_path = os.path.join(frame_dir, f"frame_{i+1}.png")
        frame_image.save(frame_path)


def extract_video_parts(video, out_dir):
    os.makedirs(out_dir, exist_ok=True)

    # Extract audio
    audio_path = f"{out_dir}/audio.mp3"
    video.audio.write_audiofile(audio_path)

    # Extract 20 frames from the video
    extract_frames(video, f"{out_dir}/frames", 20)

    # Extract 20 frames from first 5 seconds
    extract_frames(video, f"{out_dir}/5s_frames", 20, start=0, end=5)
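
# Note: extract_video_parts assumes the clip has an audio track (video.audio is None
# for silent videos, so write_audiofile would fail); moviepy also shells out to ffmpeg
# for both the audio export and the frame grabs.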


# %%
tags = []
with open("labels.txt", "r") as f:
    for line in f:
        tags.append(line.strip())
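
# labels.txt is not part of this commit; it is assumed to hold one candidate tag per
# line, e.g. (hypothetical contents):
#   dance
#   cooking
#   product review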

# %%
from transformers import AutoTokenizer, AutoModel
import torch
import torch.nn.functional as F

# Load the tokenizer and model
tokenizer = AutoTokenizer.from_pretrained('nomic-ai/nomic-embed-text-v1.5')
text_model = AutoModel.from_pretrained('nomic-ai/nomic-embed-text-v1.5', trust_remote_code=True)
text_model.eval()

# Function to get embeddings for tags
def get_tag_embeddings(tags):
    encoded_input = tokenizer(tags, padding=True, truncation=True, return_tensors='pt')
    with torch.no_grad():
        model_output = text_model(**encoded_input)
    text_embeddings = F.normalize(model_output.last_hidden_state[:, 0], p=2, dim=1)
    return text_embeddings

tag_embeddings = get_tag_embeddings(tags)
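
# The nomic-embed model cards recommend a task prefix on text inputs (for example
# "search_query: ") when the embeddings are compared against nomic-embed-vision image
# embeddings. A possible, commented-out variant of the call above:
#
#   prefixed_tags = [f"search_query: {tag}" for tag in tags]
#   tag_embeddings = get_tag_embeddings(prefixed_tags)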

# %%

from transformers import AutoImageProcessor, AutoModel
from PIL import Image
import os
from collections import Counter

processor = AutoImageProcessor.from_pretrained("nomic-ai/nomic-embed-vision-v1.5")
vision_model = AutoModel.from_pretrained("nomic-ai/nomic-embed-vision-v1.5", trust_remote_code=True)

def get_frames(frame_dir):
    # Frame files are named frame_<n>.png; sort them by their numeric suffix
    found_frames = [frame for frame in os.listdir(frame_dir) if frame.startswith("frame_")]
    frame_numbers = [int(frame.split("_")[-1].split(".")[0]) for frame in found_frames]
    frames = [Image.open(os.path.join(frame_dir, f"frame_{frame_no}.png")) for frame_no in sorted(frame_numbers)]
    return frames

def frames_to_embeddings(frames):
    inputs = processor(frames, return_tensors="pt")
    img_emb = vision_model(**inputs).last_hidden_state
    img_embeddings = F.normalize(img_emb[:, 0], p=2, dim=1)
    return img_embeddings

def compute_similarities(img_embeddings, tag_embeddings):
    similarities = torch.matmul(img_embeddings, tag_embeddings.T)
    return similarities

def get_top_tags(similarities, tags):
    top_5_tags = similarities.topk(5).indices.tolist()
    return [tags[tag_idx] for tag_idx in top_5_tags]

def analyze_frames(frame_dir, tag_embeddings):
    frames = get_frames(frame_dir)
    img_embeddings = frames_to_embeddings(frames)
    cosine_similarities = compute_similarities(img_embeddings, tag_embeddings)
    results = {
        "images": [],
        "summary": {}
    }
    summary = Counter()
    for i, img in enumerate(frames):
        top_5_tags = get_top_tags(cosine_similarities[i], tags)
        results["images"].append({"image": img.filename, "tags": top_5_tags})
        summary.update(top_5_tags)

    results["summary"]["tags"] = summary
    return results
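
# analyze_frames() returns {"images": [...], "summary": {"tags": Counter}}; note that
# get_top_tags() is called with the module-level `tags` list loaded from labels.txt,
# so the tag_embeddings passed in must have been built from that same list.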



# %%
import openai

def transcribe(audio_path):
    client = openai.OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
    with open(audio_path, "rb") as audio_file:
        transcript = client.audio.transcriptions.create(model="whisper-1", file=audio_file)
    return transcript.text
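
# Assumes OPENAI_API_KEY is set in the environment; at the time of writing the
# transcription endpoint also caps uploads at roughly 25 MB, so very long videos may
# need their audio chunked.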


# %%
# Load model directly
from transformers import AutoFeatureExtractor, AutoModelForAudioClassification

audio_extractor = AutoFeatureExtractor.from_pretrained("MIT/ast-finetuned-audioset-10-10-0.4593")
audio_feature_model = AutoModelForAudioClassification.from_pretrained("MIT/ast-finetuned-audioset-10-10-0.4593")
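
# This AST checkpoint classifies audio into the AudioSet label set, and its feature
# extractor expects 16 kHz mono input, which is why the audio is resampled below.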

# %%
from pydub import AudioSegment

def extract_audio_features(audio_path):
    # Decode the mp3 and resample to 16 kHz mono for the AST feature extractor
    audio = AudioSegment.from_file(audio_path, format="mp3")
    audio = audio.set_channels(1).set_frame_rate(16000)
    samples = np.array(audio.get_array_of_samples()).astype(np.float32) / float(2 ** (8 * audio.sample_width - 1))
    inputs = audio_extractor(samples, sampling_rate=16000, return_tensors="pt")
    with torch.no_grad():
        outputs = audio_feature_model(**inputs).logits
    # Return the top-3 predicted AudioSet labels
    predicted_class_ids = outputs.topk(3).indices.tolist()[0]
    predicted_labels = [audio_feature_model.config.id2label[class_id] for class_id in predicted_class_ids]
    return predicted_labels

# %%
import base64
from io import BytesIO

def base64_encode_image(image):
    buffered = BytesIO()
    new_width = image.width // 2
    new_height = image.height // 2
    resized_image = image.resize((new_width, new_height), Image.LANCZOS)
    resized_image.save(buffered, format="JPEG")
    img_str = base64.b64encode(buffered.getvalue())
    return 'data:image/jpeg;base64,' + img_str.decode('utf-8')
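
# Frames are downscaled to half size and re-encoded as JPEG data URLs, presumably to
# keep the multimodal request payload small; the PNG frames decode as RGB, so the
# JPEG save needs no mode conversion.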

def ai_summary(transcript, frames, audio_description, extra_context=""):
    client = openai.OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
    messages = [
        {"role": "system", "content": "You are a social media content analysis bot trying to uncover trends about what makes a video distinct. Given the transcript, frames, and a description of the audio, give a short analysis of the video content and what makes it unique."},
        {"role": "user",
         "content": [
            {
                "type": "text",
                "text": f"Transcript: {transcript}\n\n\n\nAudio: {audio_description}\n\nExtra Context?: {extra_context or 'n/a'}",
            },
            *[
                {
                    "type": "image_url",
                    "image_url": {"url": base64_encode_image(frame)},
                }
                for frame in frames
            ],
         ]},
    ]
    return client.chat.completions.create(
        model="gpt-4o",
        messages=messages
    )
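
# ai_summary() returns the full ChatCompletion object; callers pull the text out of
# response.choices[0].message.content (see tiktok_analyze below).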


# %%
import gradio as gr

# %%
import uuid, shutil
import tempfile

def tiktok_analyze(video_path):
    results = {
        "overview": "",
        "ai_overview": "",
        "first_5s_analysis": "",
        "video_analysis": "",
        "transcript": "",
    }

    video_id = str(uuid.uuid4())

    # Copy the uploaded video to <tmpdir>/videos/<video_id>.mp4 and split it into parts
    path_root = f"{tempfile.gettempdir()}/videos/{video_id}"
    os.makedirs(os.path.dirname(path_root), exist_ok=True)
    shutil.copy(video_path, f"{path_root}.mp4")
    video = VideoFileClip(f"{path_root}.mp4")
    extract_video_parts(video, f"{path_root}_parts")

    frames = get_frames(f"{path_root}_parts/frames")
    first_5s_analysis = analyze_frames(f"{path_root}_parts/5s_frames", tag_embeddings)
    whole_analysis = analyze_frames(f"{path_root}_parts/frames", tag_embeddings)

    audio_features = extract_audio_features(f"{path_root}_parts/audio.mp3")

    results["transcript"] = transcribe(f"{path_root}_parts/audio.mp3")

    ai_summary_response = ai_summary(results["transcript"], frames, audio_features).choices[0].message.content

    results["overview"] = f"""
## Overview
**duration:** {video.duration}

**major themes:** {", ".join(tag for tag, _ in whole_analysis["summary"]["tags"].most_common(5))}

**audio:** {", ".join(audio_features)}
"""

    results["ai_overview"] = "# AI Summary\n" + ai_summary_response
    results["first_5s_analysis"] = f"Major themes: {', '.join(first_5s_analysis['summary']['tags'])}"
    results["video_analysis"] = f"Major themes: {', '.join(whole_analysis['summary']['tags'])}"

    return [
        results["overview"],
        results["first_5s_analysis"],
        results["video_analysis"],
        results["ai_overview"],
        results["transcript"],
    ]
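
# The five returned values map positionally onto the five output components declared
# in the gr.Interface below (Overview, First 5s, Content Analysis, AI Summary, Transcript).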

demo = gr.Interface(
    title="TikTok Content Analyzer",
    description="Start by uploading a video to analyze.",
    fn=tiktok_analyze,
    inputs="video",
    outputs=[
        gr.Markdown(label="Overview"),
        gr.Text(label="First 5s Content Analysis"),
        gr.Text(label="Content Analysis"),
        gr.Markdown(label="AI Summary"),
        gr.Text(label="Transcript"),
    ],
)

demo.launch()

# %%
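# Running `python app.py` (or the Space runtime importing it) serves the Gradio UI,
# by default on port 7860.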