# Source code for saq.worker

"""
Workers
"""

from __future__ import annotations

import asyncio
import contextvars
import logging
import os
import signal
import sys
import traceback
import threading
import typing as t
import typing_extensions as te
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime, timezone, tzinfo

from croniter import croniter

from saq.job import Status
from saq.queue import Queue
from saq.types import (
    CtxType,
    FunctionsType,
    JobTaskContext,
    LifecycleFunctionsType,
    SettingsDict,
)
from saq.utils import cancel_tasks, millis, now, uuid1

if t.TYPE_CHECKING:
    from asyncio import Task
    from collections.abc import Callable, Collection, Coroutine

    from aiohttp.web_app import Application

    from saq.job import CronJob, Job
    from saq.types import (
        Function,
        PartialTimersDict,
        TimersDict,
        WorkerInfo,
    )


# Module-wide logger; every message in this module goes to the "saq" logger.
logger = logging.getLogger(name="saq")

# Alias for an arbitrary JSON object: string keys mapping to values of any type.
JsonDict = t.Dict[str, t.Any]


class Worker(t.Generic[CtxType]):
    """
    Worker is used to process and monitor jobs.

    Args:
        id: optional override for the worker id, if not provided, uuid will be used
        queue: instance of saq.queue.Queue
        functions: list of functions to register. Async functions are awaited
            directly; synchronous ones are run on the worker's thread pool.
            An entry may also be a ``(name, function)`` tuple to register the
            function under an explicit name instead of its ``__qualname__``.
        concurrency: number of jobs to process concurrently
        cron_jobs: List of CronJob instances.
        cron_tz: timezone for cron scheduler
        startup: function (or collection of functions) to call on startup
        shutdown: function (or collection of functions) to call on shutdown
        before_process: function (or collection of functions) to call before a job processes
        after_process: function (or collection of functions) to call after a job processes
        timers: dict with various timer overrides in seconds
            schedule: how often we poll to schedule jobs
            worker_info: how often to update worker info, stats and metadata
            sweep: how often to clean up stuck jobs
            abort: how often to check if a job is aborted
        dequeue_timeout: how long it will wait to dequeue
        burst: whether to stop the worker once all jobs have been processed
        max_burst_jobs: the maximum number of jobs to process in burst mode
        shutdown_grace_period_s: how long to wait for jobs to finish before sending
            cancellation signals.
        cancellation_hard_deadline_s: how long to wait for a job to finish after
            sending a cancellation signal.
        metadata: arbitrary data to pass to the worker which it will register with saq
        poll_interval: If > 0.0, dequeue will use polling instead of listen/notify to
            trigger dequeues. This only affects Postgres. (default 0.0)
    """

    # Signals that set the worker's stop event (see ``start``); none are
    # registered when ``os.name == "nt"``.
    SIGNALS = [signal.SIGINT, signal.SIGTERM] if os.name != "nt" else []

    def __init__(
        self,
        queue: Queue,
        functions: FunctionsType[CtxType],
        *,
        id: t.Optional[str] = None,
        concurrency: int = 10,
        cron_jobs: Collection[CronJob[CtxType]] | None = None,
        cron_tz: tzinfo = timezone.utc,
        startup: LifecycleFunctionsType[CtxType] | None = None,
        shutdown: LifecycleFunctionsType[CtxType] | None = None,
        before_process: LifecycleFunctionsType[CtxType] | None = None,
        after_process: LifecycleFunctionsType[CtxType] | None = None,
        timers: PartialTimersDict | None = None,
        dequeue_timeout: float = 0.0,
        burst: bool = False,
        max_burst_jobs: int | None = None,
        shutdown_grace_period_s: int | None = None,
        cancellation_hard_deadline_s: float = 1.0,
        metadata: t.Optional[JsonDict] = None,
        poll_interval: float = 0.0,
    ) -> None:
        self.queue = queue
        self.concurrency = concurrency
        # Synchronous job functions and hooks are submitted to this pool.
        self.pool = ThreadPoolExecutor()
        self.startup = ensure_coroutine_function_many(startup, self.pool) if startup else None
        # Kept unwrapped: ``stop`` wraps these with a fresh pool because
        # ``self.pool`` has already been shut down by the time they run.
        self.shutdown = shutdown
        self.before_process = (
            ensure_coroutine_function_many(before_process, self.pool) if before_process else None
        )
        self.after_process = (
            ensure_coroutine_function_many(after_process, self.pool) if after_process else None
        )
        self.timers: TimersDict = {
            "schedule": 1,
            "worker_info": 10,
            "sweep": 60,
            "abort": 1,
        }
        if timers is not None:
            self.timers.update(timers)
        self.event = asyncio.Event()
        functions = set(functions)
        self.functions: dict[str, Function[CtxType]] = {}
        self.cron_jobs: Collection[CronJob] = cron_jobs or []
        self.cron_tz: tzinfo = cron_tz
        self.context: CtxType = t.cast(CtxType, {"worker": self})
        self.tasks: set[Task[t.Any]] = set()
        self.job_task_contexts: dict[Job, JobTaskContext] = {}
        self.dequeue_timeout = dequeue_timeout
        self.burst = burst
        self.max_burst_jobs = max_burst_jobs
        self.burst_jobs_processed = 0
        self.burst_jobs_processed_lock = threading.Lock()
        self.burst_condition_met = False
        self._metadata = metadata
        self._poll_interval = poll_interval
        self._stop_lock = asyncio.Lock()
        self._stopped = False
        self._shutdown_grace_period_s = shutdown_grace_period_s
        self._cancellation_hard_deadline_s = cancellation_hard_deadline_s
        self.id = uuid1() if id is None else id

        if self.burst:
            if self.dequeue_timeout <= 0:
                raise ValueError(
                    "dequeue_timeout must be a positive value greater than 0 when the burst mode is enabled"
                )
            if self.max_burst_jobs is not None:
                # No point running more process tasks than jobs we are allowed to take.
                self.concurrency = min(self.concurrency, self.max_burst_jobs)

        for job in self.cron_jobs:
            if not croniter.is_valid(job.cron):
                raise ValueError(f"Cron is invalid {job.cron}")
            # Cron job functions are registered like any other job function.
            functions.add(job.function)

        for function in functions:
            if isinstance(function, tuple):
                name, function = function
            else:
                name = function.__qualname__

            self.functions[name] = function

    async def _before_process(self, ctx: CtxType) -> None:
        """Await every ``before_process`` hook, in order, with ``ctx``."""
        if self.before_process:
            for bp in self.before_process:
                await bp(ctx)

    async def _after_process(self, ctx: CtxType) -> None:
        """Await every ``after_process`` hook, in order, with ``ctx``."""
        if self.after_process:
            for ap in self.after_process:
                await ap(ctx)

    async def start(self) -> None:
        """Start processing jobs and upkeep tasks.

        Blocks until the stop event is set (by a registered signal, by ``stop``,
        or by burst mode finishing). Always calls ``stop`` and removes the signal
        handlers it installed on the way out.
        """
        logger.info("Worker starting: %s", repr(self.queue))
        logger.debug("Registered functions:\n%s", "\n".join(f" {key}" for key in self.functions))

        # Fetched before the ``try`` so the ``finally`` clause can always use it,
        # even if we are cancelled while waiting on ``_stop_lock`` below.
        loop = asyncio.get_running_loop()

        try:
            self.event = asyncio.Event()
            async with self._stop_lock:
                self._stopped = False

            for signum in self.SIGNALS:
                loop.add_signal_handler(signum, self.event.set)

            if self.startup:
                for s in self.startup:
                    await s(self.context)

            self.tasks.update(await self.upkeep())

            for _ in range(self.concurrency):
                self._process()

            await self.event.wait()
        except asyncio.CancelledError:
            pass
        finally:
            logger.info("Working shutting down")
            await self.stop()

            for signum in self.SIGNALS:
                loop.remove_signal_handler(signum)

    async def stop(self) -> None:
        """Stop the worker and cleanup.

        Sets the stop event, then (once per start, guarded by ``_stop_lock``):
        waits up to ``shutdown_grace_period_s`` for all worker tasks, cancels
        whatever is still running, shuts down the thread pool, and finally runs
        the ``shutdown`` hooks on a temporary thread pool.
        """
        self.event.set()

        async with self._stop_lock:
            if self._stopped:
                return

            try:
                all_tasks = list(self.tasks)
                self.tasks.clear()

                try:
                    # A missing grace period (None) is treated as 0 seconds.
                    await asyncio.wait_for(
                        asyncio.gather(*all_tasks, return_exceptions=True),
                        timeout=self._shutdown_grace_period_s or 0,
                    )
                except asyncio.TimeoutError:
                    logger.warning(
                        "Some tasks did not finish within the shutdown grace period, requesting cancellation"
                    )
                    cancelled = await cancel_tasks(
                        all_tasks, timeout=self._cancellation_hard_deadline_s
                    )
                    if not cancelled:
                        logger.warning(
                            "Some tasks did not finish cancellation in time, they may be stuck or blocked"
                        )

                # ``cancel_futures`` only exists on Python 3.9+.
                if sys.version_info[0:2] < (3, 9):
                    self.pool.shutdown(True)
                else:
                    self.pool.shutdown(True, cancel_futures=True)

                if not self.shutdown:
                    return

                # We can't reuse our task pool here, because we shut it to close tasks
                with ThreadPoolExecutor() as shutdown_pool:
                    shutdown_callbacks = ensure_coroutine_function_many(
                        self.shutdown, shutdown_pool
                    )
                    for s in shutdown_callbacks:
                        await s(self.context)
            finally:
                self._stopped = True

    async def schedule(self, lock: int = 1) -> None:
        """Enqueue the next run of every cron job, then run the queue's scheduler.

        Args:
            lock: passed through unchanged to ``queue.schedule``.
        """
        for cron_job in self.cron_jobs:
            kwargs = cron_job.__dict__.copy()
            function = kwargs.pop("function").__qualname__
            # Unique cron jobs share one fixed key per function; presumably the
            # queue uses it to avoid enqueueing duplicates (depends on Queue).
            kwargs["key"] = f"cron:{function}" if kwargs.pop("unique") else None
            start_time = datetime.now(self.cron_tz)
            scheduled = croniter(kwargs.pop("cron"), start_time).get_next()

            await self.queue.enqueue(
                function,
                scheduled=int(scheduled),
                **{k: v for k, v in kwargs.items() if v is not None},
            )

        job_ids = await self.queue.schedule(lock)

        if job_ids:
            logger.info("Scheduled %s", job_ids)

    async def worker_info(self, ttl: int = 60) -> WorkerInfo:
        """Publish this worker's info/metadata to the queue with the given ``ttl``."""
        return await self.queue.worker_info(
            self.id, queue_key=self.queue.name, metadata=self._metadata, ttl=ttl
        )

    async def upkeep(self) -> list[Task[None]]:
        """Start various upkeep tasks async.

        Returns the created tasks: periodic ``abort``, ``schedule``,
        ``queue.sweep`` and ``worker_info`` loops, each running until the stop
        event is set.
        """

        async def poll(
            func: Callable[[int], Coroutine], sleep: int, arg: int | None = None
        ) -> None:
            while not self.event.is_set():
                try:
                    await func(arg or sleep)
                except (Exception, asyncio.CancelledError):
                    if self.event.is_set():
                        return

                    logger.exception("Upkeep task failed unexpectedly")

                await asyncio.sleep(sleep)

        return [
            asyncio.create_task(poll(self.abort, self.timers["abort"])),
            asyncio.create_task(poll(self.schedule, self.timers["schedule"])),
            asyncio.create_task(poll(self.queue.sweep, self.timers["sweep"])),
            # worker_info receives a TTL one second longer than its polling
            # interval, presumably so the record outlives each refresh period.
            asyncio.create_task(
                poll(
                    self.worker_info,
                    self.timers["worker_info"],
                    self.timers["worker_info"] + 1,
                )
            ),
        ]

    async def abort(self, abort_threshold: float) -> None:
        """Cancel locally running jobs that the queue marks as ABORTING or ABORTED.

        Only jobs whose running duration is at least ``millis(abort_threshold)``
        are looked up. Each matching unfinished task is flagged as aborted and
        cancelled (waiting without a deadline), then ``queue.finish_abort`` is
        called for the job.
        """

        def get_duration(job: Job) -> float:
            return job.duration("running") or 0

        jobs = [
            job for job in self.job_task_contexts if get_duration(job) >= millis(abort_threshold)
        ]

        if not jobs:
            return

        for job in await self.queue.jobs(job.key for job in jobs):
            if not job or job.status not in (Status.ABORTING, Status.ABORTED):
                continue

            task_data = self.job_task_contexts.get(job, None)
            if not task_data:
                logger.warning("No task data found for job %s", job.id)
                continue

            task = task_data["task"]
            logger.info("Aborting %s", job.id)

            if not task.done():
                # ``process`` checks this flag to tell an abort apart from a shutdown.
                task_data["aborted"] = "abort" if job.error is None else job.error
                # abort should be a blocking operation
                _ = await cancel_tasks([task], None)

            await self.queue.finish_abort(job)

    async def process(self) -> bool:
        """Dequeue and run a single job.

        Returns:
            ``False`` when nothing was dequeued (or cancellation arrived before the
            job's task was created), ``True`` otherwise. ``_check_burst`` reads
            this value to decide whether burst mode is finished.
        """
        context: CtxType | None = None
        job: Job | None = None
        task_ctx: JobTaskContext | None = None

        try:
            job = await self.queue.dequeue(
                timeout=self.dequeue_timeout,
                poll_interval=self._poll_interval,
            )

            if job is None:
                return False

            job.started = now()
            job.attempts += 1
            job.worker_id = self.id
            await job.update(status=Status.ACTIVE)
            context = t.cast(CtxType, {**self.context, "job": job})
            await self._before_process(context)
            logger.info("Processing %s", job.info(logger.isEnabledFor(logging.DEBUG)))

            function = ensure_coroutine_function(self.functions[job.function], self.pool)
            task = asyncio.create_task(function(context, **(job.kwargs or {})))
            task_ctx = JobTaskContext(task=task, aborted=None)
            self.job_task_contexts[job] = task_ctx

            try:
                result = await asyncio.wait_for(
                    asyncio.shield(task), job.timeout if job.timeout else None
                )
            except asyncio.TimeoutError:
                # Since we have a shield around the task passed to wait_for,
                # we need to explicitly cancel it on timeout.
                task.cancel()
                raise

            if task_ctx["aborted"] is None:
                await job.finish(Status.COMPLETE, result=result)
        except asyncio.CancelledError:
            if not job or task_ctx is None:
                return False

            task = task_ctx["task"]
            aborted = task_ctx["aborted"]

            if aborted is not None:
                await job.finish(Status.ABORTED, error=aborted)
                # The job *was* dequeued. Returning False here would make
                # ``_check_burst`` treat the queue as drained and end burst mode
                # early just because one job was aborted.
                return True

            # Cancelled for some other reason (e.g. worker shutdown): make sure the
            # job's own task stops, then put the job back for another attempt.
            if not task.done():
                cancelled = await cancel_tasks([task], self._cancellation_hard_deadline_s)
                if not cancelled:
                    logger.warning(
                        "Function: %s did not finish cancellation in time, it may be stuck or blocked",
                        job.function,
                        extra={"job_id": job.id},
                    )
            await job.retry("cancelled")
        except Exception as ex:
            if context is not None:
                context["exception"] = ex

            if job:
                logger.exception("Error processing job %s", job)

                # Ensure that the task is done or cancelled
                if task_ctx is not None:
                    task = task_ctx["task"]
                    if not task.done():
                        cancelled = await cancel_tasks(
                            [task], self._cancellation_hard_deadline_s
                        )
                        if not cancelled:
                            logger.warning(
                                "Function '%s' did not finish cancellation in time, it may be stuck or blocked",
                                job.function,
                                extra={"job_id": job.id},
                            )

                error = traceback.format_exc()

                if job.retryable:
                    await job.retry(error)
                else:
                    await job.finish(Status.FAILED, error=error)
        finally:
            if context:
                # Only clear our own slot: a later same-key attempt (e.g. a
                # sweep re-enqueue) may have replaced it, and we must not
                # pop its context.
                if (
                    job is not None
                    and task_ctx is not None
                    and self.job_task_contexts.get(job) is task_ctx
                ):
                    del self.job_task_contexts[job]
                try:
                    await self._after_process(context)
                except (Exception, asyncio.CancelledError):
                    logger.exception("Failed to run after process hook")
        return True

    def _process(self, previous_task: Task | None = None) -> None:
        """Spawn a ``process`` task, re-arming itself as that task's done callback.

        Called with no argument to seed the pool of process tasks, and with the
        finished task when used as a done callback. In burst mode, once the burst
        condition is met no replacement is spawned, and the stop event is set when
        the last process task has finished.
        """
        if previous_task:
            self.tasks.discard(previous_task)

            if self.burst and self._check_burst(previous_task):
                if not any(task.get_name() == "process" for task in self.tasks):
                    # Stop the worker if all process tasks are done
                    self.event.set()
                return

        if not self.event.is_set():
            new_task = asyncio.create_task(self.process(), name="process")
            self.tasks.add(new_task)
            new_task.add_done_callback(self._process)

    def _check_burst(self, previous_task: Task) -> bool:
        """Update and return whether burst mode should stop spawning process tasks.

        The burst ends once a process task reports that it dequeued nothing, once
        ``max_burst_jobs`` dequeued jobs have been counted, or once a process task
        ends cancelled or with an unexpected exception.
        """
        if self.burst_condition_met:
            return self.burst_condition_met

        if previous_task.cancelled():
            # ``.result()`` would raise CancelledError inside this done callback.
            # Within this class, process tasks are only cancelled by ``stop``.
            job_dequeued = False
        else:
            exc = previous_task.exception()
            if exc is not None:
                # ``.result()`` would re-raise here, leaving this slot without a
                # replacement and never setting the stop event, so the worker
                # could hang. End the burst instead, and surface the error.
                logger.error("Process task failed unexpectedly", exc_info=exc)
                job_dequeued = False
            else:
                job_dequeued = previous_task.result()

        if not job_dequeued:
            self.burst_condition_met = True
        elif self.max_burst_jobs is not None:
            with self.burst_jobs_processed_lock:
                self.burst_jobs_processed += 1
                if self.burst_jobs_processed >= self.max_burst_jobs:
                    self.burst_condition_met = True

        return self.burst_condition_met
