import os
import gradio as gr
from gradio_client import Client
import cv2
from moviepy.editor import VideoFileClip
# 1. extract and store one frame every N frames from the video input (see extract_frames)
# 2. extract the audio track
# 3. for each extracted frame, get a caption from the caption model and collect the captions into a list
# 4. for the audio, ask an audio question-answering model to describe the sound/scene
# 5. give everything to the LLM and ask it to summarize, based on the image caption list combined with the audio caption

import re
import torch
from transformers import pipeline

# Zephyr-7B is used to turn the collected captions into a final video description
zephyr_model = "HuggingFaceH4/zephyr-7b-beta"
pipe = pipeline("text-generation", model=zephyr_model, torch_dtype=torch.bfloat16, device_map="auto")

# System prompt passed to the LLM (left empty here)
standard_sys = """
"""

def extract_frames(video_in, interval=12, output_format='.jpg'):
    """Extract frames from a video at a specified interval and store them in a list.

    Args:
    - video_in: string or path-like object pointing to the video file
    - interval: integer specifying how many frames apart to extract images (default: 12)
    - output_format: string indicating the desired format for saved images (default: '.jpg')

    Returns:
    A list of strings containing paths to the saved images.
    """
    # Initialize variables
    vidcap = cv2.VideoCapture(video_in)
    frames = []
    count = 0

    # Loop through frames until there are no more
    while True:
        success, image = vidcap.read()
        # Check for a successful read (False once the end of the video is reached)
        if success:
            print('Read a new frame:', success)
            # Save the current frame if it falls on the sampling interval
            if count % interval == 0:
                filename = f'frame_{count // interval}{output_format}'
                frames.append(filename)
                cv2.imwrite(filename, image)
                print(f'Saved {filename}')
            # Increment counter
            count += 1
        else:
            # Break out of the loop when done reading frames
            break

    # Close video capture
    vidcap.release()
    print('Done extracting frames!')
    return frames
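
# Usage sketch (illustrative only; assumes a local "example.mp4" and is not called by the app itself):
#   frames = extract_frames("example.mp4", interval=12)
#   # -> ['frame_0.jpg', 'frame_1.jpg', ...] written to the current working directory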

def process_image(image_in):
    # Caption a single frame by calling the moondream2 Space through gradio_client
    client = Client("vikhyatk/moondream2")
    result = client.predict(
        image_in,  # filepath in 'image' Image component
        "Describe precisely the image in one sentence.",  # str in 'Question' Textbox component
        api_name="/answer_question"
        #api_name="/predict"
    )
    print(result)
    return result
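
# Illustrative return value from the captioning Space (example only, not from a real run):
#   "A person rides a bicycle along a tree-lined street."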

def extract_audio(video_path):
    # Open the video clip and grab its audio stream
    audio = VideoFileClip(video_path).audio

    # Set the output file path (WAV container, written with moviepy's default PCM codec)
    output_path = 'output_audio.wav'

    # Write the audio stream to disk
    audio.write_audiofile(output_path)

    # Confirm that the audio file was written successfully
    if os.path.exists(output_path):
        print(f'Successfully wrote audio to {output_path}.')
        return output_path
    else:
        raise FileNotFoundError(f'Failed to write audio to {output_path}.')

def get_salmonn(audio_in):
    salmonn_prompt = "Please list each event in the audio in order."
    client = Client("fffiloni/SALMONN-7B-gradio")
    result = client.predict(
        audio_in,  # filepath in 'Audio' Audio component
        salmonn_prompt,  # str in 'User question' Textbox component
        4,  # float (numeric value between 1 and 10) in 'beam search numbers' Slider component
        1,  # float (numeric value between 0.8 and 2.0) in 'temperature' Slider component
        0.9,  # float (numeric value between 0.1 and 1.0) in 'top p' Slider component
        api_name="/gradio_answer"
    )
    print(result)
    return result

def llm_process(user_prompt):
    agent_maker_sys = standard_sys

    # Build the prompt following Zephyr's chat template, ending with the assistant tag
    instruction = f"""
<|system|>
{agent_maker_sys}</s>
<|user|>
"""
    prompt = f"{instruction.strip()}\n{user_prompt}</s>\n<|assistant|>"

    outputs = pipe(prompt, max_new_tokens=256, do_sample=True, temperature=0.7, top_k=50, top_p=0.95)

    # Strip everything up to and including the assistant tag so only the model's reply remains
    pattern = r'\<\|system\|\>(.*?)\<\|assistant\|\>'
    cleaned_text = re.sub(pattern, '', outputs[0]["generated_text"], flags=re.DOTALL)
    print(f"SUGGESTED video description: {cleaned_text}")
    return cleaned_text.lstrip("\n")
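
# Resulting prompt layout (sketch of the Zephyr chat format assumed above, with placeholder text):
#   <|system|>
#   ...system prompt...</s>
#   <|user|>
#   ...formatted captions...</s>
#   <|assistant|>
# The pipeline echoes the prompt in its generated text, and the regex above removes everything
# through <|assistant|> so only the model's reply is returned.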

def infer(video_in):
    # Extract frames from the video
    frame_files = extract_frames(video_in)

    # Caption each extracted frame and collect the results in a list
    processed_texts = []
    for frame_file in frame_files:
        text = process_image(frame_file)
        processed_texts.append(text)
    print(processed_texts)

    # Join the frame captions into a single newline-separated string
    string_list = '\n'.join(processed_texts)

    # Extract audio from the video
    extracted_audio = extract_audio(video_in)

    # Get a description of the audio content
    audio_content_described = get_salmonn(extracted_audio)

    # Assemble the visual and audio captions
    formatted_captions = f"""
### Visual events:\n{string_list}\n### Audio events:\n{audio_content_described}
"""
    print(formatted_captions)

    # Send the formatted captions to the LLM
    #video_description_from_llm = llm_process(formatted_captions)

    return formatted_captions

with gr.Blocks() as demo:
    with gr.Column(elem_id="col-container"):
        gr.HTML("""
        <h2 style="text-align: center;">Video description</h2>
        """)
        video_in = gr.Video(label="Video input")
        submit_btn = gr.Button("Submit")
        video_description = gr.Textbox(label="Video description")

        submit_btn.click(
            fn=infer,
            inputs=[video_in],
            outputs=[video_description]
        )

demo.queue().launch()