Source code for saq.web.starlette

"""
Starlette/FastAPI/ASGI
"""

from __future__ import annotations

import os

from starlette.applications import Starlette
from starlette.datastructures import Headers
from starlette.exceptions import HTTPException
from starlette.requests import Request
from starlette.responses import FileResponse, JSONResponse, Response
from starlette.routing import Mount, Route
from starlette.staticfiles import PathLike, StaticFiles
from starlette.types import Scope

from saq.job import Job
from saq.queue import Queue
from saq.types import QueueInfo
from saq.web.common import STATIC_PATH, job_dict, render

QUEUES: dict[str, Queue] = {}
ROOT_PATH: str = ""


class GZStaticFiles(StaticFiles):
    async def get_response(
        self,
        path: PathLike,
        scope: Scope,
    ) -> Response:
        # Check if the client accepts gzip or Brotli encoding
        headers = Headers(scope=scope)
        encoding = headers.get("accept-encoding", "")

        # Convert path to string if it's not already
        path_str = str(path)

        # Path to the file in the static directory
        full_path = os.path.join(str(self.directory), path_str)

        if "gzip" in encoding and os.path.exists(full_path + ".gz"):
            # Serve Gzip compressed file
            return FileResponse(full_path + ".gz", headers={"Content-Encoding": "gzip"})

        # Serve the original file
        return await super().get_response(path_str, scope)


async def views(request: Request) -> Response:
    return Response(content=render(root_path=ROOT_PATH), media_type="text/html")


async def health(request: Request) -> Response:
    if await _get_all_info():
        return Response(content="OK", media_type="text/plain")
    raise HTTPException(status_code=500)


async def queues_(request: Request) -> JSONResponse:
    queue_name = request.path_params.get("queue")

    response: dict[str, QueueInfo | list[QueueInfo]] = {}

    if queue_name:
        response["queue"] = await _get_queue(queue_name).info(jobs=True)
    else:
        response["queues"] = await _get_all_info()

    return JSONResponse(response)


async def jobs(request: Request) -> JSONResponse:
    queue_name = request.path_params["queue"]
    job_key = request.path_params["job"]

    job = await _get_job(queue_name, job_key)
    return JSONResponse({"job": job_dict(job)})


async def retry(request: Request) -> JSONResponse:
    queue_name = request.path_params["queue"]
    job_key = request.path_params["job"]

    job = await _get_job(queue_name, job_key)
    await job.retry("retried from ui")
    return JSONResponse({})


async def abort(request: Request) -> JSONResponse:
    queue_name = request.path_params["queue"]
    job_key = request.path_params["job"]

    job = await _get_job(queue_name, job_key)
    await job.abort("aborted from ui")
    return JSONResponse({})


async def _get_all_info() -> list[QueueInfo]:
    return [await q.info() for q in QUEUES.values()]


def _get_queue(queue_name: str) -> Queue:
    return QUEUES[queue_name]


async def _get_job(queue_name: str, job_key: str) -> Job:
    job = await _get_queue(queue_name).job(job_key)
    if not job:
        raise ValueError(f"Job {job_key} not found")
    return job


[docs] def saq_web(root_path: str, queues: list[Queue]) -> Starlette: """ Create an embeddable monitoring Web UI Example: .. code-block:: routes = [ Mount("/monitor", saq_web("/monitor", queues=all_the_queues_list)) ] Args: root_path: The absolute mount point, typically the same as where you mount it. queues: The list of known queues Returns: Starlette ASGI instance. """ global ROOT_PATH QUEUES.clear() for queue in queues: QUEUES[queue.name] = queue ROOT_PATH = root_path return Starlette( routes=[ Route("/", views), Route("/queues/{queue}", views), Route("/queues/{queue}/jobs/{job}", views), Route("/api/queues", queues_), Route("/api/queues/{queue}", queues_), Route("/api/queues/{queue}/jobs/{job}", jobs), Route("/api/queues/{queue}/jobs/{job}/retry", retry, methods=["POST"]), Route("/api/queues/{queue}/jobs/{job}/abort", abort, methods=["POST"]), Mount("/static", GZStaticFiles(directory=STATIC_PATH)), Route("/health", health), ] )