# evaluation.py
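"""
Routes for annotation evaluation: tracking annotation progress (recall curves
and BUSCAR stopping criteria) and inter-rater reliability (IRR) metrics.

Typical flow (paths relative to wherever this router is mounted):
create or update a tracker via PUT /tracking/tracker, trigger (re)computation
via POST /tracking/refresh, and poll GET /tracking/tracker/{tracker_id} for
results as they become available.
"""
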
import uuid
from typing import Literal

from fastapi import APIRouter, BackgroundTasks, Depends
from pydantic import BaseModel
from sqlalchemy import select, String, literal, delete
from sqlalchemy.ext.asyncio import AsyncSession

from nacsos_data.db.crud import upsert_orm
from nacsos_data.db.schemas import (
    AnnotationTracker,
    AssignmentScope,
    AnnotationScheme,
    BotAnnotationMetaData,
    AnnotationQuality,
)
from nacsos_data.models.annotation_quality import AnnotationQualityModel
from nacsos_data.models.annotation_tracker import AnnotationTrackerModel, DehydratedAnnotationTracker
from nacsos_data.models.bot_annotations import BotAnnotationMetaDataBaseModel
from nacsos_data.util.annotations.evaluation import get_new_label_batches
from nacsos_data.util.annotations.evaluation.buscar import (
    calculate_h0s_for_batches,
    compute_recall,
    calculate_h0s,
)
from nacsos_data.util.annotations.evaluation.irr import compute_irr_scores
from nacsos_data.util.annotations.label_transform import annotations_to_sequence, get_annotations
from nacsos_data.util.auth import UserPermissions

from server.data import db_engine
from server.api.errors import DataNotFoundWarning
from server.util.logging import get_logger
from server.util.security import UserPermissionChecker
logger = get_logger('nacsos.api.route.eval')
logger.debug('Setup nacsos.api.route.eval router')
router = APIRouter()
class LabelScope(BaseModel):
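    """
    A source of labels for evaluation: either an assignment scope
    ('H', human annotations) or a bot annotation resolution ('R').
    """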
scope_id: str
name: str
scope_type: Literal['H', 'R']
@router.get('/tracking/scopes', response_model=list[LabelScope])
async def get_project_scopes(permissions: UserPermissions = Depends(UserPermissionChecker('annotations_read'))) \
-> list[LabelScope]:
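    """
    List the label sources available in the current project: assignment scopes
    (scope_type 'H', human annotations) followed by bot annotation resolutions
    (scope_type 'R'), each ordered by creation time.
    """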
async with db_engine.session() as session: # type: AsyncSession
stmt = (select(AssignmentScope.assignment_scope_id.cast(String).label('scope_id'),
AssignmentScope.name,
literal('H', type_=String).label('scope_type'))
.join(AnnotationScheme, AnnotationScheme.annotation_scheme_id == AssignmentScope.annotation_scheme_id)
.where(AnnotationScheme.project_id == permissions.permissions.project_id)
.order_by(AssignmentScope.time_created))
rslt = (await session.execute(stmt)).mappings().all()
assignment_scopes = [LabelScope.model_validate(r) for r in rslt]
stmt = (select(BotAnnotationMetaData.bot_annotation_metadata_id.cast(String).label('scope_id'),
BotAnnotationMetaData.name,
literal('R', type_=String).label('scope_type'))
.where(BotAnnotationMetaData.project_id == permissions.permissions.project_id)
.order_by(BotAnnotationMetaData.time_created))
rslt = (await session.execute(stmt)).mappings().all()
resolution_scopes = [LabelScope.model_validate(r) for r in rslt]
return assignment_scopes + resolution_scopes
@router.get('/resolutions', response_model=list[BotAnnotationMetaDataBaseModel])
async def get_resolutions_for_scope(assignment_scope_id: str,
permissions: UserPermissions = Depends(UserPermissionChecker('annotations_read'))) \
-> list[BotAnnotationMetaDataBaseModel]:
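    """Return the metadata of all resolutions (bot annotations) based on the given assignment scope."""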
async with db_engine.session() as session: # type: AsyncSession
stmt = (select(BotAnnotationMetaData)
.where(BotAnnotationMetaData.assignment_scope_id == assignment_scope_id))
rslt = (await session.execute(stmt)).scalars().all()
return [BotAnnotationMetaDataBaseModel.model_validate(r.__dict__) for r in rslt]
async def read_tracker(session: AsyncSession, tracker_id: str | uuid.UUID,
project_id: str | uuid.UUID | None = None) -> AnnotationTracker:
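    """
    Fetch a tracker by ID or raise `DataNotFoundWarning` if it does not exist.
    Note that `project_id` is only used in the error message, not in the query.
    """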
stmt = (select(AnnotationTracker)
.where(AnnotationTracker.annotation_tracking_id == tracker_id))
rslt = (await session.scalars(stmt)).one_or_none()
if rslt is None:
raise DataNotFoundWarning(f'No Tracker in project {project_id} for id {tracker_id}!')
return rslt
@router.get('/tracking/trackers', response_model=list[DehydratedAnnotationTracker])
async def get_project_trackers(permissions: UserPermissions = Depends(UserPermissionChecker('annotations_read'))) \
-> list[DehydratedAnnotationTracker]:
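    """List name and ID of all annotation trackers in the current project."""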
async with db_engine.session() as session: # type: AsyncSession
stmt = (select(AnnotationTracker.name, AnnotationTracker.annotation_tracking_id)
.where(AnnotationTracker.project_id == permissions.permissions.project_id))
rslt = (await session.execute(stmt)).mappings().all()
return [DehydratedAnnotationTracker.model_validate(r) for r in rslt]
@router.get('/tracking/tracker/{tracker_id}', response_model=AnnotationTrackerModel)
async def get_tracker(tracker_id: str,
permissions: UserPermissions = Depends(UserPermissionChecker('annotations_read'))) \
-> AnnotationTrackerModel:
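    """Return the full tracker, including stored labels, recall, and BUSCAR scores."""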
async with db_engine.session() as session: # type: AsyncSession
tracker = await read_tracker(tracker_id=tracker_id, session=session,
project_id=permissions.permissions.project_id)
return AnnotationTrackerModel.model_validate(tracker.__dict__)
@router.put('/tracking/tracker', response_model=str)
async def save_tracker(tracker: AnnotationTrackerModel,
permissions: UserPermissions = Depends(UserPermissionChecker('annotations_read'))) -> str:
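    """
    Create or update a tracker and return its primary key.
    The computed fields (`labels`, `recall`, `buscar`) are excluded from updates,
    so changing the settings does not discard previously computed results.
    """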
pkey = await upsert_orm(upsert_model=tracker, Schema=AnnotationTracker,
primary_key='annotation_tracking_id', db_engine=db_engine,
skip_update=['labels', 'recall', 'buscar'])
return str(pkey)
@router.post('/tracking/refresh', response_model=AnnotationTrackerModel)
async def update_tracker(tracker_id: str,
background_tasks: BackgroundTasks,
batch_size: int | None = None,
reset: bool = False,
permissions: UserPermissions = Depends(UserPermissionChecker('annotations_edit'))) \
-> AnnotationTrackerModel:
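    """
    Refresh the label sequences of a tracker and schedule the (potentially slow)
    recall/BUSCAR computation as a background task.

    Annotations are fetched per source scope and collapsed into binary inclusion
    sequences via the tracker's inclusion rule (optionally by majority vote).
    With `reset=True`, previously computed scores are discarded; otherwise only
    batches that are new relative to the stored labels are passed on. A non-null
    `batch_size` makes the BUSCAR computation use fixed step sizes instead of
    per-scope batches. The response reflects the state before the background
    task has finished.
    """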
async with db_engine.session() as session: # type: AsyncSession
tracker = await read_tracker(tracker_id=tracker_id, session=session,
project_id=permissions.permissions.project_id)
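        # One query per source scope: each scope's annotations form one batch
        # in the label sequences below.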
batched_annotations = [await get_annotations(session=session, source_ids=[sid])
for sid in tracker.source_ids]
batched_sequence = [annotations_to_sequence(tracker.inclusion_rule, annotations=annotations,
majority=tracker.majority)
for annotations in batched_annotations
if len(annotations) > 0]
diff: list[list[int]] | None = None
if reset:
tracker.buscar = None
tracker.recall = None
elif tracker.labels is not None:
diff = get_new_label_batches(tracker.labels, batched_sequence)
# Update labels
tracker.labels = batched_sequence
await session.flush()
        # Hand over only the tracker_id (not the ORM instance), because this
        # session will be closed by the time the background task runs.
background_tasks.add_task(bg_populate_tracker, tracker_id, batch_size, diff)
return AnnotationTrackerModel.model_validate(tracker.__dict__)
async def bg_populate_tracker(tracker_id: str, batch_size: int | None = None, labels: list[list[int]] | None = None):
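    """
    Background task: update the recall curve and BUSCAR scores of a tracker.

    `labels` may be a partial list of batches (the diff computed in
    `update_tracker`); if omitted, the labels stored on the tracker are used.
    Intermediate results are flushed after each step.
    """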
async with db_engine.session() as session: # type: AsyncSession
tracker = await read_tracker(tracker_id=tracker_id, session=session)
if labels is None or len(labels) == 0:
labels = tracker.labels
if labels is not None:
flat_labels = [lab for batch in labels for lab in batch]
recall = compute_recall(labels_=flat_labels)
            if tracker.recall is None:
                tracker.recall = recall
            else:
                # Reassign a new list (not `+=` in place) so SQLAlchemy
                # registers the change to the JSON column.
                tracker.recall = tracker.recall + recall
await session.flush()
# Initialise buscar scores
if tracker.buscar is None:
tracker.buscar = []
if batch_size is None:
# Use scopes as batches
it = calculate_h0s_for_batches(labels=tracker.labels,
recall_target=tracker.recall_target,
n_docs=tracker.n_items_total)
else:
# Ignore the batches derived from scopes and use fixed step sizes
it = calculate_h0s(labels_=flat_labels,
batch_size=batch_size,
recall_target=tracker.recall_target,
n_docs=tracker.n_items_total)
for x, y in it:
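                # Build a new list (rather than .append) so each update is visible
                # to SQLAlchemy's change detection on the JSON column.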
tracker.buscar = tracker.buscar + [(x, y)]
                # Save after each step so the user can refresh the page and see
                # results as they become available.
await session.flush()
@router.get('/quality/load/{assignment_scope_id}', response_model=list[AnnotationQualityModel])
async def get_irr(assignment_scope_id: str,
permissions: UserPermissions = Depends(UserPermissionChecker('annotations_read'))) \
-> list[AnnotationQualityModel]:
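    """Return the stored inter-rater reliability metrics for an assignment scope."""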
async with db_engine.session() as session: # type: AsyncSession
results = (
await session.execute(select(AnnotationQuality)
.where(AnnotationQuality.assignment_scope_id == assignment_scope_id))
).scalars().all()
        return [AnnotationQualityModel.model_validate(r.__dict__) for r in results]
@router.get('/quality/compute', response_model=list[AnnotationQualityModel])
async def recompute_irr(assignment_scope_id: str,
bot_annotation_metadata_id: str | None = None,
permissions: UserPermissions = Depends(UserPermissionChecker('annotations_read'))) \
-> list[AnnotationQualityModel]:
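    """
    Delete all stored quality metrics for the assignment scope, recompute them
    (optionally against a specific resolution), persist, and return the result.
    """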
async with db_engine.session() as session: # type: AsyncSession
# Delete existing metrics
await session.execute(delete(AnnotationQuality)
.where(AnnotationQuality.assignment_scope_id == assignment_scope_id))
# Compute new metrics
metrics = await compute_irr_scores(session=session,
assignment_scope_id=assignment_scope_id,
resolution_id=bot_annotation_metadata_id,
project_id=permissions.permissions.project_id)
metrics_orm = [AnnotationQuality(**metric.model_dump()) for metric in metrics]
session.add_all(metrics_orm)
await session.commit()
return await get_irr(assignment_scope_id, permissions)