Skip to content
Snippets Groups Projects
Commit 23bcbf08 authored by Tim Repke's avatar Tim Repke
Browse files

fixing mypy

parent d78938a2
No related branches found
No related tags found
1 merge request!1Mypy fix
Pipeline #766 failed
......@@ -3,4 +3,5 @@ tox==3.25.1
pytest==7.1.2
pytest-cov==3.0.0
mypy==0.971
alembic==1.8.1
\ No newline at end of file
alembic==1.8.1
types-PyYAML==6.0.11
\ No newline at end of file
class DataNotFoundWarning(Warning):
pass
class ProjectNotFoundError(Exception):
status = 400
class UserNotFoundError(Exception):
pass
class AnnotationSchemeNotFoundError(Exception):
pass
class NoNextAssignmentWarning(Warning):
pass
class AssignmentScopeNotFoundError(Exception):
pass
class SaveFailedError(Exception):
pass
class UnknownEventError(Exception):
pass
class MissingInformationError(Exception):
pass
......@@ -34,6 +34,9 @@ from nacsos_data.util.annotations.validation import merge_scheme_and_annotations
from nacsos_data.util.annotations.assignments.random import random_assignments
from pydantic import BaseModel
from server.api.errors import SaveFailedError, AssignmentScopeNotFoundError, NoNextAssignmentWarning, \
ProjectNotFoundError, AnnotationSchemeNotFoundError, MissingInformationError
from server.util.security import UserPermissionChecker
from server.data import db_engine
......@@ -58,7 +61,10 @@ async def get_scheme_definition(annotation_scheme_id: str) -> AnnotationSchemeMo
:param annotation_scheme_id: database id of the annotation scheme.
:return: a single annotation scheme
"""
return await read_annotation_scheme(annotation_scheme_id=annotation_scheme_id, engine=db_engine)
scheme = await read_annotation_scheme(annotation_scheme_id=annotation_scheme_id, engine=db_engine)
if scheme is not None:
return scheme
raise AnnotationSchemeNotFoundError(f'No `AnnotationScheme` found in DB for id {annotation_scheme_id}')
@router.put('/schemes/definition/', response_model=str)
......@@ -85,16 +91,24 @@ async def get_scheme_definitions_for_project(project_id: str) -> list[Annotation
async def _construct_annotation_item(assignment: AssignmentModel, project_id: str) -> AnnotationItem:
if assignment.assignment_id is None:
raise MissingInformationError('No `assignment_id` set for `assignment`.')
scope = await read_assignment_scope(assignment_scope_id=assignment.assignment_scope_id, engine=db_engine)
scheme = await read_annotation_scheme(annotation_scheme_id=assignment.annotation_scheme_id, engine=db_engine)
if scheme is None:
raise AnnotationSchemeNotFoundError(f'No annotation scheme found in DB for id '
f'{assignment.annotation_scheme_id}')
annotations = await read_annotations_for_assignment(assignment_id=assignment.assignment_id, engine=db_engine)
scheme = merge_scheme_and_annotations(annotation_scheme=scheme, annotations=annotations)
merged_scheme = merge_scheme_and_annotations(annotation_scheme=scheme, annotations=annotations)
project = await read_project_by_id(project_id=project_id, engine=db_engine)
if project is None:
raise ProjectNotFoundError(f'No project found in DB for id {project_id}')
item = await read_any_item_by_item_id(item_id=assignment.item_id, item_type=project.type, engine=db_engine)
return AnnotationItem(scheme=scheme, assignment=assignment, scope=scope, item=item)
return AnnotationItem(scheme=merged_scheme, assignment=assignment, scope=scope, item=item)
@router.get('/annotate/next/{assignment_scope_id}/{current_assignment_id}', response_model=AnnotationItem)
......@@ -106,6 +120,8 @@ async def get_next_assignment_for_scope_for_user(assignment_scope_id: str,
assignment_scope_id=assignment_scope_id,
user_id=permissions.user.user_id,
engine=db_engine)
if assignment is None:
raise NoNextAssignmentWarning(f'Could not determine a next assignment for scope {assignment_scope_id}')
return await _construct_annotation_item(assignment=assignment, project_id=permissions.permissions.project_id)
......@@ -152,11 +168,13 @@ async def get_assignment_scopes_for_project(permissions=Depends(UserPermissionCh
@router.get('/annotate/scope/{assignment_scope_id}', response_model=AssignmentScopeModel)
async def get_assignment_scope(assignment_scope_id: str,
permissions=Depends(UserPermissionChecker(['annotations_read', 'annotations_edit'],
fulfill_all=False))) \
-> AssignmentScopeModel:
permissions=Depends(
UserPermissionChecker(['annotations_read', 'annotations_edit'], fulfill_all=False))
) -> AssignmentScopeModel:
scope = await read_assignment_scope(assignment_scope_id=assignment_scope_id, engine=db_engine)
return scope
if scope is not None:
return scope
raise AssignmentScopeNotFoundError(f'No assignment scope found in the DB for {assignment_scope_id}')
@router.put('/annotate/scope/', response_model=str)
......@@ -217,6 +235,9 @@ async def get_annotations(assignment_scope_id: str, permissions=Depends(UserPerm
async def save_annotation(annotated_item: AnnotatedItem,
permissions=Depends(UserPermissionChecker('annotations_read'))) -> AssignmentStatus:
# double-check, that the supposed assignment actually exists
if annotated_item.assignment.assignment_id is None:
raise MissingInformationError('Missing `assignment_id` in `annotation_item`!')
assignment_db = await read_assignment(assignment_id=annotated_item.assignment.assignment_id, engine=db_engine)
if permissions.user.user_id == assignment_db.user_id \
......@@ -227,7 +248,9 @@ async def save_annotation(annotated_item: AnnotatedItem,
status = await upsert_annotations(annotations=annotations,
assignment_id=annotated_item.assignment.assignment_id,
engine=db_engine)
return status
if status is not None:
return status
raise SaveFailedError('Failed to save annotation!')
else:
raise HTTPException(
status_code=http_status.HTTP_403_FORBIDDEN,
......
from typing import Type
from fastapi import APIRouter
from pydantic import BaseModel
from ..errors import UnknownEventError
from ...util.events import eventbus, events, AnyEvent, AnyEventType
from ...util.logging import get_logger
......@@ -16,10 +16,6 @@ class Event(BaseModel):
payload: AnyEvent
class UnknownEventError(Exception):
pass
@router.post('/emit')
async def emit(event: Event) -> None:
"""
......@@ -31,13 +27,17 @@ async def emit(event: Event) -> None:
"""
logger.info(f'Received external event to be emitted: {event.event}')
if hasattr(events, event.event):
EmitEvent: Type[AnyEvent] = getattr(events, event.event)
emit_event = EmitEvent.parse_obj(event.payload)
logger.debug(f'Going to emit {EmitEvent} ({emit_event})')
await eventbus.emit_async(emit_event._name, emit_event) # noqa PyProtectedMember
else:
emit_event_type: AnyEvent = getattr(events, event.event, None)
if emit_event_type is None:
raise UnknownEventError(f'Event {event.event} not in {AnyEvent}')
if not issubclass(emit_event_type, events.BaseEvent):
raise UnknownEventError(f'Event {event.event} is not a valid subclass of `BaseEvent`')
emit_event = emit_event_type.parse_obj(event.payload)
logger.debug(f'Going to emit {emit_event} ({emit_event})')
await eventbus.emit_async(emit_event._name, emit_event) # noqa PyProtectedMember
# TODO user-configurable triggers (e.g. trigger on event or cron-like)
# - create schema, model, crud in nacsos-data (probably could just be a JSONB field in `Project`
......
......@@ -9,6 +9,7 @@ from server.util.logging import get_logger
from . import permissions
from . import items
from ...errors import ProjectNotFoundError
logger = get_logger('nacsos.api.route.project')
router = APIRouter()
......@@ -18,7 +19,10 @@ logger.info('Setting up projects route')
@router.get('/{project_id}/info/', response_model=ProjectModel)
async def get_project(project_id: str, permission=Depends(UserPermissionChecker())) -> ProjectModel:
return await read_project_by_id(project_id=project_id, engine=db_engine)
project = await read_project_by_id(project_id=project_id, engine=db_engine)
if project is not None:
return project
raise ProjectNotFoundError(f'No project found in the database for id {project_id}')
# TODO create project (superuser only)
......
......@@ -4,7 +4,7 @@ from nacsos_data.models.projects import ProjectPermissionsModel
from nacsos_data.db.crud.projects import read_project_permissions_for_project, read_project_permissions_by_id
from server.data import db_engine
from server.util.security import UserPermissionChecker
from server.util.security import UserPermissionChecker, UserPermissions
from server.util.logging import get_logger
logger = get_logger('nacsos.api.route.project')
......@@ -12,7 +12,7 @@ router = APIRouter()
@router.get('/me', response_model=ProjectPermissionsModel)
async def get_project_permissions_current_user(permission=Depends(UserPermissionChecker())) \
async def get_project_permissions_current_user(permission: UserPermissions = Depends(UserPermissionChecker())) \
-> ProjectPermissionsModel:
return permission.permissions
......
......@@ -4,6 +4,7 @@ from nacsos_data.models.users import UserModel
from nacsos_data.models.projects import ProjectModel
from nacsos_data.db.crud.projects import read_all_projects, read_all_projects_for_user
from server.api.errors import MissingInformationError
from server.data import db_engine
from server.util.security import get_current_active_user
from server.util.logging import get_logger
......@@ -25,4 +26,8 @@ async def get_all_projects(current_user: UserModel = Depends(get_current_active_
"""
if current_user.is_superuser:
return await read_all_projects(engine=db_engine)
if current_user.user_id is None:
raise MissingInformationError('`current_user` has no `user_id`, which points to a serious issue in the system!')
return await read_all_projects_for_user(current_user.user_id, engine=db_engine)
from fastapi import APIRouter, Depends, Query
from server.api.errors import DataNotFoundWarning, UserNotFoundError
from server.util.logging import get_logger
from nacsos_data.models.users import UserModel, UserInDBModel
from nacsos_data.db.crud.users import \
......@@ -27,7 +29,9 @@ async def get_project_users(project_id: str,
permissions: UserPermissions = Depends(UserPermissionChecker('annotations_edit'))) \
-> list[UserInDBModel]:
result = await read_project_users(project_id=project_id, engine=db_engine)
return result
if result is not None:
return result
raise DataNotFoundWarning(f'Found no users for project with ID {project_id}')
# FIXME refine required permission
......@@ -36,7 +40,9 @@ async def get_user_by_id(user_id: str,
permissions: UserPermissions = Depends(UserPermissionChecker('annotations_edit'))) \
-> UserInDBModel:
result = await read_user_by_id(user_id=user_id, engine=db_engine)
return result
if result is not None:
return result
raise UserNotFoundError(f'User not found in DB for ID {user_id}')
# FIXME refine required permission
......@@ -45,4 +51,6 @@ async def get_users_by_ids(user_id: list[str] = Query(),
permissions: UserPermissions = Depends(UserPermissionChecker('annotations_edit'))) \
-> list[UserInDBModel]:
result = await read_users_by_ids(user_ids=user_id, engine=db_engine)
return result
if result is not None:
return result
raise UserNotFoundError(f'Users not found in DB for IDs {user_id}')
......@@ -34,7 +34,9 @@ class ServerConfig(BaseModel):
if isinstance(v, str) and not v.startswith('['):
return [i.strip() for i in v.split(',')]
if isinstance(v, str) and v.startswith('['):
return json.loads(v)
ret = json.loads(v)
if type(ret) == list:
return ret
elif isinstance(v, (list, str)):
return v
raise ValueError(v)
......@@ -80,7 +82,7 @@ class EmailConfig(BaseModel):
and values.get('SENDER_ADDRESS')
)
TEST_USER: EmailStr = 'test@nacsos.eu'
TEST_USER: EmailStr = EmailStr('test@nacsos.eu')
class UsersConfig(BaseModel):
......@@ -103,14 +105,18 @@ class Settings(BaseSettings):
# EMAIL: EmailConfig
LOG_CONF_FILE: str = 'config/logging.conf'
LOGGING_CONF: dict | None = None
LOGGING_CONF: dict[str, Any] | None = None
@validator('LOGGING_CONF', pre=True)
def read_logging_config(cls, v: dict, values: dict[str, Any]) -> dict:
def read_logging_config(cls, v: dict[str, Any] | None, values: dict[str, str]) -> dict[str, Any]:
if isinstance(v, dict):
return v
with open(values.get('LOG_CONF_FILE'), 'r') as f:
return yaml.safe_load(f.read())
filename = values.get('LOG_CONF_FILE', cls.LOG_CONF_FILE)
with open(filename, 'r') as f:
ret = yaml.safe_load(f.read())
if type(ret) == dict:
return ret
raise ValueError('Logging config invalid!')
class Config:
case_sensitive = True
......
from typing import Union, Literal
from typing import Union, Literal, TYPE_CHECKING
from pymitter import EventEmitter
......@@ -6,9 +6,16 @@ from .hooks import imports
from . import events
eventbus = EventEmitter(delimiter='_', wildcard=True)
AnyEvent = Union[events.BaseEvent.get_subclasses()]
AnyEventType = Literal[tuple(sc.__name__ for sc in events.BaseEvent.get_subclasses())] # noqa PyProtectedMember
AnyEventLiteral = Literal[tuple(sc._name for sc in events.BaseEvent.get_subclasses())] # noqa PyProtectedMember
if TYPE_CHECKING:
from typing import TypeVar
AnyEvent = TypeVar('AnyEvent', bound=events.BaseEvent)
AnyEventType = str
AnyEventLiteral = str
else:
AnyEvent = Union[events.BaseEvent.get_subclasses()]
AnyEventType = Literal[tuple(sc.__name__ for sc in events.BaseEvent.get_subclasses())] # noqa PyProtectedMember
AnyEventLiteral = Literal[tuple(sc._name for sc in events.BaseEvent.get_subclasses())] # noqa PyProtectedMember
# Permanent/global listeners
eventbus.on(events.PipelineTaskStatusChangedEvent._name, imports.update_import_status) # noqa PyProtectedMember
......
......@@ -19,13 +19,13 @@ async def update_import_status(event: PipelineTaskStatusChangedEvent):
async with db_engine.session() as session:
stmt = select(Import).filter_by(pipeline_task_id=event.task_id)
import_details: Import = (await session.execute(stmt)).scalars().one_or_none()
import_details: Import | None = (await session.execute(stmt)).scalars().one_or_none()
if import_details is None and event.import_id is not None:
logger.debug(f'second try with {event.import_id}')
stmt = select(Import).filter_by(import_id=event.import_id)
import_details: Import = (await session.execute(stmt)).scalars().one_or_none()
logger.debug(repr(import_details))
import_details = (await session.execute(stmt)).scalars().one_or_none()
if import_details is not None:
# Seems like task was started, remember the time
......
......@@ -8,8 +8,9 @@ from uvicorn.logging import DefaultFormatter
from server.util.config import settings
def get_logger(name=None):
logging.config.dictConfig(settings.LOGGING_CONF)
def get_logger(name: str | None = None):
if settings.LOGGING_CONF is not None:
logging.config.dictConfig(settings.LOGGING_CONF)
return logging.getLogger(name)
......
import time
import json
from typing import Literal, Any
from typing import Literal, Any, TypeVar
from resource import getrusage, RUSAGE_SELF
from pydantic import BaseModel
from fastapi import HTTPException, status as http_status
......@@ -12,16 +13,6 @@ from starlette.responses import Response
from server.util.logging import get_logger
logger = get_logger('nacsos.server.middlewares')
try:
from resource import getrusage, RUSAGE_SELF
except ImportError as e:
logger.warning(e)
RUSAGE_SELF = None
def getrusage(*args, **kwargs): # noqa:E303
return 0.0, 0.0
class ErrorDetail(BaseModel):
......@@ -35,9 +26,12 @@ class ErrorDetail(BaseModel):
args: list[Any]
Error = TypeVar('Error', bound=Warning | Exception)
class ErrorHandlingMiddleware(BaseHTTPMiddleware):
@classmethod
def _resolve_args(cls, ew: Exception | Warning) -> list[Any]:
def _resolve_args(cls, ew: Error) -> list[Any]:
if hasattr(ew, 'args') and ew.args is not None and len(ew.args) > 0:
ret = []
for arg in ew.args:
......@@ -50,9 +44,11 @@ class ErrorHandlingMiddleware(BaseHTTPMiddleware):
return [repr(ew)]
@classmethod
def _resolve_status(cls, ew: Exception | Warning) -> http_status:
def _resolve_status(cls, ew: Error) -> int:
if hasattr(ew, 'status'):
return ew.status
error_status = getattr(ew, 'status')
if type(error_status) == int:
return error_status
return http_status.HTTP_400_BAD_REQUEST
async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
......
......@@ -11,6 +11,7 @@ from nacsos_data.models.projects import ProjectPermissionsModel, ProjectPermissi
from nacsos_data.db.crud.users import read_user_by_name as crud_get_user_by_name, read_user_by_id
from nacsos_data.db.crud.projects import read_project_permissions_for_user as crud_get_project_permissions_for_user
from server.api.errors import MissingInformationError
from server.data import db_engine
from server.util.config import settings
......@@ -58,7 +59,7 @@ async def authenticate_user(username: str, plain_password: str):
return user
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None):
def create_access_token(data: dict[str, str | datetime], expires_delta: Optional[timedelta] = None):
to_encode = data.copy()
if expires_delta:
expire = datetime.utcnow() + expires_delta
......@@ -75,6 +76,7 @@ async def get_current_user(token: str = Depends(oauth2_scheme)):
detail='Could not validate credentials',
headers={'WWW-Authenticate': 'Bearer'},
)
user = None
if settings.USERS.DEFAULT_USER is None:
try:
if token is None:
......@@ -86,15 +88,16 @@ async def get_current_user(token: str = Depends(oauth2_scheme)):
token_data = TokenData(username=username)
except JWTError:
raise credentials_exception
user = await crud_get_user_by_name(username=token_data.username, engine=db_engine)
token_user = token_data.username
if token_user is not None:
user = await crud_get_user_by_name(username=token_user, engine=db_engine)
else:
user = await read_user_by_id(user_id=settings.USERS.DEFAULT_USER, engine=db_engine)
logger.warning('Authentication using fake user!')
logger.debug(f'Current user: user_id: {user.user_id} {user.username}')
if user is None:
raise credentials_exception
logger.debug(f'Current user: user_id: {user.user_id} {user.username}')
return user
......@@ -113,10 +116,12 @@ def get_current_active_superuser(current_user: UserModel = Depends(get_current_a
async def get_project_permissions_for_user(project_id: str, current_user: UserModel) -> ProjectPermissionsModel | None:
if current_user.user_id is None:
raise MissingInformationError('The `current_user` is missing the (here) required `user_id` field.')
if current_user.is_superuser:
# admin gets to do anything always, so return with simulated full permissions
return ProjectPermissionsModel.get_virtual_admin(project_id=project_id,
user_id=current_user.user_id)
user_id=str(current_user.user_id))
return await crud_get_project_permissions_for_user(user_id=current_user.user_id,
project_id=project_id,
......@@ -124,7 +129,9 @@ async def get_project_permissions_for_user(project_id: str, current_user: UserMo
class UserPermissionChecker:
def __init__(self, permissions: list[ProjectPermission] | ProjectPermission = None, fulfill_all: bool = True):
def __init__(self,
permissions: list[ProjectPermission] | ProjectPermission | None = None,
fulfill_all: bool = True):
self.permissions = permissions
self.fulfill_all = fulfill_all
......@@ -159,11 +166,12 @@ class UserPermissionChecker:
# check that each required permission is fulfilled
for permission in self.permissions:
if self.fulfill_all and not project_permissions[permission]:
p_permission = getattr(project_permissions, permission, False)
if self.fulfill_all and not p_permission:
raise InsufficientPermissions(
f'User does not have permission "{permission}" for project "{x_project_id}".'
)
any_permission_fulfilled = any_permission_fulfilled or project_permissions[permission]
any_permission_fulfilled = any_permission_fulfilled or p_permission
if not any_permission_fulfilled and not self.fulfill_all:
raise InsufficientPermissions(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment