Spaces:
Running
Running
from typing import List, Optional | |
from sqlalchemy.orm import Session | |
from data.models import AnnotationInterval | |
from utils.logger import Logger | |
# Module-level logger used by the repository methods below.
log: Logger = Logger()
class AnnotationIntervalRepo:
    """
    Data-access layer for `annotation_intervals` table.

    Every method works through the SQLAlchemy ``Session`` handed to the
    constructor. Writes are flushed but never committed here; committing
    (or rolling back) is left to the caller.
    """

    def __init__(self, db: Session) -> None:
        self.db = db

    # ------------------------------------------------------------------ #
    # CREATE
    # ------------------------------------------------------------------ #
    def assign_interval_to_annotator(
        self,
        annotator_id: int,
        start_idx: int,
        end_idx: int,
        allow_overlap: bool = False,
    ) -> AnnotationInterval:
        """
        Create a new interval [start_idx, end_idx] for the given annotator.

        Bounds are compared inclusively during the overlap check, so an
        existing interval that merely touches the new one at an endpoint
        (e.g. [0, 10] vs. [10, 20]) counts as overlapping.

        Args:
            annotator_id: Annotator that will own the interval.
            start_idx: First index of the interval; must be < ``end_idx``.
            end_idx: Last index of the interval.
            allow_overlap: When False (the default), refuse to create an
                interval that overlaps one already assigned to this
                annotator.

        Returns:
            The new ``AnnotationInterval``, flushed to the session and
            refreshed from the database (not committed).

        Raises:
            ValueError: If ``start_idx >= end_idx``, or if an overlapping
                interval exists for this annotator and ``allow_overlap``
                is False.
        """
        if start_idx >= end_idx:
            raise ValueError("start_idx must be < end_idx")

        if not allow_overlap:
            # NOTE(review): this lookup and the insert below are separate
            # statements, so two concurrent sessions could both pass the
            # check; only a DB-level constraint would rule that out.
            overlap = (
                self.db.query(AnnotationInterval)
                .filter(AnnotationInterval.annotator_id == annotator_id)
                .filter(
                    AnnotationInterval.start_index <= end_idx,
                    AnnotationInterval.end_index >= start_idx,
                )
                # Without ORDER BY the database may hand back any matching
                # row; sorting makes the reported conflict deterministic
                # (the earliest-starting overlapping interval).
                .order_by(
                    AnnotationInterval.start_index,
                    AnnotationInterval.end_index,
                )
                .first()
            )
            if overlap is not None:
                raise ValueError(
                    f"Overlap with existing interval "
                    f"[{overlap.start_index}, {overlap.end_index}]"
                )

        interval = AnnotationInterval(
            annotator_id=annotator_id,
            start_index=start_idx,
            end_index=end_idx,
        )
        self.db.add(interval)
        # Flush so the row is written within the current transaction, then
        # refresh to load any database-generated column values.
        self.db.flush()
        self.db.refresh(interval)

        log.info(
            f"Interval [{start_idx}, {end_idx}] assigned to annotator_id={annotator_id}"
        )
        return interval

    # ------------------------------------------------------------------ #
    # READ
    # ------------------------------------------------------------------ #
    def get_intervals_by_annotator(self, annotator_id: int) -> List[AnnotationInterval]:
        """
        Return every interval assigned to ``annotator_id``.

        Results are sorted by ``(start_index, end_index)`` so callers get a
        stable, predictable order. An empty list means the annotator has no
        intervals.
        """
        return (
            self.db.query(AnnotationInterval)
            .filter(AnnotationInterval.annotator_id == annotator_id)
            .order_by(
                AnnotationInterval.start_index,
                AnnotationInterval.end_index,
            )
            .all()
        )

    def get_interval(
        self, annotator_id: int, start_idx: int, end_idx: int
    ) -> Optional[AnnotationInterval]:
        """
        Return the interval of ``annotator_id`` whose bounds are exactly
        ``[start_idx, end_idx]``, or None when no such interval exists.

        NOTE(review): identical intervals may exist if they were created
        with ``allow_overlap=True``; in that case whichever matching row
        the database returns first is given back.
        """
        return (
            self.db.query(AnnotationInterval)
            .filter(
                AnnotationInterval.annotator_id == annotator_id,
                AnnotationInterval.start_index == start_idx,
                AnnotationInterval.end_index == end_idx,
            )
            .first()
        )