P = te.ParamSpec("P")
R = te.TypeVar("R")
OneOrManyCallable = t.Union[t.Callable[P, R], t.Collection[t.Callable[P, R]]]


def ensure_coroutine_function_many(
    func: OneOrManyCallable[P, R] | OneOrManyCallable[P, Coroutine[t.Any, t.Any, R]],
    pool: ThreadPoolExecutor,
) -> t.List[Callable[P, Coroutine[t.Any, t.Any, R]]]:
    """Normalize one callable, or a collection of callables, into a list of async callables.

    Each callable goes through ``ensure_coroutine_function``, so synchronous
    callables will run on ``pool``.
    """
    if callable(func):
        return [ensure_coroutine_function(func, pool)]
    return [ensure_coroutine_function(f, pool) for f in func]


def ensure_coroutine_function(
    func: Callable[P, R] | Callable[P, Coroutine[t.Any, t.Any, R]],
    pool: ThreadPoolExecutor,
) -> Callable[P, Coroutine[t.Any, t.Any, R]]:
    """Return ``func`` as an async callable.

    Coroutine functions are returned unchanged. Anything else is wrapped so that
    each call is submitted to ``pool`` and runs inside a copy of the caller's
    ``contextvars`` context.
    """
    if asyncio.iscoroutinefunction(func):
        return func

    async def wrapped(*args: t.Any, **kwargs: t.Any) -> t.Any:
        future = None
        try:
            ctx = contextvars.copy_context()
            future = pool.submit(lambda: ctx.run(func, *args, **kwargs))
            return await asyncio.wrap_future(future)
        except asyncio.CancelledError:
            try:
                # job has already been cancelled, swallow all errors
                # Wait for the submitted call to actually finish before
                # propagating the cancellation.
                if future is not None:
                    await asyncio.wrap_future(future)
            except Exception:
                pass
            raise

    return wrapped


