diff --git a/example_scilog.py b/example_scilog.py index 6a48e2b..d50708a 100755 --- a/example_scilog.py +++ b/example_scilog.py @@ -2,11 +2,13 @@ import argparse -parser = argparse.ArgumentParser() +parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) parser.add_argument("pgroup", help="Expected form: p12345") +parser.add_argument("-u", "--url", default="https://lnode2.psi.ch/api/v1", help="Server address") clargs = parser.parse_args() pgroup = clargs.pgroup +url = clargs.url import urllib3 @@ -15,23 +17,23 @@ urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) from scilog import SciLog +from scilog import Basesnippet, Paragraph + +tmp = Basesnippet() +tmp.id = "2" -url = "https://lnode2.psi.ch/api/v1" log = SciLog(url) #print(log.token) -loc = log.get_snippets(title="location", ownerGroup="admin") +logbooks = log.get_logbooks(ownerGroup=pgroup) +print(logbooks) -assert len(loc) == 1 -loc_id = loc[0]["id"] -print(loc_id) +assert len(logbooks) == 1 +logbook = logbooks[0] +print(logbook) -lb = log.get_snippets(snippetType="logbook", ownerGroup=pgroup) +log.select_logbook(logbook) -assert len(lb) == 1 -lb_id = lb[0]["id"] -print(lb_id) - -res = log.post_snippet(snippetType="paragraph", ownerGroup=pgroup, parentId=lb_id, textcontent="

from python

") +res = log.send_message("

from python

