fix folder and file save
This commit is contained in:
@@ -70,6 +70,7 @@ class ExecutionEngine:
|
||||
on_overlay: Callable[[str, str], None] | None = None,
|
||||
on_value: Callable[[str, Any], None] | None = None,
|
||||
on_warning: Callable[[str, str], None] | None = None,
|
||||
on_file_download: Callable[[str, str], None] | None = None,
|
||||
) -> dict[str, tuple]:
|
||||
"""
|
||||
Execute the workflow described by `prompt`.
|
||||
@@ -100,6 +101,7 @@ class ExecutionEngine:
|
||||
overlay=on_overlay,
|
||||
value=on_value,
|
||||
warning=on_warning,
|
||||
file_download=on_file_download,
|
||||
):
|
||||
for node_id in order:
|
||||
node_def = prompt[node_id]
|
||||
|
||||
@@ -20,6 +20,7 @@ _LEGACY_CALLBACK_ATTRS = {
|
||||
"overlay": "_broadcast_overlay_fn",
|
||||
"value": "_broadcast_value_fn",
|
||||
"warning": "_broadcast_warning_fn",
|
||||
"file_download": "_broadcast_file_download_fn",
|
||||
}
|
||||
|
||||
|
||||
@@ -32,6 +33,7 @@ def execution_callbacks(
|
||||
overlay: Callback | None = None,
|
||||
value: Callback | None = None,
|
||||
warning: Callback | None = None,
|
||||
file_download: Callback | None = None,
|
||||
):
|
||||
token = _callbacks_var.set({
|
||||
"preview": preview,
|
||||
@@ -40,6 +42,7 @@ def execution_callbacks(
|
||||
"overlay": overlay,
|
||||
"value": value,
|
||||
"warning": warning,
|
||||
"file_download": file_download,
|
||||
})
|
||||
try:
|
||||
yield
|
||||
@@ -120,3 +123,7 @@ def emit_value(payload: Any) -> None:
|
||||
|
||||
def emit_warning(message: str) -> None:
|
||||
_emit("warning", message)
|
||||
|
||||
|
||||
def emit_file_download(path: str) -> None:
|
||||
_emit("file_download", path)
|
||||
|
||||
@@ -6,10 +6,14 @@ from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
|
||||
import tempfile
|
||||
|
||||
from backend.node_registry import register_node
|
||||
from backend.execution_context import emit_warning
|
||||
from backend.execution_context import emit_warning, emit_file_download
|
||||
from backend.data_types import DataField, LineData, MeshModel, datafield_to_uint8, image_to_uint8
|
||||
|
||||
DOWNLOAD_DIR = Path(tempfile.gettempdir()) / "tono-downloads"
|
||||
|
||||
@register_node(display_name="Save")
|
||||
class Save:
|
||||
@classmethod
|
||||
@@ -21,13 +25,6 @@ class Save:
|
||||
"placeholder": "filename",
|
||||
"placement": "top",
|
||||
}),
|
||||
"directory_path": ("FOLDER_PICKER", {
|
||||
"default": "",
|
||||
"label": "directory",
|
||||
"placement": "top",
|
||||
"hide_when_input_connected": "directory",
|
||||
"top_socket_input": "directory",
|
||||
}),
|
||||
"value": ("DATA_FIELD", {
|
||||
"label": "value",
|
||||
"accepted_types": [
|
||||
@@ -56,11 +53,11 @@ class Save:
|
||||
}),
|
||||
},
|
||||
"optional": {
|
||||
"directory": ("DIRECTORY", {"label": "directory"}),
|
||||
"plot_title": ("STRING", {
|
||||
"default": "",
|
||||
"placeholder": "plot title (optional)",
|
||||
"label": "title",
|
||||
"show_when_source_type": {"value": ["LINE"]},
|
||||
}),
|
||||
},
|
||||
}
|
||||
@@ -80,13 +77,11 @@ class Save:
|
||||
def save(
|
||||
self,
|
||||
filename: str,
|
||||
directory_path: str,
|
||||
format: str,
|
||||
value,
|
||||
directory: str | None = None,
|
||||
plot_title: str = "",
|
||||
):
|
||||
path = self._resolve_save_path(filename, format, directory, directory_path)
|
||||
path = self._resolve_save_path(filename, format)
|
||||
|
||||
if isinstance(value, MeshModel):
|
||||
self._save_mesh(path, value, format)
|
||||
@@ -107,15 +102,10 @@ class Save:
|
||||
raise ValueError(f"Save does not support input type: {type(value).__name__}")
|
||||
|
||||
self._send_warning(f"Saved to {path.name}")
|
||||
emit_file_download(str(path))
|
||||
return ()
|
||||
|
||||
def _resolve_save_path(
|
||||
self,
|
||||
filename: str,
|
||||
format_name: str,
|
||||
directory: str | None,
|
||||
directory_path: str = "",
|
||||
) -> Path:
|
||||
def _resolve_save_path(self, filename: str, format_name: str) -> Path:
|
||||
ext_map = {
|
||||
"PNG": ".png",
|
||||
"TIFF": ".tiff",
|
||||
@@ -129,25 +119,16 @@ class Save:
|
||||
ext = ext_map[format_name]
|
||||
|
||||
raw_filename = str(filename).strip() if filename is not None else ""
|
||||
raw_directory = str(directory).strip() if directory is not None else ""
|
||||
if not raw_directory:
|
||||
raw_directory = str(directory_path).strip() if directory_path is not None else ""
|
||||
|
||||
if not raw_filename:
|
||||
raise ValueError("No output filename selected — enter a file name.")
|
||||
|
||||
if raw_directory:
|
||||
dir_path = Path(raw_directory).expanduser()
|
||||
if dir_path.exists() and not dir_path.is_dir():
|
||||
raise ValueError("Directory input expects a folder path, not a file path.")
|
||||
if not dir_path.exists():
|
||||
if dir_path.suffix:
|
||||
raise ValueError("Directory input expects a folder path, not a file path.")
|
||||
dir_path.mkdir(parents=True, exist_ok=True)
|
||||
path = dir_path / Path(raw_filename).name
|
||||
candidate = Path(raw_filename).expanduser()
|
||||
if candidate.is_absolute():
|
||||
candidate.parent.mkdir(parents=True, exist_ok=True)
|
||||
path = candidate
|
||||
else:
|
||||
path = Path(raw_filename).expanduser()
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
DOWNLOAD_DIR.mkdir(parents=True, exist_ok=True)
|
||||
path = DOWNLOAD_DIR / candidate.name
|
||||
|
||||
if path.suffix.lower() != ext:
|
||||
path = path.with_suffix(ext)
|
||||
@@ -156,7 +137,7 @@ class Save:
|
||||
def _save_datafield(self, path: Path, field: DataField, format_name: str):
|
||||
if format_name == "TIFF":
|
||||
import tifffile
|
||||
tifffile.imwrite(str(path), np.asarray(field.data, dtype=np.float32))
|
||||
tifffile.imwrite(str(path), datafield_to_uint8(field, field.colormap))
|
||||
return
|
||||
if format_name == "NPZ":
|
||||
np.savez(str(path), field=np.asarray(field.data))
|
||||
|
||||
@@ -4,7 +4,7 @@ import numpy as np
|
||||
from pathlib import Path
|
||||
|
||||
from backend.node_registry import register_node
|
||||
from backend.execution_context import emit_warning
|
||||
from backend.execution_context import emit_warning, emit_file_download
|
||||
from backend.data_types import DataField, image_to_uint8
|
||||
from backend.nodes.helpers import _MAX_SAVE_FIELDS
|
||||
|
||||
@@ -35,9 +35,10 @@ class SaveImage:
|
||||
"placeholder": "filename",
|
||||
"placement": "top",
|
||||
}),
|
||||
"directory_path": ("FOLDER_PICKER", {
|
||||
"directory_path": ("STRING", {
|
||||
"default": "",
|
||||
"label": "directory",
|
||||
"placeholder": "directory (optional, desktop only)",
|
||||
"placement": "top",
|
||||
"hide_when_input_connected": "directory",
|
||||
"top_socket_input": "directory",
|
||||
@@ -92,6 +93,7 @@ class SaveImage:
|
||||
self._save_npz(path, layers, layer_names)
|
||||
|
||||
self._send_warning(f"Saved {len(layers)} layer(s) to {path.name}")
|
||||
emit_file_download(str(path))
|
||||
return ()
|
||||
|
||||
def _save_tiff(self, path: Path, layers: list[DataField | np.ndarray], layer_names: list[str]):
|
||||
@@ -140,9 +142,15 @@ class SaveImage:
|
||||
path = dir_path / filename_part
|
||||
else:
|
||||
if not raw_filename:
|
||||
raise ValueError("No output path selected — use Browse to pick a location.")
|
||||
path = Path(raw_filename).expanduser()
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
raise ValueError("No output filename selected — enter a file name.")
|
||||
candidate = Path(raw_filename).expanduser()
|
||||
if candidate.is_absolute():
|
||||
candidate.parent.mkdir(parents=True, exist_ok=True)
|
||||
path = candidate
|
||||
else:
|
||||
from backend.nodes.save import DOWNLOAD_DIR
|
||||
DOWNLOAD_DIR.mkdir(parents=True, exist_ok=True)
|
||||
path = DOWNLOAD_DIR / candidate.name
|
||||
|
||||
if path.suffix.lower() != ext:
|
||||
path = path.with_suffix(ext)
|
||||
|
||||
@@ -32,6 +32,7 @@ import asyncio
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
import secrets
|
||||
import sys
|
||||
from collections import defaultdict
|
||||
from copy import deepcopy
|
||||
@@ -139,6 +140,7 @@ def create_app(
|
||||
|
||||
session_engines: dict[str, ExecutionEngine] = {}
|
||||
session_websockets: dict[str, set[web.WebSocketResponse]] = defaultdict(set)
|
||||
pending_downloads: dict[str, Path] = {}
|
||||
|
||||
def _is_link(value) -> bool:
|
||||
return (
|
||||
@@ -254,6 +256,12 @@ def create_app(
|
||||
def on_warning(session_id: str, node_id: str, message: str) -> None:
|
||||
broadcast(session_id, {"type": "node_warning", "data": {"node_id": node_id, "message": message}})
|
||||
|
||||
def on_file_download(session_id: str, node_id: str, file_path: str) -> None:
|
||||
token = secrets.token_urlsafe(16)
|
||||
path = Path(file_path)
|
||||
pending_downloads[token] = path
|
||||
broadcast(session_id, {"type": "file_download", "data": {"node_id": node_id, "token": token, "filename": path.name}})
|
||||
|
||||
async def index(request: web.Request) -> web.Response:
|
||||
if not getattr(sys, "frozen", False):
|
||||
try:
|
||||
@@ -470,6 +478,16 @@ def create_app(
|
||||
headers={"Content-Disposition": f'attachment; filename="{filename}"'},
|
||||
)
|
||||
|
||||
async def download_saved_file(request: web.Request) -> web.Response:
|
||||
token = request.match_info["token"]
|
||||
path = pending_downloads.pop(token, None)
|
||||
if path is None or not path.is_file():
|
||||
raise web.HTTPNotFound(reason="File not found")
|
||||
return web.FileResponse(
|
||||
path,
|
||||
headers={"Content-Disposition": f'attachment; filename="{path.name}"'},
|
||||
)
|
||||
|
||||
async def save_workflow_png(request: web.Request) -> web.Response:
|
||||
body = await request.read()
|
||||
target_path = request.query.get("path", "")
|
||||
@@ -535,6 +553,7 @@ def create_app(
|
||||
on_overlay=lambda node_id, overlay_data: on_overlay(session_id, node_id, overlay_data),
|
||||
on_value=lambda node_id, payload: on_value(session_id, node_id, payload),
|
||||
on_warning=lambda node_id, message: on_warning(session_id, node_id, message),
|
||||
on_file_download=lambda node_id, file_path: on_file_download(session_id, node_id, file_path),
|
||||
),
|
||||
)
|
||||
broadcast(session_id, {"type": "execution_complete", "data": {"prompt_id": prompt_id}})
|
||||
@@ -627,6 +646,7 @@ def create_app(
|
||||
app.router.add_get("/help-docs", get_help_docs)
|
||||
app.router.add_get("/help-docs/{filename}", get_help_doc_file)
|
||||
app.router.add_post("/prompt", submit_prompt)
|
||||
app.router.add_get("/download-save/{token}", download_saved_file)
|
||||
app.router.add_get("/check-update", check_update)
|
||||
app.router.add_get("/ws", websocket_handler)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user