add folder, file nodes and major usability improvements

This commit is contained in:
2026-03-25 22:18:25 -07:00
parent 61b68c142b
commit 7f3dfa8fdf
22 changed files with 3881 additions and 299 deletions

View File

@@ -8,10 +8,11 @@ import json
import sys
import os
import tempfile
from pathlib import Path
import numpy as np
sys.path.insert(0, ".")
from backend.data_types import DataField, MeasureTable, RecordTable, datafield_to_uint8
from backend.data_types import DataField, MeasureTable, RecordTable, datafield_to_uint8, render_datafield_preview
def make_field(data=None, shape=(64, 64), xreal=1e-6, yreal=1e-6):
@@ -79,6 +80,7 @@ def test_crop_resize_field():
yoff=20.0,
si_unit_xy="nm",
si_unit_z="nm",
overlays=[{"kind": "markup", "shapes": [{"kind": "line", "x1": 0.1, "y1": 0.1, "x2": 0.9, "y2": 0.9, "width": 2, "color": "#ffffff"}]}],
)
overlays = []
@@ -103,6 +105,7 @@ def test_crop_resize_field():
assert cropped.yoff == 21.0
assert cropped.si_unit_xy == field.si_unit_xy
assert cropped.si_unit_z == field.si_unit_z
assert cropped.overlays == []
assert len(overlays) == 1
assert overlays[0]["kind"] == "crop_box"
assert overlays[0]["image"].startswith("data:image/png;base64,")
@@ -192,6 +195,7 @@ def test_rotate_field():
assert rotated_90.yoff == 19.0
assert rotated_90.si_unit_xy == field.si_unit_xy
assert rotated_90.si_unit_z == field.si_unit_z
assert rotated_90.overlays == []
rotated_180, = node.process(
field,
@@ -224,6 +228,34 @@ def test_rotate_field():
print(" PASS\n")
def test_rotate_field_overlay_warning():
print("=== Test: RotateField overlay warning ===")
from backend.nodes.modify import RotateField
node = RotateField()
warnings = []
RotateField._broadcast_warning_fn = lambda nid, msg: warnings.append(msg)
RotateField._current_node_id = "test"
field = DataField(
data=np.arange(16, dtype=np.float64).reshape(4, 4),
overlays=[{"kind": "markup", "shapes": [{"kind": "line", "x1": 0.1, "y1": 0.1, "x2": 0.9, "y2": 0.9, "width": 2, "color": "#ffffff"}]}],
)
rotated, = node.process(
field,
angle=30.0,
interpolation="bilinear",
expand_canvas=True,
)
assert rotated.overlays == []
assert len(warnings) == 1
assert "clears annotation/markup overlays" in warnings[0]
RotateField._broadcast_warning_fn = None
print(" PASS\n")
def test_colormap_adjust():
print("=== Test: ColormapAdjust ===")
from backend.nodes.modify import ColormapAdjust
@@ -833,16 +865,36 @@ def test_load_file():
result_npy = node.load(filename=path_npy)
assert np.allclose(result_npy[0].data, data_npy)
custom_colormap = {
"mode": "custom",
"stops": [
{"position": 0.0, "color": "#000000"},
{"position": 0.5, "color": "#ff0000"},
{"position": 1.0, "color": "#ffffff"},
],
}
result_custom = node.load(filename=path, colormap_map=custom_colormap)
assert isinstance(result_custom[0].colormap, dict)
assert result_custom[0].colormap["mode"] == "custom"
assert len(result_custom[0].colormap["stops"]) == 3
result_from_path = node.load(filename="", path=path)
assert len(result_from_path) == 1
assert result_from_path[0].data.shape == (48, 64)
print(" PASS\n")
def test_save_image():
print("=== Test: SaveImage (Save Layers) ===")
from backend.nodes.io import SaveImage
import tifffile
node = SaveImage()
field_a = make_field(data=np.random.default_rng(4).random((32, 32)))
field_b = make_field(data=np.random.default_rng(5).random((32, 32)))
annotated = np.zeros((24, 24, 3), dtype=np.uint8)
annotated[..., 0] = 255
with tempfile.TemporaryDirectory() as tmpdir:
# Save single layer as TIFF
@@ -861,20 +913,57 @@ def test_save_image():
im2 = Image.open(tiff_path2)
assert im2.n_frames == 2
# Save as NPZ
# Save annotated image as TIFF with layer name
annotated_tiff = os.path.join(tmpdir, "annotated.tiff")
node.save(
filename=annotated_tiff,
format="TIFF",
field_0=annotated,
layer_name_0="annotated overview",
)
with tifffile.TiffFile(annotated_tiff) as tif:
assert len(tif.pages) == 1
assert tif.pages[0].description == "annotated overview"
assert tif.pages[0].asarray().shape == annotated.shape
# Save as NPZ with layer names
npz_path = os.path.join(tmpdir, "out.npz")
node.save(filename=npz_path, format="NPZ", field_0=field_a, field_1=field_b)
node.save(
filename=npz_path,
format="NPZ",
field_0=field_a,
field_1=annotated,
layer_name_0="height map",
layer_name_1="annotated-overview",
)
assert os.path.exists(npz_path)
npz = np.load(npz_path)
assert len(npz.files) == 2
assert np.allclose(npz["layer_0"], field_a.data)
assert np.allclose(npz["layer_1"], field_b.data)
assert np.allclose(npz["height_map"], field_a.data)
assert np.array_equal(npz["annotated_overview"], annotated)
# Extension is forced to match format
wrong_ext = os.path.join(tmpdir, "output.png")
node.save(filename=wrong_ext, format="TIFF", field_0=field_a)
assert os.path.exists(os.path.join(tmpdir, "output.tiff"))
# Directory input can drive the destination folder while filename supplies the basename
driven_dir = os.path.join(tmpdir, "nested-output")
node.save(filename="driven_name", directory=driven_dir, format="NPZ", field_0=field_a)
assert os.path.exists(os.path.join(driven_dir, "driven_name.npz"))
# Directory input rejects file paths
try:
node.save(
filename="bad",
directory=os.path.join(tmpdir, "looks_like_file.txt"),
format="TIFF",
field_0=field_a,
)
assert False, "Should have raised ValueError for file-like directory path"
except ValueError:
pass
# No fields connected → error
try:
node.save(filename=os.path.join(tmpdir, "empty.tiff"), format="TIFF")
@@ -896,6 +985,50 @@ def test_save_image():
# Display (limited testing — these are output nodes with WS callbacks)
# =========================================================================
def test_color_map_node():
print("=== Test: ColorMap ===")
from backend.nodes.display import ColorMap
node = ColorMap()
preset, = node.build(mode="preset", preset="magma", stops_json="[]")
assert preset["mode"] == "preset"
assert preset["preset"] == "magma"
custom, = node.build(
mode="custom",
preset="viridis",
stops_json=json.dumps([
{"position": 0.0, "color": "#000000"},
{"position": 0.4, "color": "#00ff00"},
{"position": 1.0, "color": "#ffffff"},
]),
)
assert custom["mode"] == "custom"
assert custom["stops"][0]["position"] == 0.0
assert custom["stops"][-1]["position"] == 1.0
assert len(custom["stops"]) == 3
print(" PASS\n")
def test_font_node():
print("=== Test: Font ===")
from backend.nodes.display import Font
from backend.data_types import CUSTOM_FILE_FONT, SYSTEM_DEFAULT_FONT
node = Font()
system_default, = node.build(SYSTEM_DEFAULT_FONT)
assert system_default is None
named, = node.build("Arial")
assert named == {"family": "Arial", "path": ""}
custom, = node.build(CUSTOM_FILE_FONT, "/tmp/example-font.ttf")
assert custom == {"family": "", "path": "/tmp/example-font.ttf"}
print(" PASS\n")
def test_preview_image():
print("=== Test: PreviewImage ===")
from backend.nodes.display import PreviewImage
@@ -912,6 +1045,27 @@ def test_preview_image():
assert len(captured) == 1
assert captured[0].startswith("data:image/png;base64,")
# Preview with field overlay metadata
captured.clear()
field_with_overlay = field.replace(overlays=[{"kind": "annotation", "show_scale_bar": True, "show_color_map": False, "text_size": 14.0}])
node.preview(colormap="viridis", field=field_with_overlay)
assert len(captured) == 1
assert captured[0].startswith("data:image/png;base64,")
# Preview with a custom colormap input
captured.clear()
custom_colormap = {
"mode": "custom",
"stops": [
{"position": 0.0, "color": "#000000"},
{"position": 0.5, "color": "#ff0000"},
{"position": 1.0, "color": "#ffffff"},
],
}
node.preview(colormap="auto", field=field, colormap_map=custom_colormap)
assert len(captured) == 1
assert captured[0].startswith("data:image/png;base64,")
# Preview with an IMAGE array
captured.clear()
arr = np.random.default_rng(5).integers(0, 256, (32, 32), dtype=np.uint8)
@@ -923,6 +1077,128 @@ def test_preview_image():
print(" PASS\n")
def test_annotations():
print("=== Test: Annotations ===")
from backend.nodes.display import Annotations, Font
node = Annotations()
font_node = Font()
field = DataField(
data=np.linspace(0.0, 1.0, 64 * 64, dtype=np.float64).reshape(64, 64),
xreal=1e-6,
yreal=1e-6,
si_unit_xy="m",
si_unit_z="V",
colormap="viridis",
)
base = datafield_to_uint8(field, "viridis")
plain_preview = render_datafield_preview(field, "viridis")
assert np.array_equal(plain_preview, base)
plain_field, = node.render(field, colormap="auto", show_scale_bar=False, show_color_map=False)
assert isinstance(plain_field, DataField)
assert np.array_equal(plain_field.data, field.data)
assert plain_field.colormap == "viridis"
assert plain_field.overlays[-1]["kind"] == "annotation"
plain = render_datafield_preview(plain_field, plain_field.colormap)
assert plain.shape == base.shape
assert np.array_equal(plain, base)
with_scale_field, = node.render(field, colormap="auto", show_scale_bar=True, show_color_map=False)
with_scale = render_datafield_preview(with_scale_field, with_scale_field.colormap)
assert with_scale.shape == base.shape
assert not np.array_equal(with_scale, base)
with_legend_field, = node.render(field, colormap="auto", show_scale_bar=False, show_color_map=True)
with_legend = render_datafield_preview(with_legend_field, with_legend_field.colormap)
assert with_legend.shape[0] == base.shape[0]
assert with_legend.shape[1] > base.shape[1]
assert with_legend.shape[2] == 3
larger_legend_field, = node.render(
field,
colormap="auto",
show_scale_bar=False,
show_color_map=True,
text_size=28.0,
)
larger_legend_text = render_datafield_preview(larger_legend_field, larger_legend_field.colormap)
assert larger_legend_text.shape == with_legend.shape
assert not np.array_equal(larger_legend_text, with_legend)
annotation_font, = font_node.build("Arial")
with_font_field, = node.render(
field,
colormap="auto",
show_scale_bar=False,
show_color_map=True,
text_size=28.0,
font=annotation_font,
)
assert with_font_field.overlays[-1]["font"] == {"family": "Arial", "path": ""}
with_font = render_datafield_preview(with_font_field, with_font_field.colormap)
assert with_font.shape == with_legend.shape
with_both_field, = node.render(field, colormap="auto", show_scale_bar=True, show_color_map=True)
with_both = render_datafield_preview(with_both_field, with_both_field.colormap)
assert with_both.shape == with_legend.shape
assert not np.array_equal(with_both[:, :base.shape[1]], base)
print(" PASS\n")
def test_markup():
print("=== Test: Markup ===")
from backend.nodes.display import Markup
from backend.data_types import _preview_markup_stroke_width
node = Markup()
field = make_field(data=np.linspace(0.0, 1.0, 48 * 48, dtype=np.float64).reshape(48, 48))
base = render_datafield_preview(field, field.colormap)
assert _preview_markup_stroke_width(5, 128, 128) == 5
assert _preview_markup_stroke_width(5, 2048, 2048) > 5
overlays = []
Markup._broadcast_overlay_fn = lambda nid, data: overlays.append(data)
Markup._current_node_id = "test"
plain_field, = node.process(
field=field,
shape="line",
stroke_color="#ffd54f",
stroke_width=3,
markup_shapes="[]",
)
assert isinstance(plain_field, DataField)
assert plain_field.overlays[-1]["kind"] == "markup"
plain = render_datafield_preview(plain_field, plain_field.colormap)
assert np.array_equal(plain, base)
assert overlays[-1]["kind"] == "markup"
assert overlays[-1]["image"].startswith("data:image/png;base64,")
shapes = json.dumps([
{"kind": "line", "x1": 0.1, "y1": 0.1, "x2": 0.9, "y2": 0.9, "width": 3, "color": "#ff0000"},
{"kind": "rectangle", "x1": 0.2, "y1": 0.2, "x2": 0.8, "y2": 0.5, "width": 2, "color": "#00ff00"},
{"kind": "circle", "x1": 0.25, "y1": 0.55, "x2": 0.55, "y2": 0.85, "width": 2, "color": "#4fc3f7"},
{"kind": "arrow", "x1": 0.15, "y1": 0.85, "x2": 0.85, "y2": 0.2, "width": 4, "color": "#ffffff"},
])
marked_field, = node.process(
field=field,
shape="arrow",
stroke_color="#ffffff",
stroke_width=4,
markup_shapes=shapes,
)
marked = render_datafield_preview(marked_field, marked_field.colormap)
assert marked.shape == base.shape
assert not np.array_equal(marked, base)
Markup._broadcast_overlay_fn = None
print(" PASS\n")
def test_print_table():
print("=== Test: PrintTable ===")
from backend.nodes.display import PrintTable
@@ -1086,7 +1362,8 @@ def test_load_file_warning():
def test_list_channels():
print("=== Test: list_channels ===")
from backend.nodes.io import list_channels
from backend.nodes.io import list_channels, list_folder_paths, Folder
from PIL import Image
# Non-existent file → default
ch = list_channels("/nonexistent/file.ibw")
@@ -1105,7 +1382,6 @@ def test_list_channels():
# Plain image → single default channel
with tempfile.TemporaryDirectory() as tmpdir:
from PIL import Image
img = Image.fromarray(np.zeros((8, 8), dtype=np.uint8))
path = os.path.join(tmpdir, "test.png")
img.save(path)
@@ -1122,6 +1398,32 @@ def test_list_channels():
ch = list_channels(path)
assert len(ch) == 1
with tempfile.TemporaryDirectory() as tmpdir:
img = Image.fromarray(np.zeros((8, 8), dtype=np.uint8))
png_path = os.path.join(tmpdir, "a.png")
npy_path = os.path.join(tmpdir, "b.npy")
gwy_path = os.path.join(tmpdir, "c.gwy")
sxm_path = os.path.join(tmpdir, "d.sxm")
ibw_path = os.path.join(tmpdir, "e.ibw")
txt_path = os.path.join(tmpdir, "notes.txt")
img.save(png_path)
np.save(npy_path, np.zeros((4, 4)))
Path(gwy_path).write_bytes(b"gwy")
Path(sxm_path).write_bytes(b"sxm")
Path(ibw_path).write_bytes(b"ibw")
with open(txt_path, "w", encoding="utf-8") as fh:
fh.write("ignore me")
paths = list_folder_paths(tmpdir)
assert [entry["name"] for entry in paths] == ["directory", "a.png", "b.npy", "c.gwy", "d.sxm", "e.ibw"]
assert Path(paths[0]["path"]).resolve() == Path(tmpdir).resolve()
assert paths[0]["type"] == "DIRECTORY"
assert all(entry["type"] == "FILE_PATH" for entry in paths[1:])
folder_node = Folder()
folder_result = folder_node.list_files(tmpdir)
assert folder_result == tuple(entry["path"] for entry in paths)
print(" PASS\n")
@@ -1157,6 +1459,35 @@ def test_load_demo():
print(" PASS\n")
def test_load_demo_multi_layer_preview_payload():
print("=== Test: LoadDemo multi-layer preview payload ===")
from backend.execution import ExecutionEngine
import backend.nodes # noqa: F401
previews = []
prompt = {
"1": {
"class_type": "LoadDemo",
"inputs": {
"name": "whiskers.ibw",
"colormap": "viridis",
},
},
}
ExecutionEngine().execute(prompt, on_preview=lambda node_id, payload: previews.append((node_id, payload)))
assert len(previews) == 1
node_id, payload = previews[0]
assert node_id == "1"
assert payload["kind"] == "layer_gallery"
assert len(payload["layers"]) == 4
assert all(isinstance(layer["name"], str) and layer["name"] for layer in payload["layers"])
assert all(layer["image"].startswith("data:image/png;base64,") for layer in payload["layers"])
print(" PASS\n")
# =========================================================================
# I/O — Coordinate
# =========================================================================
@@ -1181,6 +1512,25 @@ def test_coordinate():
print(" PASS\n")
# =========================================================================
# I/O — Number
# =========================================================================
def test_number():
print("=== Test: Number ===")
from backend.nodes.io import Number
node = Number()
result = node.process(value=1.25)
assert result == (1.25,)
result_neg = node.process(value=-3.5)
assert result_neg == (-3.5,)
print(" PASS\n")
def test_range_slider():
print("=== Test: RangeSlider ===")
from backend.nodes.io import RangeSlider
@@ -1205,6 +1555,62 @@ def test_range_slider():
print(" PASS\n")
def test_execution_engine_numeric_socket_coercion():
print("=== Test: ExecutionEngine numeric socket coercion ===")
from backend.execution import ExecutionEngine
from backend.node_registry import register_node
@register_node(display_name="Test Echo Int")
class TestEchoInt:
@classmethod
def INPUT_TYPES(cls):
return {"required": {"value": ("INT",)}}
RETURN_TYPES = ("INT",)
RETURN_NAMES = ("value",)
FUNCTION = "process"
CATEGORY = "tests"
def process(self, value):
return (value,)
@register_node(display_name="Test Echo Float")
class TestEchoFloat:
@classmethod
def INPUT_TYPES(cls):
return {"required": {"value": ("FLOAT",)}}
RETURN_TYPES = ("FLOAT",)
RETURN_NAMES = ("value",)
FUNCTION = "process"
CATEGORY = "tests"
def process(self, value):
return (value,)
engine = ExecutionEngine()
prompt = {
"1": {
"class_type": "Number",
"inputs": {"value": 3.6},
},
"2": {
"class_type": "TestEchoInt",
"inputs": {"value": ["1", 0]},
},
"3": {
"class_type": "TestEchoFloat",
"inputs": {"value": ["1", 0]},
},
}
outputs = engine.execute(prompt)
assert outputs["2"] == (4,)
assert outputs["3"] == (3.6,)
print(" PASS\n")
# =========================================================================
# Analysis — LineCursors
# =========================================================================