Spaces:
Runtime error
Runtime error
# Copyright (c) Meta Platforms, Inc. and affiliates. | |
# All rights reserved. | |
# This source code is licensed under the license found in the | |
# LICENSE file in the root directory of this source tree. | |
import hashlib | |
import os | |
import shutil | |
import tempfile | |
from pathlib import Path | |
from typing import Iterable, List, Optional, Tuple, Union | |
import av | |
import strawberry | |
from app_conf import ( | |
DATA_PATH, | |
DEFAULT_VIDEO_PATH, | |
MAX_UPLOAD_VIDEO_DURATION, | |
UPLOADS_PATH, | |
UPLOADS_PREFIX, | |
) | |
from data.data_types import ( | |
AddPointsInput, | |
CancelPropagateInVideo, | |
CancelPropagateInVideoInput, | |
ClearPointsInFrameInput, | |
ClearPointsInVideo, | |
ClearPointsInVideoInput, | |
CloseSession, | |
CloseSessionInput, | |
RemoveObjectInput, | |
RLEMask, | |
RLEMaskForObject, | |
RLEMaskListOnFrame, | |
StartSession, | |
StartSessionInput, | |
Video, | |
) | |
from data.loader import get_video | |
from data.store import get_videos | |
from data.transcoder import get_video_metadata, transcode, VideoMetadata | |
from inference.data_types import ( | |
AddPointsRequest, | |
CancelPropagateInVideoRequest, | |
CancelPropagateInVideoRequest, | |
ClearPointsInFrameRequest, | |
ClearPointsInVideoRequest, | |
CloseSessionRequest, | |
RemoveObjectRequest, | |
StartSessionRequest, | |
) | |
from inference.predictor import InferenceAPI | |
from strawberry import relay | |
from strawberry.file_uploads import Upload | |
class Query: | |
def default_video(self) -> Video: | |
""" | |
Return the default video. | |
The default video can be set with the DEFAULT_VIDEO_PATH environment | |
variable. It will return the video that matches this path. If no video | |
is found, it will return the first video. | |
""" | |
all_videos = get_videos() | |
# Find the video that matches the default path and return that as | |
# default video. | |
for _, v in all_videos.items(): | |
if v.path == DEFAULT_VIDEO_PATH: | |
return v | |
# Fallback is returning the first video | |
return next(iter(all_videos.values())) | |
def videos( | |
self, | |
) -> Iterable[Video]: | |
""" | |
Return all available videos. | |
""" | |
all_videos = get_videos() | |
return all_videos.values() | |
class Mutation: | |
def upload_video( | |
self, | |
file: Upload, | |
start_time_sec: Optional[float] = None, | |
duration_time_sec: Optional[float] = None, | |
) -> Video: | |
""" | |
Receive a video file and store it in the configured S3 bucket. | |
""" | |
max_time = MAX_UPLOAD_VIDEO_DURATION | |
filepath, file_key, vm = process_video( | |
file, | |
max_time=max_time, | |
start_time_sec=start_time_sec, | |
duration_time_sec=duration_time_sec, | |
) | |
video = get_video( | |
filepath, | |
UPLOADS_PATH, | |
file_key=file_key, | |
width=vm.width, | |
height=vm.height, | |
generate_poster=False, | |
) | |
return video | |
def start_session( | |
self, input: StartSessionInput, info: strawberry.Info | |
) -> StartSession: | |
inference_api: InferenceAPI = info.context["inference_api"] | |
request = StartSessionRequest( | |
type="start_session", | |
path=f"{DATA_PATH}/{input.path}", | |
) | |
response = inference_api.start_session(request=request) | |
return StartSession(session_id=response.session_id) | |
def close_session( | |
self, input: CloseSessionInput, info: strawberry.Info | |
) -> CloseSession: | |
inference_api: InferenceAPI = info.context["inference_api"] | |
request = CloseSessionRequest( | |
type="close_session", | |
session_id=input.session_id, | |
) | |
response = inference_api.close_session(request) | |
return CloseSession(success=response.success) | |
def add_points( | |
self, input: AddPointsInput, info: strawberry.Info | |
) -> RLEMaskListOnFrame: | |
inference_api: InferenceAPI = info.context["inference_api"] | |
request = AddPointsRequest( | |
type="add_points", | |
session_id=input.session_id, | |
frame_index=input.frame_index, | |
object_id=input.object_id, | |
points=input.points, | |
labels=input.labels, | |
clear_old_points=input.clear_old_points, | |
) | |
reponse = inference_api.add_points(request) | |
return RLEMaskListOnFrame( | |
frame_index=reponse.frame_index, | |
rle_mask_list=[ | |
RLEMaskForObject( | |
object_id=r.object_id, | |
rle_mask=RLEMask(counts=r.mask.counts, size=r.mask.size, order="F"), | |
) | |
for r in reponse.results | |
], | |
) | |
def remove_object( | |
self, input: RemoveObjectInput, info: strawberry.Info | |
) -> List[RLEMaskListOnFrame]: | |
inference_api: InferenceAPI = info.context["inference_api"] | |
request = RemoveObjectRequest( | |
type="remove_object", session_id=input.session_id, object_id=input.object_id | |
) | |
response = inference_api.remove_object(request) | |
return [ | |
RLEMaskListOnFrame( | |
frame_index=res.frame_index, | |
rle_mask_list=[ | |
RLEMaskForObject( | |
object_id=r.object_id, | |
rle_mask=RLEMask( | |
counts=r.mask.counts, size=r.mask.size, order="F" | |
), | |
) | |
for r in res.results | |
], | |
) | |
for res in response.results | |
] | |
def clear_points_in_frame( | |
self, input: ClearPointsInFrameInput, info: strawberry.Info | |
) -> RLEMaskListOnFrame: | |
inference_api: InferenceAPI = info.context["inference_api"] | |
request = ClearPointsInFrameRequest( | |
type="clear_points_in_frame", | |
session_id=input.session_id, | |
frame_index=input.frame_index, | |
object_id=input.object_id, | |
) | |
response = inference_api.clear_points_in_frame(request) | |
return RLEMaskListOnFrame( | |
frame_index=response.frame_index, | |
rle_mask_list=[ | |
RLEMaskForObject( | |
object_id=r.object_id, | |
rle_mask=RLEMask(counts=r.mask.counts, size=r.mask.size, order="F"), | |
) | |
for r in response.results | |
], | |
) | |
def clear_points_in_video( | |
self, input: ClearPointsInVideoInput, info: strawberry.Info | |
) -> ClearPointsInVideo: | |
inference_api: InferenceAPI = info.context["inference_api"] | |
request = ClearPointsInVideoRequest( | |
type="clear_points_in_video", | |
session_id=input.session_id, | |
) | |
response = inference_api.clear_points_in_video(request) | |
return ClearPointsInVideo(success=response.success) | |
def cancel_propagate_in_video( | |
self, input: CancelPropagateInVideoInput, info: strawberry.Info | |
) -> CancelPropagateInVideo: | |
inference_api: InferenceAPI = info.context["inference_api"] | |
request = CancelPropagateInVideoRequest( | |
type="cancel_propagate_in_video", | |
session_id=input.session_id, | |
) | |
response = inference_api.cancel_propagate_in_video(request) | |
return CancelPropagateInVideo(success=response.success) | |
def get_file_hash(video_path_or_file) -> str: | |
if isinstance(video_path_or_file, str): | |
with open(video_path_or_file, "rb") as in_f: | |
result = hashlib.sha256(in_f.read()).hexdigest() | |
else: | |
video_path_or_file.seek(0) | |
result = hashlib.sha256(video_path_or_file.read()).hexdigest() | |
return result | |
def _get_start_sec_duration_sec( | |
start_time_sec: Union[float, None], | |
duration_time_sec: Union[float, None], | |
max_time: float, | |
) -> Tuple[float, float]: | |
default_seek_t = int(os.environ.get("VIDEO_ENCODE_SEEK_TIME", "0")) | |
if start_time_sec is None: | |
start_time_sec = default_seek_t | |
if duration_time_sec is not None: | |
duration_time_sec = min(duration_time_sec, max_time) | |
else: | |
duration_time_sec = max_time | |
return start_time_sec, duration_time_sec | |
def process_video( | |
file: Upload, | |
max_time: float, | |
start_time_sec: Optional[float] = None, | |
duration_time_sec: Optional[float] = None, | |
) -> Tuple[Optional[str], str, str, VideoMetadata]: | |
""" | |
Process file upload including video trimming and content moderation checks. | |
Returns the filepath, s3_file_key, hash & video metaedata as a tuple. | |
""" | |
with tempfile.TemporaryDirectory() as tempdir: | |
in_path = f"{tempdir}/in.mp4" | |
out_path = f"{tempdir}/out.mp4" | |
with open(in_path, "wb") as in_f: | |
in_f.write(file.read()) | |
try: | |
video_metadata = get_video_metadata(in_path) | |
except av.InvalidDataError: | |
raise Exception("not valid video file") | |
if video_metadata.num_video_streams == 0: | |
raise Exception("video container does not contain a video stream") | |
if video_metadata.width is None or video_metadata.height is None: | |
raise Exception("video container does not contain width or height metadata") | |
if video_metadata.duration_sec in (None, 0): | |
raise Exception("video container does time duration metadata") | |
start_time_sec, duration_time_sec = _get_start_sec_duration_sec( | |
max_time=max_time, | |
start_time_sec=start_time_sec, | |
duration_time_sec=duration_time_sec, | |
) | |
# Transcode video to make sure videos returned to the app are all in | |
# the same format, duration, resolution, fps. | |
transcode( | |
in_path, | |
out_path, | |
video_metadata, | |
seek_t=start_time_sec, | |
duration_time_sec=duration_time_sec, | |
) | |
os.remove(in_path) # don't need original video now | |
out_video_metadata = get_video_metadata(out_path) | |
if out_video_metadata.num_video_frames == 0: | |
raise Exception( | |
"transcode produced empty video; check seek time or your input video" | |
) | |
filepath = None | |
file_key = None | |
with open(out_path, "rb") as file_data: | |
file_hash = get_file_hash(file_data) | |
file_data.seek(0) | |
file_key = UPLOADS_PREFIX + "/" + f"{file_hash}.mp4" | |
filepath = os.path.join(UPLOADS_PATH, f"{file_hash}.mp4") | |
assert filepath is not None and file_key is not None | |
shutil.move(out_path, filepath) | |
return filepath, file_key, out_video_metadata | |
schema = strawberry.Schema( | |
query=Query, | |
mutation=Mutation, | |
) | |