from typing import List, Optional

from sqlalchemy.orm import Session

from data.models import AnnotationInterval
from utils.logger import Logger

log = Logger()


class AnnotationIntervalRepo:
    """
    Data-access layer for `annotation_intervals` table.

    Every method works on the session passed to ``__init__``. Writes are
    flushed but never committed here; committing or rolling back is left
    to whoever owns the session.
    """

    def __init__(self, db: Session) -> None:
        self.db = db

    # ------------------------------------------------------------------ #
    # CREATE
    # ------------------------------------------------------------------ #
    def assign_interval_to_annotator(
        self,
        annotator_id: int,
        start_idx: int,
        end_idx: int,
        allow_overlap: bool = False,
    ) -> AnnotationInterval:
        """
        Create a new interval [start_idx, end_idx] for the given annotator.

        The new row is added to the session and flushed, then refreshed from
        the database, but the transaction is not committed.

        Args:
            annotator_id: Annotator the interval is assigned to.
            start_idx: Interval start; must be strictly less than ``end_idx``.
            end_idx: Interval end.
            allow_overlap: When False (the default), refuse to create an
                interval that overlaps any existing interval of the same
                annotator.

        Returns:
            The newly created ``AnnotationInterval``.

        Raises:
            ValueError:
                • start >= end
                • overlap detected (when allow_overlap=False)
        """
        if start_idx >= end_idx:
            raise ValueError("start_idx must be < end_idx")

        if not allow_overlap:
            # Closed intervals [a, b] and [c, d] overlap iff a <= d and b >= c,
            # so intervals that merely share an endpoint count as overlapping.
            # Ordering by start_index makes the reported conflict deterministic
            # (the earliest one) instead of whatever row the DB returns first.
            # NOTE(review): this check-then-insert is not atomic; two concurrent
            # transactions can both pass it. Guaranteeing no overlaps would need
            # a DB-level constraint — TODO confirm whether one exists.
            overlap = (
                self.db.query(AnnotationInterval)
                .filter(AnnotationInterval.annotator_id == annotator_id)
                .filter(
                    AnnotationInterval.start_index <= end_idx,
                    AnnotationInterval.end_index >= start_idx,
                )
                .order_by(AnnotationInterval.start_index)
                .first()
            )
            if overlap is not None:
                raise ValueError(
                    f"Overlap with existing interval "
                    f"[{overlap.start_index}, {overlap.end_index}]"
                )

        interval = AnnotationInterval(
            annotator_id=annotator_id,
            start_index=start_idx,
            end_index=end_idx,
        )
        self.db.add(interval)
        # Flush so the INSERT is issued (and DB-generated columns exist),
        # then refresh to load the persisted state back onto the object.
        self.db.flush()
        self.db.refresh(interval)

        log.info(
            f"Interval [{start_idx}, {end_idx}] assigned to annotator_id={annotator_id}"
        )
        return interval

    # ------------------------------------------------------------------ #
    # READ
    # ------------------------------------------------------------------ #
    def get_intervals_by_annotator(self, annotator_id: int) -> List[AnnotationInterval]:
        """
        Return all intervals assigned to ``annotator_id``.

        Results are ordered by ``(start_index, end_index)`` so callers get a
        stable, predictable ordering; an empty list means no intervals.
        """
        return (
            self.db.query(AnnotationInterval)
            .filter(AnnotationInterval.annotator_id == annotator_id)
            .order_by(AnnotationInterval.start_index, AnnotationInterval.end_index)
            .all()
        )

    def get_interval(
        self, annotator_id: int, start_idx: int, end_idx: int
    ) -> Optional[AnnotationInterval]:
        """
        Return the interval of ``annotator_id`` whose bounds are exactly
        ``[start_idx, end_idx]``, or ``None`` if no such interval exists.

        NOTE(review): duplicates are possible when intervals were created with
        ``allow_overlap=True``; in that case which matching row is returned is
        unspecified.
        """
        return (
            self.db.query(AnnotationInterval)
            .filter(
                AnnotationInterval.annotator_id == annotator_id,
                AnnotationInterval.start_index == start_idx,
                AnnotationInterval.end_index == end_idx,
            )
            .first()
        )