diff --git a/ophyd_devices/utils/static_device_test.py b/ophyd_devices/utils/static_device_test.py index 070ffb5..bf0a474 100644 --- a/ophyd_devices/utils/static_device_test.py +++ b/ophyd_devices/utils/static_device_test.py @@ -154,13 +154,17 @@ class StaticDeviceTest: if "auto_monitor" not in config: self.print_and_write(f"WARNING: Device {name} is configured without auto monitor.") - def connect_device(self, name: str, conf: dict) -> int: + def connect_device( + self, name: str, conf: dict, force_connect: bool = False, timeout_per_device: float = 30 + ) -> int: """ Connect to the device Args: name(str): name of the device conf(dict): device config + force_connect(bool): force connection to all signals even if devices report .connected = True. Default is False. + timeout_per_device(float): timeout for each device connection. Default is 30 seconds. Returns: int: 0 if all checks passed, 1 otherwise @@ -170,7 +174,9 @@ class StaticDeviceTest: conf_in["name"] = name obj, _ = device_manager.construct_device_obj(conf_in, None) - device_manager.connect_device(obj, wait_for_all=True) + device_manager.connect_device( + obj, wait_for_all=True, timeout=timeout_per_device, force=force_connect + ) assert obj.connected is True self.check_basic_ophyd_methods(obj) obj.destroy() @@ -237,12 +243,16 @@ class StaticDeviceTest: db_config.pop("deviceType", None) return db_config - def run(self, connect: bool) -> None: + def run( + self, connect: bool, force_connect: bool = False, timeout_per_device: float = 30 + ) -> None: """ Run the tests Args: connect(bool): connect to the devices + force_connect(bool): force connection to all signals even if devices report .connected = True. Default is False. + timeout_per_device(float): timeout for each device connection. Default is 30 seconds. """ failed_devices = [] for name, conf in self.config.items(): @@ -251,7 +261,9 @@ class StaticDeviceTest: return_val += self.validate_schema(name, conf) return_val += self.check_device_classes(name, conf) if connect: - return_val += self.connect_device(name, conf) + return_val += self.connect_device( + name, conf, force_connect=force_connect, timeout_per_device=timeout_per_device + ) if return_val == 0: self.print_and_write("OK") @@ -287,10 +299,17 @@ class StaticDeviceTest: if self.file is not None: # Write only if no output file is provided self.file.write(text + "\n") - def run_with_list_output(self, connect: bool = False) -> list[TestResult]: + def run_with_list_output( + self, connect: bool = False, force_connect: bool = False, timeout_per_device: float = 30 + ) -> list[TestResult]: """ Run the tests and return a list of tuples with the device name, success status, and error message. + Args: + connect(bool): connect to the devices + force_connect(bool): force connection to all signals even if devices report .connected = True. Default is False. + timeout_per_device(float): timeout for each device connection. Default is 30 seconds. + Returns: list[tuple[str, bool, str]]: list of tuples with the device name, success status, and error message """ @@ -313,7 +332,12 @@ class StaticDeviceTest: return_val += self.validate_schema(name, conf) return_val += self.check_device_classes(name, conf) if device_manager is not None and connect: - return_val += self.connect_device(name, conf) + return_val += self.connect_device( + name, + conf, + force_connect=force_connect, + timeout_per_device=timeout_per_device, + ) if return_val == 0: status = True self.print_and_write(f"{name} is OK") @@ -340,6 +364,18 @@ def launch() -> None: "--output", default="./device_test_reports", help="path to the output directory" ) optional.add_argument("--connect", action="store_true", help="connect to the devices") + optional.add_argument( + "--force-connect", + action="store_true", + default=False, + help="force connection to all signals", + ) + optional.add_argument( + "--timeout-per-device", + type=float, + default=30, + help="timeout for each device connection in seconds", + ) parser.add_help = True clargs = parser.parse_args() @@ -372,7 +408,7 @@ def launch() -> None: os.path.join(clargs.output, f"report_{report_name}.txt"), "w", encoding="utf-8" ) as report_file: device_config_test = StaticDeviceTest(config_file=file, output_file=report_file) - device_config_test.run(clargs.connect) + device_config_test.run(clargs.connect, clargs.force_connect, clargs.timeout_per_device) if __name__ == "__main__": # pragma: no cover