From bb40503b7ba8d6893f3663098f9fd62c7f03ce11 Mon Sep 17 00:00:00 2001 From: Tim Repke <repke@mcc-berlin.net> Date: Mon, 18 Dec 2023 14:00:49 +0100 Subject: [PATCH] add count nql --- server/api/routes/search.py | 39 +++++++++++++++++++++++-------------- 1 file changed, 24 insertions(+), 15 deletions(-) diff --git a/server/api/routes/search.py b/server/api/routes/search.py index f47a0f3..0c0f2f9 100644 --- a/server/api/routes/search.py +++ b/server/api/routes/search.py @@ -1,25 +1,22 @@ -from typing import TYPE_CHECKING - import httpx -from nacsos_data.db.schemas import Project, ItemType from pydantic import BaseModel from fastapi import APIRouter, Depends import sqlalchemy.sql.functions as func from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession +from nacsos_data.db.engine import ensure_session +from nacsos_data.db.schemas import Project, ItemType +from nacsos_data.util.nql import NQLQuery, NQLFilter from nacsos_data.util.academic.openalex import query_async, SearchResult from nacsos_data.models.items import AcademicItemModel, FullLexisNexisItemModel, GenericItemModel from nacsos_data.models.openalex.solr import SearchField, DefType, OpType -from nacsos_data.util.nql import NQLQuery, NQLFilter from server.util.security import UserPermissionChecker, UserPermissions from server.util.logging import get_logger from server.util.config import settings from server.data import db_engine -if TYPE_CHECKING: - from sqlalchemy.ext.asyncio import AsyncSession # noqa F401 - router = APIRouter() logger = get_logger('nacsos.api.route.search') @@ -88,22 +85,34 @@ class QueryResult(BaseModel): docs: list[AcademicItemModel] | list[FullLexisNexisItemModel] | list[GenericItemModel] +@ensure_session +async def _get_query(session: AsyncSession, query: NQLFilter, project_id: str) -> NQLQuery: + project_type: ItemType | None = ( + await session.scalar(select(Project.type).where(Project.project_id == project_id))) + + if project_type is None: + raise KeyError(f'Found no matching project for {project_id}. This should NEVER happen!') + + return NQLQuery(query, project_id=str(project_id), project_type=project_type) + + @router.post('/nql/query', response_model=QueryResult) async def nql_query(query: NQLFilter, page: int = 1, limit: int = 20, permissions: UserPermissions = Depends(UserPermissionChecker('dataset_read'))) -> QueryResult: async with db_engine.session() as session: # type: AsyncSession - project_id = permissions.permissions.project_id - project_type: ItemType | None = ( - await session.scalar(select(Project.type).where(Project.project_id == project_id))) - - if project_type is None: - raise KeyError(f'Found no matching project for {project_id}. This should NEVER happen!') - - nql = NQLQuery(query, project_id=str(project_id), project_type=project_type) + nql = await _get_query(session=session, query=query, project_id=permissions.permissions.project_id) n_docs = (await session.execute(func.count(nql.stmt.subquery().c.item_id))).scalar() docs = await nql.results_async(session=session, limit=limit, offset=(page - 1) * limit) return QueryResult(n_docs=n_docs, docs=docs) # type: ignore[arg-type] + + +@router.post('/nql/count', response_model=int) +async def nql_query_count(query: NQLFilter, + permissions: UserPermissions = Depends(UserPermissionChecker('dataset_read'))) -> QueryResult: + async with db_engine.session() as session: # type: AsyncSession + nql = await _get_query(session=session, query=query, project_id=permissions.permissions.project_id) + return (await session.execute(func.count(nql.stmt.subquery().c.item_id))).scalar() -- GitLab