"""Gradio app for previewing r/GamePhysics posts: resolve a post URL or ID to its
video in the asgaardlab/GamePhysicsDailyDump dataset, sample frames with decord,
and bundle the sampled frames into a downloadable zip."""

import json
import os
import random
import re
import sys
import tempfile
import time
from datetime import datetime
from glob import glob
from pathlib import Path
from typing import List, Optional
from uuid import uuid4
from zipfile import ZipFile

import gradio as gr
import numpy as np
import pandas as pd
import requests
from datasets import load_dataset
from decord import VideoReader, cpu
from huggingface_hub import (
    CommitScheduler,
    HfApi,
    InferenceClient,
    login,
    snapshot_download,
)
from PIL import Image

cached_latest_posts_df = None
cached_top_posts = None
last_fetched = None
last_fetched_top = None
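# These module-level slots are a simple in-process cache: get_latest_posts and
# get_top_posts below refetch the Reddit listings at most once every 10 minutes.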


def get_reddit_id(url):
    """Extract the post ID from an r/GamePhysics URL, or pass a bare ID through."""
    pattern = r"https://www\.reddit\.com/r/GamePhysics/comments/([0-9a-zA-Z]+).*|([0-9a-zA-Z]+)"
    match = re.match(pattern, url)

    if match:
        # Group 1 matches the full URL form, group 2 a bare alphanumeric ID.
        post_id = match.group(1) or match.group(2)
        print(f"Valid GamePhysics post ID: {post_id}")
    else:
        post_id = url

    return post_id
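
# Both input forms resolve to the same ID ("1abcde" here is an illustrative ID):
#   get_reddit_id("https://www.reddit.com/r/GamePhysics/comments/1abcde/some_title/") -> "1abcde"
#   get_reddit_id("1abcde")                                                           -> "1abcde"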


def download_samples(url, video_url, num_frames):
    """Sample frames from the post's video and bundle them into a zip for download."""
    frames = extract_frames_decord(video_url, num_frames)

    with tempfile.TemporaryDirectory() as temp_dir:
        # Save each sampled frame as a JPEG in the temporary directory.
        for i, frame in enumerate(frames):
            frame_path = os.path.join(temp_dir, f"frame_{i}.jpg")
            frame.save(frame_path, format="JPEG", quality=85)

        post_id = get_reddit_id(url)
        print(f"Creating zip file for post {post_id}")
        zip_path = f"frames-{post_id}.zip"
        with ZipFile(zip_path, "w") as zipf:
            # Iterate over len(frames) rather than num_frames, in case fewer
            # frames were decoded than requested.
            for i in range(len(frames)):
                frame_path = os.path.join(temp_dir, f"frame_{i}.jpg")
                zipf.write(frame_path, os.path.basename(frame_path))

    return zip_path
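
# Note: the zip is written to the current working directory (not the temp dir),
# so the path survives the TemporaryDirectory cleanup; the gr.File component
# wired up below serves it to the browser from that path.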


def extract_frames_decord(video_path, num_frames=10):
    """Decode num_frames evenly spaced frames from the video using decord."""
    try:
        start_time = time.time()
        print(f"Extracting {num_frames} frames from {video_path}")

        vr = VideoReader(video_path, ctx=cpu(0))

        # Evenly spaced indices from the first to the last frame (endpoint
        # included; the original endpoint=False skipped the tail of the clip).
        total_frames = len(vr)
        frame_indices = np.linspace(0, total_frames - 1, num_frames, dtype=int)

        # Batch decode avoids seeking through the video frame by frame.
        batch_frames = vr.get_batch(frame_indices).asnumpy()

        frame_images = [
            Image.fromarray(batch_frames[i]) for i in range(batch_frames.shape[0])
        ]

        end_time = time.time()
        print(f"Decord extraction took {end_time - start_time:.2f} seconds")

        return frame_images
    except Exception as e:
        raise Exception(f"Error extracting frames from video: {e}") from e
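
# Worked example: with total_frames=100 and num_frames=10,
# np.linspace(0, 99, 10, dtype=int) yields [0, 11, 22, 33, 44, 55, 66, 77, 88, 99].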


def get_top_posts():
    """Fetch this month's top r/GamePhysics posts, caching results for 10 minutes."""
    global cached_top_posts
    global last_fetched_top

    # total_seconds() is deliberate: timedelta.seconds wraps every 24 hours,
    # which would make a day-old cache look fresh.
    now_time = datetime.now()
    if (
        last_fetched_top is not None
        and (now_time - last_fetched_top).total_seconds() < 600
    ):
        print("Using cached data")
        return cached_top_posts

    last_fetched_top = now_time
    url = "https://www.reddit.com/r/GamePhysics/top/.json?t=month"
    headers = {"User-Agent": "Mozilla/5.0"}

    response = requests.get(url, headers=headers)
    if response.status_code != 200:
        return []

    data = response.json()
    posts = data["data"]["children"]

    examples = [[post["data"]["id"], post["data"]["title"]] for post in posts]
    examples = pd.DataFrame(examples, columns=["post_id", "title"])
    cached_top_posts = examples
    return examples


def get_latest_posts():
    """Fetch the newest r/GamePhysics posts, caching results for 10 minutes."""
    global cached_latest_posts_df
    global last_fetched

    now_time = datetime.now()
    if last_fetched is not None and (now_time - last_fetched).total_seconds() < 600:
        print("Using cached data")
        return cached_latest_posts_df

    last_fetched = now_time
    url = "https://www.reddit.com/r/GamePhysics/.json"
    headers = {"User-Agent": "Mozilla/5.0"}

    response = requests.get(url, headers=headers)
    if response.status_code != 200:
        return []

    data = response.json()
    posts = data["data"]["children"]

    examples = [[post["data"]["id"], post["data"]["title"]] for post in posts]
    examples = pd.DataFrame(examples, columns=["post_id", "title"])
    cached_latest_posts_df = examples
    return examples
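
# The two fetchers above differ only in the listing URL and the cache slot they
# use; a shared helper along these lines could replace their bodies (a sketch
# only, caching omitted, not wired into the UI below):
#
#     def fetch_posts(url):
#         response = requests.get(url, headers={"User-Agent": "Mozilla/5.0"})
#         if response.status_code != 200:
#             return []
#         posts = response.json()["data"]["children"]
#         rows = [[p["data"]["id"], p["data"]["title"]] for p in posts]
#         return pd.DataFrame(rows, columns=["post_id", "title"])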


def row_selected(evt: gr.SelectData):
    """Map a click on either posts dataframe back to the clicked row's post_id."""
    global cached_latest_posts_df
    global cached_top_posts

    # evt.value is the clicked cell's value; evt.index is [row, column].
    string_value = evt.value
    row = evt.index[0]
    target_df = None

    # Work out which cached dataframe the click came from; the None checks
    # guard against a listing that has never been fetched.
    if (
        cached_latest_posts_df is not None
        and cached_latest_posts_df.isin([string_value]).any().any()
    ):
        target_df = cached_latest_posts_df
    elif (
        cached_top_posts is not None
        and cached_top_posts.isin([string_value]).any().any()
    ):
        target_df = cached_top_posts
    else:
        raise gr.Error("Could not find selected post in any dataframe")

    post_id = target_df.iloc[row]["post_id"]
    return post_id


def load_video(url):
    """Resolve a post URL/ID to its video in the GamePhysicsDailyDump dataset repo."""
    post_id = get_reddit_id(url)
    video_url = f"https://huggingface.co/datasets/asgaardlab/GamePhysicsDailyDump/resolve/main/data/videos/{post_id}.mp4?download=true"

    # HEAD request to confirm the video exists before handing the URL to the player.
    r = requests.head(video_url)
    if r.status_code != 200 and r.status_code != 302:
        raise gr.Error(
            f"Video is not in the repo, please try another post. - {r.status_code}"
        )

    return video_url
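
# requests.head does not follow redirects by default, so the 302 that the Hub's
# resolve endpoint typically returns for LFS-stored files counts as "exists" here.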


with gr.Blocks() as demo:
    gr.Markdown("## Preview GamePhysics")
    dummy_title = gr.Textbox(visible=False)

    with gr.Row():
        with gr.Column():
            reddit_id = gr.Textbox(
                lines=1, placeholder="Post url or id here", label="URL or Post ID"
            )
            load_btn = gr.Button("Load")
            video_player = gr.Video(interactive=False)

        with gr.Column():
            gr.Markdown("## Latest Posts")
            latest_post_dataframe = gr.Dataframe()
            latest_posts_btn = gr.Button("Refresh Latest Posts")
            top_posts_btn = gr.Button("Refresh Top Posts")

        with gr.Column():
            gr.Markdown("## Sampled Frames from Video")
            with gr.Row():
                num_frames = gr.Slider(
                    minimum=1, maximum=60, step=1, value=10, label="Number of frames"
                )
                sample_decord_btn = gr.Button("Sample decord")

            sampled_frames = gr.Gallery()

            download_samples_btn = gr.Button("Download Samples")
            output_files = gr.File()

    download_samples_btn.click(
        download_samples,
        inputs=[reddit_id, video_player, num_frames],
        outputs=[output_files],
    )

    sample_decord_btn.click(
        extract_frames_decord,
        inputs=[video_player, num_frames],
        outputs=[sampled_frames],
    )

    load_btn.click(load_video, inputs=[reddit_id], outputs=[video_player])

    latest_posts_btn.click(get_latest_posts, outputs=[latest_post_dataframe])
    top_posts_btn.click(get_top_posts, outputs=[latest_post_dataframe])

    # Populate the posts table on page load.
    demo.load(get_latest_posts, outputs=[latest_post_dataframe])

    # Selecting a row fills the ID textbox, then chains into loading that video.
    latest_post_dataframe.select(fn=row_selected, outputs=[reddit_id]).then(
        load_video, inputs=[reddit_id], outputs=[video_player]
    )

demo.launch()