diff --git a/src/pydase/client/client.py b/src/pydase/client/client.py index 7371ab0..6a82514 100644 --- a/src/pydase/client/client.py +++ b/src/pydase/client/client.py @@ -8,6 +8,7 @@ import socketio # type: ignore import pydase.components from pydase.client.proxy_loader import ProxyClassMixin, ProxyLoader +from pydase.utils.helpers import current_event_loop_exists from pydase.utils.serialization.deserializer import loads from pydase.utils.serialization.types import SerializedDataService, SerializedObject @@ -74,6 +75,7 @@ class ProxyClass(ProxyClassMixin, pydase.components.DeviceConnection): self, sio_client: socketio.AsyncClient, loop: asyncio.AbstractEventLoop ) -> None: super().__init__() + pydase.components.DeviceConnection.__init__(self) self._initialise(sio_client=sio_client, loop=loop) @@ -107,7 +109,11 @@ class Client: ): self._url = url self._sio = socketio.AsyncClient() - self._loop = asyncio.new_event_loop() + if not current_event_loop_exists(): + self._loop = asyncio.new_event_loop() + asyncio.set_event_loop(self._loop) + else: + self._loop = asyncio.get_event_loop() self.proxy = ProxyClass(sio_client=self._sio, loop=self._loop) """A proxy object representing the remote service, facilitating interaction as if it were local.""" diff --git a/src/pydase/server/server.py b/src/pydase/server/server.py index 9c701d7..b7c0c08 100644 --- a/src/pydase/server/server.py +++ b/src/pydase/server/server.py @@ -13,6 +13,7 @@ from pydase.config import ServiceConfig from pydase.data_service.data_service_observer import DataServiceObserver from pydase.data_service.state_manager import StateManager from pydase.server.web_server import WebServer +from pydase.utils.helpers import current_event_loop_exists HANDLED_SIGNALS = ( signal.SIGINT, # Unix signal 2. Sent by Ctrl+C. @@ -156,13 +157,17 @@ class Server: self._web_port = web_port self._enable_web = enable_web self._kwargs = kwargs - self._loop: asyncio.AbstractEventLoop self._additional_servers = additional_servers self.should_exit = False self.servers: dict[str, asyncio.Future[Any]] = {} self._state_manager = StateManager(self._service, filename) self._observer = DataServiceObserver(self._state_manager) self._state_manager.load_state() + if not current_event_loop_exists(): + self._loop = asyncio.new_event_loop() + asyncio.set_event_loop(self._loop) + else: + self._loop = asyncio.get_event_loop() def run(self) -> None: """ @@ -170,7 +175,7 @@ class Server: This method should be called to start the server after it's been instantiated. """ - asyncio.run(self.serve()) + self._loop.run_until_complete(self.serve()) async def serve(self) -> None: process_id = os.getpid() @@ -186,10 +191,8 @@ class Server: logger.info("Finished server process [%s]", process_id) async def startup(self) -> None: - self._loop = asyncio.get_running_loop() self._loop.set_exception_handler(self.custom_exception_handler) self.install_signal_handlers() - self._service._task_manager.start_autostart_tasks() for server in self._additional_servers: addin_server = server["server"]( diff --git a/src/pydase/task/task.py b/src/pydase/task/task.py index e9d6653..7809a52 100644 --- a/src/pydase/task/task.py +++ b/src/pydase/task/task.py @@ -13,6 +13,7 @@ from typing import ( from typing_extensions import TypeIs import pydase +from pydase.utils.helpers import current_event_loop_exists logging.basicConfig(level=logging.DEBUG) logger = logging.getLogger(__name__) @@ -42,6 +43,11 @@ class Task(pydase.DataService, Generic[R]): autostart: bool = False, ) -> None: super().__init__() + if not current_event_loop_exists(): + self._loop = asyncio.new_event_loop() + asyncio.set_event_loop(self._loop) + else: + self._loop = asyncio.get_event_loop() self._func_name = func.__name__ self._bound_func: Callable[[], Coroutine[None, None, R | None]] | None = None if is_bound_method(func): diff --git a/src/pydase/utils/helpers.py b/src/pydase/utils/helpers.py index 3c5269f..3c22c6c 100644 --- a/src/pydase/utils/helpers.py +++ b/src/pydase/utils/helpers.py @@ -201,3 +201,10 @@ def function_has_arguments(func: Callable[..., Any]) -> bool: def is_descriptor(obj: object) -> bool: """Check if an object is a descriptor.""" return any(hasattr(obj, method) for method in ("__get__", "__set__", "__delete__")) + + +def current_event_loop_exists() -> bool: + """Check if an event loop has been set.""" + import asyncio + + return asyncio.get_event_loop_policy()._local._loop is not None # type: ignore