import uuid
from typing import Literal

from fastapi import APIRouter, BackgroundTasks, Depends
from nacsos_data.db.crud import upsert_orm
from nacsos_data.db.schemas import AnnotationTracker, AssignmentScope, AnnotationScheme, BotAnnotationMetaData, \
    AnnotationQuality
from nacsos_data.models.annotation_quality import AnnotationQualityModel
from nacsos_data.models.annotation_tracker import AnnotationTrackerModel, DehydratedAnnotationTracker
from nacsos_data.models.bot_annotations import BotAnnotationMetaDataBaseModel
from nacsos_data.util.annotations.evaluation import get_new_label_batches
from nacsos_data.util.annotations.evaluation.buscar import (
    calculate_h0s_for_batches,
    compute_recall,
    calculate_h0s)
from nacsos_data.util.annotations.evaluation.irr import compute_irr_scores
from nacsos_data.util.annotations.label_transform import annotations_to_sequence, get_annotations
from nacsos_data.util.auth import UserPermissions
from pydantic import BaseModel
from sqlalchemy import select, String, literal, delete

from server.data import db_engine
from server.api.errors import DataNotFoundWarning
from server.util.logging import get_logger
from server.util.security import UserPermissionChecker
from sqlalchemy.ext.asyncio import AsyncSession

logger = get_logger('nacsos.api.route.eval')
logger.debug('Setup nacsos.api.route.eval router')

router = APIRouter()


class LabelScope(BaseModel):
    """Compact description of a scope that can serve as a source of labels.

    `get_project_scopes` emits entries with `scope_type == 'H'` for assignment
    scopes and `scope_type == 'R'` for bot annotation metadata entries.
    """
    # Stringified primary key of the underlying row
    scope_id: str
    # Name of the scope, copied from the underlying row
    name: str
    # 'H': row from AssignmentScope, 'R': row from BotAnnotationMetaData
    scope_type: Literal['H', 'R']


@router.get('/tracking/scopes', response_model=list[LabelScope])
async def get_project_scopes(permissions: UserPermissions = Depends(UserPermissionChecker('annotations_read'))) \
        -> list[LabelScope]:
    """List all assignment scopes and bot annotation scopes of the current project.

    Assignment scopes (type 'H') come first, followed by bot annotation
    metadata entries (type 'R'); each group is ordered by creation time.
    """
    async with db_engine.session() as session:  # type: AsyncSession
        # Assignment scopes have no project id of their own in this query;
        # they are tied to the project via their annotation scheme.
        stmt = (select(AssignmentScope.assignment_scope_id.cast(String).label('scope_id'),
                       AssignmentScope.name,
                       literal('H', type_=String).label('scope_type'))
                .join(AnnotationScheme,
                      AnnotationScheme.annotation_scheme_id == AssignmentScope.annotation_scheme_id)
                .where(AnnotationScheme.project_id == permissions.permissions.project_id)
                .order_by(AssignmentScope.time_created))
        rslt = (await session.execute(stmt)).mappings().all()
        assignment_scopes = [LabelScope.model_validate(r) for r in rslt]

        stmt = (select(BotAnnotationMetaData.bot_annotation_metadata_id.cast(String).label('scope_id'),
                       BotAnnotationMetaData.name,
                       literal('R', type_=String).label('scope_type'))
                .where(BotAnnotationMetaData.project_id == permissions.permissions.project_id)
                .order_by(BotAnnotationMetaData.time_created))
        rslt = (await session.execute(stmt)).mappings().all()
        resolution_scopes = [LabelScope.model_validate(r) for r in rslt]

        return assignment_scopes + resolution_scopes


@router.get('/resolutions', response_model=list[BotAnnotationMetaDataBaseModel])
async def get_resolutions_for_scope(assignment_scope_id: str,
                                    permissions: UserPermissions = Depends(
                                        UserPermissionChecker('annotations_read'))) \
        -> list[BotAnnotationMetaDataBaseModel]:
    """List the bot annotation metadata entries attached to an assignment scope.

    Only entries that belong to the project of the current permission context
    are returned, so a scope id from a different project yields an empty list.
    """
    async with db_engine.session() as session:  # type: AsyncSession
        stmt = (select(BotAnnotationMetaData)
                .where(BotAnnotationMetaData.assignment_scope_id == assignment_scope_id,
                       # Restrict to the caller's project, like the other listing queries in this module
                       BotAnnotationMetaData.project_id == permissions.permissions.project_id))
        rslt = (await session.execute(stmt)).scalars().all()
        return [BotAnnotationMetaDataBaseModel.model_validate(r.__dict__) for r in rslt]


