Skip to content
Snippets Groups Projects
Commit 56692a3f authored by Tim Repke's avatar Tim Repke
Browse files

add nql route

parent 2f2b29b0
No related branches found
No related tags found
1 merge request!55Master
Pipeline #1902 passed
import httpx import httpx
from pydantic import BaseModel from pydantic import BaseModel
from fastapi import APIRouter, Depends from fastapi import APIRouter, Depends
import sqlalchemy.sql.functions as func
from typing import TYPE_CHECKING
from nacsos_data.util.academic.openalex import query_async, SearchResult, SearchField, DefType, OpType from nacsos_data.util.academic.openalex import query_async, SearchResult, SearchField, DefType, OpType
from nacsos_data.db.crud.items import Query
from nacsos_data.models.items import AcademicItemModel
from server.util.security import UserPermissionChecker, UserPermissions from server.util.security import UserPermissionChecker, UserPermissions
from server.util.logging import get_logger from server.util.logging import get_logger
from server.util.config import settings from server.util.config import settings
from server.data import db_engine
if TYPE_CHECKING:
from sqlalchemy.ext.asyncio import AsyncSession # noqa F401
logger = get_logger('nacsos.api.route.search')
router = APIRouter() router = APIRouter()
logger = get_logger('nacsos.api.route.search')
logger.info('Setting up academic search route') logger.info('Setting up academic search route')
...@@ -69,3 +77,24 @@ async def term_expansion(term_prefix: str, ...@@ -69,3 +77,24 @@ async def term_expansion(term_prefix: str,
ttf=terms[i + 1]['ttf']) ttf=terms[i + 1]['ttf'])
for i in range(0, len(terms), 2) for i in range(0, len(terms), 2)
] ]
class QueryResult(BaseModel):
n_docs: int
docs: list[AcademicItemModel]
@router.get('/nql/query', response_model=QueryResult)
async def nql_query(query: str,
limit: int = 20,
permissions: UserPermissions = Depends(UserPermissionChecker('dataset_read'))) -> QueryResult:
q = Query(query, project_id=permissions.permissions.project_id)
async with db_engine.session() as session: # type: AsyncSession
stmt = q.stmt.subquery()
cnt_stmt = func.count(stmt.c.item_id)
return QueryResult(
n_docs=(await session.execute(cnt_stmt)).scalar(), # type: ignore[arg-type]
docs=[AcademicItemModel.model_validate(item.__dict__)
for item in (await session.execute(q.stmt.limit(limit))).scalars().all()]
)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment