mirror of
https://github.com/ivan-usov-org/bec.git
synced 2025-04-22 02:20:02 +02:00
324 lines
11 KiB
Python
324 lines
11 KiB
Python
"""
This module provides the BECService class, which is the base class for all BEC services.
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import getpass
|
|
import socket
|
|
import threading
|
|
import time
|
|
import uuid
|
|
from dataclasses import asdict, dataclass
|
|
from typing import TYPE_CHECKING, Any
|
|
|
|
import psutil
|
|
from rich.console import Console
|
|
from rich.table import Table
|
|
|
|
from bec_lib import messages
|
|
from bec_lib.endpoints import MessageEndpoints
|
|
from bec_lib.logger import bec_logger
|
|
from bec_lib.messages import BECStatus
|
|
from bec_lib.service_config import ServiceConfig
|
|
|
|
if TYPE_CHECKING:
|
|
from bec_lib.connector import ConnectorBase
|
|
|
|
|
|
# Module-wide logger handle, taken from the shared ``bec_logger`` object.
logger = bec_logger.logger
|
|
|
|
# Service configuration of the most recently created BECService;
# overwritten by BECService._import_config. None until a service is created.
SERVICE_CONFIG = None
|
|
|
|
|
|
class BECService:
    """Base class for all BEC services.

    On construction the service connects to the configured server, configures
    the logger, optionally verifies that no other instance of the same service
    class is registered, and starts two daemon threads: one that periodically
    publishes the service status and one that periodically emits process
    metrics.
    """

    def __init__(
        self,
        config: str | ServiceConfig,
        connector_cls: type[ConnectorBase],
        unique_service: bool = False,
        wait_for_server: bool = False,
    ) -> None:
        """Initialize the service and start its background threads.

        Args:
            config (str | ServiceConfig): Path to a service config file or an
                already constructed ServiceConfig.
            connector_cls (type[ConnectorBase]): Connector class; it is called
                with the bootstrap server to create the connector.
            unique_service (bool): If True, the service is registered under its
                plain class name and start-up fails if a service with that name
                is already registered. Defaults to False.
            wait_for_server (bool): If True, block until the core BEC services
                report RUNNING. Defaults to False.

        Raises:
            TypeError: If config is neither a str nor a ServiceConfig.
            RuntimeError: If unique_service is set and another instance is
                still registered after the check timeout.
        """
        super().__init__()
        self._import_config(config)
        self._connector_cls = connector_cls
        self.connector = connector_cls(self.bootstrap_server)
        self._unique_service = unique_service
        self.wait_for_server = wait_for_server
        self.__service_id = str(uuid.uuid4())
        self._user = getpass.getuser()
        self._hostname = socket.gethostname()
        self._service_info_thread = None
        self._service_info_event = threading.Event()
        self._metrics_emitter_thread = None
        self._metrics_emitter_event = threading.Event()
        self._services_info = {}
        self._initialize_logger()
        self._check_services()
        # The status must exist before the status thread starts publishing it.
        self._status = BECStatus.BUSY
        self._start_update_service_info()
        self._start_metrics_emitter()
        self._wait_for_server()

    @property
    def _service_name(self):
        """Name under which the service is published.

        Unique services use the bare class name; all others append the
        service id so that several instances can coexist.
        """
        return (
            self.__class__.__name__
            if self._unique_service
            else f"{self.__class__.__name__}/{self._service_id}"
        )

    @property
    def _service_id(self):
        """Random UUID string generated for this service instance."""
        return self.__service_id

    def _import_config(self, config: str | ServiceConfig) -> None:
        """Load the service config and derive the bootstrap server from it.

        Also stores the config in the module-level SERVICE_CONFIG.

        Args:
            config (str | ServiceConfig): Config file path or ServiceConfig.

        Raises:
            TypeError: If config is neither a str nor a ServiceConfig.
        """
        if isinstance(config, str):
            self._service_config = ServiceConfig(config_path=config)
        elif isinstance(config, ServiceConfig):
            self._service_config = config
        else:
            raise TypeError("config must be of type str or ServiceConfig")
        global SERVICE_CONFIG
        SERVICE_CONFIG = self._service_config
        self.bootstrap_server = self._service_config.redis

    def _check_services(self, timeout_time=8, sleep_time=0.5) -> None:
        """Wait until no other instance of a unique service is registered.

        Does nothing for non-unique services. Otherwise the registry is polled
        every sleep_time seconds; once the accumulated sleep time exceeds
        timeout_time and a conflicting instance is still registered, the
        RuntimeError from _run_service_check propagates.

        Args:
            timeout_time (float): Maximum accumulated waiting time in seconds.
            sleep_time (float): Pause between two checks in seconds.
        """
        if not self._unique_service:
            return
        elapsed_time = 0
        while self._run_service_check(timeout_time, elapsed_time):
            elapsed_time += sleep_time
            time.sleep(sleep_time)

    def _run_service_check(self, timeout_time: float, elapsed_time: float) -> bool:
        """Check once whether a service with this class name is registered.

        Args:
            timeout_time (float): Waiting budget in seconds.
            elapsed_time (float): Time already spent waiting in seconds.

        Returns:
            bool: True if a conflicting service is registered and the caller
                should retry, False if no conflict exists.

        Raises:
            RuntimeError: If a conflicting service is registered and
                elapsed_time exceeds timeout_time.
        """
        self._update_existing_services()
        own_name = self.__class__.__name__
        if own_name not in self._services_info:
            return False
        if elapsed_time > timeout_time:
            info = self._services_info[own_name].content["info"]
            # Raise the descriptive error directly so the message is on the
            # exception the caller sees.
            raise RuntimeError(
                f"Another instance of {own_name} launched by user"
                f" {info['user']} is already running on"
                f" {info['hostname']}"
            )
        return True

    def _initialize_logger(self) -> None:
        """Configure the BEC logger for this service class."""
        bec_logger.configure(
            self.bootstrap_server, self._connector_cls, service_name=self.__class__.__name__
        )

    def _update_existing_services(self):
        """Refresh the cached mapping of service name to status message."""
        service_keys = self.connector.keys(MessageEndpoints.service_status("*").endpoint)
        if not service_keys:
            # No status keys means no registered services. Drop the cache so
            # that entries of services whose key has expired do not linger.
            self._services_info = {}
            return
        # NOTE(review): only the part before the first ":" of each key is used
        # for the lookup - confirm against the connector's key format.
        services = [service.decode().split(":", maxsplit=1)[0] for service in service_keys]
        msgs = [self.connector.get(service) for service in services]
        self._services_info = {msg.content["name"]: msg for msg in msgs if msg is not None}

    def _update_service_info(self):
        """Thread target: publish the service status until the stop event is set."""
        while not self._service_info_event.is_set():
            logger.trace("Updating service info")
            try:
                self._send_service_status()
            except Exception:
                # exception is not explicitly specified,
                # because it depends on the underlying connector.
                # Deliberately best effort: a failed update is retried on the
                # next cycle.
                pass
            self._service_info_event.wait(timeout=3)

    def _send_service_status(self):
        """Set and publish a StatusMessage with name, status, user, host and timestamp."""
        self.connector.set_and_publish(
            topic=MessageEndpoints.service_status(self._service_id),
            msg=messages.StatusMessage(
                name=self._service_name,
                status=self.status,
                info={"user": self._user, "hostname": self._hostname, "timestamp": time.time()},
            ),
            # NOTE(review): presumably seconds; the status thread refreshes
            # the entry every 3 seconds - confirm the unit with the connector.
            expire=6,
        )

    @property
    def status(self) -> BECStatus:
        """get the current BECService status"""
        return self._status

    @status.setter
    def status(self, val: BECStatus):
        """Store the new status and publish it immediately."""
        self._status = val
        self._send_service_status()

    def _start_update_service_info(self):
        """Start the daemon thread that publishes the service status."""
        self._service_info_thread = threading.Thread(
            target=self._update_service_info, daemon=True, name="update_service_info"
        )
        self._service_info_thread.start()

    def _start_metrics_emitter(self):
        """Start the daemon thread that emits process metrics."""
        self._metrics_emitter_thread = threading.Thread(
            target=self._get_metrics, daemon=True, name="metrics_emitter"
        )
        self._metrics_emitter_thread.start()

    def _get_metrics(self):
        """Thread target: send process metrics about once per second until stopped."""
        proc = psutil.Process()
        # Attributes requested from psutil; constant across iterations.
        attrs = [
            "name",
            "num_threads",
            "pid",
            "cpu_percent",
            "memory_info",
            "cmdline",
            "cpu_times",
            "create_time",
            "memory_percent",
        ]
        while not self._metrics_emitter_event.is_set():
            res = proc.as_dict(attrs=attrs)

            data = asdict(
                ServiceMetric(
                    process_name=res["name"],
                    username=self._user,
                    hostname=self._hostname,
                    cpu_percent=res["cpu_percent"],
                    # only the user-mode component of the CPU times is reported
                    cpu_times=res["cpu_times"].user,
                    cmdline=res["cmdline"],
                    num_threads=res["num_threads"],
                    pid=res["pid"],
                    memory_in_mb=res["memory_info"].rss / 1024 / 1024,
                    memory_used_percent=res["memory_percent"],
                    create_time=res["create_time"],
                )
            )
            msg = messages.ServiceMetricMessage(name=self.__class__.__name__, metrics=data)
            try:
                self.connector.send(MessageEndpoints.metrics(self._service_id), msg)
            except Exception:
                # exception is not explicitly specified,
                # because it depends on the underlying connector.
                # Deliberately best effort: a lost sample is replaced by the
                # next one.
                pass
            self._metrics_emitter_event.wait(timeout=1)

    def set_global_var(self, name: str, val: Any) -> None:
        """Set a global variable through Redis

        Args:
            name (str): Name of the variable
            val (Any): Value of the variable

        """
        self.connector.set(MessageEndpoints.global_vars(name), messages.VariableMessage(value=val))

    def get_global_var(self, name: str) -> Any:
        """Get a global variable from Redis

        Args:
            name (str): Name of the variable

        Returns:
            Any: Value of the variable, or None if no message is stored
        """
        msg = self.connector.get(MessageEndpoints.global_vars(name))
        if msg:
            return msg.content.get("value")
        return None

    def delete_global_var(self, name: str) -> None:
        """Delete a global variable from Redis

        Args:
            name (str): Name of the variable

        """
        self.connector.delete(MessageEndpoints.global_vars(name))

    def global_vars(self) -> None:
        """Print and log a table of all available global variables.

        Values whose string form is longer than 40 characters are shortened to
        their first and last 20 characters.
        """
        # sadly, this cannot be a property as it causes side effects with IPython's tab completion
        available_keys = self.connector.keys(MessageEndpoints.global_vars("*"))

        def get_endpoint_from_topic(topic: bytes) -> str:
            # keep only what follows the global-vars prefix
            return topic.decode().split(MessageEndpoints.global_vars(""))[-1]

        endpoints = [get_endpoint_from_topic(k) for k in available_keys]

        console = Console()
        table = Table(title="Global variables")
        table.add_column("Variable", justify="center")
        table.add_column("Content", justify="center")
        for endpoint in endpoints:
            var = str(self.get_global_var(endpoint))
            if len(var) > 40:
                var = var[0:20] + "..., " + var[-20:]
            table.add_row(endpoint, var)
        # Render into a string first so the same text can be logged and printed.
        with console.capture() as capture:
            console.print(table)
        out = capture.get()
        logger.info(out)
        print(out)

    def shutdown(self):
        """shutdown the BECService

        Shuts down the connector, then stops and joins the status and metrics
        threads. An AttributeError raised along the way is reported on stdout
        instead of being propagated.
        """
        try:
            self.connector.shutdown()
            self._service_info_event.set()
            if self._service_info_thread:
                self._service_info_thread.join()
            self._metrics_emitter_event.set()
            if self._metrics_emitter_thread:
                self._metrics_emitter_thread.join()
        except AttributeError:
            print("Failed to shutdown BECService.")

    @property
    def service_status(self):
        """get the status of active services"""
        self._update_existing_services()
        return self._services_info

    def wait_for_service(self, name, status=BECStatus.RUNNING):
        """Block until the named service reports the requested status.

        The service registry is polled every 50 ms. There is no timeout.

        Args:
            name (str): Name of the service as published in its status message.
            status (BECStatus): Status to wait for. Defaults to BECStatus.RUNNING.
        """
        logger.info(f"Waiting for {name}.")
        while True:
            service_status_msg = self.service_status.get(name)
            if service_status_msg is not None:
                service_status = BECStatus(service_status_msg.content["status"])
                if service_status == status:
                    break
            time.sleep(0.05)
        logger.success(f"{name} is running.")

    def _wait_for_server(self):
        """Wait for the core BEC services if wait_for_server was requested.

        A KeyboardInterrupt aborts the wait with a warning instead of
        propagating.
        """
        if not self.wait_for_server:
            return
        try:
            self.wait_for_service("ScanServer", BECStatus.RUNNING)
            self.wait_for_service("ScanBundler", BECStatus.RUNNING)
            self.wait_for_service("DeviceServer", BECStatus.RUNNING)
            self.wait_for_service("SciHub", BECStatus.RUNNING)
            logger.success("All BEC services are running.")
        except KeyboardInterrupt:
            logger.warning("KeyboardInterrupt received. Stopped waiting for BEC services.")
|
|
|
|
|
|
@dataclass
class ServiceMetric:
    """Container for keeping performance metrics.

    Field order is part of the interface (positional construction and the key
    order of ``dataclasses.asdict``) and must not be changed.
    """

    # process id
    pid: int
    # command line of the process as a list of arguments
    cmdline: list
    # name of the process
    process_name: str
    # user running the service
    username: str
    # host the service runs on
    hostname: str
    # CPU utilisation of the process in percent
    cpu_percent: float
    # user-mode CPU time; a single number, not the full cpu_times structure
    cpu_times: float
    # number of threads used by the process
    num_threads: int
    # resident memory in MiB
    memory_in_mb: float
    # share of the total memory used by the process in percent
    memory_used_percent: float
    # process creation time as reported by the metrics source
    create_time: float
|