async def read_tracker(session: AsyncSession, tracker_id: str | uuid.UUID,
                       project_id: str | uuid.UUID | None = None) -> AnnotationTracker:
    """Load a single annotation tracker ORM object.

    :param session: open database session used for the lookup
    :param tracker_id: primary key (`annotation_tracking_id`) of the tracker
    :param project_id: if given, the tracker must belong to this project;
                       if `None`, the lookup is done by id only
    :return: the tracker ORM object (attached to `session`)
    :raises DataNotFoundWarning: if no matching tracker exists
    """
    stmt = (select(AnnotationTracker)
            .where(AnnotationTracker.annotation_tracking_id == tracker_id))
    if project_id is not None:
        # Enforce project ownership so a tracker of another project cannot be reached by id alone
        stmt = stmt.where(AnnotationTracker.project_id == project_id)

    rslt = (await session.scalars(stmt)).one_or_none()
    if rslt is None:
        raise DataNotFoundWarning(f'No Tracker in project {project_id} for id {tracker_id}!')
    return rslt


@router.get('/tracking/trackers', response_model=list[DehydratedAnnotationTracker])
async def get_project_trackers(permissions: UserPermissions = Depends(UserPermissionChecker('annotations_read'))) \
        -> list[DehydratedAnnotationTracker]:
    """List name and id of all annotation trackers in the current project."""
    async with db_engine.session() as session:  # type: AsyncSession
        stmt = (select(AnnotationTracker.name, AnnotationTracker.annotation_tracking_id)
                .where(AnnotationTracker.project_id == permissions.permissions.project_id))
        rslt = (await session.execute(stmt)).mappings().all()
        return [DehydratedAnnotationTracker.model_validate(r) for r in rslt]


@router.get('/tracking/tracker/{tracker_id}', response_model=AnnotationTrackerModel)
async def get_tracker(tracker_id: str,
                      permissions: UserPermissions = Depends(UserPermissionChecker('annotations_read'))) \
        -> AnnotationTrackerModel:
    """Return the full tracker with the given id from the current project.

    :raises DataNotFoundWarning: if the tracker does not exist in this project
    """
    async with db_engine.session() as session:  # type: AsyncSession
        tracker = await read_tracker(tracker_id=tracker_id, session=session,
                                     project_id=permissions.permissions.project_id)
        return AnnotationTrackerModel.model_validate(tracker.__dict__)


@router.put('/tracking/tracker', response_model=str)
async def save_tracker(tracker: AnnotationTrackerModel,
                       permissions: UserPermissions = Depends(UserPermissionChecker('annotations_edit'))) -> str:
    """Create or update a tracker and return its primary key as a string.

    The fields `labels`, `recall` and `buscar` are passed as `skip_update`
    to the upsert; they are written by the refresh route and its background
    task instead.
    """
    # This route writes to the database, hence it requires edit (not read) permission,
    # in line with the refresh route below.
    pkey = await upsert_orm(upsert_model=tracker,
                            Schema=AnnotationTracker,
                            primary_key='annotation_tracking_id',
                            db_engine=db_engine,
                            skip_update=['labels', 'recall', 'buscar'])
    return str(pkey)


@router.post('/tracking/refresh', response_model=AnnotationTrackerModel)
async def update_tracker(tracker_id: str,
                         background_tasks: BackgroundTasks,
                         batch_size: int | None = None,
                         reset: bool = False,
                         permissions: UserPermissions = Depends(UserPermissionChecker('annotations_edit'))) \
        -> AnnotationTrackerModel:
    """Refresh the label sequences of a tracker and schedule the metric computation.

    :param tracker_id: id of the tracker to refresh (must be in the current project)
    :param background_tasks: FastAPI task queue; `bg_populate_tracker` is scheduled on it
    :param batch_size: forwarded to `bg_populate_tracker`
    :param reset: if true, stored `buscar` and `recall` values are cleared and
                  no label diff is handed to the background task
    :return: the tracker with refreshed labels; `recall`/`buscar` are filled in
             later by the background task
    :raises DataNotFoundWarning: if the tracker does not exist in this project
    """
    async with db_engine.session() as session:  # type: AsyncSession
        tracker = await read_tracker(tracker_id=tracker_id, session=session,
                                     project_id=permissions.permissions.project_id)

        # One batch of annotations per source id; sources without annotations are dropped
        batched_annotations = [await get_annotations(session=session, source_ids=[sid])
                               for sid in tracker.source_ids]
        batched_sequence = [annotations_to_sequence(tracker.inclusion_rule, annotations=annotations,
                                                    majority=tracker.majority)
                            for annotations in batched_annotations
                            if len(annotations) > 0]

        diff: list[list[int]] | None = None
        if reset:
            tracker.buscar = None
            tracker.recall = None
        elif tracker.labels is not None:
            diff = get_new_label_batches(tracker.labels, batched_sequence)

        # Update labels
        tracker.labels = batched_sequence
        await session.flush()

        # We are not handing over the existing tracker ORM, because the session is not persistent
        background_tasks.add_task(bg_populate_tracker, tracker_id, batch_size, diff)

        return AnnotationTrackerModel.model_validate(tracker.__dict__)


