tts_labeling / data /repository /annotator_workload_repo.py
vargha's picture
aligned interface and data import scripts
f7ef7d3
from typing import List, Dict, Optional, Any
from sqlalchemy import and_
from sqlalchemy.orm import Session
from data.models import TTSData, Annotation, AnnotationInterval
from data.repository.annotator_repo import AnnotatorRepo
from utils.logger import Logger
# Module-level logger used by the repository methods below.
log = Logger()
class AnnotatorWorkloadRepo:
    """Read-side repository that pairs TTS items with one annotator's annotations."""

    def __init__(self, db: Session) -> None:
        """Keep the session and build the annotator repository on top of it."""
        self.db = db
        self.annotator_repo = AnnotatorRepo(db)

    def get_tts_data_with_annotations(
        self, annotator_name: str
    ) -> List[Dict[str, Optional[Any]]]:
        """Name-based entry point for the workload query.

        Looks the annotator up by name and forwards to
        ``get_tts_data_with_annotations_for_user_id``. When no annotator with
        that name is found, a warning is logged and an empty list is returned.

        Retained for compatibility with existing callers; new code should call
        the ID-based method directly.
        """
        found = self.annotator_repo.get_annotator_by_name(annotator_name)
        if annotator is not None:
            return self.get_tts_data_with_annotations_for_user_id(
                found.id, annotator_name
            )

        log.warning(
            f"Annotator '{annotator_name}' not found in get_tts_data_with_annotations. Returning empty list."
        )
        return []

    def get_tts_data_with_annotations_for_user_id(
        self, annotator_id: int, annotator_name_for_log: str = "Unknown"
    ) -> List[Dict[str, Optional[Any]]]:
        """Fetch the TTS items inside the annotator's intervals, with annotations.

        Args:
            annotator_id: ID used to select both the annotation intervals and
                the annotations.
            annotator_name_for_log: Only interpolated into the log message.

        Returns:
            A list ordered by ``TTSData.id`` of dicts shaped like
            ``{"tts_data": <TTSData>, "annotation": <Annotation or None>}``.
            ``annotation`` is ``None`` when the outer join finds no annotation
            by this annotator for the item.
        """
        # A TTS item qualifies when its id lies within [start_index, end_index]
        # of an interval that belongs to this annotator.
        within_assigned_interval = and_(
            AnnotationInterval.annotator_id == annotator_id,
            TTSData.id >= AnnotationInterval.start_index,
            TTSData.id <= AnnotationInterval.end_index,
        )
        # Only this annotator's annotation for the item is attached.
        own_annotation = and_(
            Annotation.tts_data_id == TTSData.id,
            Annotation.annotator_id == annotator_id,
        )

        query = self.db.query(TTSData, Annotation)
        query = query.join(AnnotationInterval, within_assigned_interval)
        # Outer join so items without an annotation are still returned.
        query = query.outerjoin(Annotation, own_annotation)
        query = query.order_by(TTSData.id)
        # Keep a single row per TTSData item.
        query = query.distinct(TTSData.id)

        rows = []
        for tts_item, annotation in query.all():
            rows.append({"tts_data": tts_item, "annotation": annotation})

        log.info(
            f"{len(rows)} TTS rows fetched for annotator ID '{annotator_id}' (Name: {annotator_name_for_log})."
        )
        return rows