refactor nodes into standalone file

This commit is contained in:
2026-03-26 19:50:03 -07:00
parent 711d7995b3
commit de0b49acc5
54 changed files with 3615 additions and 3710 deletions

View File

@@ -28,7 +28,7 @@ def make_field(data=None, shape=(64, 64), xreal=1e-6, yreal=1e-6):
def test_gaussian_filter():
print("=== Test: GaussianFilter ===")
from backend.nodes.filters import GaussianFilter
from backend.nodes.gaussian_filter import GaussianFilter
node = GaussianFilter()
field = make_field()
@@ -46,7 +46,7 @@ def test_gaussian_filter():
def test_median_filter():
print("=== Test: MedianFilter ===")
from backend.nodes.filters import MedianFilter
from backend.nodes.median_filter import MedianFilter
node = MedianFilter()
# Median filter should remove salt-and-pepper noise
@@ -68,7 +68,7 @@ def test_median_filter():
def test_crop_resize_field():
print("=== Test: CropResizeField ===")
from backend.nodes.modify import CropResizeField
from backend.nodes.crop_resize_field import CropResizeField
node = CropResizeField()
data = np.arange(32, dtype=np.float64).reshape(4, 8)
@@ -167,7 +167,7 @@ def test_crop_resize_field():
def test_rotate_field():
print("=== Test: RotateField ===")
from backend.nodes.modify import RotateField
from backend.nodes.rotate_field import RotateField
node = RotateField()
data = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float64)
@@ -230,7 +230,7 @@ def test_rotate_field():
def test_rotate_field_overlay_warning():
print("=== Test: RotateField overlay warning ===")
from backend.nodes.modify import RotateField
from backend.nodes.rotate_field import RotateField
node = RotateField()
warnings = []
@@ -258,7 +258,7 @@ def test_rotate_field_overlay_warning():
def test_colormap_adjust():
print("=== Test: ColormapAdjust ===")
from backend.nodes.modify import ColormapAdjust
from backend.nodes.colormap_adjust import ColormapAdjust
node = ColormapAdjust()
field = DataField(
@@ -299,7 +299,7 @@ def test_colormap_adjust():
def test_edge_detect():
print("=== Test: EdgeDetect ===")
from backend.nodes.filters import EdgeDetect
from backend.nodes.edge_detect import EdgeDetect
node = EdgeDetect()
# Create an image with a sharp vertical edge
@@ -320,7 +320,7 @@ def test_edge_detect():
def test_fft_filter_1d():
print("=== Test: FFTFilter1D ===")
from backend.nodes.filters import FFTFilter1D
from backend.nodes.fft_filter_1d import FFTFilter1D
node = FFTFilter1D()
# Signal: low-frequency sine + high-frequency sine
@@ -364,7 +364,7 @@ def test_fft_filter_1d():
def test_fft_filter_2d():
print("=== Test: FFTFilter2D ===")
from backend.nodes.filters import FFTFilter2D
from backend.nodes.fft_filter_2d import FFTFilter2D
node = FFTFilter2D()
N = 128
@@ -406,7 +406,7 @@ def test_fft_filter_2d():
def test_plane_level():
print("=== Test: PlaneLevelField ===")
from backend.nodes.level import PlaneLevelField
from backend.nodes.plane_level_field import PlaneLevelField
node = PlaneLevelField()
# Create a tilted plane + small signal
@@ -428,7 +428,7 @@ def test_plane_level():
def test_poly_level():
print("=== Test: PolyLevelField ===")
from backend.nodes.level import PolyLevelField
from backend.nodes.poly_level_field import PolyLevelField
node = PolyLevelField()
N = 64
@@ -455,7 +455,7 @@ def test_poly_level():
def test_fix_zero():
print("=== Test: FixZero ===")
from backend.nodes.level import FixZero
from backend.nodes.fix_zero import FixZero
node = FixZero()
field = make_field(data=np.array([[10, 20], [30, 40]], dtype=np.float64))
@@ -477,7 +477,7 @@ def test_fix_zero():
def test_statistics():
print("=== Test: Statistics ===")
from backend.nodes.analysis import Statistics
from backend.nodes.statistics_node import Statistics
node = Statistics()
data = np.array([[1, 2], [3, 4]], dtype=np.float64)
@@ -507,7 +507,7 @@ def test_statistics():
def test_height_histogram():
print("=== Test: Histogram ===")
from backend.nodes.analysis import Histogram
from backend.nodes.histogram import Histogram
node = Histogram()
# Uniform data should give a roughly flat histogram
@@ -556,7 +556,7 @@ def test_height_histogram():
def test_cross_section():
print("=== Test: CrossSection ===")
from backend.nodes.analysis import CrossSection
from backend.nodes.cross_section import CrossSection
node = CrossSection()
# Create a field with a known horizontal gradient
@@ -604,7 +604,8 @@ def test_cross_section():
)
assert len(profile_diag) == 50
from backend.nodes.analysis import Cursors, Stats
from backend.nodes.cursors import Cursors
from backend.nodes.stats import Stats
cursors = Cursors()
table, _ = cursors.process(profile, x1=0.25, y1=0.5, x2=0.75, y2=0.5)
@@ -630,7 +631,7 @@ def test_cross_section():
def test_threshold_mask():
print("=== Test: ThresholdMask ===")
from backend.nodes.mask import ThresholdMask
from backend.nodes.threshold_mask import ThresholdMask
node = ThresholdMask()
# Clear bimodal data: left half = 0, right half = 1
@@ -673,7 +674,7 @@ def test_threshold_mask():
def test_mask_morphology():
print("=== Test: MaskMorphology ===")
from backend.nodes.mask import MaskMorphology
from backend.nodes.mask_morphology import MaskMorphology
node = MaskMorphology()
# Small square blob in the centre
@@ -710,7 +711,7 @@ def test_mask_morphology():
def test_mask_invert():
print("=== Test: MaskInvert ===")
from backend.nodes.mask import MaskInvert
from backend.nodes.mask_invert import MaskInvert
node = MaskInvert()
mask = np.zeros((64, 64), dtype=np.uint8)
@@ -729,7 +730,7 @@ def test_mask_invert():
def test_mask_combine():
print("=== Test: MaskCombine ===")
from backend.nodes.mask import MaskCombine
from backend.nodes.mask_combine import MaskCombine
node = MaskCombine()
# Two overlapping squares
@@ -768,7 +769,7 @@ def test_mask_combine():
def test_draw_mask():
print("=== Test: DrawMask ===")
from backend.nodes.mask import DrawMask
from backend.nodes.draw_mask import DrawMask
node = DrawMask()
field = make_field(data=np.zeros((32, 32), dtype=np.float64))
@@ -815,7 +816,7 @@ def test_draw_mask():
def test_particle_analysis():
print("=== Test: ParticleAnalysis ===")
from backend.nodes.particless import ParticleAnalysis
from backend.nodes.particle_analysis import ParticleAnalysis
node = ParticleAnalysis()
# Create a field with two distinct particles
@@ -855,7 +856,7 @@ def test_particle_analysis():
def test_load_file():
print("=== Test: Image ===")
from backend.nodes.io import Image as ImageNode
from backend.nodes.image import Image as ImageNode
from PIL import Image as PILImage
node = ImageNode()
@@ -912,7 +913,7 @@ def test_load_file():
def test_save_image():
print("=== Test: SaveImage (Save Layers) ===")
from backend.nodes.io import SaveImage
from backend.nodes.save_image import SaveImage
import tifffile
node = SaveImage()
@@ -1012,7 +1013,7 @@ def test_save_image():
def test_color_map_node():
print("=== Test: ColorMap ===")
from backend.nodes.display import ColorMap
from backend.nodes.color_map import ColorMap
node = ColorMap()
@@ -1038,7 +1039,7 @@ def test_color_map_node():
def test_font_node():
print("=== Test: Font ===")
from backend.nodes.display import Font
from backend.nodes.font_node import Font
from backend.data_types import CUSTOM_FILE_FONT, SYSTEM_DEFAULT_FONT
node = Font()
@@ -1056,7 +1057,7 @@ def test_font_node():
def test_preview_image():
print("=== Test: PreviewImage ===")
from backend.nodes.display import PreviewImage
from backend.nodes.preview_image import PreviewImage
node = PreviewImage()
# Set up a capture for the broadcast
@@ -1104,7 +1105,8 @@ def test_preview_image():
def test_annotations():
print("=== Test: Annotations ===")
from backend.nodes.display import Annotations, Font
from backend.nodes.annotations import Annotations
from backend.nodes.font_node import Font
node = Annotations()
font_node = Font()
@@ -1175,7 +1177,7 @@ def test_annotations():
def test_markup():
print("=== Test: Markup ===")
from backend.nodes.display import Markup
from backend.nodes.markup import Markup
from backend.data_types import _preview_markup_stroke_width
node = Markup()
@@ -1226,7 +1228,7 @@ def test_markup():
def test_print_table():
print("=== Test: PrintTable ===")
from backend.nodes.display import PrintTable
from backend.nodes.print_table import PrintTable
node = PrintTable()
captured = []
@@ -1244,7 +1246,7 @@ def test_print_table():
def test_value_display():
print("=== Test: ValueDisplay ===")
from backend.nodes.display import ValueDisplay
from backend.nodes.value_display import ValueDisplay
node = ValueDisplay()
captured = []
@@ -1273,7 +1275,7 @@ def test_value_display():
def test_load_file_ibw():
print("=== Test: Image IBW multi-channel ===")
from backend.nodes.io import Image
from backend.nodes.image import Image
node = Image()
ibw_path = os.path.join(os.path.dirname(__file__), "..", "demo", "BR_New20012.ibw")
@@ -1309,7 +1311,7 @@ def test_load_file_ibw():
def test_load_file_npz():
print("=== Test: Image .npz ===")
from backend.nodes.io import Image
from backend.nodes.image import Image
node = Image()
with tempfile.TemporaryDirectory() as tmpdir:
@@ -1326,7 +1328,7 @@ def test_load_file_npz():
def test_load_file_not_found():
print("=== Test: Image not found ===")
from backend.nodes.io import Image
from backend.nodes.image import Image
node = Image()
try:
@@ -1340,7 +1342,7 @@ def test_load_file_not_found():
def test_load_file_unsupported():
print("=== Test: Image unsupported format ===")
from backend.nodes.io import Image
from backend.nodes.image import Image
node = Image()
with tempfile.TemporaryDirectory() as tmpdir:
@@ -1358,7 +1360,7 @@ def test_load_file_unsupported():
def test_load_file_warning():
print("=== Test: Image warning for uncalibrated data ===")
from backend.nodes.io import Image as ImageNode
from backend.nodes.image import Image as ImageNode
from PIL import Image as PILImage
node = ImageNode()
@@ -1387,7 +1389,8 @@ def test_load_file_warning():
def test_list_channels():
print("=== Test: list_channels ===")
from backend.nodes.io import list_channels, list_folder_paths, Folder
from backend.nodes.helpers import list_channels, list_folder_paths
from backend.nodes.folder import Folder
from PIL import Image
# Non-existent file → default
@@ -1458,7 +1461,7 @@ def test_list_channels():
def test_load_demo():
print("=== Test: ImageDemo ===")
from backend.nodes.io import ImageDemo
from backend.nodes.image_demo import ImageDemo
node = ImageDemo()
@@ -1519,7 +1522,7 @@ def test_load_demo_multi_layer_preview_payload():
def test_coordinate():
print("=== Test: Coordinate ===")
from backend.nodes.io import Coordinate
from backend.nodes.coordinate import Coordinate
node = Coordinate()
@@ -1543,7 +1546,7 @@ def test_coordinate():
def test_number():
print("=== Test: Number ===")
from backend.nodes.io import Number
from backend.nodes.number import Number
node = Number()
@@ -1558,7 +1561,7 @@ def test_number():
def test_range_slider():
print("=== Test: RangeSlider ===")
from backend.nodes.io import RangeSlider
from backend.nodes.range_slider import RangeSlider
node = RangeSlider()
@@ -1642,7 +1645,7 @@ def test_execution_engine_numeric_socket_coercion():
def test_line_cursors():
print("=== Test: Cursors ===")
from backend.nodes.analysis import Cursors
from backend.nodes.cursors import Cursors
node = Cursors()
@@ -1718,7 +1721,7 @@ def test_line_cursors():
def test_fft2d():
print("=== Test: FFT2D ===")
from backend.nodes.analysis import FFT2D
from backend.nodes.fft_2d import FFT2D
node = FFT2D()
@@ -1777,7 +1780,7 @@ def test_fft2d():
def test_stats():
print("=== Test: Stats ===")
from backend.nodes.analysis import Stats
from backend.nodes.stats import Stats
node = Stats()
captured = []
@@ -1845,7 +1848,7 @@ def test_stats():
def test_view3d():
print("=== Test: View3D ===")
from backend.nodes.display import View3D
from backend.nodes.view_3d import View3D
node = View3D()
field = make_field()
@@ -1863,7 +1866,7 @@ def test_view3d():
assert "height" in mesh
assert "z_data" in mesh
assert "colors" in mesh
assert mesh["z_scale"] == 2.0
assert mesh["z_scale"] == 0.2
assert mesh["width"] <= 64
assert mesh["height"] <= 64
# z_min < z_max for non-constant data