From 9607bcaf817d80e440b34baca83a7d2a2904a9ab Mon Sep 17 00:00:00 2001
From: Tim Repke <repke@mcc-berlin.net>
Date: Fri, 5 Jul 2024 18:42:31 +0200
Subject: [PATCH] Fix import task: open its own DB session instead of receiving one from exec_context

---
 requirements.txt                   |  2 +-
 server/pipelines/actor.py          | 44 ++++++++++++++++++------------
 server/pipelines/tasks/__init__.py | 14 ----------
 server/pipelines/tasks/imports.py  | 25 +++++++++--------
 4 files changed, 42 insertions(+), 43 deletions(-)

diff --git a/requirements.txt b/requirements.txt
index 6a8bf56..7a567fc 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -13,4 +13,4 @@ aiofiles==24.1.0
 dramatiq[redis,watch]==1.17.0
 dramatiq-abort==1.1.0
 dramatiq-dashboard==0.4.0
-nacsos_data[utils,scripts] @ git+ssh://git@gitlab.pik-potsdam.de/mcc-apsis/nacsos/nacsos-data.git@v0.15.3
+nacsos_data[utils,scripts] @ git+ssh://git@gitlab.pik-potsdam.de/mcc-apsis/nacsos/nacsos-data.git@v0.15.4
diff --git a/server/pipelines/actor.py b/server/pipelines/actor.py
index f1e2979..a142a2b 100644
--- a/server/pipelines/actor.py
+++ b/server/pipelines/actor.py
@@ -16,9 +16,11 @@ from sqlalchemy.ext.asyncio import AsyncSession  # noqa F401
 from nacsos_data.models.pipeline import compute_fingerprint, TaskStatus
 from nacsos_data.db.schemas import Task
 
-from server.util.config import settings
+from server.util.config import settings, DatabaseConfig
 from server.util.logging import get_file_logger, LogRedirector
 
+logger = logging.getLogger('nacsos.pipelines.actor')
+
 R = TypeVar("R")
 P = ParamSpec("P")
 
@@ -83,7 +85,9 @@ class NacsosActor(Actor[P, R]):
     @classmethod
     @asynccontextmanager
     async def exec_context(cls) \
-            -> AsyncIterator[tuple[AsyncSession, logging.Logger, Path, str, str | None, str | None]]:
+            -> AsyncIterator[tuple[DatabaseConfig, logging.Logger, Path, str, str | None, str | None]]:
+        logger.info('Opening execution context')
+
         from nacsos_data.db import get_engine_async
         db_engine = get_engine_async(settings=settings.DB)  # type: ignore[arg-type]
 
@@ -95,6 +99,7 @@ class NacsosActor(Actor[P, R]):
             message_id = message.message_id
             actor_name = message.options.get('nacsos_actor_name')  # type: ignore[assignment]
             task_id = message.options.get('nacsos_task_id')
+            logger.info(f'message_id: {message_id}, task_id: {task_id}, actor_name: {actor_name}')
 
         target_dir = settings.PIPES.target_dir / str(task_id)
         target_dir.mkdir(parents=True, exist_ok=True)
@@ -114,22 +119,27 @@ class NacsosActor(Actor[P, R]):
             else:
                 task_logger.warning(f'Task {task_id} not found in database.')
 
-            status: TaskStatus | None = None
-            with TemporaryDirectory(dir=settings.PIPES.WORKING_DIR) as work_dir, \
-                    LogRedirector(task_logger, level='INFO', stream='stdout'), \
-                    LogRedirector(task_logger, level='ERROR', stream='stderr'):
-                try:
-                    yield session, task_logger, target_dir, work_dir, task_id, message_id
-
-                except (Exception, Warning) as e:
-                    # Oh no, something failed. Do some post-mortem logging
-                    tb = traceback.format_exc()
-                    task_logger.fatal(tb)
-                    task_logger.fatal(f'{type(e).__name__}: {e}')
-                    status = TaskStatus.FAILED
-                finally:
+        status: TaskStatus | None = None
+        with TemporaryDirectory(dir=settings.PIPES.WORKING_DIR) as work_dir, \
+                LogRedirector(task_logger, level='INFO', stream='stdout'), \
+                LogRedirector(task_logger, level='ERROR', stream='stderr'):
+            try:
+                # Yielding hands control to the caller: the body of its `async with` block runs inside this `with:` context.
+                yield settings.DB, task_logger, target_dir, work_dir, task_id, message_id
+            except (Exception, Warning) as e:
+                # Oh no, something failed. Do some post-mortem logging
+                logger.error('Big drama from an actor!')
+                logger.exception(e)
+                tb = traceback.format_exc()
+                task_logger.fatal(tb)
+                task_logger.fatal(f'{type(e).__name__}: {e}')
+                status = TaskStatus.FAILED
+            finally:
+                async with db_engine.session() as session:  # type: AsyncSession
                     task = await session.get(Task, task_id)
