From b43cf9f4ffeeea3be060d04275fe4307453c2403 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Frieder=20Sch=C3=BCler?= Date: Wed, 29 Apr 2026 21:01:03 +0200 Subject: [PATCH] Fix arg-type, return-value, and assignment mypy errors - nmt.py: Use bytes() instead of list[int] for send_message/send_periodic - emcy.py: Fix wait() return type to Optional[EmcyError] - network.py: Fix add_node/create_node return types and type narrowing - variable.py: Fix missing return in read(), annotate Bits.raw, add type: ignore for inherently dynamic type operations - sdo/base.py: Fix missing return in get_variable(), add type: ignore for SdoArray.__len__ - objectdictionary/__init__.py: Fix __getitem__ indexability, add missing return in get_variable(), annotate ODVariable.parent, fix decode_phys/encode_phys types, add type: ignore for dynamic encode_raw operations --- canopen/emcy.py | 2 +- canopen/network.py | 10 ++-- canopen/nmt.py | 8 ++-- canopen/objectdictionary/__init__.py | 33 +++++++------ canopen/sdo/base.py | 3 +- canopen/variable.py | 16 ++++--- test/test_od.py | 69 ++++++++++++++++++++++++++++ test/test_sdo.py | 3 ++ 8 files changed, 112 insertions(+), 32 deletions(-) diff --git a/canopen/emcy.py b/canopen/emcy.py index 22d1eba8..6a860f07 100644 --- a/canopen/emcy.py +++ b/canopen/emcy.py @@ -57,7 +57,7 @@ def reset(self): def wait( self, emcy_code: Optional[int] = None, timeout: float = 10 - ) -> EmcyError: + ) -> Optional[EmcyError]: """Wait for a new EMCY to arrive. :param emcy_code: EMCY code to wait for diff --git a/canopen/network.py b/canopen/network.py index 6a8d95f6..3628bbc7 100644 --- a/canopen/network.py +++ b/canopen/network.py @@ -136,7 +136,7 @@ def add_node( node: Union[int, RemoteNode, LocalNode], object_dictionary: Union[str, ObjectDictionary, None] = None, upload_eds: bool = False, - ) -> RemoteNode: + ) -> Union[RemoteNode, LocalNode]: """Add a remote node to the network. :param node: @@ -156,13 +156,14 @@ def add_node( if upload_eds: logger.info("Trying to read EDS from node %d", node) object_dictionary = import_from_node(node, self) - node = RemoteNode(node, object_dictionary) + node = RemoteNode(node, object_dictionary) # type: ignore[arg-type] + assert node.id is not None self[node.id] = node return node def create_node( self, - node: int, + node: Union[int, LocalNode], object_dictionary: Union[str, ObjectDictionary, None] = None, ) -> LocalNode: """Create a local node in the network. @@ -178,7 +179,8 @@ def create_node( The Node object that was added. """ if isinstance(node, int): - node = LocalNode(node, object_dictionary) + node = LocalNode(node, object_dictionary) # type: ignore[arg-type] + assert node.id is not None self[node.id] = node return node diff --git a/canopen/nmt.py b/canopen/nmt.py index 4637d315..1605c334 100644 --- a/canopen/nmt.py +++ b/canopen/nmt.py @@ -147,7 +147,7 @@ def send_command(self, code: int): super(NmtMaster, self).send_command(code) logger.info( "Sending NMT command 0x%X to node %d", code, self.id) - self.network.send_message(0, [code, self.id]) + self.network.send_message(0, bytes([code, self.id])) def wait_for_heartbeat(self, timeout: float = 10): """Wait until a heartbeat message is received.""" @@ -190,7 +190,7 @@ def start_node_guarding(self, period: float): """ if self._node_guarding_producer: self.stop_node_guarding() - self._node_guarding_producer = self.network.send_periodic(0x700 + self.id, None, period, True) + self._node_guarding_producer = self.network.send_periodic(0x700 + self.id, b'', period, True) def stop_node_guarding(self): """Stops the node guarding mechanism.""" @@ -225,7 +225,7 @@ def send_command(self, code: int) -> None: if self._state == 0: logger.info("Sending boot-up message") - self.network.send_message(0x700 + self.id, [0]) + self.network.send_message(0x700 + self.id, b'\x00') # The heartbeat service should start on the transition # between INITIALIZING and PRE-OPERATIONAL state @@ -256,7 +256,7 @@ def start_heartbeat(self, heartbeat_time_ms: int): if heartbeat_time_ms > 0: logger.info("Start the heartbeat timer, interval is %d ms", self._heartbeat_time_ms) self._send_task = self.network.send_periodic( - 0x700 + self.id, [self._state], heartbeat_time_ms / 1000.0) + 0x700 + self.id, bytes([self._state]), heartbeat_time_ms / 1000.0) def stop_heartbeat(self): """Stop the heartbeat service.""" diff --git a/canopen/objectdictionary/__init__.py b/canopen/objectdictionary/__init__.py index fa694c56..99544281 100644 --- a/canopen/objectdictionary/__init__.py +++ b/canopen/objectdictionary/__init__.py @@ -69,7 +69,7 @@ def export_od( finally: # If dest is opened in this fn, it should be closed if opened_here: - dest.close() + dest.close() # type: ignore[union-attr] def import_od( @@ -92,7 +92,7 @@ def import_od( return ObjectDictionary() if hasattr(source, "read"): # File like object - filename = source.name + filename = source.name # type: ignore[union-attr] elif hasattr(source, "tag"): # XML tree, probably from an EPF file filename = "od.epf" @@ -139,7 +139,10 @@ def __getitem__( if item is None: if isinstance(index, str) and '.' in index: idx, sub = index.split('.', maxsplit=1) - return self[idx][sub] + parent = self[idx] + if not isinstance(parent, (ODRecord, ODArray)): + raise KeyError(f"{pretty_index(index)} was not found in Object Dictionary") + return parent[sub] raise KeyError(f"{pretty_index(index)} was not found in Object Dictionary") return item @@ -188,6 +191,7 @@ def get_variable( return obj elif isinstance(obj, (ODRecord, ODArray)): return obj.get(subindex) + return None class ODRecord(MutableMapping): @@ -259,7 +263,7 @@ class ODArray(Mapping): def __init__(self, name: str, index: int): #: The :class:`~canopen.ObjectDictionary` owning the record. - self.parent = None + self.parent: Optional[ObjectDictionary] = None #: 16-bit address of the array self.index = index #: Name of array @@ -339,7 +343,7 @@ def __init__(self, name: str, index: int, subindex: int = 0): #: The :class:`~canopen.ObjectDictionary`, #: :class:`~canopen.objectdictionary.ODRecord` or #: :class:`~canopen.objectdictionary.ODArray` owning the variable - self.parent = None + self.parent: Union[ObjectDictionary, ODRecord, ODArray, None] = None #: 16-bit address of the object in the dictionary self.index = index #: 8-bit sub-index of the object in the dictionary @@ -451,19 +455,19 @@ def encode_raw(self, value: Union[int, float, str, bytes, bytearray]) -> bytes: if isinstance(value, (bytes, bytearray)): return value elif self.data_type == VISIBLE_STRING: - return value.encode("ascii") + return value.encode("ascii") # type: ignore[union-attr] elif self.data_type == UNICODE_STRING: - return value.encode("utf_16_le") + return value.encode("utf_16_le") # type: ignore[union-attr] elif self.data_type in (DOMAIN, OCTET_STRING): - return bytes(value) + return bytes(value) # type: ignore[arg-type] elif self.data_type in self.STRUCT_TYPES: if self.data_type in INTEGER_TYPES: value = int(value) if self.data_type in NUMBER_TYPES: - if self.min is not None and value < self.min: + if self.min is not None and value < self.min: # type: ignore[operator] logger.warning( "Value %d is less than min value %d", value, self.min) - if self.max is not None and value > self.max: + if self.max is not None and value > self.max: # type: ignore[operator] logger.warning( "Value %d is greater than max value %d", value, self.max) @@ -477,16 +481,15 @@ def encode_raw(self, value: Union[int, float, str, bytes, bytearray]) -> bytes: raise TypeError( f"Do not know how to encode {value!r} to data type 0x{self.data_type:X}") - def decode_phys(self, value: int) -> Union[int, bool, float, str, bytes]: + def decode_phys(self, value: int) -> Union[int, float]: if self.data_type in INTEGER_TYPES: - value *= self.factor + return value * self.factor return value def encode_phys(self, value: Union[int, bool, float, str, bytes]) -> int: if self.data_type in INTEGER_TYPES: - value /= self.factor - value = int(round(value)) - return value + value = int(round(value / self.factor)) # type: ignore[operator] + return value # type: ignore[return-value] def decode_desc(self, value: int) -> str: if not self.value_descriptions: diff --git a/canopen/sdo/base.py b/canopen/sdo/base.py index e4215a3a..3797ff28 100644 --- a/canopen/sdo/base.py +++ b/canopen/sdo/base.py @@ -79,6 +79,7 @@ def get_variable( return obj elif isinstance(obj, (SdoRecord, SdoArray)): return obj.get(subindex) + return None def upload(self, index: int, subindex: int) -> bytes: raise NotImplementedError() @@ -134,7 +135,7 @@ def __iter__(self) -> Iterator[int]: return iter(range(1, len(self) + 1)) def __len__(self) -> int: - return self[0].raw + return self[0].raw # type: ignore[return-value] def __contains__(self, subindex: int) -> bool: return 0 <= subindex <= len(self) diff --git a/canopen/variable.py b/canopen/variable.py index d2538c3f..649add84 100644 --- a/canopen/variable.py +++ b/canopen/variable.py @@ -77,7 +77,7 @@ def raw(self) -> Union[int, bool, float, str, bytes]: """ value = self.od.decode_raw(self.data) text = f"Value of {self.name!r} ({pretty_index(self.index, self.subindex)}) is {value!r}" - if value in self.od.value_descriptions: + if isinstance(value, int) and value in self.od.value_descriptions: text += f" ({self.od.value_descriptions[value]})" logger.debug(text) return value @@ -97,7 +97,7 @@ def phys(self) -> Union[int, bool, float, str, bytes]: either a :class:`float` or an :class:`int`. Non integers will be passed as is. """ - value = self.od.decode_phys(self.raw) + value = self.od.decode_phys(self.raw) # type: ignore[arg-type] if self.od.unit: logger.debug("Physical value is %s %s", value, self.od.unit) return value @@ -109,13 +109,13 @@ def phys(self, value: Union[int, bool, float, str, bytes]): @property def desc(self) -> str: """Converts to and from a description of the value as a string.""" - value = self.od.decode_desc(self.raw) + value = self.od.decode_desc(self.raw) # type: ignore[arg-type] logger.debug("Description is '%s'", value) return value @desc.setter def desc(self, desc: str): - self.raw = self.od.encode_desc(desc) + self.raw = self.od.encode_desc(desc) # type: ignore[assignment] @property def bits(self) -> "Bits": @@ -142,6 +142,7 @@ def read(self, fmt: str = "raw") -> Union[int, bool, float, str, bytes]: return self.phys elif fmt == "desc": return self.desc + raise ValueError(f"Invalid format '{fmt}'") def write( self, value: Union[int, bool, float, str, bytes], fmt: str = "raw" @@ -161,13 +162,14 @@ def write( elif fmt == "phys": self.phys = value elif fmt == "desc": - self.desc = value + self.desc = value # type: ignore[assignment] class Bits(Mapping): def __init__(self, variable: Variable): self.variable = variable + self.raw: Union[int, bool, float, str, bytes] = 0 self.read() @staticmethod @@ -181,11 +183,11 @@ def _get_bits(key): return bits def __getitem__(self, key) -> int: - return self.variable.od.decode_bits(self.raw, self._get_bits(key)) + return self.variable.od.decode_bits(self.raw, self._get_bits(key)) # type: ignore[arg-type] def __setitem__(self, key, value: int): self.raw = self.variable.od.encode_bits( - self.raw, self._get_bits(key), value) + self.raw, self._get_bits(key), value) # type: ignore[arg-type] self.write() def __iter__(self): diff --git a/test/test_od.py b/test/test_od.py index 9ab0e187..1f0a7924 100644 --- a/test/test_od.py +++ b/test/test_od.py @@ -1,6 +1,7 @@ import unittest from canopen import objectdictionary as od +from canopen.variable import Variable class TestDataConversions(unittest.TestCase): @@ -260,6 +261,17 @@ def test_get_item_index(self): self.assertIsInstance(item, od.ODArray) self.assertIs(item, array) + def test_get_item_dot_on_variable(self): + test_od = od.ObjectDictionary() + var = od.ODVariable("Test Variable", 0x1000) + test_od.add_object(var) + with self.assertRaises(KeyError): + test_od["Test Variable.sub"] + + def test_get_variable_not_found(self): + test_od = od.ObjectDictionary() + self.assertIsNone(test_od.get_variable(0x9999)) + class TestArray(unittest.TestCase): @@ -276,5 +288,62 @@ def test_subindexes(self): self.assertEqual(array[3].name, "Test Variable_3") +class _StubVariable(Variable): + """Minimal concrete Variable for testing read/write/bits.""" + + def __init__(self, od_var): + super().__init__(od_var) + self._data = od_var.encode_raw(od_var.default) + + def get_data(self): + return self._data + + def set_data(self, data): + self._data = data + + +class TestVariable(unittest.TestCase): + + def test_read_invalid_format(self): + var = od.ODVariable("Test UNSIGNED8", 0x1000) + var.data_type = od.UNSIGNED8 + var.default = 0 + v = _StubVariable(var) + with self.assertRaises(ValueError): + v.read(fmt="invalid") + + def test_write_desc(self): + var = od.ODVariable("Test UNSIGNED8", 0x1000) + var.data_type = od.UNSIGNED8 + var.default = 0 + var.add_value_description(0, "Off") + var.add_value_description(1, "On") + v = _StubVariable(var) + v.write("On", fmt="desc") + self.assertEqual(v.raw, 1) + + def test_raw_with_string_value(self): + var = od.ODVariable("Test VISIBLE_STRING", 0x1000) + var.data_type = od.VISIBLE_STRING + var.default = "hello" + var.add_value_description(0, "Off") + v = _StubVariable(var) + # String value must not be looked up in value_descriptions + self.assertEqual(v.raw, "hello") + + def test_bits(self): + var = od.ODVariable("Test UNSIGNED8", 0x1000) + var.data_type = od.UNSIGNED8 + var.default = 0 + var.add_bit_definition("BIT 0", [0]) + var.add_bit_definition("BIT 2 and 3", [2, 3]) + v = _StubVariable(var) + v.raw = 5 + bits = v.bits + self.assertEqual(bits[0], 1) + bits[0] = 0 + self.assertEqual(v.raw, 4) + + if __name__ == "__main__": unittest.main() diff --git a/test/test_sdo.py b/test/test_sdo.py index 9764a6d3..a316f9f6 100644 --- a/test/test_sdo.py +++ b/test/test_sdo.py @@ -48,6 +48,9 @@ def test_array_members_dynamic(self): for var in array.values(): self.assertIsInstance(var, canopen.sdo.SdoVariable) + def test_get_variable_not_found(self): + self.assertIsNone(self.sdo_node.get_variable(0x9999)) + class TestSDO(unittest.TestCase): """