diff --git a/backend/nodes/save.py b/backend/nodes/save.py index f0e8f63..bb6aa0b 100644 --- a/backend/nodes/save.py +++ b/backend/nodes/save.py @@ -46,7 +46,7 @@ class Save: "DATA_FIELD": ["TIFF", "PNG", "NPZ"], "IMAGE": ["PNG", "TIFF", "NPZ"], "ANNOTATION_SOURCE": ["PNG", "TIFF", "NPZ"], - "LINE": ["CSV", "NPZ", "JSON"], + "LINE": ["PNG", "TIFF", "CSV", "NPZ", "JSON"], "RECORD_TABLE": ["CSV", "JSON"], "DATA_TABLE": ["CSV", "JSON"], "FLOAT": ["TXT", "JSON"], @@ -57,6 +57,11 @@ class Save: }, "optional": { "directory": ("DIRECTORY", {"label": "directory"}), + "plot_title": ("STRING", { + "default": "", + "placeholder": "plot title (optional)", + "label": "title", + }), }, } @@ -79,6 +84,7 @@ class Save: format: str, value, directory: str | None = None, + plot_title: str = "", ): path = self._resolve_save_path(filename, format, directory, directory_path) @@ -88,11 +94,11 @@ class Save: self._save_datafield(path, value, format) elif isinstance(value, np.ndarray): if value.ndim == 1: - self._save_line(path, LineData(data=value), format) + self._save_line(path, LineData(data=value), format, title=plot_title) else: self._save_image_or_array(path, value, format) elif isinstance(value, LineData): - self._save_line(path, value, format) + self._save_line(path, value, format, title=plot_title) elif isinstance(value, list): self._save_table(path, value, format) elif isinstance(value, (int, float, np.floating, np.integer)): @@ -176,9 +182,12 @@ class Save: return raise ValueError(f"Format {format_name} is not supported for IMAGE.") - def _save_line(self, path: Path, line: LineData, format_name: str): + def _save_line(self, path: Path, line: LineData, format_name: str, title: str = ""): y = np.asarray(line.data, dtype=np.float64).ravel() x = np.asarray(line.x_axis, dtype=np.float64).ravel()[: len(y)] if line.x_axis is not None else np.arange(len(y), dtype=np.float64) + if format_name in ("PNG", "TIFF"): + self._save_line_plot(path, x, y, line.x_unit, line.y_unit, title, format_name) + return if format_name == "CSV": with path.open("w", newline="", encoding="utf-8") as fh: writer = csv.writer(fh) @@ -199,6 +208,33 @@ class Save: return raise ValueError(f"Format {format_name} is not supported for LINE.") + def _save_line_plot( + self, + path: Path, + x: np.ndarray, + y: np.ndarray, + x_unit: str, + y_unit: str, + title: str, + format_name: str, + ): + import matplotlib + matplotlib.use("Agg") + import matplotlib.pyplot as plt + + fig, ax = plt.subplots(figsize=(8, 5), dpi=150) + ax.plot(x, y, linewidth=1.2, color="#4f8ef7") + ax.set_xlabel(x_unit if x_unit else "x") + ax.set_ylabel(y_unit if y_unit else "y") + if title and title.strip(): + ax.set_title(title.strip()) + ax.grid(True, linestyle="--", linewidth=0.5, alpha=0.5) + fig.tight_layout() + + ext = ".png" if format_name == "PNG" else ".tiff" + fig.savefig(str(path.with_suffix(ext)), format=format_name.lower(), dpi=150) + plt.close(fig) + def _save_table(self, path: Path, rows: list, format_name: str): if format_name == "JSON": path.write_text(json.dumps(rows, indent=2), encoding="utf-8")