-                    status = status or TaskStatus.COMPLETED
+                    logger.debug(f'Pre-set actor status: {status}')
+                    if status is None:
+                        status = TaskStatus.COMPLETED
                     if task:
                         task.status = status
                         task.time_finished = datetime.datetime.now()
diff --git a/server/pipelines/tasks/__init__.py b/server/pipelines/tasks/__init__.py
index 0f7254b..579d735 100644
--- a/server/pipelines/tasks/__init__.py
+++ b/server/pipelines/tasks/__init__.py
@@ -1,25 +1,11 @@
-import asyncio
-import datetime
 import logging
-import traceback
-import uuid
-from pathlib import Path
-from tempfile import TemporaryDirectory
-from typing import TypeVar, NamedTuple, ParamSpec, Protocol, Callable, Awaitable, Any, TYPE_CHECKING, Generic
 
 import dramatiq
-from dramatiq import Actor, Broker
 from dramatiq.middleware import CurrentMessage, AsyncIO
 from dramatiq.brokers.redis import RedisBroker
 from dramatiq_abort import Abortable, backends
 
-from sqlalchemy.ext.asyncio import AsyncSession
-
-from nacsos_data.models.pipeline import TaskModel, compute_fingerprint, TaskStatus
-from nacsos_data.db.schemas import Task
-
 from server.util.config import settings
-from server.util.logging import get_file_logger, LogRedirector
 
 logger = logging.getLogger('nacsos.pipelines.task')
 broker = RedisBroker(url=settings.PIPES.REDIS_URL)
diff --git a/server/pipelines/tasks/imports.py b/server/pipelines/tasks/imports.py
index b2e394f..1aaf2da 100644
--- a/server/pipelines/tasks/imports.py
+++ b/server/pipelines/tasks/imports.py
@@ -2,6 +2,8 @@ from pathlib import Path
 from typing import cast
 
 import dramatiq
+
+from nacsos_data.db import get_engine_async
 from nacsos_data.db.schemas import Import
 from nacsos_data.models.imports import ImportConfig, ImportModel
 from nacsos_data.util import ensure_values
@@ -25,20 +27,21 @@ def prefix_sources(sources: list[Path]):
 
 @dramatiq.actor(actor_class=NacsosActor, max_retries=0)  # type: ignore[arg-type]
 async def import_task(import_id: str | None = None) -> None:
-    async with NacsosActor.exec_context() as (session, logger, target_dir, work_dir, task_id, message_id):
+    async with NacsosActor.exec_context() as (db_settings, logger, target_dir, work_dir, task_id, message_id):
         logger.info('Preparing import task!')
+        db_engine = get_engine_async(settings=db_settings)
+        async with db_engine.session() as session:
+            if import_id is None:
+                raise ValueError('import_id is required here.')
 
-        if import_id is None:
-            raise ValueError('import_id is required here.')
-
-        stmt = select(Import).where(Import.import_id == import_id)
-        result = (await session.execute(stmt)).scalars().one_or_none()
-        if result is None:
-            raise NotFoundError(f'No import info for id={import_id}')
+            stmt = select(Import).where(Import.import_id == import_id)
+            result = (await session.execute(stmt)).scalars().one_or_none()
+            if result is None:
+                raise NotFoundError(f'No import info for id={import_id}')
 
-        import_details = ImportModel.model_validate(result.__dict__)
-        result.pipeline_task_id = task_id
-        await session.commit()
+            import_details = ImportModel.model_validate(result.__dict__)
+            result.pipeline_task_id = task_id
+            await session.commit()
 
         user_id, project_id, config = cast(tuple[str, str, ImportConfig],
                                            ensure_values(import_details, 'user_id', 'project_id', 'config'))
-- 
GitLab