import uuid

from fastapi import APIRouter, BackgroundTasks, Depends
from nacsos_data.db.crud import upsert_orm
from nacsos_data.db.schemas import AnnotationTracker
from nacsos_data.models.annotation_tracker import AnnotationTrackerModel
from nacsos_data.util.annotations.evaluation import get_new_label_batches
from nacsos_data.util.annotations.evaluation.buscar import (
    calculate_h0s_for_batches,
    compute_recall,
    calculate_h0s)
from nacsos_data.util.annotations.evaluation.label_transform import annotations_to_sequence, get_annotations
from nacsos_data.util.auth import UserPermissions
from sqlalchemy import select

from server.data import db_engine
from server.api.errors import DataNotFoundWarning
from server.util.logging import get_logger
from server.util.security import UserPermissionChecker

from sqlalchemy.ext.asyncio import AsyncSession

# Router collecting the annotation-tracking / evaluation endpoints of this module
router = APIRouter()

# Module logger; announces that this router module has been set up
logger = get_logger('nacsos.api.route.eval')
logger.debug('Setup nacsos.api.route.eval router')


async def read_tracker(session: AsyncSession, tracker_id: str | uuid.UUID,
                       project_id: str | uuid.UUID | None = None) -> AnnotationTracker:
    """
    Load a single `AnnotationTracker` ORM object by its primary key.

    :param session: open async session used for the lookup
    :param tracker_id: value of `annotation_tracking_id` to look up
    :param project_id: only used to make the error message more informative
    :return: the matching tracker (attached to `session`)
    :raises DataNotFoundWarning: if no tracker with `tracker_id` exists

    NOTE(review): the query does not filter on `project_id`, so a tracker belonging to another
    project is returned as well. Confirm whether `AnnotationTracker` has a `project_id` column
    that should be added to the WHERE clause when `project_id` is given.
    """
    query = select(AnnotationTracker).where(AnnotationTracker.annotation_tracking_id == tracker_id)
    found = (await session.scalars(query)).one_or_none()

    if found is not None:
        return found

    raise DataNotFoundWarning(f'No Tracker in project {project_id} for id {tracker_id}!')


@router.get('/tracking/tracker/{tracker_id}', response_model=AnnotationTrackerModel)
async def get_tracker(tracker_id: str,
                      permissions: UserPermissions = Depends(UserPermissionChecker('annotations_read'))) \
        -> AnnotationTrackerModel:
    """
    Return a single annotation tracker.

    Requires the `annotations_read` permission.

    :param tracker_id: primary key of the tracker
    :return: the tracker converted to its pydantic model
    :raises DataNotFoundWarning: (from `read_tracker`) if no tracker with this id exists
    """
    async with db_engine.session() as session:  # type: AsyncSession
        # `read_tracker` is a coroutine function; without `await` a coroutine object (not the ORM row)
        # would be passed to `model_validate` and validation would always fail.
        tracker = await read_tracker(tracker_id=tracker_id, session=session,
                                     project_id=permissions.permissions.project_id)
        # Convert while the session is still open so attribute access on the ORM object works.
        return AnnotationTrackerModel.model_validate(tracker)


@router.put('/tracking/tracker', response_model=str)
async def save_tracker(tracker: AnnotationTrackerModel,
                       permissions: UserPermissions = Depends(UserPermissionChecker('annotations_edit'))) -> str:
    """
    Create or update an annotation tracker and return its primary key as a string.

    This endpoint writes to the database, so it requires `annotations_edit`. This is the same
    permission as `/tracking/refresh`; previously it only required `annotations_read`.

    `skip_update` lists `labels`, `recall` and `buscar`. Presumably this prevents an upsert of an
    existing tracker from overwriting the values computed by `/tracking/refresh` and its background
    task. Confirm against `upsert_orm`.

    NOTE(review): the submitted tracker's project is not checked against the caller's project
    (`permissions.permissions.project_id`). Verify whether this should be enforced here.
    """
    pkey = await upsert_orm(upsert_model=tracker, Schema=AnnotationTracker,
                            primary_key='annotation_tracking_id', db_engine=db_engine,
                            skip_update=['labels', 'recall', 'buscar'])
    return str(pkey)


