diff --git a/canopen/emcy.py b/canopen/emcy.py index 22d1eba8..8d075008 100644 --- a/canopen/emcy.py +++ b/canopen/emcy.py @@ -39,7 +39,10 @@ def on_emcy(self, can_id, data, timestamp): self.emcy_received.notify_all() for callback in self.callbacks: - callback(entry) + try: + callback(entry) + except Exception: + logger.exception("Exception in EMCY callback") def add_callback(self, callback: Callable[[EmcyError], None]): """Get notified on EMCY messages from this node. diff --git a/test/test_emcy.py b/test/test_emcy.py index d883e9c8..6c153165 100644 --- a/test/test_emcy.py +++ b/test/test_emcy.py @@ -25,13 +25,12 @@ def check_error(self, err, code, reg, data, ts): self.assertAlmostEqual(err.timestamp, ts) def test_emcy_consumer_on_emcy(self): - # Make sure multiple callbacks receive the same information. + """Make sure multiple callbacks receive the same information.""" acc1 = [] acc2 = [] self.emcy.add_callback(lambda err: acc1.append(err)) self.emcy.add_callback(lambda err: acc2.append(err)) - # Dispatch an EMCY datagram. self.emcy.on_emcy(0x81, b'\x01\x20\x02\x00\x01\x02\x03\x04', 1000) self.assertEqual(len(self.emcy.log), 1) @@ -45,7 +44,6 @@ def test_emcy_consumer_on_emcy(self): data=bytes([0, 1, 2, 3, 4]), ts=1000, ) - # Dispatch a new EMCY datagram. self.emcy.on_emcy(0x81, b'\x10\x90\x01\x04\x03\x02\x01\x00', 2000) self.assertEqual(len(self.emcy.log), 2) self.assertEqual(len(self.emcy.active), 2) @@ -58,7 +56,6 @@ def test_emcy_consumer_on_emcy(self): data=bytes([4, 3, 2, 1, 0]), ts=2000, ) - # Dispatch an EMCY reset. self.emcy.on_emcy(0x81, b'\x00\x00\x00\x00\x00\x00\x00\x00', 2000) self.assertEqual(len(self.emcy.log), 3) self.assertEqual(len(self.emcy.active), 0) @@ -123,6 +120,65 @@ def push_reset(): t.start() self.assertIsNone(self.emcy.wait(0x9000, TIMEOUT)) + def test_emcy_consumer_initialization(self): + consumer = canopen.emcy.EmcyConsumer() + self.assertEqual(consumer.log, []) + self.assertEqual(consumer.active, []) + self.assertEqual(consumer.callbacks, []) + + def test_emcy_consumer_multiple_callbacks(self): + """Test adding multiple callbacks and their execution order.""" + call_order = [] + self.emcy.add_callback(lambda err: call_order.append('callback1')) + self.emcy.add_callback(lambda err: call_order.append('callback2')) + self.emcy.add_callback(lambda err: call_order.append('callback3')) + self.emcy.on_emcy(0x81, b'\x01\x20\x02\x00\x01\x02\x03\x04', 1000) + self.assertEqual(call_order, ['callback1', 'callback2', 'callback3']) + + def test_emcy_consumer_callback_exception_handling(self): + """Test that callback exceptions don't break other callbacks or the system.""" + successful_callbacks = [] + self.emcy.add_callback(lambda err: successful_callbacks.append('success1')) + self.emcy.add_callback( + lambda err: exec('raise ValueError("Test exception in callback")') + ) + self.emcy.add_callback(lambda err: successful_callbacks.append('success2')) + self.emcy.on_emcy(0x81, b'\x01\x20\x02\x00\x01\x02\x03\x04', 1000) + self.assertEqual(successful_callbacks, ['success1', 'success2']) + + def test_emcy_consumer_error_reset_variants(self): + """Test different error reset code patterns.""" + self.emcy.on_emcy(0x81, b'\x01\x20\x02\x00\x01\x02\x03\x04', 1000) + self.emcy.on_emcy(0x81, b'\x10\x90\x01\x04\x03\x02\x01\x00', 2000) + self.assertEqual(len(self.emcy.active), 2) + self.emcy.on_emcy(0x81, b'\x00\x00\x00\x00\x00\x00\x00\x00', 3000) + self.assertEqual(len(self.emcy.active), 0) + self.emcy.on_emcy(0x81, b'\x01\x30\x02\x00\x01\x02\x03\x04', 4000) + self.assertEqual(len(self.emcy.active), 1) + self.emcy.on_emcy(0x81, b'\x99\x00\x01\x00\x00\x00\x00\x00', 5000) + self.assertEqual(len(self.emcy.active), 0) + + def test_emcy_consumer_wait_timeout_edge_cases(self): + """Test wait method with various timeout scenarios.""" + result = self.emcy.wait(timeout=0) + self.assertIsNone(result) + result = self.emcy.wait(timeout=0.001) + self.assertIsNone(result) + + def test_emcy_consumer_wait_concurrent_errors(self): + """Test wait method when multiple errors arrive concurrently.""" + def push_multiple_errors(): + self.emcy.on_emcy(0x81, b'\x01\x20\x01\x01\x02\x03\x04\x05', 100) + self.emcy.on_emcy(0x81, b'\x02\x20\x01\x01\x02\x03\x04\x05', 101) + self.emcy.on_emcy(0x81, b'\x03\x20\x01\x01\x02\x03\x04\x05', 102) + t = threading.Timer(TIMEOUT / 2, push_multiple_errors) + with self.assertLogs(level=logging.INFO): + t.start() + err = self.emcy.wait(0x2003, timeout=TIMEOUT) + t.join(TIMEOUT) + self.assertIsNotNone(err) + self.assertEqual(err.code, 0x2003) + class TestEmcyError(unittest.TestCase): def test_emcy_error(self): @@ -180,6 +236,26 @@ def check(code, expected): check(0xff00, "Device Specific") check(0xffff, "Device Specific") + def test_emcy_error_initialization_types(self): + """Test EmcyError initialization with various data types.""" + error = EmcyError(0x1000, 0, b'', 123.456) + self.assertEqual(error.code, 0x1000) + self.assertEqual(error.register, 0) + self.assertEqual(error.data, b'') + self.assertEqual(error.timestamp, 123.456) + error = EmcyError(0xFFFF, 0xFF, b'\xFF' * 5, float('inf')) + self.assertEqual(error.code, 0xFFFF) + self.assertEqual(error.register, 0xFF) + self.assertEqual(error.data, b'\xFF' * 5) + self.assertEqual(error.timestamp, float('inf')) + + def test_emcy_error_str_edge_cases(self): + for code in (0x0000, 0x0001, 0x0100, 0xFFFF): + error = EmcyError(code, 0, b'', 1000) + s = str(error) + self.assertIsInstance(s, str) + self.assertIn(f"0x{code:04X}", s) + class TestEmcyProducer(unittest.TestCase): def setUp(self): @@ -220,6 +296,91 @@ def check(*args, res): check(3, res=b'\x00\x00\x03\x00\x00\x00\x00\x00') check(3, b"\xaa\xbb", res=b'\x00\x00\x03\xaa\xbb\x00\x00\x00') + def test_emcy_producer_initialization(self): + producer = canopen.emcy.EmcyProducer(0x123) + self.assertEqual(producer.cob_id, 0x123) + self.assertIsNotNone(producer.network) + + def test_emcy_producer_send_edge_cases(self): + self.emcy.send(0xFFFF, 0xFF, b'\xFF\xFF\xFF\xFF\xFF') + self.check_response(b'\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF') + self.emcy.send(0x0000, 0x00) + self.check_response(b'\x00\x00\x00\x00\x00\x00\x00\x00') + self.emcy.send(0x1234, 0x56, b'\xAB\xCD') + self.check_response(b'\x34\x12\x56\xAB\xCD\x00\x00\x00') + self.emcy.send(0x1234, 0x56, b'\xAB\xCD\xEF\x12\x34') + self.check_response(b'\x34\x12\x56\xAB\xCD\xEF\x12\x34') + + def test_emcy_producer_reset_edge_cases(self): + self.emcy.reset(0xFF) + self.check_response(b'\x00\x00\xFF\x00\x00\x00\x00\x00') + self.emcy.reset(0xFF, b'\xFF\xFF\xFF\xFF\xFF') + self.check_response(b'\x00\x00\xFF\xFF\xFF\xFF\xFF\xFF') + self.emcy.reset(0x12, b'\xAB\xCD') + self.check_response(b'\x00\x00\x12\xAB\xCD\x00\x00\x00') + + +class TestEmcyIntegration(unittest.TestCase): + """Integration tests for EMCY producer and consumer.""" + + def setUp(self): + self.txbus = can.Bus(interface="virtual") + self.rxbus = can.Bus(interface="virtual") + self.net = canopen.Network(self.txbus) + self.net.NOTIFIER_SHUTDOWN_TIMEOUT = 0.0 + self.net.connect() + self.rx_net = canopen.Network(self.rxbus) + self.rx_net.NOTIFIER_SHUTDOWN_TIMEOUT = 0.0 + self.rx_net.connect() + self.producer = canopen.emcy.EmcyProducer(0x081) + self.producer.network = self.net + self.consumer = canopen.emcy.EmcyConsumer() + self.rx_net.subscribe(0x081, self.consumer.on_emcy) + + def tearDown(self): + self.net.disconnect() + self.rx_net.disconnect() + self.txbus.shutdown() + self.rxbus.shutdown() + + def test_producer_consumer_integration(self): + """Test that producer and consumer work together.""" + received_errors = [] + self.consumer.add_callback(lambda err: received_errors.append(err)) + t = threading.Timer( + TIMEOUT / 2, + lambda: self.producer.send(0x2001, 0x02, b'\x01\x02\x03\x04\x05'), + ) + with self.assertLogs(level=logging.INFO): + t.start() + err = self.consumer.wait(0x2001, timeout=TIMEOUT) + t.join(TIMEOUT) + self.assertIsNotNone(err) + self.assertEqual(err.code, 0x2001) + self.assertEqual(err.register, 0x02) + self.assertEqual(err.data, b'\x01\x02\x03\x04\x05') + self.assertEqual(received_errors, [err]) + + def test_producer_reset_consumer_integration(self): + """Test producer reset clears consumer active errors.""" + t = threading.Timer( + TIMEOUT / 2, + lambda: self.producer.send(0x2001, 0x02, b'\x01\x02\x03\x04\x05'), + ) + with self.assertLogs(level=logging.INFO): + t.start() + self.consumer.wait(0x2001, timeout=TIMEOUT) + t.join(TIMEOUT) + self.assertEqual(len(self.consumer.active), 1) + t = threading.Timer(TIMEOUT / 2, self.producer.reset) + with self.assertLogs(level=logging.INFO): + t.start() + err = self.consumer.wait(timeout=TIMEOUT) + t.join(TIMEOUT) + self.assertIsNotNone(err) + self.assertEqual(len(self.consumer.active), 0) + self.assertEqual(len(self.consumer.log), 2) + if __name__ == "__main__": unittest.main()