From bb40503b7ba8d6893f3663098f9fd62c7f03ce11 Mon Sep 17 00:00:00 2001
From: Tim Repke <repke@mcc-berlin.net>
Date: Mon, 18 Dec 2023 14:00:49 +0100
Subject: [PATCH] add count nql

---
 server/api/routes/search.py | 39 +++++++++++++++++++++++--------------
 1 file changed, 24 insertions(+), 15 deletions(-)

diff --git a/server/api/routes/search.py b/server/api/routes/search.py
index f47a0f3..0c0f2f9 100644
--- a/server/api/routes/search.py
+++ b/server/api/routes/search.py
@@ -1,25 +1,22 @@
-from typing import TYPE_CHECKING
-
 import httpx
-from nacsos_data.db.schemas import Project, ItemType
 from pydantic import BaseModel
 from fastapi import APIRouter, Depends
 import sqlalchemy.sql.functions as func
 from sqlalchemy import select
+from sqlalchemy.ext.asyncio import AsyncSession
 
+from nacsos_data.db.engine import ensure_session
+from nacsos_data.db.schemas import Project, ItemType
+from nacsos_data.util.nql import NQLQuery, NQLFilter
 from nacsos_data.util.academic.openalex import query_async, SearchResult
 from nacsos_data.models.items import AcademicItemModel, FullLexisNexisItemModel, GenericItemModel
 from nacsos_data.models.openalex.solr import SearchField, DefType, OpType
-from nacsos_data.util.nql import NQLQuery, NQLFilter
 
 from server.util.security import UserPermissionChecker, UserPermissions
 from server.util.logging import get_logger
 from server.util.config import settings
 from server.data import db_engine
 
-if TYPE_CHECKING:
-    from sqlalchemy.ext.asyncio import AsyncSession  # noqa F401
-
 router = APIRouter()
 
 logger = get_logger('nacsos.api.route.search')
@@ -88,22 +85,34 @@ class QueryResult(BaseModel):
     docs: list[AcademicItemModel] | list[FullLexisNexisItemModel] | list[GenericItemModel]
 
 
+@ensure_session
+async def _get_query(session: AsyncSession, query: NQLFilter, project_id: str) -> NQLQuery:
+    project_type: ItemType | None = (
+        await session.scalar(select(Project.type).where(Project.project_id == project_id)))
+
+    if project_type is None:
+        raise KeyError(f'Found no matching project for {project_id}. This should NEVER happen!')
+
+    return NQLQuery(query, project_id=str(project_id), project_type=project_type)
+
+
 @router.post('/nql/query', response_model=QueryResult)
 async def nql_query(query: NQLFilter,
                     page: int = 1,
                     limit: int = 20,
                     permissions: UserPermissions = Depends(UserPermissionChecker('dataset_read'))) -> QueryResult:
     async with db_engine.session() as session:  # type: AsyncSession
-        project_id = permissions.permissions.project_id
-        project_type: ItemType | None = (
-            await session.scalar(select(Project.type).where(Project.project_id == project_id)))
-
-        if project_type is None:
-            raise KeyError(f'Found no matching project for {project_id}. This should NEVER happen!')
-
-        nql = NQLQuery(query, project_id=str(project_id), project_type=project_type)
+        nql = await _get_query(session=session, query=query, project_id=permissions.permissions.project_id)
 
         n_docs = (await session.execute(func.count(nql.stmt.subquery().c.item_id))).scalar()
         docs = await nql.results_async(session=session, limit=limit, offset=(page - 1) * limit)
 
         return QueryResult(n_docs=n_docs, docs=docs)  # type: ignore[arg-type]
+
+
+@router.post('/nql/count', response_model=int)
+async def nql_query_count(query: NQLFilter,
+                          permissions: UserPermissions = Depends(UserPermissionChecker('dataset_read'))) -> QueryResult:
+    async with db_engine.session() as session:  # type: AsyncSession
+        nql = await _get_query(session=session, query=query, project_id=permissions.permissions.project_id)
+        return (await session.execute(func.count(nql.stmt.subquery().c.item_id))).scalar()
-- 
GitLab