Muhammad Taqi Raza committed
Commit 2ae859b · 1 Parent(s): d2d7c02

adding gradio

Files changed (5)
  1. .DS_Store +0 -0
  2. app.py +113 -73
  3. datasets/.DS_Store +0 -0
  4. inference_script.py +1 -1
  5. requirements.txt +4 -1
.DS_Store ADDED
Binary file (6.15 kB)
 
app.py CHANGED
@@ -1,82 +1,122 @@
- import gradio as gr
  import os
- import zipfile
- from pathlib import Path
-
- # The base directory that will be explored
- BASE_DIR = Path(__file__).resolve().parent / "your_target_folder"
-
- # Ensure BASE_DIR exists
- BASE_DIR.mkdir(parents=True, exist_ok=True)
-
- def resolve_path(current_rel_path: str) -> Path:
-     """Secure path resolution within BASE_DIR"""
-     resolved_path = (BASE_DIR / current_rel_path).resolve()
-     if BASE_DIR not in resolved_path.parents and resolved_path != BASE_DIR:
-         raise ValueError("Access outside base directory is not allowed.")
-     return resolved_path
-
- def list_dir(current_rel_path: str = ""):
-     current_path = resolve_path(current_rel_path)
-
-     # Parent folder navigation
-     parent_rel = str(Path(current_rel_path).parent) if current_rel_path else ""
-     entries = []
-
-     if current_path != BASE_DIR:
-         entries.append(("..", "⬆️ Parent Folder"))
-
-     # List directories and files
-     for item in sorted(current_path.iterdir()):
-         rel_item = os.path.relpath(item, BASE_DIR)
-         label = f"📁 {item.name}" if item.is_dir() else f"📄 {item.name}"
-         entries.append((rel_item, label))
-
-     return gr.update(choices=entries, value=None), f"Currently in: /{current_rel_path}"
-
- def download_entry(selected_rel_path: str):
-     selected_path = resolve_path(selected_rel_path)
-
-     if selected_path.is_file():
-         return selected_path
-     elif selected_path.is_dir():
-         zip_path = f"/tmp/{selected_path.name}.zip"
-         with zipfile.ZipFile(zip_path, "w") as zipf:
-             for root, dirs, files in os.walk(selected_path):
-                 for file in files:
-                     abs_file = os.path.join(root, file)
-                     arc_file = os.path.relpath(abs_file, selected_path)
-                     zipf.write(abs_file, arc_file)
-         return zip_path
      else:
-         return None

  with gr.Blocks() as demo:
-     current_path = gr.State("")

-     gr.Markdown("# 📁 Folder Browser")
-
      with gr.Row():
-         folder_dropdown = gr.Dropdown(label="Select File or Folder", choices=[])
-         refresh_btn = gr.Button("🔄 Refresh")
-
-     status_text = gr.Textbox(label="Current Path", interactive=False)
-     download_btn = gr.Button("⬇️ Download Selected")
-     file_output = gr.File(label="Download Result")
-
-     # Events
-     refresh_btn.click(fn=list_dir, inputs=current_path, outputs=[folder_dropdown, status_text])
-
-     folder_dropdown.change(
-         fn=lambda x: (x, *list_dir(x)), # update path, refresh list
-         inputs=folder_dropdown,
-         outputs=[current_path, folder_dropdown, status_text],
-     )

-     download_btn.click(fn=download_entry, inputs=folder_dropdown, outputs=file_output)
-
-     # Initial trigger
-     demo.load(fn=list_dir, inputs=current_path, outputs=[folder_dropdown, status_text])

  demo.launch()
-
  import os
+ import gradio as gr
+ import subprocess
+ import uuid
+ import shutil
+ from huggingface_hub import snapshot_download
+
+ # ----------------------------------------
+ # Step 1: Download Model Weights
+ # ----------------------------------------
+ MODEL_REPO = "roll-ai/DOVE"
+ MODEL_PATH = "pretrained_models/"
+
+ if not os.path.exists(MODEL_PATH) or len(os.listdir(MODEL_PATH)) == 0:
+     print("🔽 Downloading model weights from Hugging Face Hub...")
+     snapshot_download(
+         repo_id=MODEL_REPO,
+         repo_type="dataset",
+         local_dir=MODEL_PATH,
+         local_dir_use_symlinks=False
+     )
+     print("✅ Download complete.")
+
+ # ----------------------------------------
+ # Step 2: Setup Directories
+ # ----------------------------------------
+ INFERENCE_SCRIPT = "inference_script.py"
+ OUTPUT_DIR = "results/DOVE/demo"
+ UPLOAD_DIR = "input_videos"
+
+ os.makedirs(UPLOAD_DIR, exist_ok=True)
+ os.makedirs(OUTPUT_DIR, exist_ok=True)
+
+ # ----------------------------------------
+ # Step 3: Inference Function
+ # ----------------------------------------
+ def run_inference(video_path, save_format):
+     input_name = f"{uuid.uuid4()}.mp4"
+     input_path = os.path.join(UPLOAD_DIR, input_name)
+     shutil.copy(video_path, input_path)
+
+     # --- Run inference script ---
+     cmd = [
+         "python", INFERENCE_SCRIPT,
+         "--input_dir", UPLOAD_DIR,
+         "--model_path", MODEL_PATH,
+         "--output_path", OUTPUT_DIR,
+         "--is_vae_st",
+         "--save_format", save_format
+     ]
+
+     try:
+         inference_result = subprocess.run(
+             cmd,
+             capture_output=True,
+             text=True,
+             check=True
+         )
+         print("📄 Inference stdout:\n", inference_result.stdout)
+         print("⚠️ Inference stderr:\n", inference_result.stderr)
+     except subprocess.CalledProcessError as e:
+         print("❌ Inference failed.")
+         print("⚠️ STDOUT:\n", e.stdout)
+         print("⚠️ STDERR:\n", e.stderr)
+         return f"Inference failed:\n{e.stderr}", None
+
+     # --- Convert .mkv to .mp4 ---
+     mkv_path = os.path.join(OUTPUT_DIR, input_name).replace(".mp4", ".mkv")
+     mp4_path = os.path.join(OUTPUT_DIR, input_name)
+
+     if os.path.exists(mkv_path):
+         convert_cmd = [
+             "ffmpeg", "-y", "-i", mkv_path, "-c:v", "copy", "-c:a", "aac", mp4_path
+         ]
+         try:
+             convert_result = subprocess.run(
+                 convert_cmd,
+                 capture_output=True,
+                 text=True,
+                 check=True
+             )
+             print("🔄 FFmpeg stdout:\n", convert_result.stdout)
+             print("⚠️ FFmpeg stderr:\n", convert_result.stderr)
+         except subprocess.CalledProcessError as e:
+             print("❌ FFmpeg conversion failed.")
+             print("⚠️ STDOUT:\n", e.stdout)
+             print("⚠️ STDERR:\n", e.stderr)
+             return f"Inference OK, but conversion failed:\n{e.stderr}", None
+
+     if os.path.exists(mp4_path):
+         return "Inference successful!", mp4_path
      else:
+         return "Output video not found.", None

+ # ----------------------------------------
+ # Step 4: Gradio Interface
+ # ----------------------------------------
  with gr.Blocks() as demo:
+     gr.Markdown("# 🎥 DOVE Video SR + Restoration Inference Demo")
+     gr.Markdown("⚙️ **Note:** Default `save_format` is `yuv444p`. If playback fails, try `yuv420p` for compatibility.")

      with gr.Row():
+         input_video = gr.Video(label="Upload input video", type="filepath")
+         output_video = gr.Video(label="Output video")

+     with gr.Row():
+         save_format = gr.Dropdown(
+             choices=["yuv444p", "yuv420p"],
+             value="yuv444p",
+             label="Save format (for video playback compatibility)"
+         )
+
+     run_button = gr.Button("Run Inference")
+     status = gr.Textbox(label="Status")
+
+     run_button.click(
+         fn=run_inference,
+         inputs=[input_video, save_format],
+         outputs=[status, output_video],
+     )

  demo.launch()
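For reference, the new `run_inference` helper can also be exercised headlessly, without the Gradio UI. The sketch below is not part of the commit: it assumes `run_inference` is in scope (e.g. pasted in place of the final `demo.launch()`) and uses a hypothetical input clip path.

```python
# Hypothetical headless check (not part of this commit): push one clip through
# the same pipeline the "Run Inference" button uses, skipping the UI.
# "datasets/sample_clip.mp4" is an assumed placeholder path.
status, output_path = run_inference("datasets/sample_clip.mp4", save_format="yuv420p")
print(status)       # "Inference successful!" or an error message
print(output_path)  # converted .mp4 under results/DOVE/demo, or None on failure
```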
 
datasets/.DS_Store ADDED
Binary file (6.15 kB)
 
inference_script.py CHANGED
@@ -751,4 +751,4 @@ if __name__ == "__main__":
  with open(out_path, 'w') as f:
      json.dump(output, f, indent=2)

- print("All videos processed.")
+ print("All videos processed.")
requirements.txt CHANGED
@@ -1,4 +1,3 @@
- gradio
  accelerate>=1.1.1
  transformers>=4.46.2
  numpy==1.26.0
@@ -19,3 +18,7 @@ opencv-python
  decord
  av
  torchdiffeq
+ diffusers["torch"]
+ transformers
+ pyiqa
+ huggingface_hub
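After installing the updated requirements (e.g. `pip install -r requirements.txt`), a quick sanity check is to import the packages added in this commit. The snippet below is a sketch and assumes only the package names listed above.

```python
# Sketch: confirm the dependencies added in this commit resolve after install.
# Each import should succeed in the Space environment.
import diffusers
import transformers
import pyiqa
import huggingface_hub

print("diffusers", diffusers.__version__)
print("transformers", transformers.__version__)
print("huggingface_hub", huggingface_hub.__version__)
```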