@router.post('/tracking/refresh', response_model=AnnotationTrackerModel)
async def update_tracker(tracker_id: str,
                         background_tasks: BackgroundTasks,
                         batch_size: int | None = None,
                         reset: bool = False,
                         permissions: UserPermissions = Depends(UserPermissionChecker('annotations_edit'))) \
        -> AnnotationTrackerModel:
    """
    Rebuild a tracker's label sequences and schedule the recall/BUSCAR update in the background.

    For every id in `tracker.source_ids`, the annotations are fetched and converted into a label
    sequence using `tracker.inclusion_rule` and `tracker.majority`. The sequences replace
    `tracker.labels` and are committed immediately.

    :param tracker_id: primary key of the tracker to refresh
    :param batch_size: forwarded to `bg_populate_tracker` (None = use the per-source batches)
    :param reset: if True, clear `recall` and `buscar` so the background task recomputes everything;
                  otherwise only label batches reported by `get_new_label_batches` are passed on
                  (if the tracker already had labels)
    :return: the tracker with updated `labels`; `recall`/`buscar` are filled in later by the task
    :raises DataNotFoundWarning: (from `read_tracker`) if no tracker with this id exists
    """
    async with db_engine.session() as session:  # type: AsyncSession
        tracker = await read_tracker(tracker_id=tracker_id, session=session,
                                     project_id=permissions.permissions.project_id)

        # Phase 1: fetch the annotations of each source, one query per source, in `source_ids` order
        annotations_per_source = []
        for source_id in tracker.source_ids:
            annotations_per_source.append(await get_annotations(session=session, source_ids=[source_id]))

        # Phase 2: turn each source's annotations into a label sequence
        sequences = [
            annotations_to_sequence(tracker.inclusion_rule, annotations=source_annotations,
                                    majority=tracker.majority)
            for source_annotations in annotations_per_source
        ]

        new_batches: list[list[int]] | None
        if reset:
            # Drop previously computed scores; with `new_batches=None` the task starts from scratch
            tracker.buscar = None
            tracker.recall = None
            new_batches = None
        else:
            new_batches = (None if tracker.labels is None
                           else get_new_label_batches(tracker.labels, sequences))

        tracker.labels = sequences
        await session.commit()

        # Only the id is handed over (not the ORM object), because this session is closed by the
        # time the background task runs; the task opens its own session.
        background_tasks.add_task(bg_populate_tracker, tracker_id, batch_size, new_batches)

        # NOTE(review): the ORM object is read after `commit()`; this assumes the session does not
        # expire attributes on commit (otherwise async lazy-loading would fail) — confirm.
        return AnnotationTrackerModel.model_validate(tracker)


async def bg_populate_tracker(tracker_id: str, batch_size: int | None = None,
                              labels: list[list[int]] | None = None) -> None:
    """
    Compute recall and BUSCAR scores for a tracker and persist them (intended as a background task).

    :param tracker_id: primary key of the tracker to update
    :param batch_size: if None, BUSCAR h0 values are computed over the batches in `tracker.labels`;
                       otherwise over the flattened `labels` with a fixed step size
    :param labels: label batches to process; if None, `tracker.labels` is used
    :raises DataNotFoundWarning: (from `read_tracker`) if no tracker with this id exists
    """
    async with db_engine.session() as session:  # type: AsyncSession
        tracker = await read_tracker(tracker_id=tracker_id, session=session)

        if labels is None:
            labels = tracker.labels
        if labels is None:
            # `tracker.labels` has not been populated; flattening `None` would raise a TypeError.
            logger.warning('Tracker %s has no labels; skipping recall/buscar computation.', tracker_id)
            return

        flat_labels = [lab for batch in labels for lab in batch]

        recall = compute_recall(labels_=flat_labels)
        if tracker.recall is None:
            tracker.recall = recall
        else:
            # Build a new list instead of `tracker.recall += recall`. Extending the loaded list in place
            # and re-assigning the very same object is not registered as a change by SQLAlchemy's
            # attribute history unless the column is mutation-tracked, so the update could be silently
            # dropped on commit. This mirrors how `buscar` is updated below.
            merged_recall = list(tracker.recall)
            merged_recall.extend(recall)
            tracker.recall = merged_recall

        await session.commit()

        # Initialise buscar scores
        if tracker.buscar is None:
            tracker.buscar = []

        # NOTE(review): when `labels` is only the diff of new batches, the branch below still uses the
        # full `tracker.labels`, and the results are appended to the existing `buscar` list. Confirm this
        # does not duplicate previously computed points.
        if batch_size is None:
            # Use scopes as batches
            it = calculate_h0s_for_batches(labels_=tracker.labels,
                                           recall_target=tracker.recall_target,
                                           n_docs=tracker.n_items_total)
        else:
            # Ignore the batches derived from scopes and use fixed step sizes
            it = calculate_h0s(labels_=flat_labels,
                               batch_size=batch_size,
                               recall_target=tracker.recall_target,
                               n_docs=tracker.n_items_total)

        for x, y in it:
            # New list object (not in-place append) so the change is detected and written on commit
            tracker.buscar = tracker.buscar + [(x, y)]
            # save after each step, so the user can refresh the page and get data as it becomes available
            await session.commit()