def import_settings(settings: str) -> SettingsDict:
    """Load a settings dict from a dotted ``module.variable`` path.

    If the variable is callable, it is called and its return value is used.

    Raises:
        ValueError: if ``settings`` contains no ``.`` separator.
        TypeError: if the resolved value is not a dict.
    """
    import importlib

    path = settings.strip()
    if "." not in path:
        # Same exception type the unpacking below used to raise, with a clearer message.
        raise ValueError(f"Settings {settings!r} must be a dotted path like 'module.variable'")

    # given a.b.c, parses out a.b as the module path and c as the variable
    module_path, name = path.rsplit(".", 1)
    module = importlib.import_module(module_path)
    settings_obj = getattr(module, name)

    if callable(settings_obj):
        settings_obj = settings_obj()

    if not isinstance(settings_obj, dict):
        raise TypeError(
            f"Settings {settings} must be a dictionary or a callable that returns a dictionary, got final type '{type(settings_obj)}'"
        )

    return t.cast(SettingsDict, settings_obj)


def start(
    settings: str,
    web: bool = False,
    extra_web_settings: list[str] | None = None,
    port: int = 8080,
) -> None:
    """Build a Worker from ``settings`` and run it, optionally with the aiohttp web UI.

    A default ``redis://localhost`` queue is inserted into the settings when none
    is configured.
    """
    settings_obj = import_settings(settings)

    if "queue" not in settings_obj:
        settings_obj["queue"] = Queue.from_url("redis://localhost")

    loop = asyncio.new_event_loop()
    worker = Worker(**settings_obj)

    async def worker_start() -> None:
        try:
            await worker.queue.connect()
            await worker.start()
        finally:
            await worker.queue.disconnect()

    if web:
        import aiohttp.web

        from saq.web.aiohttp import create_app

        extra_web_settings = extra_web_settings or []
        web_settings = [settings_obj] + [import_settings(s) for s in extra_web_settings]
        queues = [s["queue"] for s in web_settings if s.get("queue")]

        async def shutdown(_app: Application) -> None:
            await worker.stop()

        app = create_app(queues)
        app.on_shutdown.append(shutdown)

        # When the worker exits on its own, bring the web server down too.
        loop.create_task(worker_start()).add_done_callback(
            lambda _: signal.raise_signal(signal.SIGTERM)
        )
        # run_app drives (and is responsible for) ``loop`` in this branch.
        aiohttp.web.run_app(app, port=port, loop=loop)
    else:
        try:
            loop.run_until_complete(worker_start())
        finally:
            # Release the loop's resources instead of leaking it.
            loop.close()


async def async_check_health(queue: Queue) -> int:
    """Return 0 if ``queue`` reports its own name and at least one worker, else 1.

    The queue is always disconnected afterwards, even if ``queue.info()`` raises.
    """
    await queue.connect()
    try:
        info = await queue.info()
        name = info.get("name")
        if name != queue.name:
            logger.warning(
                "Health check failed. Unknown queue name %s. Expected %s",
                name,
                queue.name,
            )
            status = 1
        elif not info.get("workers"):
            logger.warning("No active workers found for queue %s", name)
            status = 1
        else:
            workers = len(info["workers"])
            logger.info("Found %d active workers for queue %s", workers, name)
            status = 0
    finally:
        await queue.disconnect()
    return status


def check_health(settings: str) -> int:
    """Synchronously run ``async_check_health`` for the queue configured in ``settings``."""
    settings_dict = import_settings(settings)
    loop = asyncio.new_event_loop()
    try:
        queue = settings_dict.get("queue") or Queue.from_url("redis://localhost")
        return loop.run_until_complete(async_check_health(queue))
    finally:
        loop.close()