Merge branch 'authMixin' into 'master'
Minor redesign See merge request augustin_s/py_scilog!1
This commit is contained in:
@ -2,11 +2,13 @@
|
||||
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
||||
parser.add_argument("pgroup", help="Expected form: p12345")
|
||||
parser.add_argument("-u", "--url", default="https://lnode2.psi.ch/api/v1", help="Server address")
|
||||
|
||||
clargs = parser.parse_args()
|
||||
pgroup = clargs.pgroup
|
||||
url = clargs.url
|
||||
|
||||
|
||||
import urllib3
|
||||
@ -15,23 +17,23 @@ urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
|
||||
|
||||
|
||||
from scilog import SciLog
|
||||
from scilog import Basesnippet, Paragraph
|
||||
|
||||
tmp = Basesnippet()
|
||||
tmp.id = "2"
|
||||
|
||||
url = "https://lnode2.psi.ch/api/v1"
|
||||
log = SciLog(url)
|
||||
#print(log.token)
|
||||
loc = log.get_snippets(title="location", ownerGroup="admin")
|
||||
logbooks = log.get_logbooks(ownerGroup=pgroup)
|
||||
print(logbooks)
|
||||
|
||||
assert len(loc) == 1
|
||||
loc_id = loc[0]["id"]
|
||||
print(loc_id)
|
||||
assert len(logbooks) == 1
|
||||
logbook = logbooks[0]
|
||||
print(logbook)
|
||||
|
||||
lb = log.get_snippets(snippetType="logbook", ownerGroup=pgroup)
|
||||
log.select_logbook(logbook)
|
||||
|
||||
assert len(lb) == 1
|
||||
lb_id = lb[0]["id"]
|
||||
print(lb_id)
|
||||
|
||||
res = log.post_snippet(snippetType="paragraph", ownerGroup=pgroup, parentId=lb_id, textcontent="<p>from python</p>")
|
||||
res = log.send_message("<p>from python</p>")
|
||||
print(res)
|
||||
|
||||
snips = log.get_snippets(snippetType="paragraph", ownerGroup=pgroup)
|
||||
|
@ -1,5 +1,6 @@
|
||||
|
||||
from .scicat import SciCat
|
||||
from .scilog import SciLog
|
||||
from .snippet import Basesnippet, Paragraph
|
||||
|
||||
|
||||
|
@ -4,13 +4,13 @@ from .config import Config
|
||||
from .utils import typename
|
||||
|
||||
|
||||
AUTH_HEADERS = {
|
||||
HEADER_JSON = {
|
||||
"Content-type": "application/json",
|
||||
"Accept": "application/json"
|
||||
}
|
||||
|
||||
|
||||
class AuthClient(ABC):
|
||||
class AuthMixin(ABC):
|
||||
|
||||
def __init__(self, address):
|
||||
self.address = address.rstrip("/")
|
||||
@ -22,19 +22,15 @@ class AuthClient(ABC):
|
||||
tn = typename(self)
|
||||
return f"{tn} @ {self.address}"
|
||||
|
||||
|
||||
@abstractmethod
|
||||
def authenticate(self, username, password):
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def auth_headers(self):
|
||||
headers = AUTH_HEADERS.copy()
|
||||
headers["Authorization"] = self.token
|
||||
return headers
|
||||
|
||||
@property
|
||||
def token(self):
|
||||
return self._retrieve_token()
|
||||
|
||||
def _retrieve_token(self):
|
||||
username = getpass.getuser()
|
||||
token = self._token
|
||||
if token is None:
|
@ -11,7 +11,7 @@ class Config(dict):
|
||||
folder = Path.home()
|
||||
self.fname = folder / fname
|
||||
content = self._load()
|
||||
super().update(content)
|
||||
super().__init__(content)
|
||||
|
||||
def __setitem__(self, name, value):
|
||||
self.update(**{name: value})
|
||||
@ -30,6 +30,9 @@ class Config(dict):
|
||||
def _save(self):
|
||||
json_save(self, self.fname)
|
||||
|
||||
def delete(self):
|
||||
self.fname.unlink()
|
||||
|
||||
|
||||
|
||||
def json_save(what, filename, *args, indent=4, sort_keys=True, **kwargs):
|
||||
|
78
scilog/httpclient.py
Normal file
78
scilog/httpclient.py
Normal file
@ -0,0 +1,78 @@
|
||||
import functools
|
||||
import json
|
||||
import requests
|
||||
|
||||
from .authmixin import AuthMixin, AuthError, HEADER_JSON
|
||||
|
||||
|
||||
def authenticated(func):
|
||||
@functools.wraps(func)
|
||||
def authenticated_call(client, *args, **kwargs):
|
||||
if not isinstance(client, HttpClient):
|
||||
raise AttributeError("First argument must be an instance of HttpClient")
|
||||
if "headers" in kwargs:
|
||||
kwargs["headers"] = kwargs["headers"].copy()
|
||||
else:
|
||||
kwargs["headers"] = {}
|
||||
kwargs["headers"]["Authorization"] = client.token
|
||||
return func(client, *args, **kwargs)
|
||||
return authenticated_call
|
||||
|
||||
|
||||
class HttpClient(AuthMixin):
|
||||
|
||||
def __init__(self, address):
|
||||
self.address = address
|
||||
self._verify_certificate = True
|
||||
self.login_path = self.address + "/users/login"
|
||||
super().__init__(address)
|
||||
|
||||
def authenticate(self, username, password):
|
||||
auth_payload = {
|
||||
"principal": username,
|
||||
"password": password
|
||||
}
|
||||
res = self._login(auth_payload, HEADER_JSON)
|
||||
try:
|
||||
token = "Bearer " + res["token"]
|
||||
except KeyError as e:
|
||||
raise AuthError(res) from e
|
||||
else:
|
||||
return token
|
||||
|
||||
@authenticated
|
||||
def get_request(self, url, params=None, headers=None, timeout=10):
|
||||
response = requests.get(url, params=params, headers=headers, timeout=timeout, verify=self._verify_certificate)
|
||||
if response.ok:
|
||||
return response.json()
|
||||
else:
|
||||
if response.reason == "Unauthorized":
|
||||
self.config.delete()
|
||||
raise response.raise_for_status()
|
||||
|
||||
@authenticated
|
||||
def post_request(self, url, payload=None, headers=None, timeout=10):
|
||||
return requests.post(url, json=payload, headers=headers, timeout=timeout, verify=self._verify_certificate).json()
|
||||
|
||||
def _login(self, payload=None, headers=None, timeout=10):
|
||||
return requests.post(self.login_path, json=payload, headers=headers, timeout=timeout, verify=self._verify_certificate).json()
|
||||
|
||||
@staticmethod
|
||||
def make_filter(where:dict=None, limit:int=0, skip:int=0, fields:dict=None, include:dict=None, order:list=None):
|
||||
filt = dict()
|
||||
if where is not None:
|
||||
items = [where.copy()]
|
||||
filt["where"] = {"and": items}
|
||||
if limit > 0:
|
||||
filt["limit"] = limit
|
||||
if skip > 0:
|
||||
filt["skip"] = skip
|
||||
if fields is not None:
|
||||
filt["fields"] = include
|
||||
if order is not None:
|
||||
filt["order"] = order
|
||||
filt = json.dumps(filt)
|
||||
return {"filter": filt}
|
||||
|
||||
|
||||
|
@ -1,11 +0,0 @@
|
||||
import json
|
||||
|
||||
|
||||
def make_filter(**kwargs):
|
||||
items = [{k: v} for k, v in kwargs.items()]
|
||||
filt = {"where": {"and": items}}
|
||||
filt = json.dumps(filt)
|
||||
return {"filter": filt}
|
||||
|
||||
|
||||
|
@ -1,29 +1,35 @@
|
||||
from .authclient import AuthClient, AuthError, AUTH_HEADERS
|
||||
from .utils import post_request, get_request
|
||||
from .authmixin import AuthMixin, AuthError, HEADER_JSON
|
||||
from .httpclient import HttpClient
|
||||
|
||||
|
||||
class SciCat(AuthClient):
|
||||
class SciCatRestAPI(HttpClient):
|
||||
def __init__(self, url):
|
||||
super().__init__(url)
|
||||
self.login_path = "https://dacat.psi.ch/auth/msad"
|
||||
|
||||
def authenticate(self, username, password):
|
||||
url = self.address + "/users/login"
|
||||
auth_payload = {
|
||||
"username": username,
|
||||
"password": password
|
||||
}
|
||||
res = post_request(url, auth_payload, AUTH_HEADERS)
|
||||
res = self._login(auth_payload, HEADER_JSON)
|
||||
try:
|
||||
token = res["id"]
|
||||
token = res["access_token"]
|
||||
except KeyError as e:
|
||||
raise SciCatAuthError(res) from e
|
||||
else:
|
||||
return token
|
||||
|
||||
class SciCat():
|
||||
|
||||
def __init__(self, url="https://dacat.psi.ch/api/v3/"):
|
||||
self.http_client = SciCatRestAPI(url)
|
||||
|
||||
|
||||
@property
|
||||
def proposals(self):
|
||||
url = self.address + "/proposals"
|
||||
headers = self.auth_headers
|
||||
return get_request(url, headers=headers)
|
||||
url = self.http_client.address + "/proposals"
|
||||
return self.http_client.get_request(url, headers=HEADER_JSON)
|
||||
|
||||
|
||||
|
||||
|
@ -1,37 +1,74 @@
|
||||
from .authclient import AuthClient, AuthError, AUTH_HEADERS
|
||||
from .utils import post_request, get_request
|
||||
from .mkfilt import make_filter
|
||||
from __future__ import annotations
|
||||
import functools
|
||||
import warnings
|
||||
|
||||
from .authmixin import AuthError, HEADER_JSON
|
||||
from .httpclient import HttpClient
|
||||
from .snippet import Snippet, Basesnippet, Paragraph
|
||||
|
||||
|
||||
class SciLog(AuthClient):
|
||||
|
||||
def authenticate(self, username, password):
|
||||
url = self.address + "/users/login"
|
||||
auth_payload = {
|
||||
"principal": username,
|
||||
"password": password
|
||||
}
|
||||
res = post_request(url, auth_payload, AUTH_HEADERS)
|
||||
try:
|
||||
token = "Bearer " + res["token"]
|
||||
except KeyError as e:
|
||||
raise SciLogAuthError(res) from e
|
||||
else:
|
||||
return token
|
||||
def pinned_to_logbook(logbook_keys):
|
||||
def pinned_to_logbook_inner(func):
|
||||
@functools.wraps(func)
|
||||
def pinned_to_logbook_call(log, *args, **kwargs):
|
||||
if not isinstance(log.logbook, Basesnippet):
|
||||
warnings.warn("No logbook selected.")
|
||||
else:
|
||||
for key in logbook_keys:
|
||||
if key not in kwargs:
|
||||
if key == "parentId":
|
||||
kwargs[key] = log.logbook.id
|
||||
else:
|
||||
kwargs[key] = getattr(log.logbook, key)
|
||||
return func(log, *args, **kwargs)
|
||||
return pinned_to_logbook_call
|
||||
return pinned_to_logbook_inner
|
||||
|
||||
|
||||
class SciLogRestAPI(HttpClient):
|
||||
def __init__(self, url):
|
||||
super().__init__(url)
|
||||
self._verify_certificate = False
|
||||
|
||||
|
||||
class SciLog():
|
||||
|
||||
def __init__(self, url="https://lnode2.psi.ch/api/v1"):
|
||||
self.http_client = SciLogRestAPI(url)
|
||||
self.logbook = None
|
||||
|
||||
def select_logbook(self, logbook:type(Basesnippet)):
|
||||
self.logbook = logbook
|
||||
|
||||
@pinned_to_logbook(["parentId", "ownerGroup", "accessGroups"])
|
||||
def get_snippets(self, **kwargs):
|
||||
url = self.address + "/basesnippets"
|
||||
params = make_filter(**kwargs)
|
||||
headers = self.auth_headers
|
||||
return get_request(url, params=params, headers=headers)
|
||||
url = self.http_client.address + "/basesnippets"
|
||||
params = self.http_client.make_filter(where=kwargs)
|
||||
headers = HEADER_JSON.copy()
|
||||
return Basesnippet.from_http_response(self.http_client.get_request(url, params=params, headers=headers))
|
||||
|
||||
@pinned_to_logbook(["parentId", "ownerGroup", "accessGroups"])
|
||||
def send_message(self, msg, **kwargs):
|
||||
url = self.http_client.address + "/basesnippets"
|
||||
snippet = Paragraph()
|
||||
snippet.import_dict(kwargs)
|
||||
snippet.textcontent = msg
|
||||
payload = snippet.to_dict(include_none=False)
|
||||
return Basesnippet.from_http_response(self.http_client.post_request(url, payload=payload, headers=HEADER_JSON))
|
||||
|
||||
@pinned_to_logbook(["parentId", "ownerGroup", "accessGroups"])
|
||||
def post_snippet(self, **kwargs):
|
||||
url = self.address + "/basesnippets"
|
||||
url = self.http_client.address + "/basesnippets"
|
||||
payload = kwargs
|
||||
headers = self.auth_headers
|
||||
return post_request(url, payload=payload, headers=headers)
|
||||
return Basesnippet.from_http_response(self.http_client.post_request(url, payload=payload, headers=HEADER_JSON))
|
||||
|
||||
def get_logbooks(self, **kwargs):
|
||||
url = self.http_client.address + "/basesnippets"
|
||||
snippet = Basesnippet()
|
||||
snippet.import_dict(kwargs)
|
||||
snippet.snippetType = "logbook"
|
||||
params = self.http_client.make_filter(where=snippet.to_dict(include_none=False))
|
||||
return Basesnippet.from_http_response(self.http_client.get_request(url, params=params, headers=HEADER_JSON))
|
||||
|
||||
|
||||
class SciLogAuthError(AuthError):
|
||||
|
122
scilog/snippet.py
Normal file
122
scilog/snippet.py
Normal file
@ -0,0 +1,122 @@
|
||||
import functools
|
||||
from typing import get_type_hints
|
||||
from .utils import typename
|
||||
|
||||
|
||||
def typechecked(func):
|
||||
@functools.wraps(func)
|
||||
def typechecked_call(obj, *args, **kwargs):
|
||||
type_hints = get_type_hints(func)
|
||||
del type_hints["return"]
|
||||
for arg, dtype in zip(args, type_hints.values()):
|
||||
arg_type = type(arg)
|
||||
if dtype != arg_type:
|
||||
raise TypeError(f"{func} expected to receive input of type {dtype.__name__} but received {arg_type.__name__}")
|
||||
return func(obj, *args, **kwargs)
|
||||
return typechecked_call
|
||||
|
||||
|
||||
def property_maker(name, dtype):
|
||||
storage_name = '_' + name
|
||||
|
||||
@property
|
||||
def prop(self) -> dtype:
|
||||
return getattr(self, storage_name)
|
||||
|
||||
@prop.setter
|
||||
@typechecked
|
||||
def prop(self, value: dtype) -> None:
|
||||
setattr(self, storage_name, value)
|
||||
|
||||
return prop
|
||||
|
||||
|
||||
|
||||
class Snippet:
|
||||
|
||||
def __init__(self, snippetType="snippet"):
|
||||
self._properties = []
|
||||
self.init_properties(snippetType=str)
|
||||
self.snippetType = snippetType
|
||||
|
||||
def init_properties(self, **kwargs):
|
||||
for name, dtype in kwargs.items():
|
||||
storage_name = '_' + name
|
||||
cls = type(self)
|
||||
setattr(cls, storage_name, None)
|
||||
setattr(cls, name, property_maker(name, dtype))
|
||||
self._properties.append(name)
|
||||
|
||||
def to_dict(self, include_none=True):
|
||||
if include_none:
|
||||
return {key: getattr(self, key) for key in self._properties}
|
||||
else:
|
||||
return {key: getattr(self, key) for key in self._properties if getattr(self, key) is not None}
|
||||
|
||||
def import_dict(self, properties):
|
||||
for name, value in properties.items():
|
||||
setattr(self, name, value)
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, properties):
|
||||
new = cls()
|
||||
new.import_dict(properties)
|
||||
return new
|
||||
|
||||
def __str__(self):
|
||||
return typename(self)
|
||||
|
||||
@classmethod
|
||||
def from_http_response(cls, response):
|
||||
if isinstance(response, list):
|
||||
return [cls.from_dict(resp) for resp in response]
|
||||
else:
|
||||
return cls.from_dict(response)
|
||||
|
||||
|
||||
|
||||
class Basesnippet(Snippet):
|
||||
|
||||
def __init__(self, snippetType="basesnippet"):
|
||||
super().__init__(snippetType=snippetType)
|
||||
self.init_properties(
|
||||
id=str,
|
||||
parentId=str,
|
||||
ownerGroup=str,
|
||||
accessGroups=list,
|
||||
isPrivate=bool,
|
||||
createdAt=str,
|
||||
createdBy=str,
|
||||
updatedAt=str,
|
||||
updateBy=str,
|
||||
subsnippets=list,
|
||||
tags=list,
|
||||
dashboardName=str,
|
||||
files=list,
|
||||
location=str,
|
||||
defaultOrder=int,
|
||||
linkType=str,
|
||||
versionable=bool,
|
||||
deleted=bool
|
||||
)
|
||||
|
||||
|
||||
|
||||
class Paragraph(Basesnippet):
|
||||
|
||||
def __init__(self, snippetType="paragraph"):
|
||||
super().__init__(snippetType=snippetType)
|
||||
self.init_properties(
|
||||
textcontent=str,
|
||||
isMessage=str
|
||||
)
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tmp = Snippet(id=str, textcontent=str, defaultOrder=int)
|
||||
print(tmp.id)
|
||||
tmp.id = 2
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user