From 626ac3d82e199deb9cf2fbb39efa55a644868c1b Mon Sep 17 00:00:00 2001 From: kylmcgr Date: Mon, 18 May 2026 20:41:11 -0600 Subject: [PATCH] remove last channel name code --- src/ezmsg/learn/process/adaptive_linear_regressor.py | 4 ++-- src/ezmsg/learn/process/sklearn.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/ezmsg/learn/process/adaptive_linear_regressor.py b/src/ezmsg/learn/process/adaptive_linear_regressor.py index c86a08e..7965027 100644 --- a/src/ezmsg/learn/process/adaptive_linear_regressor.py +++ b/src/ezmsg/learn/process/adaptive_linear_regressor.py @@ -171,7 +171,7 @@ def partial_fit(self, message: AxisArray) -> None: ]: # river path: needs numpy/pandas data_np = np.asarray(message.data) if not is_numpy_array(message.data) else message.data - x = pd.DataFrame(data_np, columns=_axis_labels(message.axes["ch"].data)) + x = pd.DataFrame(data_np, columns=[f"f{i}" for i in range(data_np.shape[1])]) targets = message.attrs["trigger"].value target_np = np.asarray(targets.data) if target_np.ndim == 1: @@ -219,7 +219,7 @@ def _process(self, message: AxisArray) -> AxisArray | None: ]: # river path: needs numpy/pandas data_np = np.asarray(message.data) if not is_numpy_array(message.data) else message.data - x = pd.DataFrame(data_np, columns=_axis_labels(message.axes["ch"].data)) + x = pd.DataFrame(data_np, columns=[f"f{i}" for i in range(data_np.shape[1])]) n_outputs = len(self.state.model) if isinstance(self.state.model, dict) else 1 out_labels = self._prediction_labels(n_outputs) if isinstance(self.state.model, dict): diff --git a/src/ezmsg/learn/process/sklearn.py b/src/ezmsg/learn/process/sklearn.py index 7d1f802..de9163f 100644 --- a/src/ezmsg/learn/process/sklearn.py +++ b/src/ezmsg/learn/process/sklearn.py @@ -110,7 +110,7 @@ def partial_fit(self, message: AxisArray) -> None: kwargs["classes"] = self.settings.partial_fit_classes self._state.model.partial_fit(X, y, **kwargs) elif hasattr(self._state.model, "learn_many"): - df_X = pd.DataFrame({k: v for k, v in zip(message.axes["ch"].data, message.data.T)}) + df_X = pd.DataFrame(message.data, columns=[f"f{i}" for i in range(message.data.shape[1])]) name = ( message.attrs["trigger"].value.axes["ch"].data[0] if hasattr(message.attrs["trigger"].value, "axes") and "ch" in message.attrs["trigger"].value.axes