async def bg_populate_tracker(tracker_id: str, batch_size: int | None = None,
                              labels: list[list[int]] | None = None) -> None:
    """Compute recall and buscar scores for a tracker (meant to run as a background task).

    :param tracker_id: id of the tracker to populate
    :param batch_size: if `None`, buscar scores are computed via
                       `calculate_h0s_for_batches` on the tracker's label batches;
                       otherwise via `calculate_h0s` on the flattened labels with
                       this fixed step size
    :param labels: label batches to use for the recall computation; if `None` or
                   empty, all labels stored on the tracker are used
    :raises DataNotFoundWarning: if the tracker does not exist
    """
    async with db_engine.session() as session:  # type: AsyncSession
        tracker = await read_tracker(tracker_id=tracker_id, session=session)

        if labels is None or len(labels) == 0:
            labels = tracker.labels

        if labels is None:
            # Neither new labels nor stored labels: there is nothing to compute.
            # (Previously this path ran into an unbound `flat_labels` when `batch_size` was set.)
            logger.warning('Tracker %s has no labels, skipping recall/buscar computation.', tracker_id)
            return

        flat_labels = [lab for batch in labels for lab in batch]
        recall = compute_recall(labels_=flat_labels)

        if tracker.recall is None:
            tracker.recall = recall
        else:
            # Assign a fresh list instead of extending in place (same approach as for `buscar`
            # below), so the ORM attribute receives a new value rather than a mutated one.
            tracker.recall = [*tracker.recall, *recall]
        await session.flush()

        # Initialise buscar scores
        if tracker.buscar is None:
            tracker.buscar = []

        if batch_size is None:
            # Use scopes as batches
            it = calculate_h0s_for_batches(labels=tracker.labels,
                                           recall_target=tracker.recall_target,
                                           n_docs=tracker.n_items_total)
        else:
            # Ignore the batches derived from scopes and use fixed step sizes
            it = calculate_h0s(labels_=flat_labels,
                               batch_size=batch_size,
                               recall_target=tracker.recall_target,
                               n_docs=tracker.n_items_total)

        for x, y in it:
            tracker.buscar = tracker.buscar + [(x, y)]
            # save after each step, so the user can refresh the page and get data as it becomes available
            await session.flush()


@router.get('/quality/load/{assignment_scope_id}', response_model=list[AnnotationQualityModel])
async def get_irr(assignment_scope_id: str,
                  permissions: UserPermissions = Depends(UserPermissionChecker('annotations_read'))) \
        -> list[AnnotationQualityModel]:
    """Return the stored annotation quality metrics for an assignment scope."""
    async with db_engine.session() as session:  # type: AsyncSession
        results = (
            await session.execute(select(AnnotationQuality)
                                  .where(AnnotationQuality.assignment_scope_id == assignment_scope_id))
        ).scalars().all()

        return [AnnotationQualityModel(**r.__dict__) for r in results]


@router.get('/quality/compute', response_model=list[AnnotationQualityModel])
async def recompute_irr(assignment_scope_id: str,
                        bot_annotation_metadata_id: str | None = None,
                        permissions: UserPermissions = Depends(UserPermissionChecker('annotations_read'))) \
        -> list[AnnotationQualityModel]:
    """Recompute and store the annotation quality metrics for an assignment scope.

    Existing metrics of the scope are deleted, new ones are computed with
    `compute_irr_scores`, stored, and then returned via `get_irr`.

    :param assignment_scope_id: scope to compute the metrics for
    :param bot_annotation_metadata_id: optional resolution, passed to
                                       `compute_irr_scores` as `resolution_id`
    """
    # NOTE(review): this route deletes and inserts rows, yet it is a GET that only requires
    #               `annotations_read`. Left unchanged to keep the route contract stable;
    #               consider switching to POST with `annotations_edit`.
    async with db_engine.session() as session:  # type: AsyncSession
        # Delete existing metrics
        await session.execute(delete(AnnotationQuality)
                              .where(AnnotationQuality.assignment_scope_id == assignment_scope_id))

        # Compute new metrics
        metrics = await compute_irr_scores(session=session,
                                           assignment_scope_id=assignment_scope_id,
                                           resolution_id=bot_annotation_metadata_id,
                                           project_id=permissions.permissions.project_id)

        metrics_orm = [AnnotationQuality(**metric.model_dump()) for metric in metrics]
        session.add_all(metrics_orm)
        await session.commit()

    return await get_irr(assignment_scope_id, permissions)