rework web server so multiple clients can be server at a time
This commit is contained in:
@@ -6,7 +6,11 @@ Routes
|
||||
GET / → serve frontend/index.html
|
||||
GET /static/{path} → serve frontend JS/CSS
|
||||
GET /nodes → JSON dict of all registered node definitions
|
||||
POST /upload → multipart file upload to input/
|
||||
GET /files → list files in the current session upload workspace
|
||||
GET /folder-files → list compatible files in a picked folder
|
||||
GET /channels → inspect channels for a picked file
|
||||
POST /upload → multipart file upload to the current session workspace
|
||||
POST /upload-folder → create a folder in the current session workspace
|
||||
POST /prompt → submit a workflow; returns {prompt_id}
|
||||
GET /ws → WebSocket upgrade
|
||||
|
||||
@@ -15,7 +19,7 @@ WebSocket message types sent to clients
|
||||
{"type": "execution_start", "data": {"prompt_id": "..."}}
|
||||
{"type": "executing", "data": {"node": "...", "prompt_id": "..."}}
|
||||
{"type": "preview", "data": {"node_id": "...", "image": "data:..."}}
|
||||
{"type": "table", "data": {"node_id": "...", "rows": [...]}}
|
||||
{"type": "table", "data": {"node_id": "...", "rows": [...]} }
|
||||
{"type": "scalar", "data": {"node_id": "...", "value": 1.23, "unit": "nm"}}
|
||||
{"type": "node_timing", "data": {"node_id": "...", "elapsed_ms": 12.34}}
|
||||
{"type": "execution_error", "data": {"node_id": "...", "message": "..."}}
|
||||
@@ -23,39 +27,43 @@ WebSocket message types sent to clients
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import sys
|
||||
from collections import defaultdict
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
|
||||
from aiohttp import web, WSMsgType
|
||||
|
||||
from backend.frontend_build import FrontendBuildError, ensure_frontend_dist_ready
|
||||
from backend.runtime_paths import (
|
||||
ensure_runtime_dirs,
|
||||
frontend_dir,
|
||||
frontend_dist_dir,
|
||||
input_dir,
|
||||
output_dir,
|
||||
project_root,
|
||||
from backend.runtime_paths import ensure_runtime_dirs, frontend_dir, frontend_dist_dir, project_root
|
||||
from backend.session_runtime import (
|
||||
PATH_INPUT_TYPES,
|
||||
SESSION_HEADER,
|
||||
SESSION_QUERY,
|
||||
ensure_session_runtime_dirs,
|
||||
normalize_relative_upload_path,
|
||||
resolve_client_path,
|
||||
server_path_to_client_path,
|
||||
session_input_dir,
|
||||
session_upload_uri,
|
||||
validate_session_id,
|
||||
)
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
FRONTEND_DIR = frontend_dir()
|
||||
DIST_DIR = frontend_dist_dir()
|
||||
INPUT_DIR = input_dir()
|
||||
OUTPUT_DIR = output_dir()
|
||||
PNG_SIGNATURE = b"\x89PNG\r\n\x1a\n"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# JSON helper — numpy scalars are not serialisable by default
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class _SafeEncoder(json.JSONEncoder):
|
||||
def default(self, obj):
|
||||
import numpy as np
|
||||
|
||||
if isinstance(obj, (np.integer,)):
|
||||
return int(obj)
|
||||
if isinstance(obj, (np.floating,)):
|
||||
@@ -81,45 +89,115 @@ def save_png_bytes(target_path: str, payload: bytes) -> Path:
|
||||
return path
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Application factory
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def create_app(loop: asyncio.AbstractEventLoop) -> web.Application:
|
||||
# Import nodes to trigger registration decorators
|
||||
def create_app(
|
||||
loop: asyncio.AbstractEventLoop,
|
||||
*,
|
||||
allow_local_filesystem: bool = False,
|
||||
) -> web.Application:
|
||||
import backend.nodes # noqa: F401
|
||||
from backend.node_registry import get_all_node_info
|
||||
from backend.execution import ExecutionEngine, new_prompt_id
|
||||
from backend.node_registry import NODE_CLASS_MAPPINGS, get_all_node_info
|
||||
|
||||
ensure_runtime_dirs()
|
||||
|
||||
engine = ExecutionEngine()
|
||||
websockets: set[web.WebSocketResponse] = set()
|
||||
session_engines: dict[str, ExecutionEngine] = {}
|
||||
session_websockets: dict[str, set[web.WebSocketResponse]] = defaultdict(set)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# WebSocket broadcast helpers
|
||||
# ------------------------------------------------------------------
|
||||
def _is_link(value) -> bool:
|
||||
return (
|
||||
isinstance(value, (list, tuple))
|
||||
and len(value) == 2
|
||||
and isinstance(value[0], str)
|
||||
and isinstance(value[1], int)
|
||||
)
|
||||
|
||||
def broadcast(msg: dict) -> None:
|
||||
"""Schedule a broadcast to all connected WebSocket clients."""
|
||||
def require_session_id(request: web.Request) -> str:
|
||||
raw_session = request.headers.get(SESSION_HEADER) or request.query.get(SESSION_QUERY)
|
||||
if not raw_session:
|
||||
if allow_local_filesystem:
|
||||
raw_session = "desktop-local-session"
|
||||
else:
|
||||
raise web.HTTPBadRequest(reason="Missing session id")
|
||||
|
||||
try:
|
||||
session_id = validate_session_id(raw_session)
|
||||
except ValueError as exc:
|
||||
raise web.HTTPBadRequest(reason=str(exc)) from exc
|
||||
|
||||
ensure_session_runtime_dirs(session_id)
|
||||
return session_id
|
||||
|
||||
def get_session_engine(session_id: str) -> ExecutionEngine:
|
||||
engine = session_engines.get(session_id)
|
||||
if engine is None:
|
||||
engine = ExecutionEngine()
|
||||
session_engines[session_id] = engine
|
||||
return engine
|
||||
|
||||
def resolve_request_path(session_id: str, raw_value: str) -> Path:
|
||||
try:
|
||||
return resolve_client_path(
|
||||
raw_value,
|
||||
session_id=session_id,
|
||||
allow_local_filesystem=allow_local_filesystem,
|
||||
)
|
||||
except PermissionError as exc:
|
||||
raise web.HTTPForbidden(reason=str(exc)) from exc
|
||||
except ValueError as exc:
|
||||
raise web.HTTPBadRequest(reason=str(exc)) from exc
|
||||
|
||||
def rewrite_prompt_paths(prompt: dict, session_id: str) -> dict:
|
||||
normalized = deepcopy(prompt)
|
||||
for node_def in normalized.values():
|
||||
class_name = node_def.get("class_type")
|
||||
cls = NODE_CLASS_MAPPINGS.get(class_name)
|
||||
if cls is None:
|
||||
continue
|
||||
|
||||
input_types = cls.INPUT_TYPES()
|
||||
specs = {}
|
||||
specs.update(input_types.get("required", {}))
|
||||
specs.update(input_types.get("optional", {}))
|
||||
|
||||
inputs = node_def.get("inputs", {})
|
||||
if not isinstance(inputs, dict):
|
||||
continue
|
||||
|
||||
for input_name, raw_value in list(inputs.items()):
|
||||
if _is_link(raw_value) or not isinstance(raw_value, str):
|
||||
continue
|
||||
if not raw_value.strip():
|
||||
continue
|
||||
|
||||
spec = specs.get(input_name)
|
||||
input_type = spec[0] if isinstance(spec, (list, tuple)) and spec else spec
|
||||
if not isinstance(input_type, str):
|
||||
continue
|
||||
if input_type not in PATH_INPUT_TYPES:
|
||||
continue
|
||||
|
||||
inputs[input_name] = str(resolve_request_path(session_id, raw_value))
|
||||
return normalized
|
||||
|
||||
def broadcast(session_id: str, msg: dict) -> None:
|
||||
payload = _dumps(msg)
|
||||
for ws in list(websockets):
|
||||
for ws in list(session_websockets.get(session_id, ())):
|
||||
if not ws.closed:
|
||||
asyncio.run_coroutine_threadsafe(ws.send_str(payload), loop)
|
||||
|
||||
def on_preview(node_id: str, data_uri: str) -> None:
|
||||
broadcast({"type": "preview", "data": {"node_id": node_id, "image": data_uri}})
|
||||
def on_preview(session_id: str, node_id: str, data_uri: str) -> None:
|
||||
broadcast(session_id, {"type": "preview", "data": {"node_id": node_id, "image": data_uri}})
|
||||
|
||||
def on_table(node_id: str, rows: list) -> None:
|
||||
broadcast({"type": "table", "data": {"node_id": node_id, "rows": rows}})
|
||||
def on_table(session_id: str, node_id: str, rows: list) -> None:
|
||||
broadcast(session_id, {"type": "table", "data": {"node_id": node_id, "rows": rows}})
|
||||
|
||||
def on_mesh(node_id: str, mesh_data: dict) -> None:
|
||||
broadcast({"type": "mesh3d", "data": {"node_id": node_id, "mesh": mesh_data}})
|
||||
def on_mesh(session_id: str, node_id: str, mesh_data: dict) -> None:
|
||||
broadcast(session_id, {"type": "mesh3d", "data": {"node_id": node_id, "mesh": mesh_data}})
|
||||
|
||||
def on_overlay(node_id: str, overlay_data) -> None:
|
||||
broadcast({"type": "overlay", "data": {"node_id": node_id, "overlay": overlay_data}})
|
||||
def on_overlay(session_id: str, node_id: str, overlay_data) -> None:
|
||||
broadcast(session_id, {"type": "overlay", "data": {"node_id": node_id, "overlay": overlay_data}})
|
||||
|
||||
def on_value(node_id: str, payload) -> None:
|
||||
def on_value(session_id: str, node_id: str, payload) -> None:
|
||||
if isinstance(payload, dict):
|
||||
value = payload.get("value")
|
||||
unit = payload.get("unit", "")
|
||||
@@ -130,14 +208,10 @@ def create_app(loop: asyncio.AbstractEventLoop) -> web.Application:
|
||||
data = {"node_id": node_id, "value": value}
|
||||
if isinstance(unit, str) and unit.strip():
|
||||
data["unit"] = unit.strip()
|
||||
broadcast({"type": "scalar", "data": data})
|
||||
broadcast(session_id, {"type": "scalar", "data": data})
|
||||
|
||||
def on_warning(node_id: str, message: str) -> None:
|
||||
broadcast({"type": "node_warning", "data": {"node_id": node_id, "message": message}})
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Route handlers
|
||||
# ------------------------------------------------------------------
|
||||
def on_warning(session_id: str, node_id: str, message: str) -> None:
|
||||
broadcast(session_id, {"type": "node_warning", "data": {"node_id": node_id, "message": message}})
|
||||
|
||||
async def index(request: web.Request) -> web.Response:
|
||||
if not getattr(sys, "frozen", False):
|
||||
@@ -167,88 +241,96 @@ def create_app(loop: asyncio.AbstractEventLoop) -> web.Application:
|
||||
)
|
||||
|
||||
async def get_nodes(request: web.Request) -> web.Response:
|
||||
info = get_all_node_info()
|
||||
return web.Response(
|
||||
text=_dumps(info),
|
||||
text=_dumps(get_all_node_info()),
|
||||
content_type="application/json",
|
||||
)
|
||||
|
||||
async def list_files(request: web.Request) -> web.Response:
|
||||
"""List files in the input/ directory for the file picker widget."""
|
||||
session_id = require_session_id(request)
|
||||
input_path = session_input_dir(session_id)
|
||||
files = sorted(
|
||||
f.name for f in INPUT_DIR.iterdir()
|
||||
if f.is_file() and not f.name.startswith(".")
|
||||
) if INPUT_DIR.exists() else []
|
||||
server_path_to_client_path(entry, session_id)
|
||||
for entry in input_path.iterdir()
|
||||
if entry.is_file() and not entry.name.startswith(".")
|
||||
) if input_path.exists() else []
|
||||
return web.Response(text=_dumps(files), content_type="application/json")
|
||||
|
||||
async def browse_dir(request: web.Request) -> web.Response:
|
||||
"""
|
||||
Server-side directory browser for local file picking.
|
||||
GET /browse?dir=/some/path → {parent, dirs[], files[]}
|
||||
"""
|
||||
dir_path = request.query.get("dir", str(Path.home()))
|
||||
p = Path(dir_path).expanduser().resolve()
|
||||
|
||||
if not p.is_dir():
|
||||
raise web.HTTPBadRequest(reason=f"Not a directory: {p}")
|
||||
|
||||
dirs = []
|
||||
files = []
|
||||
try:
|
||||
for entry in sorted(p.iterdir(), key=lambda e: e.name.lower()):
|
||||
if entry.name.startswith("."):
|
||||
continue
|
||||
if entry.is_dir():
|
||||
dirs.append(entry.name)
|
||||
elif entry.is_file():
|
||||
files.append(entry.name)
|
||||
except PermissionError:
|
||||
pass
|
||||
|
||||
async def create_upload_folder(request: web.Request) -> web.Response:
|
||||
session_id = require_session_id(request)
|
||||
body = await request.json()
|
||||
relative_path = normalize_relative_upload_path(body.get("path", ""))
|
||||
target = session_input_dir(session_id) / Path(relative_path.as_posix())
|
||||
target.mkdir(parents=True, exist_ok=True)
|
||||
return web.Response(
|
||||
text=_dumps({
|
||||
"path": str(p),
|
||||
"parent": str(p.parent) if p.parent != p else None,
|
||||
"dirs": dirs,
|
||||
"files": files,
|
||||
}),
|
||||
text=_dumps({"path": session_upload_uri(relative_path)}),
|
||||
content_type="application/json",
|
||||
)
|
||||
|
||||
async def get_folder_files(request: web.Request) -> web.Response:
|
||||
folder_path = request.query.get("folder", "")
|
||||
from backend.nodes.helpers import list_folder_paths
|
||||
loop = asyncio.get_running_loop()
|
||||
entries = await loop.run_in_executor(None, list_folder_paths, folder_path)
|
||||
return web.Response(text=_dumps(entries), content_type="application/json")
|
||||
|
||||
session_id = require_session_id(request)
|
||||
folder_path = request.query.get("folder", "")
|
||||
if not folder_path:
|
||||
return web.Response(text=_dumps([]), content_type="application/json")
|
||||
|
||||
resolved_path = resolve_request_path(session_id, folder_path)
|
||||
running_loop = asyncio.get_running_loop()
|
||||
entries = await running_loop.run_in_executor(None, list_folder_paths, str(resolved_path))
|
||||
|
||||
payload = []
|
||||
for entry in entries:
|
||||
mapped = dict(entry)
|
||||
if "path" in mapped:
|
||||
mapped["path"] = server_path_to_client_path(mapped["path"], session_id)
|
||||
payload.append(mapped)
|
||||
return web.Response(text=_dumps(payload), content_type="application/json")
|
||||
|
||||
async def upload_file(request: web.Request) -> web.Response:
|
||||
session_id = require_session_id(request)
|
||||
reader = await request.multipart()
|
||||
field = await reader.next()
|
||||
if field is None or field.name != "file":
|
||||
relative_path = None
|
||||
filename = ""
|
||||
file_bytes = None
|
||||
|
||||
while True:
|
||||
field = await reader.next()
|
||||
if field is None:
|
||||
break
|
||||
if field.name == "relative_path":
|
||||
relative_path = await field.text()
|
||||
continue
|
||||
if field.name == "file":
|
||||
filename = Path(field.filename or "upload.bin").name
|
||||
chunks = []
|
||||
while True:
|
||||
chunk = await field.read_chunk(65536)
|
||||
if not chunk:
|
||||
break
|
||||
chunks.append(chunk)
|
||||
file_bytes = b"".join(chunks)
|
||||
|
||||
if file_bytes is None:
|
||||
raise web.HTTPBadRequest(reason="Expected a 'file' field in multipart body")
|
||||
|
||||
filename = Path(field.filename).name # strip any path traversal
|
||||
dest = INPUT_DIR / filename
|
||||
with open(dest, "wb") as f:
|
||||
while True:
|
||||
chunk = await field.read_chunk(65536)
|
||||
if not chunk:
|
||||
break
|
||||
f.write(chunk)
|
||||
relative = normalize_relative_upload_path(relative_path or filename)
|
||||
dest = session_input_dir(session_id) / Path(relative.as_posix())
|
||||
dest.parent.mkdir(parents=True, exist_ok=True)
|
||||
dest.write_bytes(file_bytes)
|
||||
|
||||
return web.Response(text=_dumps({"filename": filename}), content_type="application/json")
|
||||
return web.Response(
|
||||
text=_dumps({"filename": filename, "path": session_upload_uri(relative)}),
|
||||
content_type="application/json",
|
||||
)
|
||||
|
||||
async def download_file(request: web.Request) -> web.Response:
|
||||
"""Accept a blob POST and return it with Content-Disposition: attachment."""
|
||||
body = await request.read()
|
||||
filename = request.query.get("filename", "workflow.png")
|
||||
return web.Response(
|
||||
body=body,
|
||||
content_type="application/octet-stream",
|
||||
headers={
|
||||
"Content-Disposition": f'attachment; filename="{filename}"',
|
||||
},
|
||||
headers={"Content-Disposition": f'attachment; filename="{filename}"'},
|
||||
)
|
||||
|
||||
async def save_workflow_png(request: web.Request) -> web.Response:
|
||||
@@ -266,34 +348,39 @@ def create_app(loop: asyncio.AbstractEventLoop) -> web.Application:
|
||||
)
|
||||
|
||||
async def get_channels(request: web.Request) -> web.Response:
|
||||
"""Return available channels for a given file path."""
|
||||
from backend.nodes.helpers import list_channels
|
||||
|
||||
session_id = require_session_id(request)
|
||||
filepath = request.query.get("file", "")
|
||||
if not filepath:
|
||||
return web.Response(
|
||||
text=_dumps([{"name": "field", "type": "DATA_FIELD"}]),
|
||||
content_type="application/json",
|
||||
)
|
||||
channels = await loop.run_in_executor(None, list_channels, filepath)
|
||||
|
||||
resolved_path = resolve_request_path(session_id, filepath)
|
||||
channels = await loop.run_in_executor(None, list_channels, str(resolved_path))
|
||||
return web.Response(text=_dumps(channels), content_type="application/json")
|
||||
|
||||
async def submit_prompt(request: web.Request) -> web.Response:
|
||||
session_id = require_session_id(request)
|
||||
body = await request.json()
|
||||
prompt = body.get("prompt")
|
||||
if not isinstance(prompt, dict) or not prompt:
|
||||
raise web.HTTPBadRequest(reason="'prompt' must be a non-empty dict")
|
||||
|
||||
normalized_prompt = rewrite_prompt_paths(prompt, session_id)
|
||||
prompt_id = new_prompt_id()
|
||||
engine = get_session_engine(session_id)
|
||||
|
||||
# Run execution in a thread pool so scipy doesn't block the event loop
|
||||
async def run():
|
||||
broadcast({"type": "execution_start", "data": {"prompt_id": prompt_id}})
|
||||
broadcast(session_id, {"type": "execution_start", "data": {"prompt_id": prompt_id}})
|
||||
|
||||
def on_start(node_id: str) -> None:
|
||||
broadcast({"type": "executing", "data": {"node": node_id, "prompt_id": prompt_id}})
|
||||
broadcast(session_id, {"type": "executing", "data": {"node": node_id, "prompt_id": prompt_id}})
|
||||
|
||||
def on_done(node_id: str, elapsed_ms: float) -> None:
|
||||
broadcast({
|
||||
broadcast(session_id, {
|
||||
"type": "node_timing",
|
||||
"data": {"node_id": node_id, "elapsed_ms": elapsed_ms},
|
||||
})
|
||||
@@ -302,21 +389,21 @@ def create_app(loop: asyncio.AbstractEventLoop) -> web.Application:
|
||||
await loop.run_in_executor(
|
||||
None,
|
||||
lambda: engine.execute(
|
||||
prompt,
|
||||
normalized_prompt,
|
||||
on_node_start=on_start,
|
||||
on_node_done=on_done,
|
||||
on_preview=on_preview,
|
||||
on_table=on_table,
|
||||
on_mesh=on_mesh,
|
||||
on_overlay=on_overlay,
|
||||
on_value=on_value,
|
||||
on_warning=on_warning,
|
||||
on_preview=lambda node_id, payload: on_preview(session_id, node_id, payload),
|
||||
on_table=lambda node_id, rows: on_table(session_id, node_id, rows),
|
||||
on_mesh=lambda node_id, mesh_data: on_mesh(session_id, node_id, mesh_data),
|
||||
on_overlay=lambda node_id, overlay_data: on_overlay(session_id, node_id, overlay_data),
|
||||
on_value=lambda node_id, payload: on_value(session_id, node_id, payload),
|
||||
on_warning=lambda node_id, message: on_warning(session_id, node_id, message),
|
||||
),
|
||||
)
|
||||
broadcast({"type": "execution_complete", "data": {"prompt_id": prompt_id}})
|
||||
broadcast(session_id, {"type": "execution_complete", "data": {"prompt_id": prompt_id}})
|
||||
except Exception as exc:
|
||||
log.exception("Execution error")
|
||||
broadcast({
|
||||
broadcast(session_id, {
|
||||
"type": "execution_error",
|
||||
"data": {"node_id": "", "message": str(exc)},
|
||||
})
|
||||
@@ -328,32 +415,40 @@ def create_app(loop: asyncio.AbstractEventLoop) -> web.Application:
|
||||
)
|
||||
|
||||
async def websocket_handler(request: web.Request) -> web.WebSocketResponse:
|
||||
session_id = require_session_id(request)
|
||||
ws = web.WebSocketResponse()
|
||||
await ws.prepare(request)
|
||||
websockets.add(ws)
|
||||
log.info("WebSocket client connected (%d total)", len(websockets))
|
||||
session_websockets[session_id].add(ws)
|
||||
log.info(
|
||||
"WebSocket client connected for session %s (%d total in session)",
|
||||
session_id,
|
||||
len(session_websockets[session_id]),
|
||||
)
|
||||
try:
|
||||
async for msg in ws:
|
||||
if msg.type == WSMsgType.TEXT:
|
||||
pass # clients don't need to send anything currently
|
||||
pass
|
||||
elif msg.type in (WSMsgType.ERROR, WSMsgType.CLOSE):
|
||||
break
|
||||
finally:
|
||||
websockets.discard(ws)
|
||||
log.info("WebSocket client disconnected (%d total)", len(websockets))
|
||||
session_websockets[session_id].discard(ws)
|
||||
if not session_websockets[session_id]:
|
||||
session_websockets.pop(session_id, None)
|
||||
log.info(
|
||||
"WebSocket client disconnected for session %s (%d remaining in session)",
|
||||
session_id,
|
||||
len(session_websockets.get(session_id, ())),
|
||||
)
|
||||
return ws
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# App assembly
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
app = web.Application()
|
||||
app["allow_local_filesystem"] = allow_local_filesystem
|
||||
|
||||
app.router.add_get("/", index)
|
||||
app.router.add_get("/nodes", get_nodes)
|
||||
app.router.add_get("/files", list_files)
|
||||
app.router.add_get("/browse", browse_dir)
|
||||
app.router.add_get("/folder-files", get_folder_files)
|
||||
app.router.add_post("/upload-folder", create_upload_folder)
|
||||
app.router.add_post("/upload", upload_file)
|
||||
app.router.add_post("/download", download_file)
|
||||
app.router.add_post("/save-workflow-png", save_workflow_png)
|
||||
@@ -361,26 +456,24 @@ def create_app(loop: asyncio.AbstractEventLoop) -> web.Application:
|
||||
app.router.add_post("/prompt", submit_prompt)
|
||||
app.router.add_get("/ws", websocket_handler)
|
||||
|
||||
# Serve frontend static files (Vite build or raw)
|
||||
if (DIST_DIR / "assets").exists():
|
||||
app.router.add_static("/assets", DIST_DIR / "assets")
|
||||
if FRONTEND_DIR.exists():
|
||||
app.router.add_static("/static", FRONTEND_DIR)
|
||||
|
||||
# CORS — allow any origin (local dev only)
|
||||
async def _cors_middleware(app_, handler):
|
||||
async def middleware(request):
|
||||
if request.method == "OPTIONS":
|
||||
return web.Response(headers={
|
||||
"Access-Control-Allow-Origin": "*",
|
||||
"Access-Control-Allow-Methods": "GET, POST, OPTIONS",
|
||||
"Access-Control-Allow-Headers": "Content-Type",
|
||||
"Access-Control-Allow-Headers": f"Content-Type, {SESSION_HEADER}",
|
||||
})
|
||||
response = await handler(request)
|
||||
response.headers["Access-Control-Allow-Origin"] = "*"
|
||||
return response
|
||||
|
||||
return middleware
|
||||
|
||||
app.middlewares.append(_cors_middleware)
|
||||
|
||||
return app
|
||||
|
||||
Reference in New Issue
Block a user