diff --git a/src/frontEnd/TerminalUi.py b/src/frontEnd/TerminalUi.py index e097c654f..2384991e3 100644 --- a/src/frontEnd/TerminalUi.py +++ b/src/frontEnd/TerminalUi.py @@ -74,7 +74,7 @@ def cancelSimulation(self): self.cancelSimulationButton.setEnabled(False) self.redoSimulationButton.setEnabled(True) - if (self.qProcess.state() == QtCore.QProcess.NotRunning): + if (self.qProcess.state() == QtCore.QProcess.ProcessState.NotRunning): return self.simulationCancelled = True @@ -98,7 +98,7 @@ def redoSimulation(self): self.cancelSimulationButton.setEnabled(True) self.redoSimulationButton.setEnabled(False) - if (self.qProcess.state() != QtCore.QProcess.NotRunning): + if (self.qProcess.state() != QtCore.QProcess.ProcessState.NotRunning): return # To make the progressbar running @@ -122,8 +122,7 @@ def redoSimulation(self): else: self.Flag = False - # Emit a custom signal with name plotFlag2 depending upon the Flag - self.qProcess.setProperty("plotFlag2", self.Flag) + self.qProcess.setProperty("redoPlotFlag", self.Flag) self.qProcess.start('ngspice', self.args) diff --git a/src/ngspiceSimulation/NgspiceWidget.py b/src/ngspiceSimulation/NgspiceWidget.py index 5e0f2e8d1..413e9766d 100644 --- a/src/ngspiceSimulation/NgspiceWidget.py +++ b/src/ngspiceSimulation/NgspiceWidget.py @@ -6,6 +6,7 @@ """ import os +import shlex import logging from typing import List, Optional from PyQt6 import QtWidgets, QtCore @@ -56,7 +57,6 @@ def __init__(self, netlist: str, sim_end_signal: pyqtSignal, plotFlag: Optional[ """ super().__init__() - # **CRITICAL FIX**: Set expanding size policy self.setSizePolicy(QtWidgets.QSizePolicy.Policy.Expanding, QtWidgets.QSizePolicy.Policy.Expanding) @@ -68,7 +68,6 @@ def __init__(self, netlist: str, sim_end_signal: pyqtSignal, plotFlag: Optional[ self.netlist_path = netlist self.sim_end_signal = sim_end_signal - # **IMPORTANT**: Store plotFlag and command for dual plot functionality self.plotFlag = plotFlag self.command = netlist logger.info(f"Value of plotFlag: {self.plotFlag}") @@ -109,10 +108,11 @@ def _prepare_ngspice_arguments(self, netlist: str) -> List[str]: def _configure_process(self) -> None: """Configure the NGSpice process with working directory and signals.""" self.process.setWorkingDirectory(self.project_dir) - self.process.setProcessChannelMode(QtCore.QProcess.ProcessChannelMode.MergedChannels) - + self.process.setProcessChannelMode(QtCore.QProcess.ProcessChannelMode.SeparateChannels) + # Connect process signals - self.process.readyRead.connect(self.ready_read_all) + self.process.readyReadStandardOutput.connect(self._handle_stdout) + self.process.readyReadStandardError.connect(self._handle_stderr) self.process.finished.connect( lambda exit_code, exit_status: self.finish_simulation( exit_code, exit_status, self.sim_end_signal, False @@ -122,19 +122,18 @@ def _configure_process(self) -> None: lambda: self.finish_simulation(None, None, self.sim_end_signal, True) ) + def _register_process(self, process: QtCore.QProcess) -> None: + """Register a process with the application config tracker.""" + self.obj_appconfig.process_obj.append(process) + current_project_name = self.obj_appconfig.current_project['ProjectName'] + if current_project_name in self.obj_appconfig.proc_dict: + self.obj_appconfig.proc_dict[current_project_name].append(process.pid()) + def _start_process(self) -> None: """Start the NGSpice process and register it with the application.""" self.process.start('ngspice', self.ngspice_args) - - # Register process with application config - self.obj_appconfig.process_obj.append(self.process) logger.debug(f"Process dictionary: {self.obj_appconfig.proc_dict}") - - current_project_name = self.obj_appconfig.current_project['ProjectName'] - if current_project_name in self.obj_appconfig.proc_dict: - self.obj_appconfig.proc_dict[current_project_name].append( - self.process.pid() - ) + self._register_process(self.process) def _is_linux(self) -> bool: """Check if the current operating system is Linux.""" @@ -150,46 +149,40 @@ def _start_gaw_process(self, netlist: str) -> None: try: self.gaw_process = QtCore.QProcess(self) raw_file = netlist.replace(".cir.out", ".raw") - self.gaw_command = f"gaw {raw_file}" + self.gaw_command = f"gaw {shlex.quote(raw_file)}" self.gaw_process.start('sh', ['-c', self.gaw_command]) logger.info(f"Started GAW with command: {self.gaw_command}") except Exception as e: logger.error(f"Failed to start GAW process: {e}") @pyqtSlot() - def ready_read_all(self) -> None: - """ - Handle process output and display it in the terminal console. - - Reads both standard output and standard error from the NGSpice process - and displays them in the TerminalUi console. Filters out specific - NGSpice warnings that are not relevant in batch mode. - """ + def _handle_stdout(self) -> None: + """Read stdout from the NGSpice process and display in terminal.""" + try: + data = self.process.readAllStandardOutput().data() + if data: + self.terminal_ui.simulationConsole.insertPlainText( + data.decode('utf-8', errors='replace') + ) + except Exception as e: + logger.error(f"Error reading stdout: {e}") + + @pyqtSlot() + def _handle_stderr(self) -> None: + """Read stderr, filter batch-mode noise, display the rest.""" try: - # Read and display standard output - std_output = self.process.readAllStandardOutput().data() - if std_output: - output_text = str(std_output, encoding='utf-8') - self.terminal_ui.simulationConsole.insertPlainText(output_text) - - # Read and filter standard error - std_error = self.process.readAllStandardError().data() - if std_error: - error_text = str(std_error, encoding='utf-8') - - # Filter out irrelevant NGSpice warnings in batch mode - filtered_lines = [] - for line in error_text.split('\n'): - if ('PrinterOnly' not in line and - 'viewport for graphics' not in line): - filtered_lines.append(line) - - filtered_error = '\n'.join(filtered_lines) - if filtered_error.strip(): - self.terminal_ui.simulationConsole.insertPlainText(filtered_error) - + data = self.process.readAllStandardError().data() + if not data: + return + text = data.decode('utf-8', errors='replace') + filtered = '\n'.join( + line for line in text.split('\n') + if 'PrinterOnly' not in line and 'viewport for graphics' not in line + ) + if filtered.strip(): + self.terminal_ui.simulationConsole.insertPlainText(filtered) except Exception as e: - logger.error(f"Error reading process output: {e}") + logger.error(f"Error reading stderr: {e}") def finish_simulation(self, exit_code: Optional[int], exit_status: Optional[QtCore.QProcess.ExitStatus], @@ -220,7 +213,9 @@ def finish_simulation(self, exit_code: Optional[int], exit_code = self.process.exitCode() error_type = self.process.error() - if error_type <= self.ERROR_TIMED_OUT: # FailedToStart, Crashed, TimedOut + if error_type in (QtCore.QProcess.ProcessError.FailedToStart, + QtCore.QProcess.ProcessError.Crashed, + QtCore.QProcess.ProcessError.Timedout): exit_status = QtCore.QProcess.ExitStatus.CrashExit elif exit_status is None: exit_status = self.process.exitStatus() @@ -230,18 +225,13 @@ def finish_simulation(self, exit_code: Optional[int], self._show_cancellation_message() elif self._is_simulation_successful(exit_status, exit_code, error_type): self._show_success_message() - - # **CRITICAL ADDITION**: Check and update plotFlag from process properties - # This handles the re-simulation case from TerminalUi - new_plot_flag = self.process.property("plotFlag") - if new_plot_flag is not None: - self.plotFlag = new_plot_flag - - new_plot_flag2 = self.process.property("plotFlag2") - if new_plot_flag2 is not None: - self.plotFlag = new_plot_flag2 - - # **CRITICAL ADDITION**: Open NGSpice plot windows if requested + + # On redo-simulation, TerminalUi sets "redoPlotFlag" on the process + # to pass the user's plot choice back here + redo_flag = self.process.property("redoPlotFlag") + if redo_flag is not None: + self.plotFlag = redo_flag + if self.plotFlag: self.open_ngspice_plots() else: @@ -266,55 +256,36 @@ def open_ngspice_plots(self) -> None: config_path = os.path.join('library', 'config', '.nghdl', 'config.ini') parser_nghdl.read(config_path) msys_home = parser_nghdl.get('COMPILER', 'MSYS_HOME') - - temp_dir = os.getcwd() - os.chdir(self.project_dir) - - # Create command for Windows using mintty - mintty_command = ( - f'cmd /c "start /min {msys_home}/usr/bin/mintty.exe ' - f'ngspice -p {self.command}"' - ) - - # Create a new QProcess for mintty + mintty_exe = os.path.join(msys_home, 'usr', 'bin', 'mintty.exe') + self.mintty_process = QtCore.QProcess(self) - self.mintty_process.start(mintty_command) - - os.chdir(temp_dir) - logger.info(f"Started mintty with command: {mintty_command}") - + self.mintty_process.setWorkingDirectory(self.project_dir) + # Pass program + args directly — Qt handles quoting internally + self.mintty_process.start(mintty_exe, ['ngspice', '-p', self.command]) + logger.info(f"Started mintty: {mintty_exe} ngspice -p {self.command}") + except Exception as e: logger.error(f"Failed to start Windows NGSpice plots: {e}") - + else: # Linux/Unix try: - # Create xterm command for interactive NGSpice + raw_file = self.command.replace('.cir.out', '.raw') + # Quote all paths so spaces in project names don't break the shell command xterm_command = ( - f"cd {self.project_dir}; " - f"ngspice -r {self.command.replace('.cir.out', '.raw')} " - f"{self.command}" + f"cd {shlex.quote(self.project_dir)} && " + f"ngspice -r {shlex.quote(raw_file)} {shlex.quote(self.command)}" ) - xterm_args = ['-hold', '-e', xterm_command] - - # Create new QProcess for xterm self.xterm_process = QtCore.QProcess(self) - self.xterm_process.start('xterm', xterm_args) - - # Register the process - self.obj_appconfig.process_obj.append(self.xterm_process) - current_project = self.obj_appconfig.current_project['ProjectName'] - if current_project in self.obj_appconfig.proc_dict: - self.obj_appconfig.proc_dict[current_project].append( - self.xterm_process.pid() - ) - - # Also restart GAW for the new plot window + self.xterm_process.start('xterm', ['-hold', '-e', xterm_command]) + + self._register_process(self.xterm_process) + if hasattr(self, 'gaw_process') and hasattr(self, 'gaw_command'): self.gaw_process.start('sh', ['-c', self.gaw_command]) logger.info(f"Restarted GAW: {self.gaw_command}") - - logger.info(f"Started xterm with args: {xterm_args}") - + + logger.info(f"Started xterm: {xterm_command}") + except Exception as e: logger.error(f"Failed to start Linux NGSpice plots: {e}") @@ -339,8 +310,8 @@ def _is_simulation_successful(self, exit_status: QtCore.QProcess.ExitStatus, Returns: True if simulation was successful, False otherwise """ - return (exit_status == QtCore.QProcess.ExitStatus.NormalExit and - exit_code == 0 and + return (exit_status == QtCore.QProcess.ExitStatus.NormalExit and + exit_code == 0 and error_type == QtCore.QProcess.ProcessError.UnknownError) def _show_cancellation_message(self) -> None: diff --git a/src/ngspiceSimulation/data_extraction.py b/src/ngspiceSimulation/data_extraction.py index 5d8d758cb..cc491b540 100644 --- a/src/ngspiceSimulation/data_extraction.py +++ b/src/ngspiceSimulation/data_extraction.py @@ -2,321 +2,294 @@ """ Data extraction module for NGSpice simulation results. -This module handles the extraction and processing of simulation data from NGSpice -output files, supporting AC, DC, and Transient analysis types. +Parses plot_data_v.txt and plot_data_i.txt produced by ngspice. + +Transient / DC format: + * /path/to/circuit.cir <- marks start of each column group + Transient Analysis date <- analysis type line + ----... <- separator + Index time node1 node2 <- header (parts[1]=x-axis, parts[2:]=node names) + ----... + 0\tt0\tv1\tv2\t <- data rows (tab-separated, trailing \t) + ... + 54\tt54\tv1\tv2\t + <- blank line + Index time node1 node2 <- page-break header (every ~55 rows, same group) + ----... + 55\tt55\tv1\tv2\t + ... + * /path/to/circuit.cir <- new column group (circuit with many nodes) + Transient Analysis date + Index time node3 node4 + 0\tt0\tv3\tv4\t <- same time axis, new node values + +AC format (differs from Transient/DC): + Each node value is split into TWO tab columns per row: + real_part,\timaginary_part + The real_part has a trailing comma (ngspice artifact). + Example: 0\t1.0e+03\t9.96e+00,\t-4.50e+00\t + Only the real part is stored; the imaginary part is discarded, + matching the original implementation behaviour. """ import os import logging -from decimal import Decimal -from typing import List, Tuple, Dict, Any, Optional +import numpy as np +from typing import List, Tuple, Optional from PyQt6 import QtWidgets from configuration.Appconfig import Appconfig -# Set up logging logger = logging.getLogger(__name__) class DataExtraction: - """ - Extracts and processes simulation data from NGSpice output files. - - This class handles reading and parsing voltage and current data from - NGSpice simulation output files for different analysis types. - """ - - # Analysis type constants + """Extracts simulation data from NGSpice output files.""" + AC_ANALYSIS = 0 TRANSIENT_ANALYSIS = 1 DC_ANALYSIS = 2 def __init__(self) -> None: - """Initialize the DataExtraction instance.""" self.obj_appconfig = Appconfig() - self.data: List[str] = [] - # consists of all the columns of data belonging to nodes and branches - self.y: List[List[Decimal]] = [] # stores y-axis data - self.x: List[Decimal] = [] # stores x-axis data - # Add the missing instance variables + self.x: np.ndarray = np.array([], dtype=np.float64) + self.y: List[np.ndarray] = [] self.NBList: List[str] = [] self.NBIList: List[str] = [] self.volts_length: int = 0 + self.data: List[str] = [] # kept for backward compat + self.analysisType: int = self.TRANSIENT_ANALYSIS + self.dec: int = 0 + + # ------------------------------------------------------------------ + # Internal helpers + # ------------------------------------------------------------------ - def numberFinder(self, file_path: str) -> List[int]: + def _parse_plot_file( + self, filepath: str, is_ac: bool = False + ) -> Tuple[np.ndarray, str, List[str], List[np.ndarray]]: """ - Analyze simulation files to determine data structure parameters. - - Args: - file_path: Path to the directory containing simulation files - - Returns: - List containing [lines_per_node, voltage_nodes, analysis_type, - dec_flag, current_branches] + Parse one ngspice plot file. + + Returns (x_array, x_name, names, arrays) where names and arrays are + parallel lists — one entry per output column in file order. + + Duplicate column names (ngspice truncates long node names so two + distinct nodes can share the same string) are preserved as separate + entries; each gets its own data list keyed by position, not by name. + + is_ac=True: each node occupies 2 tab columns ("real, imag"); only + the real part is kept (comma stripped), imaginary discarded. + + Line dispatch: + - starts with digit -> data row (fast path) + - stripped starts with * -> new column group incoming + - stripped starts with - -> separator, skip + - stripped starts 'Index'-> column header (new group or page-break) + - everything else -> analysis-type banner, skip + """ + x_list: List[float] = [] + all_names: List[str] = [] # output channel names (duplicates kept) + all_data: List[List[float]] = [] # parallel data lists, one per channel + + x_name: str = 'time' + # Indices into all_data for the columns of the current group. + # On a page-break (same group, same header) we reuse the same indices. + current_indices: Optional[List[int]] = None + new_group_incoming: bool = False + collecting_x: bool = True + cols_per_node: int = 2 if is_ac else 1 + + try: + with open(filepath, 'r') as f: + for line in f: + # ---- Fast path: data rows always start with a digit ---- + if line and line[0].isdigit(): + if current_indices is None: + continue + parts = line.split('\t') + if len(parts) < 2 + cols_per_node * len(current_indices): + continue + try: + x_val = float(parts[1]) + if collecting_x: + x_list.append(x_val) + for i, idx in enumerate(current_indices): + if is_ac: + raw = parts[2 + 2 * i].rstrip(',') + else: + raw = parts[2 + i] + all_data[idx].append(float(raw)) + except (ValueError, IndexError): + continue + continue + + # ---- Non-data lines ---- + stripped = line.strip() + if not stripped: + continue + + if stripped[0] == '*': + new_group_incoming = True + if current_indices is not None: + collecting_x = False + continue + + if stripped[0] == '-': + continue + + if stripped.startswith('Index'): + parts = stripped.split() + if new_group_incoming: + x_name = parts[1] + col_names = parts[2:] + current_indices = [] + for name in col_names: + all_names.append(name) + all_data.append([]) + current_indices.append(len(all_data) - 1) + new_group_incoming = False + # else: page-break — same group, same columns, same indices + continue + + except OSError as e: + logger.error(f"Cannot open {filepath}: {e}") + raise + + x_arr = np.array(x_list, dtype=np.float64) + arrays = [np.array(d, dtype=np.float64) for d in all_data] + + logger.debug( + f"Parsed {filepath}: {len(x_arr)} x-pts, " + f"{len(arrays)} channels, x_name='{x_name}'" + ) + return x_arr, x_name, all_names, arrays + + def _detect_analysis_type(self, file_path: str) -> Tuple[int, int]: + """ + Read the 'analysis' file and return (analysis_type, dec_flag). + + dec_flag=1 means AC analysis uses decade (log) frequency sweep. """ - # Opening Analysis file - with open(os.path.join(file_path, "analysis")) as analysis_file: - self.analysisInfo = analysis_file.read() - self.analysisInfo = self.analysisInfo.split(" ") - - # Reading data file for voltage - with open(os.path.join(file_path, "plot_data_v.txt")) as voltage_file: - self.voltData = voltage_file.read() - - self.voltData = self.voltData.split("\n") - - # Initializing variable - # 'lines_per_node' gives no. of lines of data for each node/branch - # 'partitions_per_voltage_node' gives the no of partitions for a single voltage node - # 'voltage_node_count' gives total number of voltage - # 'current_branch_count' gives total number of current - - lines_per_node = partitions_per_voltage_node = voltage_node_count = current_branch_count = 0 - - # Finding total number of voltage node - for line in self.voltData[3:]: - # it has possible names of voltage nodes in NgSpice - if "Index" in line: # "V(" in line or "x1" in line or "u3" in line: - voltage_node_count += 1 - - # Reading Current Source Data - with open(os.path.join(file_path, "plot_data_i.txt")) as current_file: - self.currentData = current_file.read() - self.currentData = self.currentData.split("\n") - - # Finding Number of Branch - for line in self.currentData[3:]: - if "#branch" in line: - current_branch_count += 1 - - self.dec = 0 - - # For AC - if self.analysisInfo[0][-3:] == ".ac": - self.analysisType = self.AC_ANALYSIS - if "dec" in self.analysisInfo: - self.dec = 1 - - for line in self.voltData[3:]: - lines_per_node += 1 # 'lines_per_node' gives no. of lines of data for each node/branch - if "Index" in line: - partitions_per_voltage_node += 1 - # 'partitions_per_voltage_node' gives the no of partitions for a single voltage node - logger.debug(f"partitions_per_voltage_node: {partitions_per_voltage_node}") - if "AC" in line: # DC for dc files and AC for ac ones - break - - elif ".tran" in self.analysisInfo: - self.analysisType = self.TRANSIENT_ANALYSIS - for line in self.voltData[3:]: - lines_per_node += 1 - if "Index" in line: - partitions_per_voltage_node += 1 - # 'partitions_per_voltage_node' gives the no of partitions for a single voltage node - logger.debug(f"partitions_per_voltage_node: {partitions_per_voltage_node}") - if "Transient" in line: # DC for dc files and AC for ac ones - break - - # For DC: - else: - self.analysisType = self.DC_ANALYSIS - for line in self.voltData[3:]: - lines_per_node += 1 - if "Index" in line: - partitions_per_voltage_node += 1 - # 'partitions_per_voltage_node' gives the no of partitions for a single voltage node - logger.debug(f"partitions_per_voltage_node: {partitions_per_voltage_node}") - if "DC" in line: # DC for dc files and AC for ac ones - break - - voltage_node_count = voltage_node_count // partitions_per_voltage_node # voltage_node_count gives the no of voltage nodes - current_branch_count = current_branch_count // partitions_per_voltage_node # current_branch_count gives the no of branches - - analysis_params = [lines_per_node, voltage_node_count, self.analysisType, self.dec, current_branch_count] - - return analysis_params + analysis_file = os.path.join(file_path, "analysis") + with open(analysis_file, 'r') as f: + content = f.read().strip() + + tokens = content.split() + dec = 0 + + if not tokens: + logger.warning("analysis file is empty, defaulting to Transient") + return self.TRANSIENT_ANALYSIS, 0 + + first = tokens[0].lower() + + if first.endswith('.ac'): + if 'dec' in tokens: + dec = 1 + return self.AC_ANALYSIS, dec + + if '.tran' in tokens: + return self.TRANSIENT_ANALYSIS, dec + + return self.DC_ANALYSIS, dec + + # ------------------------------------------------------------------ + # Public interface (matches what plot_window.py expects) + # ------------------------------------------------------------------ def openFile(self, file_path: str) -> List[int]: """ - Open and process simulation data files. - - Args: - file_path: Path to the directory containing simulation files - + Open and process both simulation data files. + Returns: - List containing [analysis_type, dec_flag] - - Raises: - Exception: If files cannot be read or processed + [analysis_type, dec_flag] + where analysis_type is AC_ANALYSIS=0, TRANSIENT_ANALYSIS=1, DC_ANALYSIS=2 + and dec_flag=1 for log-scale AC sweep, 0 otherwise. + + Populates: + self.x - 1-D numpy array of x-axis values (time/freq/sweep) + self.y - list of 1-D numpy arrays, one per node/branch + self.NBList - list of all node+branch names (voltage first, then current) + self.NBIList - list of current branch names only + self.volts_length - number of voltage nodes """ try: - with open(os.path.join(file_path, "plot_data_i.txt")) as current_file: - all_current_data = current_file.read() + v_path = os.path.join(file_path, "plot_data_v.txt") + i_path = os.path.join(file_path, "plot_data_i.txt") - all_current_data = all_current_data.split("\n") - self.NBIList = [] + # ---- Detect analysis type ---- + analysis_type, dec = self._detect_analysis_type(file_path) + self.analysisType = analysis_type + self.dec = dec + is_ac = (analysis_type == self.AC_ANALYSIS) - with open(os.path.join(file_path, "plot_data_v.txt")) as voltage_file: - all_voltage_data = voltage_file.read() + # ---- Parse voltage file ---- + x_arr, x_name, v_names, v_arrays = self._parse_plot_file(v_path, is_ac=is_ac) - except Exception as e: - logger.error(f"Exception reading files: {e}") - self.obj_appconfig.print_error(f'Exception Message: {e}') - self.msg = QtWidgets.QErrorMessage() - self.msg.setModal(True) - self.msg.setWindowTitle("Error Message") - self.msg.showMessage('Unable to open plot data files.') - self.msg.exec() - - try: + # ---- Parse current file (graceful if missing or empty) ---- + i_names: List[str] = [] + i_arrays: List[np.ndarray] = [] try: - for token in all_current_data[3].split(" "): - if len(token) > 0: - self.NBIList.append(token) - self.NBIList = self.NBIList[2:] - current_list_length = len(self.NBIList) - except (IndexError, AttributeError) as e: - logger.warning(f"Error parsing current data: {e}") - self.NBIList = [] - current_list_length = 0 + _, _, i_names, i_arrays = self._parse_plot_file(i_path, is_ac=is_ac) + except OSError: + logger.warning(f"Current file not found or unreadable: {i_path}") + except Exception as e: + logger.warning(f"Could not parse current file: {e}") + + # ---- Populate public attributes ---- + self.x = x_arr + self.volts_length = len(v_names) + self.NBIList = i_names + self.NBList = v_names + i_names + self.y = v_arrays + i_arrays + + # self.data kept as non-empty so numVals() won't crash on old + # callers that still reach for data[0] - we give it one dummy row + # with the right column count so numVals()[0] is correct. + dummy_cols = '\t'.join(['0.0'] * len(self.NBList)) + self.data = [dummy_cols] + + _analysis_name = { + self.AC_ANALYSIS: 'AC', + self.TRANSIENT_ANALYSIS: 'Tran', + self.DC_ANALYSIS: 'DC', + }.get(analysis_type, '?') + logger.info( + f"openFile done | analysis={_analysis_name} " + f"| {len(v_names)} V-nodes, {len(i_names)} I-branches " + f"| {len(x_arr)} data points" + ) + logger.info(f"NBList: {self.NBList}") + + return [analysis_type, dec] + except Exception as e: - logger.error(f"Exception parsing current data: {e}") - self.obj_appconfig.print_error(f'Exception Message: {e}') - self.msg = QtWidgets.QErrorMessage() - self.msg.setModal(True) - self.msg.setWindowTitle("Error Message") - self.msg.showMessage('Unable to read Analysis File.') - self.msg.exec() - - data_params = self.numberFinder(file_path) - lines_per_partition = int(data_params[0] + 1) - voltage_node_count = int(data_params[1]) - analysis_type = data_params[2] - current_branch_count = data_params[4] - - analysis_info = [analysis_type, data_params[3]] - self.NBList = [] - all_voltage_data = all_voltage_data.split("\n") - for token in all_voltage_data[3].split(" "): - if len(token) > 0: - self.NBList.append(token) - self.NBList = self.NBList[2:] - voltage_list_length = len(self.NBList) - logger.info(f"NBLIST: {self.NBList}") - - processed_current_data = [] - voltage_column_count = len(all_voltage_data[5].split("\t")) - current_column_count = len(all_current_data[5].split("\t")) - - full_data = [] - - # Creating list of data: - if analysis_type < 3: - for voltage_node_index in range(1, voltage_node_count): - for token in all_voltage_data[3 + voltage_node_index * lines_per_partition].split(" "): - if len(token) > 0: - self.NBList.append(token) - self.NBList.pop(voltage_list_length) - self.NBList.pop(voltage_list_length) - voltage_list_length = len(self.NBList) - - for current_branch_index in range(1, current_branch_count): - for token in all_current_data[3 + current_branch_index * lines_per_partition].split(" "): - if len(token) > 0: - self.NBIList.append(token) - self.NBIList.pop(current_list_length) - self.NBIList.pop(current_list_length) - current_list_length = len(self.NBIList) - - partition_row_index = 0 - data_row_index = 0 - combined_row_index = 0 - - for line in all_current_data[5:lines_per_partition - 1]: - if len(line.split("\t")) == current_column_count: - current_row = line.split("\t") - current_row.pop(0) - current_row.pop(0) - current_row.pop() - if analysis_type == 0: # not in trans - current_row.pop() - - for current_partition_index in range(1, current_branch_count): - additional_current_line = all_current_data[5 + current_partition_index * lines_per_partition + data_row_index].split("\t") - additional_current_line.pop(0) - additional_current_line.pop(0) - if analysis_type == 0: - additional_current_line.pop() # not required for dc - additional_current_line.pop() - current_row = current_row + additional_current_line - - full_data.append(current_row) - - data_row_index += 1 - - for line in all_voltage_data[5:lines_per_partition - 1]: - if len(line.split("\t")) == voltage_column_count: - voltage_row = line.split("\t") - voltage_row.pop() - if analysis_type == 0: - voltage_row.pop() - for voltage_partition_index in range(1, voltage_node_count): - additional_voltage_line = all_voltage_data[5 + voltage_partition_index * lines_per_partition + partition_row_index].split("\t") - additional_voltage_line.pop(0) - additional_voltage_line.pop(0) - if analysis_type == 0: - additional_voltage_line.pop() # not required for dc - if self.NBList[len(self.NBList) - 1] == 'v-sweep': - self.NBList.pop() - additional_voltage_line.pop() - - additional_voltage_line.pop() - voltage_row = voltage_row + additional_voltage_line - voltage_row = voltage_row + full_data[combined_row_index] - combined_row_index += 1 - - combined_row_str = "\t".join(voltage_row[1:]) - combined_row_str = combined_row_str.replace(",", "") - self.data.append(combined_row_str) - - partition_row_index += 1 - - self.volts_length = len(self.NBList) - self.NBList = self.NBList + self.NBIList - - logger.info(f"Analysis info: {analysis_info}") - return analysis_info + logger.error(f"openFile failed: {e}", exc_info=True) + self.obj_appconfig.print_error(f'DataExtraction error: {e}') + try: + msg = QtWidgets.QErrorMessage() + msg.setModal(True) + msg.setWindowTitle("Error Message") + msg.showMessage(f'Unable to open plot data files:\n{e}') + msg.exec() + except Exception: + pass + return [self.TRANSIENT_ANALYSIS, 0] - def numVals(self) -> List[int]: + def computeAxes(self) -> None: """ - Get the number of data columns and voltage nodes. - - Returns: - List containing [total_columns, voltage_node_count] + No-op: x and y are already numpy arrays populated by openFile(). + Kept for backward compatibility with plot_window.py call sequence. """ - total_columns = len(self.data[0].split("\t")) - voltage_node_count = self.volts_length - return [total_columns, voltage_node_count] + # plot_window.py calls: openFile() -> computeAxes() -> numVals() + # In the old implementation computeAxes() built self.x and self.y + # from self.data. Now openFile() does it all directly. + pass - def computeAxes(self) -> None: + def numVals(self) -> List[int]: """ - Compute x and y axis data from the processed simulation data. - - This method extracts the time/frequency data (x-axis) and - voltage/current data (y-axis) from the processed data. + Return [total_node_count, voltage_node_count]. + + plot_window.py only uses index [1] (volts_length). """ - if not self.data: - logger.warning("No data available for axis computation") - return - - num_columns = len(self.data[0].split("\t")) - self.y = [] - first_row_values = self.data[0].split("\t") - for column_index in range(1, num_columns): - self.y.append([Decimal(first_row_values[column_index])]) - for row in self.data[1:]: - row_values = row.split("\t") - for column_index in range(1, num_columns): - self.y[column_index - 1].append(Decimal(row_values[column_index])) - for row in self.data: - row_values = row.split("\t") - self.x.append(Decimal(row_values[0])) + return [len(self.y), self.volts_length] diff --git a/src/ngspiceSimulation/plot_window.py b/src/ngspiceSimulation/plot_window.py index f5e657432..9b65b8e11 100644 --- a/src/ngspiceSimulation/plot_window.py +++ b/src/ngspiceSimulation/plot_window.py @@ -14,18 +14,18 @@ import logging from pathlib import Path from decimal import Decimal, getcontext -from typing import Dict, List, Optional, Tuple, Any, Union +from typing import Dict, List, Optional, Tuple, Any from PyQt6 import QtGui, QtCore, QtWidgets from PyQt6.QtCore import Qt, QSettings, pyqtSignal -from PyQt6.QtWidgets import (QApplication, QMainWindow, QWidget, QVBoxLayout, +from PyQt6.QtWidgets import (QWidget, QVBoxLayout, QHBoxLayout, QListWidget, QListWidgetItem, QPushButton, - QCheckBox, QRadioButton, QButtonGroup, QGroupBox, + QCheckBox, QGroupBox, QLabel, QLineEdit, QSlider, QDoubleSpinBox, QMenu, QFileDialog, QColorDialog, QInputDialog, - QMessageBox, QErrorMessage, QStatusBar, QStyle, + QMessageBox, QStatusBar, QSplitter, QToolButton, QWidgetAction, QGridLayout, - QSpacerItem, QSizePolicy,QScrollArea) + QSizePolicy, QScrollArea) from PyQt6.QtGui import (QColor, QBrush, QPalette, QKeySequence, QPainter, QPixmap, QFont, QAction) @@ -35,7 +35,6 @@ from matplotlib.backends.backend_qt5agg import NavigationToolbar2QT as NavigationToolbar from matplotlib.backend_bases import NavigationToolbar2 from matplotlib.figure import Figure -from matplotlib.widgets import Cursor from matplotlib.lines import Line2D from matplotlib.text import Text @@ -131,12 +130,92 @@ class CustomListWidget(QListWidget): def __init__(self, parent: Optional[QWidget] = None) -> None: super().__init__(parent) - self.setSelectionMode(QListWidget.MultiSelection) + self.setSelectionMode(QListWidget.SelectionMode.MultiSelection) def paintEvent(self, event: QtGui.QPaintEvent) -> None: super().paintEvent(event) +def _safe_eval(expr: str, data_map: dict) -> "np.ndarray": + """Evaluate a math expression over trace arrays without using eval() on raw user input. + + Allowed: trace names, numeric literals, + - * / ** unary minus, numpy via 'np'. + Raises ValueError for anything else (attribute access, calls, etc.). + """ + import ast, operator as op + + ALLOWED_OPS = { + ast.Add: op.add, ast.Sub: op.sub, + ast.Mult: op.mul, ast.Div: op.truediv, + ast.Pow: op.pow, ast.USub: op.neg, + } + + def _eval(node): + if isinstance(node, ast.Constant) and isinstance(node.value, (int, float)): + return node.value + if isinstance(node, ast.Name): + if node.id in data_map: + return data_map[node.id] + raise ValueError(f"Unknown trace: '{node.id}'") + if isinstance(node, ast.BinOp): + op_fn = ALLOWED_OPS.get(type(node.op)) + if op_fn is None: + raise ValueError(f"Unsupported operator: {type(node.op).__name__}") + return op_fn(_eval(node.left), _eval(node.right)) + if isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.USub): + return op.neg(_eval(node.operand)) + raise ValueError(f"Unsupported expression: {ast.dump(node)}") + + tree = ast.parse(expr, mode='eval') + return np.array(_eval(tree.body), dtype=float) + + +def _format_measurement(value: float, unit: str) -> str: + """Format a voltage or current with an appropriate SI prefix.""" + abs_val = abs(value) + if unit == "A": + if abs_val >= 1: return f"{value:.3g} A" + if abs_val >= 1e-3: return f"{value * 1e3:.3g} mA" + if abs_val >= 1e-6: return f"{value * 1e6:.3g} µA" + if abs_val >= 1e-9: return f"{value * 1e9:.3g} nA" + return "0 A" + else: + if abs_val >= 1: return f"{value:.3g} V" + if abs_val >= 1e-3: return f"{value * 1e3:.3g} mV" + if abs_val >= 1e-6: return f"{value * 1e6:.3g} µV" + return "0 V" + + +def _format_frequency(freq_hz: float) -> str: + """Format a frequency in Hz with an appropriate SI prefix.""" + if freq_hz >= 1e9: return f"{freq_hz / 1e9:.3g} GHz" + if freq_hz >= 1e6: return f"{freq_hz / 1e6:.3g} MHz" + if freq_hz >= 1e3: return f"{freq_hz / 1e3:.3g} kHz" + return f"{freq_hz:.3g} Hz" + + +def _detect_frequency(time_data: "np.ndarray", + logic_normalized: "np.ndarray") -> "Optional[float]": + """Return signal frequency in Hz if periodic, else None. + + Uses rising-edge timing. Requires ≥2 complete cycles and a coefficient of + variation below 10% — rejects glitchy or non-periodic signals. + """ + transitions = np.diff(logic_normalized.astype(np.int8)) + rising_idx = np.where(transitions == 1)[0] + if len(rising_idx) < 3: # need 2+ periods to verify consistency + return None + periods = np.diff(time_data[rising_idx]) + if len(periods) == 0: + return None + mean_p = float(np.mean(periods)) + if mean_p <= 0: + return None + if len(periods) > 1 and float(np.std(periods)) / mean_p > 0.10: + return None + return 1.0 / mean_p + + class plotWindow(QWidget): """Main plotting widget for NGSpice simulation results.""" @@ -160,20 +239,17 @@ def __init__(self, file_path: str, project_name: str, parent=None) -> None: self.apply_theme() def _initialize_data_structures(self) -> None: - self.active_traces: Dict[int, Line2D] = {} - self.trace_visibility: Dict[int, bool] = {} - self.trace_colors: Dict[int, str] = {} - self.trace_thickness: Dict[int, float] = {} - self.trace_style: Dict[int, str] = {} - self.trace_names: Dict[int, str] = {} + self.traces: Dict[int, Trace] = {} self.cursor_lines: List[Optional[Line2D]] = [] self.cursor_positions: List[Optional[float]] = [] self.timing_annotations: Dict[int, Any] = {} self.color_palette = VIBRANT_COLOR_PALETTE.copy() - self.color: List[str] = [] - self.color_index = 0 - self.logic_threshold: Optional[float] = None + self.logic_thresholds: Dict[int, float] = {} self.vertical_spacing = DEFAULT_VERTICAL_SPACING + self._func_line: Optional[Line2D] = None + self._drag_cursor_idx: Optional[int] = None + self._meters: List[Any] = [] + self._last_was_timing: bool = False def _initialize_configuration(self) -> None: self.config_dir = Path.home() / '.pythonPlotting' @@ -197,9 +273,9 @@ def load_config(self) -> Dict[str, Any]: def save_config(self) -> None: try: self.config_dir.mkdir(exist_ok=True) - self.config['trace_colours'] = {self.trace_names.get(idx, self.obj_dataext.NBList[idx]): color for idx, color in self.trace_colors.items()} - self.config['trace_thickness'] = {self.trace_names.get(idx, self.obj_dataext.NBList[idx]): thickness for idx, thickness in self.trace_thickness.items()} - self.config['trace_style'] = {self.trace_names.get(idx, self.obj_dataext.NBList[idx]): style for idx, style in self.trace_style.items()} + self.config['trace_colours'] = {t.name: t.color for t in self.traces.values()} + self.config['trace_thickness'] = {t.name: t.thickness for t in self.traces.values()} + self.config['trace_style'] = {t.name: t.style for t in self.traces.values()} temp_file = self.config_file.with_suffix('.tmp') with open(temp_file, 'w', encoding='utf-8') as config_file: json.dump(self.config, config_file, indent=2) @@ -209,6 +285,9 @@ def save_config(self) -> None: def closeEvent(self, event: QtGui.QCloseEvent) -> None: self.save_config() + for meter in self._meters: + meter.close() + self._meters.clear() if hasattr(self, 'canvas'): self.canvas.close() if hasattr(self, 'fig'): @@ -263,7 +342,7 @@ def create_main_frame(self) -> None: scroll_area = QScrollArea() scroll_area.setWidget(right_widget) scroll_area.setWidgetResizable(True) - scroll_area.setFrameShape(QtWidgets.QFrame.NoFrame) + scroll_area.setFrameShape(QtWidgets.QFrame.Shape.NoFrame) scroll_area.setHorizontalScrollBarPolicy(Qt.ScrollBarPolicy.ScrollBarAlwaysOff) scroll_area.setVerticalScrollBarPolicy(Qt.ScrollBarPolicy.ScrollBarAsNeeded) scrollbar_style = "QScrollBar:vertical{background-color:#F5F5F5;width:8px;border:none;border-radius:4px;}QScrollBar::handle:vertical{background-color:#BDBDBD;border-radius:4px;min-height:20px;margin:2px;}QScrollBar::handle:vertical:hover{background-color:#9E9E9E;}QScrollBar::add-line:vertical,QScrollBar::sub-line:vertical{height:0px;}QScrollBar::add-page:vertical,QScrollBar::sub-page:vertical{background:transparent;}" @@ -322,11 +401,11 @@ def create_plot_area(self) -> QWidget: center_layout.addWidget(self.nav_toolbar) center_layout.addWidget(self.canvas) self.canvas.mpl_connect('button_press_event', self.on_canvas_click) + self.canvas.mpl_connect('button_release_event', self.on_canvas_release) self.canvas.mpl_connect('motion_notify_event', self.on_mouse_move) self.canvas.mpl_connect('key_press_event', self.on_key_press) self.canvas.mpl_connect('scroll_event', self.on_scroll) - self.canvas.setContextMenuPolicy(Qt.ContextMenuPolicy.CustomContextMenu) - self.canvas.customContextMenuRequested.connect(self.show_canvas_context_menu) + self.canvas.setContextMenuPolicy(Qt.ContextMenuPolicy.NoContextMenu) return center_widget def create_control_panel(self) -> QWidget: @@ -347,6 +426,7 @@ def create_control_panel(self) -> QWidget: display_layout.addWidget(self.legend_check) self.autoscale_check = QCheckBox("Autoscale") self.autoscale_check.setChecked(True) + self.autoscale_check.stateChanged.connect(self.refresh_plot) display_layout.addWidget(self.autoscale_check) self.timing_check = QCheckBox("Digital Timing View") self.timing_check.stateChanged.connect(self.on_timing_view_changed) @@ -364,8 +444,9 @@ def create_control_panel(self) -> QWidget: self.threshold_spinbox.setRange(-100, 100) self.threshold_spinbox.setDecimals(3) self.threshold_spinbox.setSingleStep(0.1) - self.threshold_spinbox.setSuffix(" V") + self.threshold_spinbox.setSuffix("") self.threshold_spinbox.setSpecialValueText("Auto") + self.threshold_spinbox.setValue(self.threshold_spinbox.minimum()) self.threshold_spinbox.valueChanged.connect(self.on_threshold_changed) threshold_layout.addWidget(self.threshold_spinbox) timing_layout.addLayout(threshold_layout) @@ -387,6 +468,9 @@ def create_control_panel(self) -> QWidget: cursor_box = CollapsibleBox("Cursor Measurements") cursor_group = QWidget() cursor_layout = QVBoxLayout(cursor_group) + cursor_hint = QLabel("Left click: C1 · Right click: C2 · Drag to move") + cursor_hint.setStyleSheet("color: #757575; font-size: 10px;") + cursor_layout.addWidget(cursor_hint) self.cursor1_label = QLabel("Cursor 1: Not set") self.cursor2_label = QLabel("Cursor 2: Not set") self.delta_label = QLabel("Delta: --") @@ -445,9 +529,6 @@ def load_simulation_data(self) -> None: self.plot_type = self.obj_dataext.openFile(self.file_path) self.obj_dataext.computeAxes() self.data_info = self.obj_dataext.numVals() - for i in range(0, self.data_info[0] - 1): - color_idx = i % len(self.color_palette) - self.color.append(self.color_palette[color_idx]) self.volts_length = self.data_info[1] if self.plot_type[0] == DataExtraction.AC_ANALYSIS: self.analysis_label.setText("AC Analysis") @@ -455,15 +536,20 @@ def load_simulation_data(self) -> None: self.analysis_label.setText("Transient Analysis") else: self.analysis_label.setText("DC Analysis") - for i, name in enumerate(self.obj_dataext.NBList): - self.trace_names[i] = name self.populate_waveform_list() + is_transient = self.plot_type[0] == DataExtraction.TRANSIENT_ANALYSIS + self.timing_check.setEnabled(is_transient) + if not is_transient: + self.timing_check.setChecked(False) + self.timing_check.setToolTip("Digital timing view is only available for transient analysis") + else: + self.timing_check.setToolTip("") def create_colored_icon(self, color: QColor, is_selected: bool) -> QtGui.QIcon: pixmap = QPixmap(18, 18) pixmap.fill(Qt.GlobalColor.transparent) painter = QPainter(pixmap) - painter.setRenderHint(QPainter.Antialiasing) + painter.setRenderHint(QPainter.RenderHint.Antialiasing) if is_selected: painter.setBrush(QBrush(color)) painter.setPen(Qt.PenStyle.NoPen) @@ -479,29 +565,18 @@ def create_colored_icon(self, color: QColor, is_selected: bool) -> QtGui.QIcon: def populate_waveform_list(self) -> None: self.waveform_list.clear() + self.traces.clear() saved_colors = self.config.get('trace_colours', {}) saved_thickness = self.config.get('trace_thickness', {}) saved_style = self.config.get('trace_style', {}) for i, node_name in enumerate(self.obj_dataext.NBList): + color = saved_colors.get(node_name, self.color_palette[i % len(self.color_palette)]) + thickness = saved_thickness.get(node_name, DEFAULT_LINE_THICKNESS) + style = saved_style.get(node_name, '-') + self.traces[i] = Trace(index=i, name=node_name, color=color, thickness=thickness, style=style) item = QListWidgetItem() item.setData(Qt.ItemDataRole.UserRole, i) - if node_name in saved_colors: - self.trace_colors[i] = saved_colors[node_name] - elif i < len(self.color): - self.trace_colors[i] = self.color[i] - else: - color_idx = i % len(self.color_palette) - self.trace_colors[i] = self.color_palette[color_idx] - if node_name in saved_thickness: - self.trace_thickness[i] = saved_thickness[node_name] - else: - self.trace_thickness[i] = DEFAULT_LINE_THICKNESS - if node_name in saved_style: - self.trace_style[i] = saved_style[node_name] - else: - self.trace_style[i] = '-' item.setToolTip("Voltage signal" if i < self.obj_dataext.volts_length else "Current signal") - self.trace_visibility[i] = False self.waveform_list.addItem(item) self.update_list_item_appearance(item, i) @@ -513,63 +588,46 @@ def filter_waveforms(self, text: str) -> None: def on_waveform_toggle(self, item: QListWidgetItem) -> None: index = item.data(Qt.ItemDataRole.UserRole) - self.trace_visibility[index] = item.isSelected() - if item.isSelected() and index not in self.trace_colors: - self.assign_trace_color(index) + # item.isSelected() is unreliable when setItemWidget is used — clicks land + # on the child widget and never update Qt's selection model. Toggle instead. + self.traces[index].visible = not self.traces[index].visible self.update_list_item_appearance(item, index) self.refresh_plot() - def assign_trace_color(self, index: int) -> None: - used_colors = set(self.trace_colors.values()) - available_colors = [color for color in self.color_palette if color not in used_colors] - if available_colors: - self.trace_colors[index] = available_colors[0] - else: - hue = (0.618033988749895 * len(self.trace_colors)) % 1.0 - color = QtGui.QColor.fromHsvF(hue, 0.7, 0.8) - self.trace_colors[index] = color.name() - self.save_config() - def update_list_item_appearance(self, item: QListWidgetItem, index: int) -> None: - node_name = self.trace_names.get(index, self.obj_dataext.NBList[index]) - is_selected = self.trace_visibility.get(index, False) + t = self.traces[index] widget = QWidget() layout = QHBoxLayout(widget) layout.setContentsMargins(6, 4, 6, 4) layout.setSpacing(10) icon_label = QLabel() - color = QColor(self.trace_colors[index]) if is_selected and index in self.trace_colors else QColor("#9E9E9E") - icon = self.create_colored_icon(color, is_selected) + color = QColor(t.color) if t.visible else QColor("#9E9E9E") + icon = self.create_colored_icon(color, t.visible) icon_label.setPixmap(icon.pixmap(18, 18)) - text_label = QLabel(node_name) - text_label.setStyleSheet("color: #212121; font-weight: 500;" if is_selected and index in self.trace_colors else "color: #757575; font-weight: normal;") + text_label = QLabel(t.name) + text_label.setStyleSheet("color: #212121; font-weight: 500;" if t.visible else "color: #757575; font-weight: normal;") layout.addWidget(icon_label) layout.addWidget(text_label) layout.addStretch() self.waveform_list.setItemWidget(item, widget) - item.setText(node_name) + item.setText(t.name) def select_all_waveforms(self) -> None: for i in range(self.waveform_list.count()): item = self.waveform_list.item(i) if item and not item.isHidden(): - item.setSelected(True) index = item.data(Qt.ItemDataRole.UserRole) - self.trace_visibility[index] = True - if index not in self.trace_colors: - self.assign_trace_color(index) + self.traces[index].visible = True self.update_list_item_appearance(item, index) self.refresh_plot() def deselect_all_waveforms(self) -> None: - self.waveform_list.clearSelection() - for index in self.trace_visibility: - self.trace_visibility[index] = False + for t in self.traces.values(): + t.visible = False for i in range(self.waveform_list.count()): item = self.waveform_list.item(i) if item: - index = item.data(Qt.ItemDataRole.UserRole) - self.update_list_item_appearance(item, index) + self.update_list_item_appearance(item, item.data(Qt.ItemDataRole.UserRole)) self.refresh_plot() def show_list_context_menu(self, position: QtCore.QPoint) -> None: @@ -600,14 +658,9 @@ def show_list_context_menu(self, position: QtCore.QPoint) -> None: rename_action.triggered.connect(lambda: self.rename_trace(item)) index = item.data(Qt.ItemDataRole.UserRole) - visible = False - if index in self.active_traces and self.active_traces[index]: - visible = self.active_traces[index].get_visible() - - hide_show_action = menu.addAction("Show" if not visible else "Hide") - if visible: - hide_show_action.setCheckable(True) - hide_show_action.setChecked(True) + t = self.traces[index] + + hide_show_action = menu.addAction("Hide" if t.visible else "Show") hide_show_action.triggered.connect(lambda: self.toggle_trace_visibility([item])) menu.addSeparator() @@ -615,16 +668,7 @@ def show_list_context_menu(self, position: QtCore.QPoint) -> None: properties_action = menu.addAction("Figure Options...") properties_action.triggered.connect(self.open_figure_options) - menu.exec_(self.waveform_list.mapToGlobal(position)) - - def show_canvas_context_menu(self, position: QtCore.QPoint) -> None: - menu = QMenu() - export_action = menu.addAction("Export Image...") - export_action.triggered.connect(self.export_image) - menu.addSeparator() - clear_action = menu.addAction("Clear Plot") - clear_action.triggered.connect(self.clear_plot) - menu.exec_(self.canvas.mapToGlobal(position)) + menu.exec(self.waveform_list.mapToGlobal(position)) def populate_color_menu(self, menu: QMenu, selected_items: List[QListWidgetItem]) -> None: color_widget = QWidget() @@ -655,27 +699,24 @@ def change_color_and_close(self, items: List[QListWidgetItem], color: str, menu: def change_color(self, items: List[QListWidgetItem], color: str) -> None: for item in items: index = item.data(Qt.ItemDataRole.UserRole) - self.trace_colors[index] = color + self.traces[index].update_line(color=color) self.update_list_item_appearance(item, index) - if index in self.active_traces and self.active_traces[index]: - self.active_traces[index].set_color(color) if self.timing_check.isChecked() and hasattr(self, 'axes'): self.update_timing_tick_colors() - if hasattr(self, 'timing_annotations') and index in self.timing_annotations: - self.timing_annotations[index].set_color(color) + for ann_text in self.timing_annotations.get(index, []): + ann_text.set_color(color) self.save_config() self.canvas.draw() def update_timing_tick_colors(self) -> None: if not hasattr(self, 'axes'): return - visible_indices = [i for i, v in self.trace_visibility.items() if v] + visible_indices = [i for i, t in self.traces.items() if t.visible] ytick_labels = self.axes.get_yticklabels() for i, label in enumerate(ytick_labels): if i < len(visible_indices): idx = visible_indices[::-1][i] - if idx in self.trace_colors: - label.set_color(self.trace_colors[idx]) + label.set_color(self.traces[idx].color) def change_color_dialog(self, items: List[QListWidgetItem]) -> None: color = QColorDialog.getColor() @@ -684,44 +725,45 @@ def change_color_dialog(self, items: List[QListWidgetItem]) -> None: def change_thickness(self, items: List[QListWidgetItem], thickness: float) -> None: for item in items: - index = item.data(Qt.ItemDataRole.UserRole) - self.trace_thickness[index] = thickness - if index in self.active_traces and self.active_traces[index]: - self.active_traces[index].set_linewidth(thickness) + self.traces[item.data(Qt.ItemDataRole.UserRole)].update_line(thickness=thickness) self.save_config() self.canvas.draw() def change_style(self, items: List[QListWidgetItem], style: str) -> None: + needs_replot = style == 'steps-post' for item in items: index = item.data(Qt.ItemDataRole.UserRole) - self.trace_style[index] = style - if index in self.active_traces and self.active_traces[index]: - if style == 'steps-post': - self.refresh_plot() - return - else: - self.active_traces[index].set_linestyle(style) + if needs_replot: + self.traces[index].style = style + else: + self.traces[index].update_line(style=style) self.save_config() - self.canvas.draw() + if needs_replot: + self.refresh_plot() + else: + self.canvas.draw() def rename_trace(self, item: QListWidgetItem) -> None: index = item.data(Qt.ItemDataRole.UserRole) - current_name = self.trace_names.get(index, self.obj_dataext.NBList[index]) - new_name, ok = QInputDialog.getText(self, "Rename Trace", "New name:", text=current_name) - if ok and new_name and new_name != current_name: - self.trace_names[index] = new_name - self.update_list_item_appearance(item, index) + t = self.traces[index] + new_name, ok = QInputDialog.getText(self, "Rename Trace", "New name:", text=t.name) + if ok and new_name and new_name != t.name: + t.name = new_name self.obj_dataext.NBList[index] = new_name + self.update_list_item_appearance(item, index) if self.legend_check.isChecked(): self.refresh_plot() def toggle_trace_visibility(self, items: List[QListWidgetItem]) -> None: - any_visible = any(item.data(Qt.ItemDataRole.UserRole) in self.active_traces and self.active_traces[item.data(Qt.ItemDataRole.UserRole)].get_visible() for item in items) + # Use t.visible as single source of truth — same path as left-click toggle. + # Going directly to line_object.set_visible() bypasses refresh_plot and + # gets stomped the next time anything triggers a redraw. + any_visible = any(self.traces[item.data(Qt.ItemDataRole.UserRole)].visible for item in items) for item in items: index = item.data(Qt.ItemDataRole.UserRole) - if index in self.active_traces and self.active_traces[index]: - self.active_traces[index].set_visible(not any_visible) - self.canvas.draw() + self.traces[index].visible = not any_visible + self.update_list_item_appearance(item, index) + self.refresh_plot() def open_figure_options(self) -> None: try: @@ -731,7 +773,8 @@ def open_figure_options(self) -> None: from matplotlib.backends.qt_compat import QtWidgets from matplotlib.backends.qt_editor import _formlayout if hasattr(_formlayout, 'FormDialog'): - options = [('Title', self.fig.suptitle('').get_text())] + current_title = self.fig._suptitle.get_text() if self.fig._suptitle is not None else '' + options = [('Title', current_title)] if hasattr(self, 'axes'): options.extend([('X Label', self.axes.get_xlabel()), ('Y Label', self.axes.get_ylabel()), ('X Min', self.axes.get_xlim()[0]), ('X Max', self.axes.get_xlim()[1]), ('Y Min', self.axes.get_ylim()[0]), ('Y Max', self.axes.get_ylim()[1])]) dialog = _formlayout.FormDialog(options, parent=self, title='Figure Options') @@ -752,14 +795,28 @@ def open_figure_options(self) -> None: QMessageBox.information(self, "Figure Options", "Basic figure editing is available through the toolbar.") def on_timing_view_changed(self, state: int) -> None: - timing_enabled = state == Qt.CheckState.Checked + timing_enabled = state == Qt.CheckState.Checked.value self.timing_box.content_area.setEnabled(timing_enabled) self.autoscale_check.setEnabled(not timing_enabled) self.refresh_plot() def refresh_plot(self) -> None: + # Preserve zoom when autoscale is off. + # Guard _last_was_timing: timing view y-axis is normalized [0..N] space, + # not voltage/current — restoring it into a normal view clips all signals. + saved_xlim = saved_ylim = None + if (not self.autoscale_check.isChecked() + and not self.timing_check.isChecked() + and not self._last_was_timing + and hasattr(self, 'axes')): + saved_xlim = self.axes.get_xlim() + saved_ylim = self.axes.get_ylim() + + self._func_line = None # fig.clear() below wipes all artists + self.timing_annotations.clear() self.fig.clear() - self.active_traces.clear() + for t in self.traces.values(): + t.line_object = None if self.timing_check.isChecked(): self.axes = self.fig.add_subplot(111) self.plot_timing_diagram() @@ -775,20 +832,26 @@ def refresh_plot(self) -> None: self.on_push_dc() if hasattr(self, 'axes'): self.axes.grid(self.grid_check.isChecked()) + if saved_xlim is not None: + self.axes.set_xlim(saved_xlim) + self.axes.set_ylim(saved_ylim) if self.legend_check.isChecked(): - plt.subplots_adjust(top=0.85, bottom=0.1) + self.fig.subplots_adjust(top=0.85, bottom=0.1) self.position_legend() else: - plt.subplots_adjust(top=0.95, bottom=0.1) + self.fig.subplots_adjust(top=0.95, bottom=0.1) + self._restore_cursors() self.canvas.draw() + self._last_was_timing = self.timing_check.isChecked() def position_legend(self) -> None: if hasattr(self, 'axes') and self.legend_check.isChecked(): handles, labels = [], [] - for idx in sorted(self.trace_visibility.keys()): - if self.trace_visibility.get(idx) and idx in self.active_traces and self.active_traces[idx]: - handles.append(self.active_traces[idx]) - labels.append(self.trace_names.get(idx, self.obj_dataext.NBList[idx])) + for idx in sorted(self.traces.keys()): + t = self.traces[idx] + if t.visible and t.line_object: + handles.append(t.line_object) + labels.append(t.name) if handles: ncol = min(6, len(handles)) if len(handles) > 6 else min(4, len(handles)) legend = self.axes.legend(handles, labels, bbox_to_anchor=(0.5, 1.02), loc='lower center', ncol=ncol, frameon=True, fancybox=False, shadow=False, fontsize=LEGEND_FONT_SIZE, borderaxespad=0, columnspacing=1.5) @@ -798,202 +861,164 @@ def position_legend(self) -> None: frame.set_linewidth(1) frame.set_alpha(0.95) + def _get_transient_start_idx(self, time_data: "np.ndarray") -> int: + """Return the index into time_data where the .tran start time begins, or 0.""" + try: + with open(os.path.join(self.file_path, "analysis"), 'r') as f: + parts = f.read().strip().split() + if len(parts) >= 4 and parts[0] == '.tran': + start_time = float(parts[3]) + if start_time > 0: + return int(np.searchsorted(time_data, start_time)) + except Exception: + pass + return 0 + def plot_timing_diagram(self) -> None: - """ - Plot digital timing diagram with proper time offset handling. - - This method now correctly handles transient analysis with non-zero start times - by detecting and applying the appropriate time offset. - """ - # Clear any existing timing annotations + """Plot digital timing diagram with normalized trace heights.""" self.timing_annotations.clear() - visible_indices = [i for i, v in self.trace_visibility.items() if v] - if not visible_indices: - self.axes.text(0.5, 0.5, 'Select a waveform to display', - ha='center', va='center', transform=self.axes.transAxes) + if self.plot_type[0] != DataExtraction.TRANSIENT_ANALYSIS: + self.axes.text(0.5, 0.5, 'Digital timing view is only\navailable for transient analysis.', + ha='center', va='center', transform=self.axes.transAxes, + fontsize=11, color='#757575') self.axes.set_yticks([]) self.axes.set_yticklabels([]) return - # Collect all voltage data for threshold calculation - all_voltage_data = [] - for idx in visible_indices: - if idx < self.obj_dataext.volts_length: - all_voltage_data.extend(self.obj_dataext.y[idx]) - - # If no voltage data, use current data - if not all_voltage_data: - for idx in visible_indices: - all_voltage_data.extend(self.obj_dataext.y[idx]) - - if not all_voltage_data: + visible_indices = [i for i, t in self.traces.items() if t.visible] + if not visible_indices: + self.axes.text(0.5, 0.5, 'Select a waveform to display', + ha='center', va='center', transform=self.axes.transAxes) + self.axes.set_yticks([]) + self.axes.set_yticklabels([]) return - all_voltage_data = np.array(all_voltage_data, dtype=float) - vmin = np.min(all_voltage_data) - vmax = np.max(all_voltage_data) - - # Handle threshold setting - if self.threshold_spinbox.value() == self.threshold_spinbox.minimum(): - self.logic_threshold = vmin + 0.7 * (vmax - vmin) - self.threshold_spinbox.setSpecialValueText(f"Auto ({self.logic_threshold:.3f} V)") - else: - self.logic_threshold = self.threshold_spinbox.value() + manual_threshold = (None if self.threshold_spinbox.value() == self.threshold_spinbox.minimum() + else self.threshold_spinbox.value()) + if manual_threshold is None: + self.threshold_spinbox.setSpecialValueText("Auto (midpoint)") + self.logic_thresholds = {} - # Get time data + # Build local float arrays for all traces — never touch obj_dataext time_data = np.asarray(self.obj_dataext.x, dtype=float) - - # CRITICAL FIX: Detect and handle transient analysis time offset - # For transient analysis with .tran step stop start, we need to find - # where the actual analysis begins - - # Check if this is a transient analysis + y_data = {i: np.asarray(self.obj_dataext.y[i], dtype=float) + for i in range(len(self.obj_dataext.y))} + if self.plot_type[0] == DataExtraction.TRANSIENT_ANALYSIS: - # Read the analysis file to get the actual start time - try: - with open(os.path.join(self.file_path, "analysis"), 'r') as f: - analysis_content = f.read().strip() - - # Parse .tran command: .tran step stop start - if analysis_content.startswith('.tran'): - parts = analysis_content.split() - if len(parts) >= 4: - try: - # Convert scientific notation to float - start_time = float(parts[3]) - - # If start_time is not 0, we need to offset our data - if start_time > 0: - # Find the index where time >= start_time - start_idx = np.searchsorted(time_data, start_time) - - # Adjust time_data to start from the correct point - if start_idx > 0 and start_idx < len(time_data): - time_data = time_data[start_idx:] - - # Also adjust all data arrays - for idx in list(self.obj_dataext.y.keys()): - if idx < len(self.obj_dataext.y): - self.obj_dataext.y[idx] = self.obj_dataext.y[idx][start_idx:] - except (ValueError, IndexError): - pass # If parsing fails, use full data - except Exception as e: - logger.debug(f"Could not parse analysis file for time offset: {e}") - - # Prepare spacing for multiple traces - spacing_ref = max(1.0, vmax) - spacing = self.vertical_spacing * spacing_ref + start_idx = self._get_transient_start_idx(time_data) + if 0 < start_idx < len(time_data): + time_data = time_data[start_idx:] + y_data = {i: arr[start_idx:] for i, arr in y_data.items()} + + # Each trace occupies exactly 1.0 normalized unit of y-space. + # spacing = vertical_spacing (e.g. 1.2 → 20% gap between traces). + # This guarantees uniform height for all signals regardless of voltage domain. + spacing = self.vertical_spacing yticks, ylabels = [], [] - - # Calculate annotation offset based on time range - annotation_offset_base = 0.01 * (time_data[-1] - time_data[0]) if len(time_data) > 1 else 0.01 - # Plot each visible trace as a digital signal for rank, idx in enumerate(visible_indices[::-1]): - # Get the raw data for this trace - raw_data = np.asarray(self.obj_dataext.y[idx], dtype=float) - - # Make sure raw_data matches time_data length after offset adjustment - if len(raw_data) > len(time_data): - raw_data = raw_data[:len(time_data)] - elif len(raw_data) < len(time_data): - # This shouldn't happen, but handle it gracefully - time_data = time_data[:len(raw_data)] - + raw_data = y_data[idx] + + # Safety clamp — guards against malformed simulation output where a + # y array is shorter or longer than the time axis. Use a local + # trace_time so time_data is never mutated across iterations. + n = min(len(raw_data), len(time_data)) + raw_data = raw_data[:n] + trace_time = time_data[:n] + trace_vmin, trace_vmax = np.min(raw_data), np.max(raw_data) - - # Convert to digital logic levels - logic_data = np.where(raw_data > self.logic_threshold, trace_vmax, trace_vmin) - - # Apply vertical offset for stacking - logic_offset = logic_data + rank * spacing - - # Get trace properties - color = self.trace_colors.get(idx, 'blue') - thickness = self.trace_thickness.get(idx, DEFAULT_LINE_THICKNESS) - label = self.trace_names.get(idx, self.obj_dataext.NBList[idx]) - - # Plot the digital waveform - line, = self.axes.step(time_data, logic_offset, where="post", - linewidth=thickness, color=color, label=label) - self.active_traces[idx] = line - - # Add y-axis tick for this trace - y_center = rank * spacing + (trace_vmax + trace_vmin) / 2.0 + trace_unit = "V" if idx < self.obj_dataext.volts_length else "A" + + if trace_vmax - trace_vmin < 1e-10: + # Constant (DC) signal — state indeterminate, park at 0.5. + # No threshold line drawn (nothing to threshold against). + logic_normalized = np.full(n, 0.5) + else: + # Per-trace threshold: midpoint of its own swing (CMOS VDD/2 convention). + # Manual override applies the user's voltage, normalized into [0,1] for + # this trace so the axhline always sits within the trace bounds. + threshold = (manual_threshold if manual_threshold is not None + else (trace_vmin + trace_vmax) / 2.0) + logic_normalized = np.where(raw_data > threshold, 1.0, 0.0) + threshold_norm = float(np.clip( + (threshold - trace_vmin) / (trace_vmax - trace_vmin), 0.0, 1.0 + )) + self.logic_thresholds[idx] = threshold_norm + + logic_offset = logic_normalized + rank * spacing + + t = self.traces[idx] + line, = self.axes.step(trace_time, logic_offset, where="post", + linewidth=t.thickness, color=t.color, label=t.name) + t.line_object = line + + # y_center is always rank * spacing + 0.5 in normalized space. + y_center = rank * spacing + 0.5 yticks.append(y_center) - ylabels.append(label) - - # Add voltage annotation at the end - # Add voltage annotation at the right edge of the graph - # Add voltage annotation at the right edge of the graph - if len(raw_data) > 0: - final_voltage = f"{float(raw_data[-1]):.3f} V" - # Position the text at the right edge of the plot area - # Using transform coordinates: 1.01 means just outside the right edge - text_obj = self.axes.text(1.01, y_center, final_voltage, - transform=self.axes.get_yaxis_transform(), - va='center', ha='left', - fontsize=8, color=color, - clip_on=False) # This allows text to appear outside axes - self.timing_annotations[idx] = text_obj - - # Set y-axis limits and labels - total_height = (len(visible_indices) - 1) * spacing + vmax - self.axes.set_ylim(vmin - 0.1 * spacing_ref, total_height + 0.1 * spacing_ref) + ylabels.append(t.name) + + ann = [] + xform = self.axes.get_yaxis_transform() + if trace_vmax - trace_vmin < 1e-10: + ann.append(self.axes.text( + 1.01, y_center, + f"DC: {_format_measurement(float(trace_vmax), trace_unit)}", + transform=xform, va='center', ha='left', + fontsize=8, color=t.color, clip_on=False)) + else: + ann.append(self.axes.text( + 1.01, rank * spacing + 0.82, + f"H: {_format_measurement(float(trace_vmax), trace_unit)}", + transform=xform, va='center', ha='left', + fontsize=8, color=t.color, clip_on=False)) + ann.append(self.axes.text( + 1.01, rank * spacing + 0.18, + f"L: {_format_measurement(float(trace_vmin), trace_unit)}", + transform=xform, va='center', ha='left', + fontsize=8, color=t.color, clip_on=False)) + freq = _detect_frequency(trace_time, logic_normalized) + if freq is not None: + ann.append(self.axes.text( + 1.01, y_center, _format_frequency(freq), + transform=xform, va='center', ha='left', + fontsize=7.5, color=t.color, alpha=0.75, clip_on=False)) + self.timing_annotations[idx] = ann + + # Y-axis bounds: normalized traces sit in [0,1] per rank, evenly spaced. + total_height = (len(visible_indices) - 1) * spacing + 1.0 + margin = 0.15 * spacing + self.axes.set_ylim(-margin, total_height + margin) self.axes.set_yticks(yticks) self.axes.set_yticklabels(ylabels, fontsize=8) - - # Update tick colors to match trace colors + self.update_timing_tick_colors() - - # Set time axis with proper units - self.set_time_axis_label() - - # Add threshold line - self.axes.axhline(y=self.logic_threshold, color='red', linestyle=':', - alpha=THRESHOLD_ALPHA, linewidth=1) - - # Add title if legend is not shown + self.set_time_axis_label(time_data) + + # Threshold lines: logic_thresholds stores normalized [0,1] position, + # so axhline y = threshold_norm + rank * spacing sits correctly within the trace. + for rank, idx in enumerate(visible_indices[::-1]): + if idx in self.logic_thresholds: + self.axes.axhline(y=self.logic_thresholds[idx] + rank * spacing, + color='red', linestyle=':', alpha=THRESHOLD_ALPHA, linewidth=0.8) + if not self.legend_check.isChecked(): - self.axes.set_title(f'Digital Timing Diagram (Threshold: {self.logic_threshold:.3f} V)', - fontsize=10, pad=10) - def set_time_axis_label(self) -> None: + self.axes.set_title('Digital Timing Diagram', fontsize=10, pad=10) + + def set_time_axis_label(self, time_data: Optional["np.ndarray"] = None) -> None: if not hasattr(self, 'axes') or not hasattr(self.obj_dataext, 'x'): return - time_data = np.array(self.obj_dataext.x, dtype=float) + if time_data is None: + time_data = np.asarray(self.obj_dataext.x, dtype=float) if len(time_data) < 2: self.axes.set_xlabel('Time (s)', fontsize=10) return - time_span = abs(time_data[-1] - time_data[0]) - if time_span == 0: - scale, unit = 1, 's' - elif time_span < TIME_UNIT_THRESHOLD_PS: - scale, unit = 1e12, 'ps' - elif time_span < TIME_UNIT_THRESHOLD_NS: - scale, unit = 1e9, 'ns' - elif time_span < TIME_UNIT_THRESHOLD_US: - scale, unit = 1e6, 'µs' - elif time_span < TIME_UNIT_THRESHOLD_MS: - scale, unit = 1e3, 'ms' - else: - scale, unit = 1, 's' + scale, unit = self._get_time_scale_and_unit(time_data) scaled_time = time_data * scale - for line in self.active_traces.values(): + for line in (t.line_object for t in self.traces.values()): if line: - y_data = line.get_ydata() - # Step plots have one more y-value than x-value - if len(y_data) == len(scaled_time) + 1: - x_step_data = np.append(scaled_time, scaled_time[-1]) - line.set_data(x_step_data, y_data) - elif len(y_data) == len(scaled_time): - line.set_xdata(scaled_time) - + line.set_xdata(line.get_xdata() * scale) self.axes.set_xlim(scaled_time[0], scaled_time[-1]) - if hasattr(self, 'cursor_lines'): - for i, line in enumerate(self.cursor_lines): - if line and i < len(self.cursor_positions) and self.cursor_positions[i] is not None: - line.set_xdata([self.cursor_positions[i] * scale, self.cursor_positions[i] * scale]) self.axes.set_xlabel(f'Time ({unit})', fontsize=10) def on_threshold_changed(self, value: float) -> None: @@ -1006,21 +1031,85 @@ def on_spacing_changed(self, value: int) -> None: if self.timing_check.isChecked(): self.refresh_plot() + def _find_nearest_cursor(self, event) -> Optional[int]: + """Return cursor index if the click is within 8px of an existing cursor line.""" + if not self.cursor_lines or not hasattr(self, 'axes') or event.xdata is None: + return None + xlim = self.axes.get_xlim() + width_px = self.axes.get_window_extent().width + if width_px == 0: + return None + threshold = 8 * (xlim[1] - xlim[0]) / width_px + for i, line in enumerate(self.cursor_lines): + if line is None: + continue + if abs(event.xdata - line.get_xdata()[0]) < threshold: + return i + return None + + def _update_cursor_position(self, cursor_num: int, x_pos_scaled: float) -> None: + """Move an existing cursor line without recreating it (fast path for dragging).""" + if cursor_num >= len(self.cursor_lines) or self.cursor_lines[cursor_num] is None: + self.set_cursor(cursor_num, x_pos_scaled) + return + self.cursor_lines[cursor_num].set_xdata([x_pos_scaled, x_pos_scaled]) + scale = self._current_time_scale() + x_pos_original = x_pos_scaled / scale + self.cursor_positions[cursor_num] = x_pos_original + label = self.cursor1_label if cursor_num == 0 else self.cursor2_label + label.setText(f"Cursor {cursor_num + 1}: {x_pos_scaled:.6g}") + if len(self.cursor_positions) >= 2 and all(p is not None for p in self.cursor_positions[:2]): + delta_original = abs(self.cursor_positions[1] - self.cursor_positions[0]) + self.delta_label.setText(f"Delta: {delta_original * scale:.6g}") + if delta_original > 0: + self.measure_label.setText(f"Freq: {1.0 / delta_original:.6g} Hz") + self.canvas.draw_idle() + + def _get_time_scale_and_unit(self, time_data: Optional["np.ndarray"] = None) -> Tuple[float, str]: + """Single source of truth for time-axis unit selection. + + All callers (set_time_axis_label, _current_time_scale, set_cursor) + derive their scale factor from here — ensures they can never diverge. + time_data defaults to obj_dataext.x; pass a trimmed slice when a + subset of the axis is being displayed (e.g. transient start offset). + """ + if time_data is None: + time_data = np.asarray(self.obj_dataext.x, dtype=float) + time_span = abs(time_data[-1] - time_data[0]) if len(time_data) > 1 else 0.0 + if time_span == 0: return 1.0, 's' + if time_span < TIME_UNIT_THRESHOLD_PS: return 1e12, 'ps' + if time_span < TIME_UNIT_THRESHOLD_NS: return 1e9, 'ns' + if time_span < TIME_UNIT_THRESHOLD_US: return 1e6, 'µs' + if time_span < TIME_UNIT_THRESHOLD_MS: return 1e3, 'ms' + return 1.0, 's' + + def _current_time_scale(self) -> float: + return self._get_time_scale_and_unit()[0] + def on_canvas_click(self, event) -> None: - if hasattr(self, 'axes') and event.inaxes == self.axes: - if event.button == 1: + if not hasattr(self, 'axes') or event.inaxes != self.axes: + return + if self.nav_toolbar.mode: + return + near = self._find_nearest_cursor(event) + if event.button == 1: + if near is not None: + self._drag_cursor_idx = near + else: + self._drag_cursor_idx = None self.set_cursor(0, event.xdata) - elif event.button == 3: + elif event.button == 3: + if near is not None: + self._drag_cursor_idx = near + else: + self._drag_cursor_idx = None self.set_cursor(1, event.xdata) + def on_canvas_release(self, event) -> None: + self._drag_cursor_idx = None + def set_cursor(self, cursor_num: int, x_pos_scaled: float) -> None: - time_data = np.array(self.obj_dataext.x, dtype=float) - time_span = abs(time_data[-1] - time_data[0]) if len(time_data) > 1 else 0 - if time_span < TIME_UNIT_THRESHOLD_PS: scale = 1e12 - elif time_span < TIME_UNIT_THRESHOLD_NS: scale = 1e9 - elif time_span < TIME_UNIT_THRESHOLD_US: scale = 1e6 - elif time_span < TIME_UNIT_THRESHOLD_MS: scale = 1e3 - else: scale = 1 + scale = self._current_time_scale() x_pos_original = x_pos_scaled / scale if cursor_num < len(self.cursor_lines) and self.cursor_lines[cursor_num]: @@ -1051,7 +1140,10 @@ def set_cursor(self, cursor_num: int, x_pos_scaled: float) -> None: def clear_cursors(self) -> None: for line in self.cursor_lines: if line: - line.remove() + try: + line.remove() + except ValueError: + pass # already removed by fig.clear() self.cursor_lines.clear() self.cursor_positions.clear() self.cursor1_label.setText("Cursor 1: Not set") @@ -1060,9 +1152,31 @@ def clear_cursors(self) -> None: self.measure_label.setText("") self.canvas.draw() + def _restore_cursors(self) -> None: + """Re-create cursor axvlines after fig.clear(), using stored positions.""" + if not hasattr(self, 'axes') or not self.cursor_positions: + return + scale = self._current_time_scale() + colors = ['red', 'blue'] + new_lines: List[Optional[Line2D]] = [] + for i, x_orig in enumerate(self.cursor_positions): + if x_orig is None: + new_lines.append(None) + continue + color = colors[i] if i < len(colors) else 'green' + line = self.axes.axvline( + x=x_orig * scale, color=color, linestyle='--', alpha=CURSOR_ALPHA + ) + new_lines.append(line) + self.cursor_lines = new_lines + if new_lines: + logger.debug("Restored %d cursor(s) after plot refresh", len(new_lines)) + def on_mouse_move(self, event) -> None: if event.inaxes: self.coord_label.setText(f"X: {event.xdata:.6g}, Y: {event.ydata:.6g}") + if self._drag_cursor_idx is not None: + self._update_cursor_position(self._drag_cursor_idx, event.xdata) else: self.coord_label.setText("X: --, Y: --") @@ -1070,7 +1184,14 @@ def on_key_press(self, event) -> None: if event.key == 'g': self.grid_check.toggle() elif event.key == 'l': self.legend_check.toggle() elif event.key == 'p': self.open_figure_options() - elif event.key == 'escape': self.clear_cursors() + elif event.key == 'escape': + mode = str(self.nav_toolbar.mode).lower() + if 'zoom' in mode: + self.nav_toolbar.zoom() + elif 'pan' in mode: + self.nav_toolbar.pan() + else: + self.clear_cursors() def on_scroll(self, event) -> None: if not event.inaxes: return @@ -1091,9 +1212,9 @@ def export_image(self) -> None: file_name, file_filter = QFileDialog.getSaveFileName(self, "Export Image", "", "PNG Files (*.png);;SVG Files (*.svg);;All Files (*)") if file_name: try: - format = 'svg' if "svg" in file_filter else 'png' - if '.' not in os.path.basename(file_name): file_name += f'.{format}' - self.fig.savefig(file_name, format=format, dpi=DEFAULT_EXPORT_DPI, bbox_inches='tight') + fmt = 'svg' if "svg" in file_filter else 'png' + if '.' not in os.path.basename(file_name): file_name += f'.{fmt}' + self.fig.savefig(file_name, format=fmt, dpi=DEFAULT_EXPORT_DPI, bbox_inches='tight') self.status_bar.showMessage(f"Image exported to {file_name}", 3000) except Exception as e: logger.error(f"Error exporting image: {e}") @@ -1104,10 +1225,28 @@ def clear_plot(self) -> None: self.deselect_all_waveforms() def zoom_in(self) -> None: - if hasattr(self, 'axes'): self.nav_toolbar.zoom() + if not hasattr(self, 'axes'): + return + xlim, ylim = self.axes.get_xlim(), self.axes.get_ylim() + x_center = (xlim[0] + xlim[1]) / 2 + y_center = (ylim[0] + ylim[1]) / 2 + x_half = (xlim[1] - xlim[0]) * DEFAULT_ZOOM_FACTOR / 2 + y_half = (ylim[1] - ylim[0]) * DEFAULT_ZOOM_FACTOR / 2 + self.axes.set_xlim(x_center - x_half, x_center + x_half) + self.axes.set_ylim(y_center - y_half, y_center + y_half) + self.canvas.draw() def zoom_out(self) -> None: - if hasattr(self, 'axes'): self.nav_toolbar.back() + if not hasattr(self, 'axes'): + return + xlim, ylim = self.axes.get_xlim(), self.axes.get_ylim() + x_center = (xlim[0] + xlim[1]) / 2 + y_center = (ylim[0] + ylim[1]) / 2 + x_half = (xlim[1] - xlim[0]) / (DEFAULT_ZOOM_FACTOR * 2) + y_half = (ylim[1] - ylim[0]) / (DEFAULT_ZOOM_FACTOR * 2) + self.axes.set_xlim(x_center - x_half, x_center + x_half) + self.axes.set_ylim(y_center - y_half, y_center + y_half) + self.canvas.draw() def reset_view(self) -> None: if hasattr(self, 'axes'): self.nav_toolbar.home() @@ -1121,54 +1260,49 @@ def toggle_legend(self) -> None: self.refresh_plot() def plot_function(self) -> None: - # This function remains complex, will copy simplified logic if possible - # For now, keeping the original logic function_text = self.func_input.text() if not function_text: QMessageBox.warning(self, "Input Error", "Function input cannot be empty.") return - # Basic parsing (this is a simplified example, not a full math parser) - # It expects "trace1 vs trace2" or a simple expression with +, -, *, / - # For security, avoid using eval() directly on user input in production. - # This implementation is for a controlled environment. + # Remove previous function trace before adding new one + if self._func_line is not None: + try: + self._func_line.remove() + except ValueError: + pass # already cleared by a refresh_plot + self._func_line = None - if 'vs' in function_text: - parts = [p.strip() for p in function_text.split('vs')] - if len(parts) != 2: + if ' vs ' in function_text: + parts = function_text.split(' vs ', 1) + y_name, x_name = parts[0].strip(), parts[1].strip() + if not y_name or not x_name: QMessageBox.warning(self, "Syntax Error", "Use format 'trace1 vs trace2'.") return - y_name, x_name = parts[0], parts[1] try: x_idx = self.obj_dataext.NBList.index(x_name) y_idx = self.obj_dataext.NBList.index(y_name) x_data = np.array(self.obj_dataext.y[x_idx], dtype=float) y_data = np.array(self.obj_dataext.y[y_idx], dtype=float) - is_voltage_x = x_idx < self.volts_length is_voltage_y = y_idx < self.volts_length - - self.axes.plot(x_data, y_data, label=function_text) + line, = self.axes.plot(x_data, y_data, label=function_text) + self._func_line = line self.axes.set_xlabel(f"{x_name} ({'V' if is_voltage_x else 'A'})") self.axes.set_ylabel(f"{y_name} ({'V' if is_voltage_y else 'A'})") - except ValueError: QMessageBox.warning(self, "Trace Not Found", f"Could not find one of the traces: {x_name}, {y_name}") return else: - # Simple expression evaluation (use with caution) try: - # Replace trace names with data arrays - result_expr = function_text - for i, name in enumerate(self.obj_dataext.NBList): - if name in result_expr: - result_expr = result_expr.replace(name, f"np.array(self.obj_dataext.y[{i}], dtype=float)") - - # Evaluate the expression - y_data = eval(result_expr, {"np": np, "self": self}) + data_map = { + name: np.array(self.obj_dataext.y[i], dtype=float) + for i, name in enumerate(self.obj_dataext.NBList) + } + y_data = _safe_eval(function_text, data_map) x_data = np.array(self.obj_dataext.x, dtype=float) - self.axes.plot(x_data, y_data, label=function_text) - + line, = self.axes.plot(x_data, y_data, label=function_text) + self._func_line = line except Exception as e: QMessageBox.warning(self, "Evaluation Error", f"Could not plot function: {e}") return @@ -1179,15 +1313,15 @@ def plot_function(self) -> None: def multi_meter(self) -> None: - visible_indices = [i for i, v in self.trace_visibility.items() if v] - if not visible_indices: + visible = [(idx, t) for idx, t in self.traces.items() if t.visible] + if not visible: QMessageBox.warning(self, "Warning", "Please select at least one waveform") return location_x, location_y = 300, 300 - for idx in visible_indices: - is_voltage = idx < self.obj_dataext.volts_length + for idx, t in visible: rms_value = self.get_rms_value(self.obj_dataext.y[idx]) - meter = MultimeterWidgetClass(self.trace_names.get(idx, self.obj_dataext.NBList[idx]), rms_value, location_x, location_y, is_voltage) + meter = MultimeterWidgetClass(t.name, rms_value, location_x, location_y, idx < self.obj_dataext.volts_length) + self._meters.append(meter) # keep strong ref — no parent, otherwise GC'd if hasattr(self.obj_appconfig, 'dock_dict') and self.obj_appconfig.current_project['ProjectName'] in self.obj_appconfig.dock_dict: self.obj_appconfig.dock_dict[self.obj_appconfig.current_project['ProjectName']].append(meter) location_x += 50 @@ -1197,50 +1331,42 @@ def get_rms_value(self, data_points: List) -> Decimal: getcontext().prec = 5 return Decimal(str(np.sqrt(np.mean(np.square([float(x) for x in data_points]))))) - def redraw_cursors(self) -> None: - # This function might be redundant if set_time_axis_label handles cursor redraws - pass - def _plot_analysis_data(self, analysis_type: str) -> None: self.axes = self.fig.add_subplot(111) traces_plotted = 0 - for trace_index, is_visible in self.trace_visibility.items(): - if not is_visible: + first_visible = None + x_data = np.asarray(self.obj_dataext.x, dtype=float) + for idx, t in self.traces.items(): + if not t.visible: continue traces_plotted += 1 - color = self.trace_colors.get(trace_index, '#000000') - label = self.trace_names.get(trace_index, self.obj_dataext.NBList[trace_index]) - thickness = self.trace_thickness.get(trace_index, DEFAULT_LINE_THICKNESS) - style = self.trace_style.get(trace_index, '-') - x_data = np.asarray(self.obj_dataext.x, dtype=float) - y_data = np.asarray(self.obj_dataext.y[trace_index], dtype=float) - - plot_style = '-' if style == 'steps-post' else style - plot_func = self.axes.plot - if style == 'steps-post' and analysis_type in ['transient', 'dc']: + if first_visible is None: + first_visible = idx + y_data = np.asarray(self.obj_dataext.y[idx], dtype=float) + plot_style = '-' if t.style == 'steps-post' else t.style + plot_kwargs: dict = {} + if t.style == 'steps-post' and analysis_type in ['transient', 'dc']: plot_func = self.axes.step + plot_kwargs['where'] = 'post' elif analysis_type == 'ac_log': plot_func = self.axes.semilogx - - line, = plot_func(x_data, y_data, c=color, label=label, linewidth=thickness, linestyle=plot_style) - self.active_traces[trace_index] = line + else: + plot_func = self.axes.plot + line, = plot_func(x_data, y_data, color=t.color, label=t.name, + linewidth=t.thickness, linestyle=plot_style, **plot_kwargs) + t.line_object = line if analysis_type in ['ac_linear', 'ac_log']: self.axes.set_xlabel('Frequency (Hz)') - elif analysis_type == 'transient': - # set_time_axis_label is now called from refresh_plot - pass elif analysis_type == 'dc': self.axes.set_xlabel('Voltage Sweep (V)') - - # Set Y label based on the first plotted trace - first_visible = next((i for i, v in self.trace_visibility.items() if v), None) + if first_visible is not None: - self.axes.set_ylabel('Voltage (V)' if first_visible < self.volts_length else 'Current (A)') + self.axes.set_ylabel('Voltage (V)' if first_visible < self.volts_length else 'Current (A)') if traces_plotted == 0: self.axes.text(0.5, 0.5, 'Please select a waveform to plot', ha='center', va='center', transform=self.axes.transAxes) - + if analysis_type == 'transient': self.set_time_axis_label() diff --git a/src/ngspiceSimulation/plotting_widgets.py b/src/ngspiceSimulation/plotting_widgets.py index d6e4df42a..91bc2fa22 100644 --- a/src/ngspiceSimulation/plotting_widgets.py +++ b/src/ngspiceSimulation/plotting_widgets.py @@ -40,7 +40,6 @@ def __init__(self, title: str = "", parent: Optional[QWidget] = None) -> None: super().__init__(parent) self.title = title - # **FIX**: Set proper size policy self.setSizePolicy(QtWidgets.QSizePolicy.Policy.Expanding, QtWidgets.QSizePolicy.Policy.Maximum) @@ -147,7 +146,6 @@ def __init__(self, node_branch: str, rms_value: Decimal, """ super().__init__() - # **FIX**: Don't force window size, let it be managed by parent self.node_branch = node_branch self.rms_value = rms_value self.location_x = location_x @@ -165,14 +163,7 @@ def __init__(self, node_branch: str, rms_value: Decimal, f"{node_branch} = {rms_value}") def _setup_ui(self) -> None: - """Set up the user interface elements.""" - # Create main container widget - self.multimeter_container = QWidget(self) - - # Create labels based on measurement type self._create_labels() - - # Set up layout self._setup_layout() def _create_labels(self) -> None: @@ -191,14 +182,11 @@ def _create_labels(self) -> None: self.rms_value_label = QLabel(f"{self.rms_value} {unit_text}") def _setup_layout(self) -> None: - """Set up the grid layout for the widget.""" layout = QGridLayout(self) layout.addWidget(self.type_label, 0, 0) layout.addWidget(self.rms_title_label, 0, 1) layout.addWidget(self.node_branch_value_label, 1, 0) layout.addWidget(self.rms_value_label, 1, 1) - - self.multimeter_container.setLayout(layout) def _configure_window(self) -> None: """Configure window properties and display the widget."""