|
""" |
|
Caption tab for Video Model Studio UI |
|
""" |
|
|
|
import gradio as gr |
|
import logging |
|
import asyncio |
|
import traceback |
|
from typing import Dict, Any, List, Optional, AsyncGenerator, Tuple |
|
from pathlib import Path |
|
|
|
from .base_tab import BaseTab |
|
from ..config import DEFAULT_CAPTIONING_BOT_INSTRUCTIONS, DEFAULT_PROMPT_PREFIX, STAGING_PATH, TRAINING_VIDEOS_PATH |
|
from ..utils import is_image_file, is_video_file, copy_files_to_training_dir |
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
class CaptionTab(BaseTab):
    """Caption tab for managing asset captions.

    The tab lists the media files found in ``STAGING_PATH`` (plus already
    captioned files found in ``TRAINING_VIDEOS_PATH``), lets the user preview
    a file and edit its caption, and drives the automatic captioning run
    exposed by ``self.app.captioner``.

    ``self.app`` and ``self.components`` are not defined in this class;
    presumably they are set up by ``BaseTab`` -- TODO confirm.
    """

    def __init__(self, app_state):
        """Initialise the tab identity and the stop flag.

        Args:
            app_state: Object forwarded unchanged to ``BaseTab.__init__``.
        """
        super().__init__(app_state)
        self.id = "caption_tab"
        self.title = "3️⃣ Caption"
        # Raised by stop_captioning() and cleared when a run (or a copy)
        # starts. Nothing in this class reads it.
        # NOTE(review): confirm whether any external code depends on it.
        self._should_stop_captioning = False

    def create(self, parent=None) -> gr.TabItem:
        """Create the Caption tab UI components.

        Every component is stored in ``self.components`` under a string key
        so that ``connect_events`` can wire it up later.

        Args:
            parent: Not used by this method.

        Returns:
            The ``gr.TabItem`` that contains the whole tab.
        """
        with gr.TabItem(self.title, id=self.id) as tab:
            with gr.Row():
                self.components["caption_title"] = gr.Markdown(
                    "## Captioning of 0 files (0 bytes)"
                )

            # Prompt settings and action buttons.
            with gr.Row():
                with gr.Column():
                    with gr.Row():
                        self.components["custom_prompt_prefix"] = gr.Textbox(
                            scale=3,
                            label='Prefix to add to ALL captions (eg. "In the style of TOK, ")',
                            placeholder="In the style of TOK, ",
                            lines=2,
                            value=DEFAULT_PROMPT_PREFIX,
                        )
                        self.components["captioning_bot_instructions"] = gr.Textbox(
                            scale=6,
                            label="System instructions for the automatic captioning model",
                            placeholder="Please generate a full description of...",
                            lines=5,
                            value=DEFAULT_CAPTIONING_BOT_INSTRUCTIONS,
                        )
                    with gr.Row():
                        self.components["run_autocaption_btn"] = gr.Button(
                            "Automatically fill missing captions",
                            variant="primary",
                        )
                        self.components["copy_files_to_training_dir_btn"] = gr.Button(
                            "Copy assets to training directory",
                            variant="primary",
                        )
                        self.components["stop_autocaption_btn"] = gr.Button(
                            "Stop Captioning",
                            variant="stop",
                            interactive=False,
                        )

            # File table on the left, preview / caption editor on the right.
            with gr.Row():
                with gr.Column():
                    self.components["training_dataset"] = gr.Dataframe(
                        headers=["name", "status"],
                        interactive=False,
                        wrap=True,
                        value=self.list_training_files_to_caption(),
                        row_count=10,
                    )

                with gr.Column():
                    self.components["preview_video"] = gr.Video(
                        label="Video Preview",
                        interactive=False,
                        visible=False,
                    )
                    self.components["preview_image"] = gr.Image(
                        label="Image Preview",
                        interactive=False,
                        visible=False,
                    )
                    self.components["preview_caption"] = gr.Textbox(
                        label="Caption",
                        lines=6,
                        interactive=True,
                    )
                    self.components["save_caption_btn"] = gr.Button("Save Caption")
                    self.components["preview_status"] = gr.Textbox(
                        label="Status",
                        interactive=False,
                        visible=True,
                    )
                    # Path of the file currently shown in the preview; read
                    # back by save_caption_changes().
                    self.components["original_file_path"] = gr.State(value=None)

        return tab

    def connect_events(self) -> None:
        """Connect event handlers to UI components."""
        ui = self.components
        action_buttons = [
            ui["run_autocaption_btn"],
            ui["stop_autocaption_btn"],
            ui["copy_files_to_training_dir_btn"],
        ]

        # Auto-caption: placeholder row -> lock buttons -> run -> unlock.
        ui["run_autocaption_btn"].click(
            fn=self.show_refreshing_status,
            outputs=[ui["training_dataset"]],
        ).then(
            fn=self.update_captioning_buttons_start,
            outputs=list(action_buttons),
        ).then(
            fn=self.start_caption_generation,
            inputs=[
                ui["captioning_bot_instructions"],
                ui["custom_prompt_prefix"],
            ],
            outputs=[ui["training_dataset"]],
        ).then(
            fn=self.update_captioning_buttons_end,
            outputs=list(action_buttons),
        )

        ui["copy_files_to_training_dir_btn"].click(
            fn=self.copy_files_to_training_dir,
            inputs=[ui["custom_prompt_prefix"]],
        )

        # stop_captioning() returns its values in exactly this order.
        ui["stop_autocaption_btn"].click(
            fn=self.stop_captioning,
            outputs=[ui["training_dataset"], *action_buttons],
        )

        # handle_training_dataset_select() returns its values in this order.
        ui["training_dataset"].select(
            fn=self.handle_training_dataset_select,
            outputs=[
                ui["preview_image"],
                ui["preview_video"],
                ui["preview_caption"],
                ui["original_file_path"],
                ui["preview_status"],
            ],
        )

        # Save the edited caption, then refresh the table.
        ui["save_caption_btn"].click(
            fn=self.save_caption_changes,
            inputs=[
                ui["preview_caption"],
                ui["preview_image"],
                ui["preview_video"],
                ui["original_file_path"],
                ui["custom_prompt_prefix"],
            ],
            outputs=[ui["preview_status"]],
        ).success(
            fn=self.list_training_files_to_caption,
            outputs=[ui["training_dataset"]],
        )

    def refresh(self) -> Dict[str, Any]:
        """Refresh the dataset list with current data.

        Returns:
            A dict with the single key ``"training_dataset"`` mapped to the
            rows produced by ``list_training_files_to_caption``.
        """
        return {"training_dataset": self.list_training_files_to_caption()}

    def show_refreshing_status(self) -> List[List[str]]:
        """Return a single placeholder row shown while a run is starting."""
        return [["Refreshing...", "please wait"]]

    def update_captioning_buttons_start(self):
        """Return the button states used while captioning is running.

        Returns:
            A tuple ``(run, stop, copy)``: run and copy are disabled, stop is
            enabled. Values are positional rather than a dictionary.
        """
        return (
            gr.Button(interactive=False, variant="secondary"),
            gr.Button(interactive=True, variant="stop"),
            gr.Button(interactive=False, variant="secondary"),
        )

    def update_captioning_buttons_end(self):
        """Return the button states used when no captioning is running.

        Returns:
            A tuple ``(run, stop, copy)``: run and copy are enabled, stop is
            disabled. Values are positional rather than a dictionary.
        """
        return (
            gr.Button(interactive=True, variant="primary"),
            gr.Button(interactive=False, variant="secondary"),
            gr.Button(interactive=True, variant="primary"),
        )

    def stop_captioning(self):
        """Stop the ongoing captioning process and reset the UI state.

        Returns:
            A tuple ``(training_dataset, run, stop, copy)`` matching the order
            of the outputs wired in ``connect_events``. The previous
            implementation returned a dict keyed by plain strings, which does
            not correspond to the component outputs of the event.
        """
        try:
            self._should_stop_captioning = True

            if self.app.captioner:
                self.app.captioner.stop_captioning()

            rows = self.list_training_files_to_caption()
        except Exception as e:
            logger.exception("Error stopping captioning: %s", e)
            rows = [[f"Error stopping captioning: {str(e)}", "error"]]

        # Both the success and the error path put the buttons back in the
        # idle state, which is what update_captioning_buttons_end() returns.
        return (gr.update(value=rows), *self.update_captioning_buttons_end())

    def copy_files_to_training_dir(self, prompt_prefix: str):
        """Copy the assets to the training directory.

        Args:
            prompt_prefix: Passed unchanged to the module-level
                ``copy_files_to_training_dir`` helper.

        Raises:
            gr.Error: If the helper raises; the original exception is chained.
        """
        self._should_stop_captioning = False

        try:
            # Inside a method body this name resolves to the function imported
            # from ..utils, not to this method.
            copy_files_to_training_dir(prompt_prefix)
        except Exception as e:
            traceback.print_exc()
            raise gr.Error(f"Error copying assets to training dir: {str(e)}") from e

    async def _process_caption_generator(self, captioning_bot_instructions, prompt_prefix):
        """Drain ``start_caption_generation`` and discard its updates.

        Exceptions are logged and not re-raised.
        """
        try:
            async for _ in self.start_caption_generation(
                captioning_bot_instructions,
                prompt_prefix,
            ):
                pass
            logger.info("Background captioning completed")
        except Exception as e:
            logger.exception("Error in background captioning: %s", e)

    def _dataset_update(self, rows):
        """Build the update for the dataset table from ``[name, status]`` rows."""
        return gr.update(value=rows, headers=["name", "status"])

    async def start_caption_generation(
        self, captioning_bot_instructions: str, prompt_prefix: str
    ) -> AsyncGenerator[Any, None]:
        """Run the auto-captioning process and stream table updates.

        Args:
            captioning_bot_instructions: Forwarded to the captioner.
            prompt_prefix: Forwarded to the captioner.

        Yields:
            Updates for the dataset table: an initial placeholder, one update
            per batch of rows received from the captioner, and finally the
            listing from ``list_training_files_to_caption``. If anything
            raises, a single error row is yielded instead of propagating.
        """
        try:
            self._should_stop_captioning = False

            yield self._dataset_update([["Starting captioning service...", "initializing"]])

            # Latest known status per file name; later batches overwrite
            # earlier entries for the same name.
            latest_status = {}

            async for rows in self.app.captioner.start_caption_generation(
                captioning_bot_instructions, prompt_prefix
            ):
                for name, status in rows:
                    latest_status[name] = status

                yield self._dataset_update(
                    sorted(
                        ([name, status] for name, status in latest_status.items()),
                        key=lambda row: row[0],
                    )
                )

            yield self._dataset_update(self.list_training_files_to_caption())

        except Exception as e:
            logger.exception("Error in captioning: %s", e)
            yield self._dataset_update([[f"Error: {str(e)}", "error"]])

    def list_training_files_to_caption(self) -> List[List[str]]:
        """List all clips and images - both pending and captioned.

        Files in ``STAGING_PATH`` are always listed. Files in
        ``TRAINING_VIDEOS_PATH`` are listed only when they have a non-empty
        caption file and no file of the same name was already listed.

        Returns:
            ``[name, status]`` rows sorted by file name, where status is
            ``"captioned (<type>)"`` or ``"no caption (<type>)"`` and
            ``<type>`` is ``video`` or ``image``.
        """

        def is_media(path) -> bool:
            return is_video_file(path) or is_image_file(path)

        def has_caption(path) -> bool:
            # A caption counts only if the sibling .txt exists and is not empty.
            txt_file = path.with_suffix('.txt')
            return txt_file.exists() and txt_file.stat().st_size > 0

        def kind_of(path) -> str:
            return "video" if is_video_file(path) else "image"

        rows = []
        seen_names = set()

        for file in STAGING_PATH.glob("*.*"):
            if not is_media(file):
                continue
            status = "captioned" if has_caption(file) else "no caption"
            rows.append([file.name, f"{status} ({kind_of(file)})"])
            seen_names.add(file.name)

        for file in TRAINING_VIDEOS_PATH.glob("*.*"):
            if not is_media(file) or file.name in seen_names:
                continue
            if has_caption(file):
                rows.append([file.name, f"captioned ({kind_of(file)})"])
                seen_names.add(file.name)

        rows.sort(key=lambda row: row[0])
        return rows

    def _hidden_preview(self, status: str, lock_caption: bool = False) -> List[Any]:
        """Build the handler result that hides every preview widget.

        Args:
            status: Message for the status textbox.
            lock_caption: When true the hidden caption textbox is also made
                non-interactive.

        Returns:
            ``[image, video, caption, original_file_path, status]``.
        """
        caption_kwargs = {"visible": False}
        if lock_caption:
            caption_kwargs["interactive"] = False
        return [
            gr.Image(interactive=False, visible=False),
            gr.Video(interactive=False, visible=False),
            gr.Textbox(**caption_kwargs),
            None,
            status,
        ]

    def handle_training_dataset_select(self, evt: gr.SelectData) -> List[Any]:
        """Handle selection of both video clips and images.

        Args:
            evt: Selection event from the dataset table.

        Returns:
            A 5-element list in the order wired in ``connect_events``:
            ``[preview_image, preview_video, preview_caption,
            original_file_path, preview_status]``. On any failure the preview
            widgets are hidden and the status carries the message.
        """
        try:
            if not evt:
                return self._hidden_preview("No file selected")

            # When the event exposes the whole selected row, take the name
            # from its first column so that clicking the status cell still
            # resolves the file; otherwise use the clicked cell's value.
            row_value = getattr(evt, "row_value", None)
            if isinstance(row_value, (list, tuple)) and row_value:
                file_name = row_value[0]
            else:
                file_name = evt.value

            if not file_name:
                return self._hidden_preview("No file selected")

            # Only the staging directory is searched.
            # NOTE(review): list_training_files_to_caption() can also list
            # files that exist only in TRAINING_VIDEOS_PATH; those end up in
            # the "File not found" branch below -- confirm this is intended.
            possible_paths = [
                STAGING_PATH / file_name,
            ]
            file_path = next((path for path in possible_paths if path.exists()), None)

            if not file_path:
                return self._hidden_preview(f"File not found: {file_name}")

            txt_path = file_path.with_suffix('.txt')
            caption = txt_path.read_text() if txt_path.exists() else ""

            if is_video_file(file_path):
                return [
                    gr.Image(interactive=False, visible=False),
                    gr.Video(
                        label="Video Preview",
                        interactive=False,
                        visible=True,
                        value=str(file_path),
                    ),
                    gr.Textbox(
                        label="Caption",
                        lines=6,
                        interactive=True,
                        visible=True,
                        value=str(caption),
                    ),
                    str(file_path),
                    None,
                ]

            if is_image_file(file_path):
                return [
                    gr.Image(
                        label="Image Preview",
                        interactive=False,
                        visible=True,
                        value=str(file_path),
                    ),
                    gr.Video(interactive=False, visible=False),
                    gr.Textbox(
                        label="Caption",
                        lines=6,
                        interactive=True,
                        visible=True,
                        value=str(caption),
                    ),
                    str(file_path),
                    None,
                ]

            return self._hidden_preview(
                f"Unsupported file type: {file_path.suffix}", lock_caption=True
            )
        except Exception as e:
            logger.exception("Error handling selection: %s", e)
            return self._hidden_preview(
                f"Error handling selection: {str(e)}", lock_caption=True
            )

    def save_caption_changes(
        self,
        preview_caption: str,
        preview_image: str,
        preview_video: str,
        original_file_path: str,
        prompt_prefix: str,
    ):
        """Save changes to caption.

        Args:
            preview_caption: Caption text to store.
            preview_image: Not used by this method.
            preview_video: Not used by this method.
            original_file_path: Path of the file being captioned; falsy when
                nothing is selected.
            prompt_prefix: Not used by this method.

        Returns:
            An update for the status textbox with a success or error message.
            Errors are reported through the message, not raised.
        """
        if not original_file_path:
            return gr.update(value="Error: No original file path found")

        try:
            file_path = Path(original_file_path)
            self.app.captioner.update_file_caption(file_path, preview_caption)
        except Exception as e:
            logger.exception("Error saving caption: %s", e)
            return gr.update(value=f"Error saving caption: {str(e)}")

        return gr.update(value="Caption saved successfully!")

    def preview_file(self, selected_text: str) -> Dict:
        """Generate preview based on selected file.

        Args:
            selected_text: Text of the selected item containing filename. The
                file name is the part before the first ``" ("``.

        Returns:
            Dict with the keys ``"video"``, ``"image"`` and ``"text"``; at
            most one of them is set. ``"text"`` also carries error messages.
        """
        import mimetypes

        def preview(video=None, image=None, text=None) -> Dict:
            return {"video": video, "image": image, "text": text}

        if not selected_text or "Caption:" in selected_text:
            return preview()

        filename = selected_text.split(" (")[0].strip()
        file_path = TRAINING_VIDEOS_PATH / filename

        if not file_path.exists():
            return preview(text=f"File not found: {filename}")

        mime_type, _ = mimetypes.guess_type(str(file_path))
        if not mime_type:
            return preview(text=f"Unknown file type: {filename}")

        if mime_type.startswith('video/'):
            return preview(video=str(file_path))

        if mime_type.startswith('image/'):
            return preview(image=str(file_path))

        if mime_type.startswith('text/'):
            try:
                return preview(text=file_path.read_text())
            except (OSError, UnicodeError) as e:
                return preview(text=f"Error reading file: {str(e)}")

        return preview(text=f"Unsupported file type: {mime_type}")