") print(res) snips = log.get_snippets(snippetType="paragraph", ownerGroup=pgroup) diff --git a/scilog/__init__.py b/scilog/__init__.py index 0fd2f3d..5790e10 100644 --- a/scilog/__init__.py +++ b/scilog/__init__.py @@ -1,5 +1,6 @@ from .scicat import SciCat from .scilog import SciLog +from .snippet import Basesnippet, Paragraph diff --git a/scilog/authclient.py b/scilog/authmixin.py similarity index 85% rename from scilog/authclient.py rename to scilog/authmixin.py index f4273aa..be0d395 100644 --- a/scilog/authclient.py +++ b/scilog/authmixin.py @@ -4,13 +4,13 @@ from .config import Config from .utils import typename -AUTH_HEADERS = { +HEADER_JSON = { "Content-type": "application/json", "Accept": "application/json" } -class AuthClient(ABC): +class AuthMixin(ABC): def __init__(self, address): self.address = address.rstrip("/") @@ -22,19 +22,15 @@ class AuthClient(ABC): tn = typename(self) return f"{tn} @ {self.address}" - @abstractmethod def authenticate(self, username, password): raise NotImplementedError - @property - def auth_headers(self): - headers = AUTH_HEADERS.copy() - headers["Authorization"] = self.token - return headers - @property def token(self): + return self._retrieve_token() + + def _retrieve_token(self): username = getpass.getuser() token = self._token if token is None: diff --git a/scilog/config.py b/scilog/config.py index bcb69a8..35fdb02 100644 --- a/scilog/config.py +++ b/scilog/config.py @@ -11,7 +11,7 @@ class Config(dict): folder = Path.home() self.fname = folder / fname content = self._load() - super().update(content) + super().__init__(content) def __setitem__(self, name, value): self.update(**{name: value}) @@ -30,6 +30,9 @@ class Config(dict): def _save(self): json_save(self, self.fname) + def delete(self): + self.fname.unlink() + def json_save(what, filename, *args, indent=4, sort_keys=True, **kwargs): diff --git a/scilog/httpclient.py b/scilog/httpclient.py new file mode 100644 index 0000000..e01d99a --- /dev/null +++ b/scilog/httpclient.py @@ -0,0 +1,78 @@ +import functools +import json +import requests + +from .authmixin import AuthMixin, AuthError, HEADER_JSON + + +def authenticated(func): + @functools.wraps(func) + def authenticated_call(client, *args, **kwargs): + if not isinstance(client, HttpClient): + raise AttributeError("First argument must be an instance of HttpClient") + if "headers" in kwargs: + kwargs["headers"] = kwargs["headers"].copy() + else: + kwargs["headers"] = {} + kwargs["headers"]["Authorization"] = client.token + return func(client, *args, **kwargs) + return authenticated_call + + +class HttpClient(AuthMixin): + + def __init__(self, address): + self.address = address + self._verify_certificate = True + self.login_path = self.address + "/users/login" + super().__init__(address) + + def authenticate(self, username, password): + auth_payload = { + "principal": username, + "password": password + } + res = self._login(auth_payload, HEADER_JSON) + try: + token = "Bearer " + res["token"] + except KeyError as e: + raise AuthError(res) from e + else: + return token + + @authenticated + def get_request(self, url, params=None, headers=None, timeout=10): + response = requests.get(url, params=params, headers=headers, timeout=timeout, verify=self._verify_certificate) + if response.ok: + return response.json() + else: + if response.reason == "Unauthorized": + self.config.delete() + raise response.raise_for_status() + + @authenticated + def post_request(self, url, payload=None, headers=None, timeout=10): + return requests.post(url, json=payload, headers=headers, timeout=timeout, verify=self._verify_certificate).json() + + def _login(self, payload=None, headers=None, timeout=10): + return requests.post(self.login_path, json=payload, headers=headers, timeout=timeout, verify=self._verify_certificate).json() + + @staticmethod + def make_filter(where:dict=None, limit:int=0, skip:int=0, fields:dict=None, include:dict=None, order:list=None): + filt = dict() + if where is not None: + items = [where.copy()] + filt["where"] = {"and": items} + if limit > 0: + filt["limit"] = limit + if skip > 0: + filt["skip"] = skip + if fields is not None: + filt["fields"] = include + if order is not None: + filt["order"] = order + filt = json.dumps(filt) + return {"filter": filt} + + + diff --git a/scilog/mkfilt.py b/scilog/mkfilt.py deleted file mode 100644 index d11097b..0000000 --- a/scilog/mkfilt.py +++ /dev/null @@ -1,11 +0,0 @@ -import json - - -def make_filter(**kwargs): - items = [{k: v} for k, v in kwargs.items()] - filt = {"where": {"and": items}} - filt = json.dumps(filt) - return {"filter": filt} - - - diff --git a/scilog/scicat.py b/scilog/scicat.py index 91f21db..a351340 100644 --- a/scilog/scicat.py +++ b/scilog/scicat.py @@ -1,29 +1,35 @@ -from .authclient import AuthClient, AuthError, AUTH_HEADERS -from .utils import post_request, get_request +from .authmixin import AuthMixin, AuthError, HEADER_JSON +from .httpclient import HttpClient -class SciCat(AuthClient): +class SciCatRestAPI(HttpClient): + def __init__(self, url): + super().__init__(url) + self.login_path = "https://dacat.psi.ch/auth/msad" def authenticate(self, username, password): - url = self.address + "/users/login" auth_payload = { "username": username, "password": password } - res = post_request(url, auth_payload, AUTH_HEADERS) + res = self._login(auth_payload, HEADER_JSON) try: - token = res["id"] + token = res["access_token"] except KeyError as e: raise SciCatAuthError(res) from e else: return token +class SciCat(): + + def __init__(self, url="https://dacat.psi.ch/api/v3/"): + self.http_client = SciCatRestAPI(url) + @property def proposals(self): - url = self.address + "/proposals" - headers = self.auth_headers - return get_request(url, headers=headers) + url = self.http_client.address + "/proposals" + return self.http_client.get_request(url, headers=HEADER_JSON) diff --git a/scilog/scilog.py b/scilog/scilog.py index d9109a3..22e55bb 100644 --- a/scilog/scilog.py +++ b/scilog/scilog.py @@ -1,37 +1,74 @@ -from .authclient import AuthClient, AuthError, AUTH_HEADERS -from .utils import post_request, get_request -from .mkfilt import make_filter +from __future__ import annotations +import functools +import warnings + +from .authmixin import AuthError, HEADER_JSON +from .httpclient import HttpClient +from .snippet import Snippet, Basesnippet, Paragraph -class SciLog(AuthClient): - - def authenticate(self, username, password): - url = self.address + "/users/login" - auth_payload = { - "principal": username, - "password": password - } - res = post_request(url, auth_payload, AUTH_HEADERS) - try: - token = "Bearer " + res["token"] - except KeyError as e: - raise SciLogAuthError(res) from e - else: - return token +def pinned_to_logbook(logbook_keys): + def pinned_to_logbook_inner(func): + @functools.wraps(func) + def pinned_to_logbook_call(log, *args, **kwargs): + if not isinstance(log.logbook, Basesnippet): + warnings.warn("No logbook selected.") + else: + for key in logbook_keys: + if key not in kwargs: + if key == "parentId": + kwargs[key] = log.logbook.id + else: + kwargs[key] = getattr(log.logbook, key) + return func(log, *args, **kwargs) + return pinned_to_logbook_call + return pinned_to_logbook_inner +class SciLogRestAPI(HttpClient): + def __init__(self, url): + super().__init__(url) + self._verify_certificate = False + + +class SciLog(): + + def __init__(self, url="https://lnode2.psi.ch/api/v1"): + self.http_client = SciLogRestAPI(url) + self.logbook = None + + def select_logbook(self, logbook:type(Basesnippet)): + self.logbook = logbook + + @pinned_to_logbook(["parentId", "ownerGroup", "accessGroups"]) def get_snippets(self, **kwargs): - url = self.address + "/basesnippets" - params = make_filter(**kwargs) - headers = self.auth_headers - return get_request(url, params=params, headers=headers) + url = self.http_client.address + "/basesnippets" + params = self.http_client.make_filter(where=kwargs) + headers = HEADER_JSON.copy() + return Basesnippet.from_http_response(self.http_client.get_request(url, params=params, headers=headers)) + @pinned_to_logbook(["parentId", "ownerGroup", "accessGroups"]) + def send_message(self, msg, **kwargs): + url = self.http_client.address + "/basesnippets" + snippet = Paragraph() + snippet.import_dict(kwargs) + snippet.textcontent = msg + payload = snippet.to_dict(include_none=False) + return Basesnippet.from_http_response(self.http_client.post_request(url, payload=payload, headers=HEADER_JSON)) + + @pinned_to_logbook(["parentId", "ownerGroup", "accessGroups"]) def post_snippet(self, **kwargs): - url = self.address + "/basesnippets" + url = self.http_client.address + "/basesnippets" payload = kwargs - headers = self.auth_headers - return post_request(url, payload=payload, headers=headers) + return Basesnippet.from_http_response(self.http_client.post_request(url, payload=payload, headers=HEADER_JSON)) + def get_logbooks(self, **kwargs): + url = self.http_client.address + "/basesnippets" + snippet = Basesnippet() + snippet.import_dict(kwargs) + snippet.snippetType = "logbook" + params = self.http_client.make_filter(where=snippet.to_dict(include_none=False)) + return Basesnippet.from_http_response(self.http_client.get_request(url, params=params, headers=HEADER_JSON)) class SciLogAuthError(AuthError): diff --git a/scilog/snippet.py b/scilog/snippet.py new file mode 100644 index 0000000..c717b00 --- /dev/null +++ b/scilog/snippet.py @@ -0,0 +1,122 @@ +import functools +from typing import get_type_hints +from .utils import typename + + +def typechecked(func): + @functools.wraps(func) + def typechecked_call(obj, *args, **kwargs): + type_hints = get_type_hints(func) + del type_hints["return"] + for arg, dtype in zip(args, type_hints.values()): + arg_type = type(arg) + if dtype != arg_type: + raise TypeError(f"{func} expected to receive input of type {dtype.__name__} but received {arg_type.__name__}") + return func(obj, *args, **kwargs) + return typechecked_call + + +def property_maker(name, dtype): + storage_name = '_' + name + + @property + def prop(self) -> dtype: + return getattr(self, storage_name) + + @prop.setter + @typechecked + def prop(self, value: dtype) -> None: + setattr(self, storage_name, value) + + return prop + + + +class Snippet: + + def __init__(self, snippetType="snippet"): + self._properties = [] + self.init_properties(snippetType=str) + self.snippetType = snippetType + + def init_properties(self, **kwargs): + for name, dtype in kwargs.items(): + storage_name = '_' + name + cls = type(self) + setattr(cls, storage_name, None) + setattr(cls, name, property_maker(name, dtype)) + self._properties.append(name) + + def to_dict(self, include_none=True): + if include_none: + return {key: getattr(self, key) for key in self._properties} + else: + return {key: getattr(self, key) for key in self._properties if getattr(self, key) is not None} + + def import_dict(self, properties): + for name, value in properties.items(): + setattr(self, name, value) + + @classmethod + def from_dict(cls, properties): + new = cls() + new.import_dict(properties) + return new + + def __str__(self): + return typename(self) + + @classmethod + def from_http_response(cls, response): + if isinstance(response, list): + return [cls.from_dict(resp) for resp in response] + else: + return cls.from_dict(response) + + + +class Basesnippet(Snippet): + + def __init__(self, snippetType="basesnippet"): + super().__init__(snippetType=snippetType) + self.init_properties( + id=str, + parentId=str, + ownerGroup=str, + accessGroups=list, + isPrivate=bool, + createdAt=str, + createdBy=str, + updatedAt=str, + updateBy=str, + subsnippets=list, + tags=list, + dashboardName=str, + files=list, + location=str, + defaultOrder=int, + linkType=str, + versionable=bool, + deleted=bool + ) + + + +class Paragraph(Basesnippet): + + def __init__(self, snippetType="paragraph"): + super().__init__(snippetType=snippetType) + self.init_properties( + textcontent=str, + isMessage=str + ) + + + +if __name__ == "__main__": + tmp = Snippet(id=str, textcontent=str, defaultOrder=int) + print(tmp.id) + tmp.id = 2 + + +