2 Commits

Author SHA1 Message Date
96a2bfb362 acc.qss for Injector/Cyclotron 2025-01-07 12:38:15 +01:00
6b59fe16ce added 3rd party packages, elog, bigtree 2024-02-27 15:40:00 +01:00
103 changed files with 20854 additions and 1 deletions

View File

@@ -250,7 +250,7 @@ QGroupBox::title#MACHINE2 {
border: 2px solid #98c998;
border-radius: 3px;
background-color: qlineargradient(x1: 0, y1: 0, x2: 0, y2: 1,
stop: 0 #ffffff, stop: 1#98c998);
stop: 0 #ffffff, stop: 1 #98c998);
}
QGroupBox#Machine::disabled
@@ -271,6 +271,102 @@ QGroupBox#Porthos::disabled
}
/* ---- Injector section: teal (#008b8b) colour scheme. ---- */

/* Plain widgets and tab widgets: white-to-teal vertical gradient background. */
QWidget#INJECTOR, QTabWidget#INJECTOR
{
    background-color: qlineargradient(x1: 0, y1: 0, x2: 0, y2: 1,
        stop: 0 #FFFFFF, stop: 1 #008b8b);
    color: black;
    font-size: 10pt;
    font-style: normal;
    font-weight: 600;
    font-family: "Sans Serif";
    border-radius: 0px;
    margin-top: 0.0ex;
    margin-left: 0.0ex;
    padding-top: 2px;
    padding-bottom: 4px;
}

/* Group boxes: inverted gradient (teal-to-white), teal text and border,
   extra top margin to leave room for the floating title. */
QGroupBox#INJECTOR
{
    background-color: qlineargradient(x1: 0, y1: 0, x2: 0, y2: 1,
        stop: 0 #008b8b, stop: 1 #ffffff);
    color: #008b8b;
    font-size: 10pt;
    font-style: normal;
    font-weight: 600;
    font-family: "Sans Serif";
    border: 2px solid #008b8b;
    border-radius: 5px;
    margin-top: 1.5ex;
    margin-left: 0.0ex;
    margin-bottom: 0.0ex;
    padding-top: 2px;
    padding-bottom: 4px;
    qproperty-alignment: 'AlignCenter | AlignVCenter';
}

/* Group-box title: centered pill on the top border, white-to-teal gradient. */
QGroupBox::title#INJECTOR {
    subcontrol-origin: margin;
    subcontrol-position: top center;
    padding: 2px 2px 2px 2px;
    margin: 0px 0px 0px 0px;
    border: 2px solid #008b8b;
    border-radius: 3px;
    background-color: qlineargradient(x1: 0, y1: 0, x2: 0, y2: 1,
        stop: 0 #ffffff , stop: 1 #008b8b);
}
/* ---- Cyclotron section: cobalt-blue (#0047ab) colour scheme. ---- */

/* Plain widgets and tab widgets: white-to-blue vertical gradient background.
   BUGFIX: the gradient end colour was '#000047ab' -- an 8-digit #AARRGGBB
   value with alpha 00, i.e. fully transparent. Every other CYCLOTRON rule
   uses the 6-digit cobalt blue #0047ab, so use that here as well. */
QWidget#CYCLOTRON, QTabWidget#CYCLOTRON
{
    background-color: qlineargradient(x1: 0, y1: 0, x2: 0, y2: 1,
        stop: 0 #FFFFFF, stop: 1 #0047ab);
    color: black;
    font-size: 10pt;
    font-style: normal;
    font-weight: 600;
    font-family: "Sans Serif";
    border-radius: 0px;
    margin-top: 0.0ex;
    margin-left: 0.0ex;
    padding-top: 2px;
    padding-bottom: 4px;
}

/* Group boxes: inverted gradient (blue-to-white), blue text and border,
   extra top margin to leave room for the floating title. */
QGroupBox#CYCLOTRON
{
    background-color: qlineargradient(x1: 0, y1: 0, x2: 0, y2: 1,
        stop: 0 #0047ab, stop: 1 #ffffff);
    color: #0047ab;
    font-size: 10pt;
    font-style: normal;
    font-weight: 600;
    font-family: "Sans Serif";
    border: 2px solid #0047ab;
    border-radius: 5px;
    margin-top: 1.5ex;
    margin-left: 0.0ex;
    margin-bottom: 0.0ex;
    padding-top: 2px;
    padding-bottom: 4px;
    qproperty-alignment: 'AlignCenter | AlignVCenter';
}

/* Group-box title: centered pill on the top border, white-to-blue gradient. */
QGroupBox::title#CYCLOTRON {
    subcontrol-origin: margin;
    subcontrol-position: top center;
    padding: 2px 2px 2px 2px;
    margin: 0px 0px 0px 0px;
    border: 2px solid #0047ab;
    border-radius: 3px;
    background-color: qlineargradient(x1: 0, y1: 0, x2: 0, y2: 1,
        stop: 0 #ffffff , stop: 1 #0047ab);
}
QWidget#MACHINE, QTabWidget#MACHINE
{
background-color: qlineargradient(x1: 0, y1: 0, x2: 0, y2: 1,

View File

@@ -0,0 +1,13 @@
Metadata-Version: 1.1
Name: pyscan
Version: 2.8.0
Summary: PyScan is a python class that performs a scan for single or multiple given knobs.
Home-page: UNKNOWN
Author: Paul Scherrer Institute
Author-email: UNKNOWN
License: UNKNOWN
Description: UNKNOWN
Platform: UNKNOWN
Requires: numpy
Requires: pcaspy
Requires: requests

View File

@@ -0,0 +1,32 @@
README.md
setup.py
pyscan/__init__.py
pyscan/config.py
pyscan/scan.py
pyscan/scan_actions.py
pyscan/scan_parameters.py
pyscan/scanner.py
pyscan/utils.py
pyscan.egg-info/PKG-INFO
pyscan.egg-info/SOURCES.txt
pyscan.egg-info/dependency_links.txt
pyscan.egg-info/top_level.txt
pyscan/dal/__init__.py
pyscan/dal/bsread_dal.py
pyscan/dal/epics_dal.py
pyscan/dal/function_dal.py
pyscan/dal/pshell_dal.py
pyscan/interface/__init__.py
pyscan/interface/pshell.py
pyscan/interface/pyScan/__init__.py
pyscan/interface/pyScan/scan.py
pyscan/interface/pyScan/utils.py
pyscan/positioner/__init__.py
pyscan/positioner/area.py
pyscan/positioner/bsread.py
pyscan/positioner/compound.py
pyscan/positioner/line.py
pyscan/positioner/serial.py
pyscan/positioner/static.py
pyscan/positioner/time.py
pyscan/positioner/vector.py

View File

@@ -0,0 +1 @@

View File

@@ -0,0 +1 @@
pyscan

View File

@@ -0,0 +1 @@

View File

@@ -0,0 +1,111 @@
Metadata-Version: 1.0
Name: elog
Version: 1.3.4
Summary: Python library to access Elog.
Home-page: https://github.com/paulscherrerinstitute/py_elog
Author: Paul Scherrer Institute (PSI)
Author-email: UNKNOWN
License: UNKNOWN
Description: [![Build Status](https://travis-ci.org/paulscherrerinstitute/py_elog.svg?branch=master)](https://travis-ci.org/paulscherrerinstitute/py_elog) [![Build status](https://ci.appveyor.com/api/projects/status/glo428gqw951y512?svg=true)](https://ci.appveyor.com/project/simongregorebner/py-elog)
# Overview
This Python module provides a native interface to [electronic logbooks](https://midas.psi.ch/elog/). It is compatible with Python versions 3.5 and higher.
# Usage
For accessing a logbook at ```http[s]://<hostname>:<port>/[<subdir>/]<logbook>/[<msg_id>]``` a logbook handle must be retrieved.
```python
import elog
# Open GFA SwissFEL test logbook
logbook = elog.open('https://elog-gfa.psi.ch/SwissFEL+test/')
# Constructor using detailed arguments
# Open demo logbook on local host: http://localhost:8080/demo/
logbook = elog.open('localhost', 'demo', port=8080, use_ssl=False)
```
Once you have hold of the logbook handle one of its public methods can be used to read, create, reply to, edit or delete the message.
## Get Existing Message Ids
Get all the existing message ids of a logbook
```python
message_ids = logbook.get_message_ids()
```
To get the id of the last inserted message
```python
last_message_id = logbook.get_last_message_id()
```
## Read Message
```python
# Read message with message ID = 23
message, attributes, attachments = logbook.read(23)
```
## Create Message
```python
# Create new message with some text, attributes (dict of attributes + kwargs) and attachments
new_msg_id = logbook.post('This is message text', attributes=dict_of_attributes, attachments=list_of_attachments, attribute_as_param='value')
```
What attributes are required is determined by the configuration of the elog server (keyword `Required Attributes`).
If the configuration looks like this:
```
Required Attributes = Author, Type
```
You have to provide author and type when posting a message.
In case the type needs to be specified, the supported keywords can also be found in the elog configuration with the key `Options Type`.
If the config looks like this:
```
Options Type = Routine, Software Installation, Problem Fixed, Configuration, Other
```
A working create call would look like this:
```python
new_msg_id = logbook.post('This is message text', author='me', type='Routine')
```
## Reply to Message
```python
# Reply to message with ID=23
new_msg_id = logbook.post('This is a reply', msg_id=23, reply=True, attributes=dict_of_attributes, attachments=list_of_attachments, attribute_as_param='value')
```
## Edit Message
```python
# Edit message with ID=23. Changed message text, some attributes (dict of edited attributes + kwargs) and new attachments
edited_msg_id = logbook.post('This is new message text', msg_id=23, attributes=dict_of_changed_attributes, attachments=list_of_new_attachments, attribute_as_param='new value')
```
## Delete Message (and all its replies)
```python
# Delete message with ID=23. All its replies will also be deleted.
logbook.delete(23)
```
__Note:__ Due to the way elog implements delete this function is only supported on english logbooks.
# Installation
The Elog module only depends on the `passlib` and `requests` libraries, used for password encryption and http(s) communication. It is packed as an [anaconda package](https://anaconda.org/paulscherrerinstitute/elog) and can be installed as follows:
```bash
conda install -c paulscherrerinstitute elog
```
Keywords: elog,electronic,logbook
Platform: UNKNOWN

View File

@@ -0,0 +1,8 @@
setup.py
elog/__init__.py
elog/logbook.py
elog/logbook_exceptions.py
elog.egg-info/PKG-INFO
elog.egg-info/SOURCES.txt
elog.egg-info/dependency_links.txt
elog.egg-info/top_level.txt

View File

@@ -0,0 +1 @@

View File

@@ -0,0 +1 @@
elog

View File

@@ -0,0 +1 @@

1
packages/elog.pth Normal file
View File

@@ -0,0 +1 @@
./elog-1.3.4-py3.7.egg

13
packages/elog/__init__.py Normal file
View File

@@ -0,0 +1,13 @@
from elog.logbook import Logbook
from elog.logbook import LogbookError, LogbookAuthenticationError, LogbookServerProblem, LogbookMessageRejected, \
LogbookInvalidMessageID, LogbookInvalidAttachmentType
def open(*args, **kwargs):
    """
    Factory helper returning a logbook handle.

    Every positional and keyword argument is forwarded verbatim to the
    ``Logbook`` constructor, so see ``Logbook.__init__`` for the accepted
    arguments (hostname, logbook, port, credentials, ...).

    :param args: positional arguments for Logbook()
    :param kwargs: keyword arguments for Logbook()
    :return: Logbook() instance
    """
    # NOTE: this intentionally shadows builtins.open -- it is the public,
    # documented entry point of the elog package (elog.open(...)).
    return Logbook(*args, **kwargs)

571
packages/elog/logbook.py Normal file
View File

@@ -0,0 +1,571 @@
import requests
import urllib.parse
import os
import builtins
import re
from elog.logbook_exceptions import *
from datetime import datetime
class Logbook(object):
    """
    Logbook provides methods to interface with a logbook on location
    "server:port/subdir/logbook". The user can create, read, edit, delete
    and search logbook messages.
    """

    def __init__(self, hostname, logbook='', port=None, user=None, password=None, subdir='', use_ssl=True,
                 encrypt_pwd=True):
        """
        :param hostname: elog server hostname. If whole url is specified here, it will be parsed and arguments:
                         "logbook, port, subdir, use_ssl" will be overwritten by parsed values.
        :param logbook: name of the logbook on the elog server
        :param port: elog server port (if not specified will default to '80' if use_ssl=False or '443' if
                     use_ssl=True)
        :param user: username (if authentication needed)
        :param password: password (if authentication needed). Password will be encrypted with sha256 unless
                         encrypt_pwd=False (default: True)
        :param subdir: subdirectory of logbooks locations
        :param use_ssl: connect using ssl (ignored if url starts with 'http://' or 'https://')
        :param encrypt_pwd: To avoid exposing password in the code, this flag can be set to False and password
                            will then be handled as it is (user needs to provide sha256 encrypted password with
                            salt='' and rounds=5000)
        :return:
        """
        hostname = hostname.strip()

        # Parse url to see if some parameters are defined with the url itself.
        parsed_url = urllib.parse.urlsplit(hostname)

        # ---- handle SSL -----
        # hostname must be modified according to use_ssl flag. If hostname starts with https:// or http://
        # the use_ssl flag is ignored.
        url_scheme = parsed_url.scheme
        if url_scheme == 'http':
            use_ssl = False
        elif url_scheme == 'https':
            use_ssl = True
        elif not url_scheme:
            # No scheme in the url --> derive it from the use_ssl flag.
            if use_ssl:
                url_scheme = 'https'
            else:
                url_scheme = 'http'

        # ---- handle port -----
        # 1) by default use port defined in the url
        # 2) remove any 'default' ports such as 80 for http and 443 for https
        # 3) if port not defined in url and not 'default' add it to netloc
        netloc = parsed_url.netloc
        if netloc == "" and "localhost" in hostname:
            netloc = 'localhost'

        netloc_split = netloc.split(':')
        if len(netloc_split) > 1:
            # Port defined in url --> remove it when it is the scheme's default.
            # BUGFIX: the port parsed out of the url is a *string*; the original
            # compared it against the ints 80/443, so default ports were never
            # stripped. Compare against the string forms instead.
            port = netloc_split[1]
            if (port == '80' and not use_ssl) or (port == '443' and use_ssl):
                netloc = netloc_split[0]
        else:
            # Port not in url --> append the argument unless it is the default.
            if port is not None and not (port == 80 and not use_ssl) and not (port == 443 and use_ssl):
                netloc += ':{}'.format(port)

        # ---- handle subdir and logbook -----
        # parsed_url.path = /<subdir>/<logbook>/
        # Remove last '/' for easier parsing.
        url_path = parsed_url.path
        if url_path.endswith('/'):
            url_path = url_path[:-1]

        splitted_path = url_path.split('/')
        if url_path and len(splitted_path) > 1:
            # If here ... then at least some part of path is defined.
            # If logbook defined --> treat current path as subdir and add logbook at the end
            # to define the full path. Else treat existing path as <subdir>/<logbook>.
            if logbook:
                url_path += '/{}'.format(logbook)
            else:
                logbook = splitted_path[-1]
        else:
            # There is nothing in the url path. Use the arguments.
            url_path = subdir + '/' + logbook

        # NOTE(review): quoting of special characters is deliberately disabled
        # (see the commented line) -- paths with spaces etc. are sent verbatim.
        # self._logbook_path = urllib.parse.quote('/' + url_path + '/').replace('//', '/')
        self._logbook_path = ('/' + url_path + '/').replace('//', '/')

        self._url = url_scheme + '://' + netloc + self._logbook_path
        self.logbook = logbook
        self._user = user
        self._password = _handle_pswd(password, encrypt_pwd)

    def post(self, message, msg_id=None, reply=False, attributes=None, attachments=None, encoding=None,
             **kwargs):
        """
        Posts message to the logbook. If msg_id is not specified new message will be created, otherwise existing
        message will be edited, or a reply (if reply=True) to it will be created. This method returns the msg_id
        of the newly created message.

        :param message: string with message text
        :param msg_id: ID number of message to edit or reply. If not specified new message is created.
        :param reply: If 'True' reply to existing message is created instead of editing it
        :param attributes: Dictionary of attributes. Following attributes are used internally by the elog and will
                           be ignored: Text, Date, Encoding, Reply to, In reply to, Locked by, Attachment
        :param attachments: list of:
                            - file like objects which read() will return bytes (if file_like_object.name is not
                              defined, default name "attachment<i>" will be used.
                            - paths to the files
                            All items will be appended as attachment to the elog entry. In case of unknown
                            attachment an exception LogbookInvalidAttachment will be raised.
        :param encoding: Defines encoding of the message. Can be: 'plain' -> plain text, 'html'->html-text,
                         'ELCode' --> elog formatting syntax
        :param kwargs: Anything in the kwargs will be interpreted as attribute. e.g.: logbook.post('Test text',
                       Author='Rok Vintar'), "Author" will be sent as an attribute. If named same as one of the
                       attributes defined in "attributes", kwargs will have priority.
        :return: msg_id
        """
        attributes = attributes or {}
        attributes = {**attributes, **kwargs}  # kwargs as attributes with higher priority
        attachments = attachments or []

        if encoding is not None:
            if encoding not in ['plain', 'HTML', 'ELCode']:
                raise LogbookMessageRejected('Invalid message encoding. Valid options: plain, HTML, ELCode.')
            attributes['Encoding'] = encoding

        attributes_to_edit = dict()
        if msg_id:
            # Message exists, we can continue.
            if reply:
                # Verify that there is a message on the server, otherwise do not reply to it!
                self._check_if_message_on_server(msg_id)  # raises exception in case of non existing message
                attributes['reply_to'] = str(msg_id)
            else:  # Edit existing
                attributes['edit_id'] = str(msg_id)
                attributes['skiplock'] = '1'

                # Handle existing attachments.
                msg_to_edit, attributes_to_edit, attach_to_edit = self.read(msg_id)

                i = 0
                for attachment in attach_to_edit:
                    if attachment:
                        # Existing attachments must be passed as regular arguments attachment<i> with
                        # value = file name. Read message returns full urls to existing attachments:
                        # <hostname>:[<port>][/<subdir>]/<logbook>/<msg_id>/<file_name>
                        attributes['attachment' + str(i)] = os.path.basename(attachment)
                        i += 1

                # Overlay the caller's (non-None) attributes over the ones read
                # from the server, so unspecified attributes are preserved.
                for attribute, data in attributes.items():
                    if data is not None:
                        attributes_to_edit[attribute] = data
        else:
            # As we create a new message, specify creation time if not already specified in attributes.
            if 'When' not in attributes:
                attributes['When'] = int(datetime.now().timestamp())

        if not attributes_to_edit:
            attributes_to_edit = attributes

        # Remove any attributes that should not be sent.
        _remove_reserved_attributes(attributes_to_edit)

        if attachments:
            files_to_attach, objects_to_close = self._prepare_attachments(attachments)
        else:
            objects_to_close = list()
            files_to_attach = list()

        # Make requests module think that Text is a "file". This is the only way to force requests to send data
        # as multipart/form-data even if there are no attachments. Elog understands only multipart/form-data.
        files_to_attach.append(('Text', ('', message)))

        # Base attributes are common to all messages.
        self._add_base_msg_attributes(attributes_to_edit)

        # Keys in attributes cannot have certain characters like whitespaces or dashes for the http request.
        attributes_to_edit = _replace_special_characters_in_attribute_keys(attributes_to_edit)

        try:
            response = requests.post(self._url, data=attributes_to_edit, files=files_to_attach,
                                     allow_redirects=False, verify=False)

            # Validate response. Any problems will raise an Exception.
            resp_message, resp_headers, resp_msg_id = _validate_response(response)

            # Close file like objects that were opened by the elog (if a path was given).
            for file_like_object in objects_to_close:
                if hasattr(file_like_object, 'close'):
                    file_like_object.close()

        except requests.RequestException as e:
            # Check if message on server.
            self._check_if_message_on_server(msg_id)  # raises exceptions if no message or no response from server

            # If here: message is on server but cannot be downloaded (should never happen).
            raise LogbookServerProblem('Cannot access logbook server to post a message, ' + 'because of:\n' +
                                       '{0}'.format(e))

        # Any error before here should raise an exception, but check again for any case.
        if not resp_msg_id or resp_msg_id < 1:
            raise LogbookInvalidMessageID('Invalid message ID: ' + str(resp_msg_id) + ' returned')
        return resp_msg_id

    def read(self, msg_id):
        """
        Reads message from the logbook server and returns tuple of (message, attributes, attachments) where:
        message: string with message body
        attributes: dictionary of all attributes returned by the logbook
        attachments: list of urls to attachments on the logbook server

        :param msg_id: ID of the message to be read
        :return: message, attributes, attachments
        """
        request_headers = dict()
        if self._user or self._password:
            request_headers['Cookie'] = self._make_user_and_pswd_cookie()

        try:
            self._check_if_message_on_server(msg_id)  # raises exceptions if no message or no response from server
            response = requests.get(self._url + str(msg_id) + '?cmd=download', headers=request_headers,
                                    allow_redirects=False, verify=False)

            # Validate response. If problems Exception will be thrown.
            resp_message, resp_headers, resp_msg_id = _validate_response(response)

        except requests.RequestException as e:
            # If here: message is on server but cannot be downloaded (should never happen).
            raise LogbookServerProblem('Cannot access logbook server to read the message with ID: ' + str(msg_id) +
                                       'because of:\n' + '{0}'.format(e))

        # Parse message to separate message body, attributes and attachments.
        attributes = dict()
        attachments = list()

        returned_msg = resp_message.decode('utf-8', 'ignore').splitlines()
        # The elog download format is "<header lines>\n====...====\n<body>".
        delimiter_idx = returned_msg.index('========================================')

        message = '\n'.join(returned_msg[delimiter_idx + 1:])
        for line in returned_msg[0:delimiter_idx]:
            line = line.split(': ')
            # NOTE(review): joining with '' drops any ': ' occurring inside the
            # attribute value itself -- kept as-is to preserve behaviour.
            data = ''.join(line[1:])
            if line[0] == 'Attachment':
                attachments = data.split(',')
                # Here are only attachment names, make a full url out of it, so they could be
                # recognisable by others, and downloaded if needed.
                attachments = [self._url + '{0}'.format(i) for i in attachments]
            else:
                attributes[line[0]] = data

        return message, attributes, attachments

    def delete(self, msg_id):
        """
        Deletes message thread (!!!message + all replies!!!) from logbook.
        It also deletes all attachments of corresponding messages from the server.

        :param msg_id: message to be deleted
        :return:
        """
        request_headers = dict()
        if self._user or self._password:
            request_headers['Cookie'] = self._make_user_and_pswd_cookie()

        try:
            self._check_if_message_on_server(msg_id)  # check if something to delete

            response = requests.get(self._url + str(msg_id) + '?cmd=Delete&confirm=Yes', headers=request_headers,
                                    allow_redirects=False, verify=False)

            _validate_response(response)  # raises exception if any other error identified

        except requests.RequestException as e:
            # If here: message is on server but cannot be downloaded (should never happen).
            raise LogbookServerProblem('Cannot access logbook server to delete the message with ID: ' + str(msg_id) +
                                       'because of:\n' + '{0}'.format(e))

        # Additional validation: If successfully deleted then status_code = 302. In case command was not executed
        # at all (not English language --> no delete command supported) status_code = 200 and the content is just
        # a html page of this whole message.
        if response.status_code == 200:
            raise LogbookServerProblem('Cannot process delete command (only logbooks in English supported).')

    def search(self, search_term, n_results=20, scope="subtext"):
        """
        Searches the logbook and returns the message ids.

        :param search_term: text to search for
        :param n_results: maximum number of results to request (minimum 1)
        :param scope: elog search scope parameter (e.g. "subtext")
        :return: list of message ids, newest first
        """
        request_headers = dict()
        if self._user or self._password:
            request_headers['Cookie'] = self._make_user_and_pswd_cookie()

        # Putting n_results = 0 crashes the elog, also in the web-gui.
        n_results = 1 if n_results < 1 else n_results

        params = {
            "mode": "full",
            "reverse": "1",
            "npp": n_results,
            scope: search_term
        }

        try:
            response = requests.get(self._url, params=params, headers=request_headers,
                                    allow_redirects=False, verify=False)

            # Validate response. If problems Exception will be thrown.
            _validate_response(response)
            resp_message = response

        except requests.RequestException as e:
            # If here: message is on server but cannot be downloaded (should never happen).
            raise LogbookServerProblem('Cannot access logbook server to read message ids '
                                       'because of:\n' + '{0}'.format(e))

        from lxml import html
        tree = html.fromstring(resp_message.content)
        message_ids = tree.xpath('(//tr/td[@class="list1" or @class="list2"][1])/a/@href')
        message_ids = [int(m.split("/")[-1]) for m in message_ids]
        return message_ids

    def get_last_message_id(self):
        """Return the id of the newest message, or None if the logbook is empty."""
        ids = self.get_message_ids()
        if len(ids) > 0:
            return ids[0]
        else:
            return None

    def get_message_ids(self):
        """Return all existing message ids of the logbook, newest first."""
        request_headers = dict()
        if self._user or self._password:
            request_headers['Cookie'] = self._make_user_and_pswd_cookie()

        try:
            response = requests.get(self._url + 'page', headers=request_headers,
                                    allow_redirects=False, verify=False)

            # Validate response. If problems Exception will be thrown.
            _validate_response(response)
            resp_message = response

        except requests.RequestException as e:
            # If here: message is on server but cannot be downloaded (should never happen).
            raise LogbookServerProblem('Cannot access logbook server to read message ids '
                                       'because of:\n' + '{0}'.format(e))

        from lxml import html
        tree = html.fromstring(resp_message.content)
        message_ids = tree.xpath('(//tr/td[@class="list1" or @class="list2"][1])/a/@href')
        message_ids = [int(m.split("/")[-1]) for m in message_ids]
        return message_ids

    def _check_if_message_on_server(self, msg_id):
        """
        Try to load the page for a specific message. If there is a html tag like <td class="errormsg"> then
        there is no such message.

        :param msg_id: ID of message to be checked
        :return:
        """
        request_headers = dict()
        if self._user or self._password:
            request_headers['Cookie'] = self._make_user_and_pswd_cookie()
        try:
            response = requests.get(self._url + str(msg_id), headers=request_headers, allow_redirects=False,
                                    verify=False)

            # If there is no message, code 200 will be returned (OK) and _validate_response will not recognise it,
            # but there will be some error indication in the html code.
            resp_message, resp_headers, resp_msg_id = _validate_response(response)
            if re.findall('<td.*?class="errormsg".*?>.*?</td>',
                          resp_message.decode('utf-8', 'ignore'),
                          flags=re.DOTALL):
                raise LogbookInvalidMessageID('Message with ID: ' + str(msg_id) + ' does not exist on logbook.')

        except requests.RequestException as e:
            raise LogbookServerProblem('No response from the logbook server.\nDetails: ' + '{0}'.format(e))

    def _add_base_msg_attributes(self, data):
        """
        Adds base message attributes which are used by all messages.

        :param data: dict of current attributes (modified in place)
        :return:
        """
        data['cmd'] = 'Submit'
        data['exp'] = self.logbook
        if self._user:
            data['unm'] = self._user
        if self._password:
            data['upwd'] = self._password

    def _prepare_attachments(self, files):
        """
        Parses attachments to content objects. Attachments can be:
        - file like objects: must have method read() which returns bytes. If it has attribute .name it will be
          used for attachment name, otherwise generic attfile<i> name will be used.
        - path to the file on disk

        Note that if an attachment is an url pointing to the existing Logbook server it will be ignored and no
        exceptions will be raised. This can happen if attachments returned with the read method are resent.

        :param files: list of file like objects or paths
        :return: (prepared content list, list of file objects opened here that must be closed by the caller)
        """
        prepared = list()
        i = 0
        objects_to_close = list()  # objects that are created (opened) by elog must be later closed

        for file_obj in files:
            if hasattr(file_obj, 'read'):
                i += 1
                attribute_name = 'attfile' + str(i)

                # BUGFIX: the fallback name was always overwritten because the
                # original tested the (always truthy) fallback instead of the
                # candidate, and it crashed when the object had no .name.
                filename = attribute_name  # if file like object has no name specified use this one
                candidate_filename = os.path.basename(getattr(file_obj, 'name', ''))
                if candidate_filename:  # use only if not empty string
                    filename = candidate_filename

            elif isinstance(file_obj, str):
                # Check if it is:
                #   - a path to the file --> open file and append
                #   - an url pointing to the existing Logbook server --> ignore
                filename = ""
                attribute_name = ""
                if os.path.isfile(file_obj):
                    i += 1
                    attribute_name = 'attfile' + str(i)
                    file_obj = builtins.open(file_obj, 'rb')
                    filename = os.path.basename(file_obj.name)
                    objects_to_close.append(file_obj)
                elif not file_obj.startswith(self._url):
                    raise LogbookInvalidAttachmentType('Invalid type of attachment: \"' + file_obj + '\".')
            else:
                raise LogbookInvalidAttachmentType('Invalid type of attachment[' + str(i) + '].')

            prepared.append((attribute_name, (filename, file_obj)))

        return prepared, objects_to_close

    def _make_user_and_pswd_cookie(self):
        """
        Prepares user name and password cookie. It is sent in the header when posting a message.

        :return: user name and password value for the Cookie header
        """
        cookie = ''
        if self._user:
            cookie += 'unm=' + self._user + ';'
        if self._password:
            cookie += 'upwd=' + self._password + ';'
        return cookie
def _remove_reserved_attributes(attributes):
"""
Removes elog reserved attributes (from the attributes dict) that can not be sent.
:param attributes: dictionary of attributes to be cleaned.
:return:
"""
if attributes:
attributes.get('$@MID@$', None)
attributes.pop('Date', None)
attributes.pop('Attachment', None)
attributes.pop('Text', None) # Remove this one because it will be send attachment like
def _replace_special_characters_in_attribute_keys(attributes):
"""
Replaces special characters in elog attribute keys by underscore, otherwise attribute values will be erased in
the http request. This is using the same replacement elog itself is using to handle these cases
:param attributes: dictionary of attributes to be cleaned.
:return: attributes with replaced keys
"""
return {re.sub('[^0-9a-zA-Z]', '_', key): value for key, value in attributes.items()}
def _validate_response(response):
    """
    Validate the http response of a request to the elog server.

    Raises LogbookMessageRejected / LogbookServerProblem / LogbookAuthenticationError
    on any recognised error, otherwise returns (content, headers, msg_id) where
    msg_id is parsed from the redirect Location header (None if not present).
    """
    msg_id = None

    if response.status_code not in [200, 302]:
        # 200 --> OK; 302 --> Found
        # A html page is returned with the error description (handling errors the
        # same way as the original C client does -- there is no other way).
        err = re.findall('<td.*?class="errormsg".*?>.*?</td>',
                         response.content.decode('utf-8', 'ignore'),
                         flags=re.DOTALL)
        if len(err) > 0:
            # Remove html tags from the extracted error cell.
            err = re.sub('(?:<.*?>)', '', err[0])
            if err:
                raise LogbookMessageRejected('Rejected because of: ' + err)
            else:
                raise LogbookMessageRejected('Rejected because of unknown error.')

        # Other unknown errors (no errormsg cell found in the response body).
        raise LogbookMessageRejected('Rejected because of unknown error.')
    else:
        location = response.headers.get('Location')
        if location is not None:
            if 'has moved' in location:
                raise LogbookServerProblem('Logbook server has moved to another location.')
            elif 'fail' in location:
                raise LogbookAuthenticationError('Invalid username or password.')
            else:
                # Returned location is something like: <host>/<sub_dir>/<logbook>/<msg_id><query>;
                # urllib.parse.urlsplit returns attribute path=<sub_dir>/<logbook>/<msg_id>.
                msg_id = int(urllib.parse.urlsplit(location).path.split('/')[-1])

        if b'form name=form1' in response.content or b'type=password' in response.content:
            # Not too smart to check this way, but there is no other indication of this
            # kind of error. The C client does it the same way.
            raise LogbookAuthenticationError('Invalid username or password.')

    return response.content, response.headers, msg_id
def _handle_pswd(password, encrypt=True):
"""
Takes password string and returns password as needed by elog. If encrypt=True then password will be
sha256 encrypted (salt='', rounds=5000). Before returning password, any trailing $5$$ will be removed
independent off encrypt flag.
:param password: password string
:param encrypt: encrypt password?
:return: elog prepared password
"""
if encrypt and password:
from passlib.hash import sha256_crypt
return sha256_crypt.encrypt(password, salt='', rounds=5000)[4:]
elif password and password.startswith('$5$$'):
return password[4:]
else:
return password

View File

@@ -0,0 +1,28 @@
class LogbookError(Exception):
    """Parent logbook exception; root of the elog exception hierarchy."""
class LogbookAuthenticationError(LogbookError):
    """Raised when there is a problem with the username or password."""
class LogbookServerProblem(LogbookError):
    """Raised when there is a problem accessing the logbook server."""
class LogbookMessageRejected(LogbookError):
    """
    Raised when manipulating/creating a message was rejected by the server
    or there was a problem composing the message.
    """
class LogbookInvalidMessageID(LogbookMessageRejected):
    """Raised when there is no message with the specified ID on the server."""
class LogbookInvalidAttachmentType(LogbookMessageRejected):
    """Raised when a passed attachment has an invalid type."""

1
packages/pyscan.pth Executable file
View File

@@ -0,0 +1 @@
./pyscan-2.8.0-py3.7.egg

View File

@@ -0,0 +1,19 @@
# Import the scan part.
from .scan import *
from .scan_parameters import *
from .scan_actions import *
from .scanner import *
# Import DALs
from .dal.epics_dal import *
from .dal.bsread_dal import *
from .dal.pshell_dal import *
# Import positioners.
from .positioner.line import *
from .positioner.serial import *
from .positioner.vector import *
from .positioner.area import *
from .positioner.compound import *
from .positioner.time import *
from .positioner.static import *

58
packages/pyscan/config.py Normal file
View File

@@ -0,0 +1,58 @@
#########################
# General configuration #
#########################

# Maximum tolerance when comparing two floats for equality.
max_float_tolerance = 0.00001

# Time tolerance (seconds) for time critical measurements.
# NOTE(review): the value is 0.05 s (50 ms), although an older comment
# described it as 1 ms -- confirm the intended tolerance.
max_time_tolerance = 0.05

######################
# Scan configuration #
######################

# Default number of scans.
scan_default_n_measurements = 1
# Default interval between multiple measurements in a single position. Taken into account when n_measurements > 1.
scan_default_measurement_interval = 0
# Interval to sleep while the scan is paused.
scan_pause_sleep_interval = 0.1
# Maximum number of retries to read the channels to get valid data.
scan_acquisition_retry_limit = 3
# Delay between acquisition retries.
scan_acquisition_retry_delay = 1

############################
# BSREAD DAL configuration #
############################

# Queue size for collecting messages from bs_read.
bs_queue_size = 20
# Max time to wait until the bs read message we need arrives.
bs_read_timeout = 5
# Max time to wait for a message (if there is none). Important for stopping threads etc.
bs_receive_timeout = 1
# Default bs_read connection address.
bs_default_host = None
# Default bs_read connection port.
bs_default_port = None
# Default bs connection mode ("sub" or "pull").
bs_connection_mode = "sub"
# Default property value for bs properties missing in stream. Exception means to raise an Exception when this happens.
bs_default_missing_property_value = Exception

###########################
# EPICS DAL configuration #
###########################

# Default set and match timeout - how much time a PV has to reach the target value.
epics_default_set_and_match_timeout = 3
# After all motors have reached their destination (set_and_match), extra time to wait.
epics_default_settling_time = 0

############################
# PShell DAL configuration #
############################

pshell_default_server_url = "http://sf-daq-mgmt:8090"
pshell_default_scan_in_background = False
View File

View File

@@ -0,0 +1,186 @@
import math
from time import time
from bsread import Source, mflow
from pyscan import config
from pyscan.utils import convert_to_list
class ReadGroupInterface(object):
    """
    Provide a beam synchronous acquisition for PV data.
    """

    def __init__(self, properties, conditions=None, host=None, port=None, filter_function=None):
        """
        Create the bsread group read interface.
        :param properties: List of PVs to read for processing.
        :param conditions: List of PVs to read as conditions.
        :param host: Stream host to connect to. Falls back to config.bs_default_host.
        :param port: Stream port to connect to. Falls back to config.bs_default_port.
        :param filter_function: Filter the BS stream with a custom function.
        """
        self.host = host
        self.port = port
        self.properties = convert_to_list(properties)
        self.conditions = convert_to_list(conditions)
        self.filter = filter_function

        self._message_cache = None
        self._message_cache_timestamp = None
        self._message_cache_position_index = None

        # BUG FIX: the provided host/port were stored but the connection was
        # always opened with the config defaults - use the instance values and
        # fall back to the config defaults only when they are not set.
        self._connect_bsread(self.host or config.bs_default_host,
                             self.port or config.bs_default_port)

    def _connect_bsread(self, host, port):
        """
        Open the bsread stream, either by explicit address or by channel names.
        :param host: Stream host (may be None).
        :param port: Stream port (may be None).
        :raise ValueError if config.bs_connection_mode is not 'sub' or 'pull'.
        """
        # Configure the connection type.
        if config.bs_connection_mode.lower() == "sub":
            mode = mflow.SUB
        elif config.bs_connection_mode.lower() == "pull":
            mode = mflow.PULL
        else:
            # BUG FIX: an unknown mode previously caused a confusing
            # NameError ('mode' undefined) - fail with a clear message.
            raise ValueError("Unknown bs_connection_mode '%s'. Valid values are 'sub' and 'pull'."
                             % config.bs_connection_mode)

        if host and port:
            self.stream = Source(host=host,
                                 port=port,
                                 queue_size=config.bs_queue_size,
                                 receive_timeout=config.bs_receive_timeout,
                                 mode=mode)
        else:
            # No explicit address: let bsread resolve the channels by name.
            channels = [x.identifier for x in self.properties] + [x.identifier for x in self.conditions]
            self.stream = Source(channels=channels,
                                 queue_size=config.bs_queue_size,
                                 receive_timeout=config.bs_receive_timeout,
                                 mode=mode)

        self.stream.connect()

    @staticmethod
    def is_message_after_timestamp(message, timestamp):
        """
        Check if the received message was captured after the provided timestamp.
        :param message: Message to inspect.
        :param timestamp: Timestamp (seconds, float) to compare the message to.
        :return: True if the message is after the timestamp, False otherwise.
        """
        # Receive might timeout, in this case we have nothing to compare.
        if not message:
            return False

        # This is how BSread encodes the timestamp: whole seconds plus a
        # separate nanoseconds offset.
        current_sec = int(timestamp)
        current_ns = int(math.modf(timestamp)[0] * 1e9)

        message_sec = message.data.global_timestamp
        message_ns = message.data.global_timestamp_offset

        # If the seconds are the same, the nanoseconds must be equal or larger.
        if message_sec == current_sec:
            return message_ns >= current_ns
        # If the seconds are not the same, the message seconds need to be larger than the current seconds.
        else:
            return message_sec > current_sec

    @staticmethod
    def _get_missing_property_default(property_definition):
        """
        In case a bs read value is missing, either return the default value or raise an Exception.
        :param property_definition: Property definition with 'default_value' and 'identifier'.
        :return: The default value, unless the default is the Exception class itself.
        :raise Exception when the configured default is the Exception class.
        """
        # Exception is defined, raise it.
        if Exception == property_definition.default_value:
            raise property_definition.default_value("Property '%s' missing in bs stream."
                                                    % property_definition.identifier)
        # Else just return the default value.
        else:
            return property_definition.default_value

    def _read_pvs_from_cache(self, properties):
        """
        Read the requested properties from the cached message.
        :param properties: List of properties to read.
        :return: List with PV values.
        :raise ValueError if no message has been cached yet.
        """
        if not self._message_cache:
            raise ValueError("Message cache is empty, cannot read PVs %s." % properties)

        pv_values = []
        for property_name, property_definition in ((x.identifier, x) for x in properties):
            if property_name in self._message_cache.data.data:
                value = self._message_cache.data.data[property_name].value
            else:
                value = self._get_missing_property_default(property_definition)

            # TODO: Check if the python conversion works in every case?
            # BS read always return numpy, and we always convert it to Python.
            pv_values.append(value)

        return pv_values

    def read(self, current_position_index=None, retry=False):
        """
        Reads the PV values from BSread. It uses the first PVs data sampled after the invocation of this method.
        :param current_position_index: Index of the current scan position (stored with the cache).
        :param retry: Unused here; kept for interface compatibility.
        :return: List of values for read pvs. Note: Condition PVs are excluded.
        :raise Exception when no suitable message arrives within config.bs_read_timeout.
        """
        read_timestamp = time()
        while time() - read_timestamp < config.bs_read_timeout:
            message = self.stream.receive(filter=self.filter)
            # Only messages sampled after this call started are acceptable.
            if self.is_message_after_timestamp(message, read_timestamp):
                self._message_cache = message
                self._message_cache_position_index = current_position_index
                self._message_cache_timestamp = read_timestamp
                return self._read_pvs_from_cache(self.properties)

        raise Exception("Read timeout exceeded for BS read stream. Could not find the desired package in time.")

    def read_cached_conditions(self):
        """
        Returns the conditions associated with the last read command.
        :return: List of condition values.
        """
        return self._read_pvs_from_cache(self.conditions)

    def close(self):
        """
        Disconnect from the stream and clear the message cache.
        """
        if self.stream:
            self.stream.disconnect()

        self._message_cache = None
        self._message_cache_timestamp = None
class ImmediateReadGroupInterface(ReadGroupInterface):
    """
    Read interface that accepts the first available BS message instead of
    waiting for a message sampled after the read call.
    """

    @staticmethod
    def is_message_after_timestamp(message, timestamp):
        """
        Accept any non-empty message, regardless of its timestamp.
        :param message: Message to inspect.
        :param timestamp: Ignored; kept for interface compatibility.
        :return: True if the message is not empty.
        """
        # Receive might timeout, in which case there is nothing to use.
        return bool(message)

    def read(self, current_position_index=None, retry=False):
        """
        Read the PVs, serving the cached message when it matches the position.
        :param current_position_index: Index of the current scan position.
        :param retry: If True, the previous cache is invalidated first.
        :return: List of values for read PVs.
        """
        # A retry means the cached message was rejected - drop it.
        if retry:
            self._message_cache_position_index = None

        # Serve from the cache when it belongs to the requested position.
        cached_index = self._message_cache_position_index
        if current_position_index is not None and cached_index == current_position_index:
            return self._read_pvs_from_cache(self.properties)

        return super(ImmediateReadGroupInterface, self).read(current_position_index=current_position_index,
                                                             retry=retry)

View File

@@ -0,0 +1,208 @@
import time
from itertools import count
from pyscan import config
from pyscan.utils import convert_to_list, validate_lists_length, connect_to_pv, compare_channel_value
class PyEpicsDal(object):
    """
    Provide a high level abstraction over PyEpics with group support.
    """

    def __init__(self):
        # Registered group interfaces, indexed by group name.
        self.groups = {}
        self.pvs = {}

    def add_group(self, group_name, group_interface):
        """
        Register a group interface under a unique name.
        :param group_name: Unique group name, used as handle for later access.
        :param group_interface: Interface instance managing the group's PVs.
        :return: The group name.
        :raise ValueError if a group with the same name already exists.
        """
        # Do not allow to overwrite the group.
        if group_name in self.groups:
            # BUG FIX: corrected "name of close" typo in the error message.
            raise ValueError("Group with name %s already exists. "
                             "Use different name or close existing group first." % group_name)

        self.groups[group_name] = group_interface
        return group_name

    def add_reader_group(self, group_name, pv_names):
        """
        Create and register a read group.
        :return: The group name, consistent with add_group.
        """
        return self.add_group(group_name, ReadGroupInterface(pv_names))

    def add_writer_group(self, group_name, pv_names, readback_pv_names=None, tolerances=None, timeout=None):
        """
        Create and register a write group.
        :return: The group name, consistent with add_group.
        """
        return self.add_group(group_name, WriteGroupInterface(pv_names, readback_pv_names, tolerances, timeout))

    def get_group(self, handle):
        """
        Return the group registered under the handle, or None if missing.
        """
        return self.groups.get(handle)

    def close_group(self, group_name):
        """
        Close a single group and remove it from the registry.
        :raise ValueError if the group does not exist.
        """
        if group_name not in self.groups:
            raise ValueError("Group does not exist. Available groups:\n%s" % self.groups.keys())

        # Close the PV connection.
        self.groups[group_name].close()
        del self.groups[group_name]

    def close_all_groups(self):
        """
        Close every registered group and clear the registry.
        """
        for group in self.groups.values():
            group.close()
        self.groups.clear()
class WriteGroupInterface(object):
    """
    Manage a group of Write PVs.

    Values are written to the PVs and then verified through readback PVs
    until they match within a tolerance, or a timeout expires.
    """
    # Fallback timeout (seconds) when none is provided (zero is also replaced).
    default_timeout = 5
    # Polling interval (seconds) between readback checks in set_and_match.
    default_get_sleep = 0.1

    def __init__(self, pv_names, readback_pv_names=None, tolerances=None, timeout=None):
        """
        Initialize the write group.
        :param pv_names: PV names (or name, list or single string) to connect to.
        :param readback_pv_names: PV names (or name, list or single string) of readback PVs to connect to.
        :param tolerances: Tolerances to be used for set_and_match. You can also specify them on the set_and_match
        :param timeout: Timeout to reach the destination.
        """
        self.pv_names = convert_to_list(pv_names)
        self.pvs = [self.connect(pv_name) for pv_name in self.pv_names]

        if readback_pv_names:
            self.readback_pv_name = convert_to_list(readback_pv_names)
            self.readback_pvs = [self.connect(pv_name) for pv_name in self.readback_pv_name]
        else:
            # No dedicated readback PVs: read back from the write PVs themselves.
            self.readback_pv_name = self.pv_names
            self.readback_pvs = self.pvs

        self.tolerances = self._setup_tolerances(tolerances)

        # We also do not allow timeout to be zero.
        self.timeout = timeout or self.default_timeout

        # Verify if all provided lists are of same size.
        validate_lists_length(self.pvs, self.readback_pvs, self.tolerances)

        # Check if timeout is int or float.
        if not isinstance(self.timeout, (int, float)):
            raise ValueError("Timeout must be int or float, but %s was provided." % self.timeout)

    def _setup_tolerances(self, tolerances):
        """
        Construct the list of tolerances. No tolerance can be less than the minimal tolerance.
        :param tolerances: Input tolerances.
        :return: Tolerances adjusted to the minimum value, if needed.
        """
        # If the provided tolerances are empty, substitute them with a list of default tolerances.
        tolerances = convert_to_list(tolerances) or [config.max_float_tolerance] * len(self.pvs)

        # Each tolerance needs to be at least the size of the minimum tolerance.
        tolerances = [max(config.max_float_tolerance, tolerance) for tolerance in tolerances]

        return tolerances

    def set_and_match(self, values, tolerances=None, timeout=None):
        """
        Set the value and wait for the PV to reach it, within tolerance.
        :param values: Values to set (Must match the number of PVs in this group)
        :param tolerances: Tolerances for each PV (Must match the number of PVs in this group)
        :param timeout: Timeout, single value, to wait until the value is reached.
        :raise ValueError if any position cannot be reached.
        """
        values = convert_to_list(values)

        if not tolerances:
            tolerances = self.tolerances
        else:
            # We do not allow tolerances to be less than the default tolerance.
            tolerances = self._setup_tolerances(tolerances)

        if not timeout:
            timeout = self.timeout

        # Verify if all provided lists are of same size.
        validate_lists_length(self.pvs, values, tolerances)

        # Check if timeout is int or float.
        if not isinstance(timeout, (int, float)):
            raise ValueError("Timeout must be int or float, but %s was provided." % timeout)

        # Write all the PV values.
        for pv, value in zip(self.pvs, values):
            pv.put(value)

        # Boolean array to represent which PVs have reached their target values.
        within_tolerance = [False] * len(self.pvs)
        initial_timestamp = time.time()

        # Read values until all PVs have reached the desired value or time has run out.
        while (not all(within_tolerance)) and (time.time() - initial_timestamp < timeout):
            # Get only the PVs that have not yet reached the final position.
            for index, pv, tolerance in ((index, pv, tolerance) for index, pv, tolerance, values_reached
                                         in zip(count(), self.readback_pvs, tolerances, within_tolerance)
                                         if not values_reached):
                current_value = pv.get()
                expected_value = values[index]
                if compare_channel_value(current_value, expected_value, tolerance):
                    within_tolerance[index] = True

            # Poll interval - avoid hammering the control system with gets.
            time.sleep(self.default_get_sleep)

        if not all(within_tolerance):
            error_message = ""

            # Get the indexes that did not reach the supposed values.
            for index in [index for index, reached_value in enumerate(within_tolerance) if not reached_value]:
                expected_value = values[index]
                pv_name = self.pv_names[index]
                tolerance = tolerances[index]

                error_message += "Cannot achieve value %s, on PV %s, with tolerance %s.\n" % \
                                 (expected_value, pv_name, tolerance)

            raise ValueError(error_message)

    @staticmethod
    def connect(pv_name):
        # Delegate to the shared pyscan helper.
        return connect_to_pv(pv_name)

    def close(self):
        """
        Close all PV connections.
        """
        for pv in self.pvs:
            pv.disconnect()
class ReadGroupInterface(object):
    """
    Hold connections to a group of read-only PVs and read them on demand.
    """

    def __init__(self, pv_names):
        """
        Connect to all the given PVs.
        :param pv_names: PV names (or name, list or single string) to connect to.
        """
        self.pv_names = convert_to_list(pv_names)
        self.pvs = [self.connect(name) for name in self.pv_names]

    def read(self, current_position_index=None, retry=None):
        """
        Read every PV in the group, one by one.
        :param current_position_index: Index of the current scan (unused here).
        :param retry: Is this the first read attempt or a retry (unused here).
        :return: List of PV values.
        """
        return [pv.get() for pv in self.pvs]

    @staticmethod
    def connect(pv_name):
        # Delegate to the shared pyscan helper.
        return connect_to_pv(pv_name)

    def close(self):
        """
        Close all PV connections.
        """
        for pv in self.pvs:
            pv.disconnect()

View File

@@ -0,0 +1,40 @@
from pyscan.utils import convert_to_list
class FunctionProxy(object):
    """
    Adapter that lets plain callables act as a data access layer.
    """

    def __init__(self, functions):
        """
        Initialize the function dal.
        :param functions: List (or single item) of FUNCTION_VALUE type.
        """
        self.functions = convert_to_list(functions)

    def read(self, current_position_index=None, retry=False):
        """
        Invoke every wrapped function and collect the results.
        :param current_position_index: Index of the current scan position.
        :param retry: Unused; kept for interface compatibility.
        :return: List with one result per function.
        """
        collected = []
        for function in self.functions:
            # A function either accepts the current position index, or nothing.
            try:
                collected.append(function.call_function())
            except TypeError:
                collected.append(function.call_function(current_position_index))
        return collected

    def write(self, values):
        """
        Forward each value to the matching function.
        :param values: Values to write.
        """
        for function, value in zip(self.functions, convert_to_list(values)):
            function.call_function(value)

View File

@@ -0,0 +1,118 @@
import json
from collections import OrderedDict
import requests
from bsread.data.helpers import get_channel_reader
from pyscan import config
# REST endpoint paths on the PShell server, appended to the server URL.
SERVER_URL_PATHS = {
    "run": "/run",       # submit a script run request (used with PUT)
    "data": "/data-bs"   # fetch stored scan data in bs encoding (used with GET)
}
class PShellFunction(object):
    """
    Execute a PShell script on a remote server and read back its results.
    """

    def __init__(self, script_name, parameters, server_url=None, scan_in_background=None, multiple_parameters=False,
                 return_values=None):
        """
        :param script_name: Name of the script to run on the PShell server.
        :param parameters: Script parameters. When multiple_parameters is True,
                           one parameter set per scan position index.
        :param server_url: PShell server URL. Defaults to config.pshell_default_server_url.
        :param scan_in_background: Run the scan in the background.
                                   Defaults to config.pshell_default_scan_in_background.
        :param multiple_parameters: Interpret 'parameters' as one entry per position index.
        :param return_values: Names of values to return (kept for interface compatibility).
        """
        if server_url is None:
            server_url = config.pshell_default_server_url

        if scan_in_background is None:
            scan_in_background = config.pshell_default_scan_in_background

        self.script_name = script_name
        self.parameters = parameters
        self.server_url = server_url.rstrip("/")
        self.scan_in_background = scan_in_background
        self.multiple_parameters = multiple_parameters
        self.return_values = return_values

    @staticmethod
    def _load_raw_data(server_url, data_path):
        # Download the raw (binary) scan data from the server.
        load_data_url = server_url + SERVER_URL_PATHS["data"] + "/" + data_path
        raw_data = requests.get(url=load_data_url, stream=True).raw.read()
        return raw_data

    @classmethod
    def read_raw_data(cls, data_path, server_url=None):
        """
        Download and decode a bs-encoded data file from the PShell server.
        :param data_path: Path of the data on the server.
        :param server_url: Server to load from. Defaults to config.pshell_default_server_url.
        :return: Dictionary mapping channel name to decoded channel data.
        """
        if server_url is None:
            server_url = config.pshell_default_server_url

        raw_data_bytes = cls._load_raw_data(server_url, data_path)
        offset = 0

        def read_chunk():
            # Each chunk is a 4-byte big-endian length followed by the payload.
            nonlocal offset
            nonlocal raw_data_bytes

            size = int.from_bytes(raw_data_bytes[offset:offset + 4], byteorder='big', signed=False)
            # Offset for the size of the length.
            offset += 4

            data_chunk = raw_data_bytes[offset:offset + size]
            offset += size

            return data_chunk

        # First chunk is main header.
        main_header = json.loads(read_chunk().decode(), object_pairs_hook=OrderedDict)
        # Second chunk is data header.
        data_header = json.loads(read_chunk().decode(), object_pairs_hook=OrderedDict)

        result_data = {}
        for channel in data_header["channels"]:
            raw_channel_data = read_chunk()
            raw_channel_timestamp = read_chunk()

            channel_name = channel["name"]

            # Default encoding is little; the other valid value is 'big'.
            # BUG FIX: the previous truthiness test mapped every non-empty
            # encoding (including 'big') to '<'; compare explicitly.
            channel["encoding"] = "<" if channel.get("encoding", "little") == "little" else ">"
            channel_value_reader = get_channel_reader(channel)

            result_data[channel_name] = channel_value_reader(raw_channel_data)

        return result_data

    def read(self, current_position_index=None, retry=False):
        """
        Run the script for the given position and return the parsed JSON result.
        :param current_position_index: Index of the current scan position.
        :param retry: Unused; kept for interface compatibility.
        :return: Parsed scan result.
        """
        parameters = self.get_scan_parameters(current_position_index)

        run_request = {"script": self.script_name,
                       "pars": parameters,
                       "background": self.scan_in_background}

        raw_scan_result = self._execute_scan(run_request)
        scan_result = json.loads(raw_scan_result)
        return scan_result

    def get_scan_parameters(self, current_position_index):
        """
        Select the parameters for the given position index.
        :param current_position_index: Index of the current scan position.
        :return: Parameter set for this position.
        :raise ValueError if multiple_parameters is set and no entry exists for the index.
        """
        if self.multiple_parameters:
            try:
                position_parameters = self.parameters[current_position_index]
            except IndexError:
                # BUG FIX: the original format string had a single placeholder
                # for two values and raised a TypeError instead of this ValueError.
                raise ValueError("Cannot find parameters for position index %s. Parameters: %s" %
                                 (current_position_index, self.parameters))

            return position_parameters
        else:
            return self.parameters

    def _execute_scan(self, execution_parameters):
        # Submit the run request; a non-200 response is propagated as an error.
        run_url = self.server_url + SERVER_URL_PATHS["run"]
        result = requests.put(url=run_url, json=execution_parameters)

        if result.status_code != 200:
            raise Exception(result.text)

        return result.text

View File

View File

@@ -0,0 +1,385 @@
from pyscan import scan, action_restore, ZigZagVectorPositioner, VectorPositioner, CompoundPositioner
from pyscan.scan import EPICS_READER
from pyscan.positioner.area import AreaPositioner, ZigZagAreaPositioner
from pyscan.positioner.line import ZigZagLinePositioner, LinePositioner
from pyscan.positioner.time import TimePositioner
from pyscan.scan_parameters import scan_settings
from pyscan.utils import convert_to_list
def _generate_scan_parameters(relative, writables, latency):
    """
    Derive common scan inputs: writable offsets, finalization actions, settings.

    For a relative scan, the current writable positions are read and used as
    offsets, and an action restoring the original positions is scheduled to
    run when the scan finishes.
    """
    offsets = None
    finalization_actions = []
    if relative:
        pv_names = [writable.pv_name for writable in convert_to_list(writables) or []]
        position_reader = EPICS_READER(pv_names)
        offsets = position_reader.read()
        position_reader.close()
        finalization_actions.append(action_restore(writables))
    settings = scan_settings(settling_time=latency)
    return offsets, finalization_actions, settings
def _convert_steps_parameter(steps):
    """
    Split the 'steps' scan argument into (n_steps, step_size).

    :param steps: Number of steps (int) or step size (float or list of floats).
    :return: Tuple (n_steps, step_size); at most one is set, both are None
             when the input cannot be interpreted.
    """
    n_steps = None
    step_size = None
    steps_list = convert_to_list(steps)
    # A float (or list of floats) is interpreted as step size(s).
    # BUG FIX: guard against empty/None input before inspecting the first
    # element (previously raised IndexError).
    if steps_list and isinstance(steps_list[0], float):
        step_size = steps_list
    # A plain int is the number of steps.
    elif isinstance(steps, int):
        n_steps = steps
    return n_steps, step_size
def lscan(writables, readables, start, end, steps, latency=0.0, relative=False,
          passes=1, zigzag=False, before_read=None, after_read=None, title=None):
    """Line scan: move all positioners together, linearly from start to end.

    Args:
        writables(list of Writable): Positioners set on each step.
        readables(list of Readable): Sensors sampled on each step.
        start(list of float): Start positions of the writables.
        end(list of float): Final positions of the writables.
        steps(int or float or list of float): Number of steps (int) or step size (float).
        latency(float, optional): Settling time before each readout. Default 0.0.
        relative(bool, optional): If True, start/end are relative to the positions
            at the start of the scan.
        passes(int, optional): Number of passes.
        zigzag(bool, optional): If True, writables invert direction on each pass.
        before_read(function, optional): Callback before each readout; may accept
            the list of positions.
        after_read(function, optional): Callback after each readout; may accept
            a ScanRecord object.
        title(str, optional): Unused - kept for interface compatibility.

    Returns:
        ScanResult object.
    """
    offsets, finalization_actions, settings = _generate_scan_parameters(relative, writables, latency)
    n_steps, step_size = _convert_steps_parameter(steps)

    positioner_class = ZigZagLinePositioner if zigzag else LinePositioner
    line_positioner = positioner_class(start=start, end=end, step_size=step_size,
                                       n_steps=n_steps, offsets=offsets, passes=passes)

    return scan(line_positioner, readables, writables,
                before_read=before_read, after_read=after_read,
                settings=settings, finalization=finalization_actions)
def ascan(writables, readables, start, end, steps, latency=0.0, relative=False,
          passes=1, zigzag=False, before_read=None, after_read=None, title=None):
    """
    Area scan: multi-dimensional scan where each positioner is one dimension.

    :param writables: Identifiers to write to at each step.
    :param readables: Identifiers to read from at each step.
    :param start: Start positions for the writables.
    :param end: Stop positions for the writables.
    :param steps: Number of scan steps (int) or step size (float).
    :param latency: Settling time before each readout. Default 0.
    :param relative: Start/stop positions are relative to the current position.
    :param passes: Number of passes for each scan.
    :param zigzag: If True and passes > 1, invert moving direction on each pass.
    :param before_read: Callbacks invoked on each step before readback.
    :param after_read: Callbacks invoked on each step after readback.
    :param title: Not used in this implementation - legacy.
    :return: Data from the scan.
    """
    offsets, finalization_actions, settings = _generate_scan_parameters(relative, writables, latency)
    n_steps, step_size = _convert_steps_parameter(steps)

    positioner_class = ZigZagAreaPositioner if zigzag else AreaPositioner
    area_positioner = positioner_class(start=start, end=end, step_size=step_size,
                                       n_steps=n_steps, offsets=offsets, passes=passes)

    return scan(area_positioner, readables, writables,
                before_read=before_read, after_read=after_read,
                settings=settings, finalization=finalization_actions)
def vscan(writables, readables, vector, line=False, latency=0.0, relative=False, passes=1, zigzag=False,
          before_read=None, after_read=None, title=None):
    """Vector scan: positioners follow the values listed in a vector.

    Args:
        writables(list of Writable): Positioners set on each step.
        readables(list of Readable): Sensors sampled on each step.
        vector(list of list of float): Table of positioner values.
        line(bool, optional): If True, process as a line scan (1d).
        latency(float, optional): Settling time before each readout. Default 0.0.
        relative(bool, optional): If True, positions are relative to the values
            at the start of the scan.
        passes(int, optional): Number of passes.
        zigzag(bool, optional): If True, writables invert direction on each pass
            (line mode only).
        before_read(function, optional): Callback before each readout.
        after_read(function, optional): Callback after each readout.
        title(str, optional): Unused - kept for interface compatibility.

    Returns:
        ScanResult object.
    """
    offsets, finalization_actions, settings = _generate_scan_parameters(relative, writables, latency)

    # The compound positioner used for area mode cannot zigzag.
    if zigzag and not line:
        raise ValueError("Area vector scan cannot use zigzag positioning.")

    positioner_class = ZigZagVectorPositioner if zigzag else VectorPositioner

    if line:
        # Line mode: all motors move to the next vector entry at the same time.
        positioner = positioner_class(positions=vector, passes=passes, offsets=offsets)
    else:
        # Area mode: motors move one by one, covering every combination.
        vector = convert_to_list(vector)
        if not all(isinstance(positions, list) for positions in vector):
            raise ValueError("In case of area scan, a list of lists is required for a vector.")
        positioner = CompoundPositioner([VectorPositioner(positions=positions, passes=passes, offsets=offsets)
                                         for positions in vector])

    return scan(positioner, readables, writables,
                before_read=before_read, after_read=after_read,
                settings=settings, finalization=finalization_actions)
def rscan(writable, readables, regions, latency=0.0, relative=False, passes=1, zigzag=False, before_read=None,
          after_read=None, title=None):
    """Region scan: run the positioner linearly over several regions.

    Args:
        writable(Writable): Positioner set on each step, for each region.
        readables(list of Readable): Sensors sampled on each step.
        regions(list of tuples): Each tuple is (start, stop, steps) or
            (start, stop, step_size).
        latency(float, optional): Settling time before each readout. Default 0.0.
        relative(bool, optional): If True, positions are relative to the starting position.
        passes(int, optional): Number of passes.
        zigzag(bool, optional): If True, invert direction on each pass.
        before_read(function, optional): Callback before each readout.
        after_read(function, optional): Callback after each readout.
        title(str, optional): Unused - kept for interface compatibility.

    Returns:
        ScanResult object.

    Raises:
        NotImplementedError: region scans are not supported by this backend.
    """
    raise NotImplementedError("Region scan not supported.")
def cscan(writables, readables, start, end, steps, latency=0.0, time=None, relative=False, passes=1, zigzag=False,
          before_read=None, after_read=None, title=None):
    """Continuous scan: positioners move continuously while sensors are sampled on the fly.

    Args:
        writables(Speedable or list of Motor): Positioner with a getSpeed method,
            or a list of motors.
        readables(list of Readable): Sensors sampled on each step.
        start(float or list of float): Start positions of the writables.
        end(float or list of float): Final positions of the writables.
        steps(int or float or list of float): Number of steps (int) or step size (float).
        latency(float, optional): Sleep time before each readout. Default 0.0.
        time(float, optional): If set, writables is a Motor array and speeds are
            derived from this duration (seconds).
        relative(bool, optional): If True, positions are relative to the starting position.
        passes(int, optional): Number of passes.
        zigzag(bool, optional): If True, invert direction on each pass.
        before_read(function, optional): Callback before each readout.
        after_read(function, optional): Callback after each readout.
        title(str, optional): Unused - kept for interface compatibility.

    Returns:
        ScanResult object.

    Raises:
        NotImplementedError: continuous scans are not supported by this backend.
    """
    raise NotImplementedError("Continuous scan not supported.")
def hscan(config, writable, readables, start, end, steps, passes=1, zigzag=False, before_stream=None, after_stream=None,
          after_read=None, title=None):
    """Hardware scan: values sampled by external hardware, received asynchronously.

    Args:
        config(dict): Hardware scan configuration; the "class" key provides the
            implementation class, other keys are implementation specific.
        writable(Writable): Positioner appropriate to the hardware scan type.
        readables(list of Readable): Sensors appropriate to the hardware scan type.
        start(float): Start position of the writable.
        end(float): Final position of the writable.
        steps(int or float): Number of steps (int) or step size (float).
        passes(int, optional): Number of passes.
        zigzag(bool, optional): If True, invert direction on each pass.
        before_stream(function, optional): Callback just before the positioner starts moving.
        after_stream(function, optional): Callback just after the positioner stops moving.
        after_read(function, optional): Callback on each readout.
        title(str, optional): Unused - kept for interface compatibility.

    Returns:
        ScanResult object.

    Raises:
        NotImplementedError: hardware scans are not supported by this backend.
    """
    raise NotImplementedError("Hardware scan not supported.")
def bscan(stream, records, before_read=None, after_read=None, title=None):
    """BS scan: record all values from a beam synchronous stream.

    Args:
        stream(Stream): Stream object to record from.
        records(int): Number of records to store.
        before_read(function, optional): Callback before each readout.
        after_read(function, optional): Callback after each readout.
        title(str, optional): Unused - kept for interface compatibility.

    Returns:
        ScanResult object.

    Raises:
        NotImplementedError: BS scans are not supported by this backend.
    """
    raise NotImplementedError("BS scan not supported.")
def tscan(readables, points, interval, before_read=None, after_read=None, title=None):
    """Time scan: sample the sensors at fixed time intervals.

    Args:
        readables(list of Readable): Sensors sampled on each step.
        points(int): Number of samples.
        interval(float): Time between readouts; minimum temporization is 0.001s.
        before_read(function, optional): Callback before each readout.
        after_read(function, optional): Callback after each readout.
        title(str, optional): Unused - kept for interface compatibility.

    Returns:
        ScanResult object.
    """
    time_positioner = TimePositioner(interval, points)
    return scan(time_positioner, readables, before_read=before_read, after_read=after_read)
def mscan(trigger, readables, points, timeout=None, async_=True, take_initial=False, before_read=None,
          after_read=None, title=None):
    """Monitor scan: sensors are sampled on change events of the trigger device.

    BUG FIX: the parameter was originally named 'async', which is a reserved
    keyword since Python 3.7 and made this module unimportable (SyntaxError);
    it is renamed to 'async_'.

    Args:
        trigger(Device): Source of the sampling triggering.
        readables(list of Readable): Sensors sampled on each step. If trigger has
            cache and is included in readables, the change event value is used
            instead of re-reading it.
        points(int): Number of samples.
        timeout(float, optional): Maximum scan time in seconds.
        async_(bool, optional): If True, records are sampled and stored in the
            event change callback (cached sensor values only). If False, the
            scan loop waits for trigger cache updates and may lose change events.
        take_initial(bool, optional): If True, include current values as the
            first record (before the first trigger).
        before_read(function, optional): Callback before each readout.
        after_read(function, optional): Callback after each readout.
        title(str, optional): Unused - kept for interface compatibility.

    Returns:
        ScanResult object.

    Raises:
        NotImplementedError: monitor scans are not supported by this backend.
    """
    raise NotImplementedError("Monitor scan not supported.")
def escan(name, title=None):
    """Epics scan: execute an EPICS Scan Record.

    Args:
        name(str): Name of the scan record.
        title(str, optional): Unused - kept for interface compatibility.

    Returns:
        ScanResult object.

    Raises:
        NotImplementedError: EPICS scan records are not supported by this backend.
    """
    raise NotImplementedError("Epics scan not supported.")
def bsearch(writables, readable, start, end, steps, maximum=True, strategy="Normal", latency=0.0, relative=False,
            before_read=None, after_read=None, title=None):
    """Binary search: probe the writables to find a local extremum of the readable.

    Args:
        writables(list of Writable): Positioners set on each step.
        readable(Readable): Sensor to be sampled.
        start(list of float): Start positions of the writables.
        end(list of float): Final positions of the writables.
        steps(float or list of float): Search resolution for each writable.
        maximum(bool, optional): If True (default) search a maximum, otherwise a minimum.
        strategy(str, optional): "Normal" starts midway in the scan range and
            advances in the best direction using the orthogonal neighborhood
            (4-neighborhood for 2d); "Boundary" starts on the scan range;
            "FullNeighborhood" uses the complete neighborhood (8-neighborhood for 2d).
        latency(float, optional): Settling time before each readout. Default 0.0.
        relative(bool, optional): If True, positions are relative to the starting position.
        before_read(function, optional): Callback before each readout.
        after_read(function, optional): Callback after each readout.
        title(str, optional): Unused - kept for interface compatibility.

    Returns:
        SearchResult object.

    Raises:
        NotImplementedError: binary search scans are not supported by this backend.
    """
    raise NotImplementedError("Binary search scan not supported.")
def hsearch(writables, readable, range_min, range_max, initial_step, resolution, noise_filtering_steps=1, maximum=True,
            latency=0.0, relative=False, before_read=None, after_read=None, title=None):
    """Hill climbing search: probe the writables in decreasing steps to find a local extremum.

    Args:
        writables(list of Writable): Positioners set on each step.
        readable(Readable): Sensor to be sampled.
        range_min(list of float): Minimum positions of the writables.
        range_max(list of float): Maximum positions of the writables.
        initial_step(float or list of float): Initial step size for each writable.
        resolution(float or list of float): Search resolution for each writable
            (minimum step size).
        noise_filtering_steps(int): Number of additional steps used to filter noise.
        maximum(bool, optional): If True (default) search a maximum, otherwise a minimum.
        latency(float, optional): Settling time before each readout. Default 0.0.
        relative(bool, optional): If True, range_min/range_max are relative to the
            position at the start of the scan.
        before_read(function, optional): Callback before each readout.
        after_read(function, optional): Callback after each readout.
        title(str, optional): Unused - kept for interface compatibility.

    Returns:
        SearchResult object.

    Raises:
        NotImplementedError: hill climbing scans are not supported by this backend.
    """
    raise NotImplementedError("Hill climbing scan not supported.")

View File

@@ -0,0 +1 @@
from .scan import *

View File

@@ -0,0 +1,713 @@
import traceback
from copy import deepcopy
from datetime import datetime
from time import sleep
import numpy as np
from pyscan.dal.epics_dal import PyEpicsDal
from pyscan.interface.pyScan.utils import PyScanDataProcessor
from pyscan.positioner.compound import CompoundPositioner
from pyscan.positioner.serial import SerialPositioner
from pyscan.positioner.vector import VectorPositioner
from pyscan.scan_parameters import scan_settings
from pyscan.scanner import Scanner
from pyscan.utils import convert_to_list, convert_to_position_list, compare_channel_value
READ_GROUP = "Measurements"
WRITE_GROUP = "Knobs"
MONITOR_GROUP = "Monitors"
class Scan(object):
    """Legacy pyScan-compatible driver around the new pyscan ``Scanner``.

    Usage: call ``initializeScan(inlist)`` with one dictionary per scan
    dimension (validated and expanded in place), then ``startScan()`` to run
    the scan and obtain the nested result dictionary.
    """

    def execute_scan(self):
        """Assemble the Scanner from the prepared dimensions and run the discrete scan."""
        after_executor = self.get_action_executor("In-loopPostAction")

        # Wrap the post action executor to update the number of completed scans.
        def progress_after_executor(scanner_instance, data):
            # Execute other post actions.
            after_executor(scanner_instance)
            # Update progress.
            self.n_done_measurements += 1
            self.ProgDisp.Progress = 100.0 * (self.n_done_measurements /
                                              self.n_total_positions)

        def prepare_monitors(reader):
            # If there are no monitors defined we have nothing to validate.
            if not self.dimensions[-1]["Monitor"]:
                return None

            def validate_monitors(position, data):
                monitor_values = reader.read()
                combined_data = zip(self.dimensions[-1]['Monitor'],
                                    self.dimensions[-1]['MonitorValue'],
                                    self.dimensions[-1]['MonitorTolerance'],
                                    self.dimensions[-1]['MonitorAction'],
                                    self.dimensions[-1]['MonitorTimeout'],
                                    monitor_values)
                for pv, expected_value, tolerance, action, timeout, value in combined_data:
                    # Monitor value does not match.
                    if not compare_channel_value(value, expected_value, tolerance):
                        if action == "Abort":
                            raise ValueError("Monitor %s, expected value %s, tolerance %s, has value %s. Aborting."
                                             % (pv, expected_value, tolerance, value))
                        elif action == "WaitAndAbort":
                            # Signal the scanner to wait; aborting on timeout is the scanner's job.
                            return False
                        else:
                            # NOTE(review): "Wait" passes _setup_monitors validation but ends up
                            # here - confirm whether it should also return False.
                            # Fixed argument order: the message previously printed the PV as
                            # the action and vice versa.
                            raise ValueError("MonitorAction %s, on PV %s, is not supported." % (action, pv))
                return True
            return validate_monitors

        # Setup scan settings.
        settings = scan_settings(settling_time=self.dimensions[-1]["KnobWaitingExtra"],
                                 n_measurements=self.dimensions[-1]["NumberOfMeasurements"],
                                 measurement_interval=self.dimensions[-1]["Waiting"])
        data_processor = PyScanDataProcessor(self.outdict,
                                             n_readbacks=self.n_readbacks,
                                             n_validations=self.n_validations,
                                             n_observables=self.n_observables,
                                             n_measurements=settings.n_measurements)
        # NOTE(review): when no monitors are configured, MONITOR_GROUP was never added by
        # _setup_monitors - confirm epics_dal.get_group tolerates a missing group name.
        self.scanner = Scanner(positioner=self.get_positioner(), data_processor=data_processor,
                               reader=self.epics_dal.get_group(READ_GROUP).read,
                               writer=self.epics_dal.get_group(WRITE_GROUP).set_and_match,
                               before_measurement_executor=self.get_action_executor("In-loopPreAction"),
                               after_measurement_executor=progress_after_executor,
                               initialization_executor=self.get_action_executor("PreAction"),
                               finalization_executor=self.get_action_executor("PostAction"),
                               data_validator=prepare_monitors(self.epics_dal.get_group(MONITOR_GROUP)),
                               settings=settings)
        self.outdict.update(self.scanner.discrete_scan())

    def get_positioner(self):
        """
        Generate a positioner for the provided dimensions.
        :return: Positioner object.
        """
        # Read all the initial positions - in case we need to do an additive scan.
        initial_positions = self.epics_dal.get_group(READ_GROUP).read()
        positioners = []
        knob_readback_offset = 0
        for dimension in self.dimensions:
            is_additive = bool(dimension.get("Additive", 0))
            is_series = bool(dimension.get("Series", 0))
            n_knob_readbacks = len(dimension["KnobReadback"])
            # This dimension uses relative positions, read the PVs initial state.
            # We also need initial positions for the series scan.
            if is_additive or is_series:
                # Knob readbacks come first in READ_GROUP, so this slice addresses
                # exactly this dimension's knobs.
                offsets = convert_to_list(
                    initial_positions[knob_readback_offset:knob_readback_offset + n_knob_readbacks])
            else:
                offsets = None
            # Series scan in this dimension, use StepByStepVectorPositioner.
            if is_series:
                # In the StepByStep positioner, the initial values need to be added to the steps.
                positions = convert_to_list(dimension["ScanValues"])
                positioners.append(SerialPositioner(positions, initial_positions=offsets,
                                                    offsets=offsets if is_additive else None))
            # Line scan in this dimension, use VectorPositioner.
            else:
                positions = convert_to_position_list(convert_to_list(dimension["KnobExpanded"]))
                positioners.append(VectorPositioner(positions, offsets=offsets))
            # Increase the knob readback offset.
            knob_readback_offset += n_knob_readbacks
        # Assemble all individual positioners together.
        positioner = CompoundPositioner(positioners)
        return positioner

    def get_action_executor(self, entry_name):
        """Build a callable executing every dimension's actions named entry_name,
        then sleeping for the longest declared '<entry_name>Waiting' time.

        :param entry_name: Action key, e.g. 'PreAction' or 'In-loopPostAction'.
        :return: Function of one argument (the scanner) executing the actions.
        """
        actions = []
        max_waiting = 0
        for dim_index, dim in enumerate(self.dimensions):
            for action_index, action in enumerate(dim[entry_name]):
                set_pv, read_pv, value, tolerance, timeout = action
                if set_pv == "match":
                    raise NotImplementedError("match not yet implemented for PreAction.")
                # Initialize the write group, to speed up in loop stuff.
                group_name = "%s_%d_%d" % (entry_name, dim_index, action_index)
                self.epics_dal.add_writer_group(group_name, set_pv, read_pv, tolerance, timeout)
                actions.append((group_name, value))
            if entry_name + "Waiting" in dim:
                max_waiting = max(max_waiting, dim[entry_name + "Waiting"])

        def execute(scanner):
            for action in actions:
                name = action[0]
                value = action[1]
                # Retrieve the epics group and write the value.
                self.epics_dal.get_group(name).set_and_match(value)
            sleep(max_waiting)
        return execute

    class DummyProgress(object):
        """Minimal stand-in for the legacy GUI progress object."""
        def __init__(self):
            # For Thomas?
            self.Progress = 1
            self.abortScan = 0

    def __init__(self):
        self.dimensions = None
        self.epics_dal = None
        self.scanner = None
        self.outdict = None
        self.all_read_pvs = None
        self.n_readbacks = None
        self.n_validations = None
        self.n_observables = None
        self.n_total_positions = None
        self.n_measurements = None
        # Accessed by some clients.
        self.ProgDisp = Scan.DummyProgress()
        self._pauseScan = 0
        # Just to make old GUI work.
        self._abortScan = 0
        self.n_done_measurements = 0

    @property
    def abortScan(self):
        # Fixed: previously returned the non-existent "_abort_scan" attribute,
        # so reading this property always raised AttributeError.
        return self._abortScan

    @abortScan.setter
    def abortScan(self, value):
        self._abortScan = value
        if self._abortScan:
            self.scanner.abort_scan()

    @property
    def pauseScan(self):
        return self._pauseScan

    @pauseScan.setter
    def pauseScan(self, value):
        self._pauseScan = value
        if self._pauseScan:
            self.scanner.pause_scan()
        else:
            self.scanner.resume_scan()

    def initializeScan(self, inlist, dal=None):
        """
        Initialize and verify the provided scan values.
        :param inlist: List of dictionaries for each dimension.
        :param dal: Which reader should be used to access the PVs. Default: PyEpicsDal.
        :return: Dictionary with results.
        """
        if not inlist:
            raise ValueError("Provided inlist is empty.")
        if dal is not None:
            self.epics_dal = dal
        else:
            self.epics_dal = PyEpicsDal()
        # Prepare the scan dimensions.
        if isinstance(inlist, list):
            self.dimensions = inlist
        # In case it is a simple one dimensional scan.
        else:
            self.dimensions = [inlist]
        try:
            for index, dic in enumerate(self.dimensions):
                # We read most of the PVs only if declared in the last dimension.
                is_last_dimension = index == (len(self.dimensions) - 1)
                # Just in case there are identical input dictionaries. (Normally, it may not happen.)
                dic['ID'] = index
                # Waiting time.
                if is_last_dimension and ('Waiting' not in dic.keys()):
                    raise ValueError('Waiting for the scan was not given.')
                # Validation channels - values just added to the results.
                if 'Validation' in dic.keys():
                    if not isinstance(dic['Validation'], list):
                        raise ValueError('Validation should be a list of channels. Input dictionary %d.' % index)
                else:
                    dic['Validation'] = []
                # Relative scan.
                if 'Additive' not in dic.keys():
                    dic['Additive'] = 0
                # Step back when pause is invoked.
                if is_last_dimension and ('StepbackOnPause' not in dic.keys()):
                    dic['StepbackOnPause'] = 1
                # Number of measurements per position.
                if is_last_dimension and ('NumberOfMeasurements' not in dic.keys()):
                    dic['NumberOfMeasurements'] = 1
                # PVs to sample.
                if is_last_dimension and ('Observable' not in dic.keys()):
                    raise ValueError('The observable is not given.')
                elif is_last_dimension:
                    if not isinstance(dic['Observable'], list):
                        dic['Observable'] = [dic['Observable']]
                self._setup_knobs(index, dic)
                self._setup_knob_scan_values(index, dic)
                self._setup_pre_actions(index, dic)
                self._setup_inloop_pre_actions(index, dic)
                self._setup_post_action(index, dic)
                self._setup_inloop_post_action(index, dic)
            # Total number of measurements
            self.n_total_positions = 1
            for dic in self.dimensions:
                if not dic['Series']:
                    self.n_total_positions = self.n_total_positions * dic['Nstep']
                else:
                    self.n_total_positions = self.n_total_positions * sum(dic['Nstep'])
            self._setup_epics_dal()
            # Monitors only in the last dimension.
            self._setup_monitors(self.dimensions[-1])
            # Preallocating the place for the output
            self.outdict = {"ErrorMessage": None,
                            "KnobReadback": self.allocateOutput(),
                            "Validation": self.allocateOutput(),
                            "Observable": self.allocateOutput()}
        except ValueError:
            self.outdict = {"ErrorMessage": traceback.format_exc()}
        # Backward compatibility.
        self.ProgDisp.Progress = 0
        self.ProgDisp.abortScan = 0
        self._pauseScan = 0
        self.abortScan = 0
        self.n_done_measurements = 0
        return self.outdict

    def allocateOutput(self):
        """Build the nested (per dimension / per step) list structure results are written into.

        :return: Nested lists mirroring the scan shape, with empty leaf lists.
        """
        root_list = []
        for dimension in reversed(self.dimensions):
            n_steps = dimension['Nstep']
            if dimension['Series']:
                # For Series scan, each step of each knob represents another result.
                current_dimension_list = []
                for n_steps_in_knob in n_steps:
                    current_knob_list = []
                    for _ in range(n_steps_in_knob):
                        current_knob_list.append(deepcopy(root_list))
                    current_dimension_list.append(deepcopy(current_knob_list))
                root_list = current_dimension_list
            else:
                # For line scan, each step represents another result.
                current_dimension_list = []
                for _ in range(n_steps):
                    current_dimension_list.append(deepcopy(root_list))
                root_list = current_dimension_list
        return root_list

    def _setup_epics_dal(self):
        """Create the global read and write PV groups from all dimensions' channels."""
        # Collect all PVs that need to be read at each scan step.
        self.all_read_pvs = []
        all_write_pvs = []
        all_readback_pvs = []
        all_tolerances = []
        max_knob_waiting = -1
        self.n_readbacks = 0
        for d in self.dimensions:
            self.all_read_pvs.append(d['KnobReadback'])
            self.n_readbacks += len(d['KnobReadback'])
            # Collect all data need to write to PVs.
            all_write_pvs.append(d["Knob"])
            all_readback_pvs.append(d["KnobReadback"])
            all_tolerances.append(d["KnobTolerance"])
            max_knob_waiting = max(max_knob_waiting, max(d["KnobWaiting"]))
        self.all_read_pvs.append(self.dimensions[-1]['Validation'])
        self.n_validations = len(self.dimensions[-1]['Validation'])
        self.all_read_pvs.append(self.dimensions[-1]['Observable'])
        self.n_observables = len(self.dimensions[-1]['Observable'])
        # Expand all read PVs
        self.all_read_pvs = [item for sublist in self.all_read_pvs for item in sublist]
        # Expand Knobs and readbacks PVs.
        all_write_pvs = [item for sublist in all_write_pvs for item in sublist]
        all_readback_pvs = [item for sublist in all_readback_pvs for item in sublist]
        all_tolerances = [item for sublist in all_tolerances for item in sublist]
        # Initialize PV connections and check if all PV names are valid.
        self.epics_dal.add_reader_group(READ_GROUP, self.all_read_pvs)
        self.epics_dal.add_writer_group(WRITE_GROUP, all_write_pvs, all_readback_pvs, all_tolerances, max_knob_waiting)

    def _setup_knobs(self, index, dic):
        """
        Setup the values for moving knobs in the scan.
        :param index: Index in the dictionary.
        :param dic: The dictionary.
        """
        if 'Knob' not in dic.keys():
            raise ValueError('Knob for the scan was not given for the input dictionary %d.' % index)
        else:
            if not isinstance(dic['Knob'], list):
                dic['Knob'] = [dic['Knob']]
        if 'KnobReadback' not in dic.keys():
            dic['KnobReadback'] = dic['Knob']
        if not isinstance(dic['KnobReadback'], list):
            dic['KnobReadback'] = [dic['KnobReadback']]
        if len(dic['KnobReadback']) != len(dic['Knob']):
            raise ValueError('The number of KnobReadback does not meet to the number of Knobs.')
        if 'KnobTolerance' not in dic.keys():
            dic['KnobTolerance'] = [1.0] * len(dic['Knob'])
        if not isinstance(dic['KnobTolerance'], list):
            dic['KnobTolerance'] = [dic['KnobTolerance']]
        if len(dic['KnobTolerance']) != len(dic['Knob']):
            raise ValueError('The number of KnobTolerance does not meet to the number of Knobs.')
        if 'KnobWaiting' not in dic.keys():
            dic['KnobWaiting'] = [10.0] * len(dic['Knob'])
        if not isinstance(dic['KnobWaiting'], list):
            dic['KnobWaiting'] = [dic['KnobWaiting']]
        if len(dic['KnobWaiting']) != len(dic['Knob']):
            raise ValueError('The number of KnobWaiting does not meet to the number of Knobs.')
        if 'KnobWaitingExtra' not in dic.keys():
            dic['KnobWaitingExtra'] = 0.0
        else:
            try:
                dic['KnobWaitingExtra'] = float(dic['KnobWaitingExtra'])
            except:
                raise ValueError('KnobWaitingExtra is not a number in the input dictionary %d.' % index)
        # Originally dic["Knob"] values were saved. I'm supposing this was a bug - readback values needed to be saved.
        # TODO: We can optimize this by moving the initialization in the epics_dal init
        # but pre actions need to be moved after the epics_dal init than
        self.epics_dal.add_reader_group("KnobReadback", dic['KnobReadback'])
        dic['KnobSaved'] = self.epics_dal.get_group("KnobReadback").read()
        self.epics_dal.close_group("KnobReadback")

    def _setup_knob_scan_values(self, index, dic):
        """Expand ScanRange/ScanValues/Nstep/StepSize into 'KnobExpanded' and 'Nstep'."""
        if 'Series' not in dic.keys():
            dic['Series'] = 0
        if not dic['Series']:  # Setting up scan values for SKS and MKS
            if 'ScanValues' not in dic.keys():
                if 'ScanRange' not in dic.keys():
                    raise ValueError('Neither ScanRange nor ScanValues is given '
                                     'in the input dictionary %d.' % index)
                elif not isinstance(dic['ScanRange'], list):
                    raise ValueError('ScanRange is not given in the right format. '
                                     'Input dictionary %d.' % index)
                elif not isinstance(dic['ScanRange'][0], list):
                    dic['ScanRange'] = [dic['ScanRange']]
                if ('Nstep' not in dic.keys()) and ('StepSize' not in dic.keys()):
                    raise ValueError('Neither Nstep nor StepSize is given.')
                if 'Nstep' in dic.keys():  # StepSize is ignored when Nstep is given
                    if not isinstance(dic['Nstep'], int):
                        raise ValueError('Nstep should be an integer. Input dictionary %d.' % index)
                    ran = []
                    for r in dic['ScanRange']:
                        s = (r[1] - r[0]) / (dic['Nstep'] - 1)
                        f = np.arange(r[0], r[1], s)
                        f = np.append(f, np.array(r[1]))
                        ran.append(f.tolist())
                    dic['KnobExpanded'] = ran
                else:  # StepSize given
                    if len(dic['Knob']) > 1:
                        raise ValueError('Give Nstep instead of StepSize for MKS. '
                                         'Input dictionary %d.' % index)
                    # StepSize is only valid for SKS
                    r = dic['ScanRange'][0]
                    # TODO: THIS IS RECONSTRUCTED AND MIGHT BE WRONG, CHECK!
                    s = dic['StepSize'][0]
                    f = np.arange(r[0], r[1], s)
                    f = np.append(f, np.array(r[1]))
                    dic['Nstep'] = len(f)
                    dic['KnobExpanded'] = [f.tolist()]
            else:
                # Scan values explicitly defined.
                if not isinstance(dic['ScanValues'], list):
                    raise ValueError('ScanValues is not given in the right format. '
                                     'Input dictionary %d.' % index)
                if len(dic['ScanValues']) != len(dic['Knob']) and len(dic['Knob']) != 1:
                    raise ValueError('The length of ScanValues does not meet to the number of Knobs.')
                if len(dic['Knob']) > 1:
                    minlen = 100000
                    for r in dic['ScanValues']:
                        if minlen > len(r):
                            minlen = len(r)
                    ran = []
                    for r in dic['ScanValues']:
                        ran.append(r[0:minlen])  # Cut at the length of the shortest list.
                    dic['KnobExpanded'] = ran
                    dic['Nstep'] = minlen
                else:
                    dic['KnobExpanded'] = [dic['ScanValues']]
                    dic['Nstep'] = len(dic['ScanValues'])
        else:  # Setting up scan values for Series scan
            if 'ScanValues' not in dic.keys():
                raise ValueError('ScanValues should be given for Series '
                                 'scan in the input dictionary %d.' % index)
            if not isinstance(dic['ScanValues'], list):
                raise ValueError('ScanValues should be given as a list (of lists) '
                                 'for Series scan in the input dictionary %d.' % index)
            if len(dic['Knob']) != len(dic['ScanValues']):
                raise ValueError('Scan values length does not match to the '
                                 'number of knobs in the input dictionary %d.' % index)
            Nstep = []
            for vl in dic['ScanValues']:
                if not isinstance(vl, list):
                    raise ValueError('ScanValue element should be given as a list for '
                                     'Series scan in the input dictionary %d.' % index)
                Nstep.append(len(vl))
            dic['Nstep'] = Nstep

    def _setup_pre_actions(self, index, dic):
        """Validate and default the 'PreAction' entries of one dimension."""
        if 'PreAction' in dic.keys():
            if not isinstance(dic['PreAction'], list):
                raise ValueError('PreAction should be a list. Input dictionary %d.' % index)
            for l in dic['PreAction']:
                if not isinstance(l, list):
                    raise ValueError('Every PreAction should be a list. Input dictionary %d.' % index)
                if len(l) != 5:
                    if not l[0] == 'SpecialAction':
                        raise ValueError('Every PreAction should be in a form of '
                                         '[Ch-set, Ch-read, Value, Tolerance, Timeout]. '
                                         'Input dictionary ' + str(index) + '.')
            if 'PreActionWaiting' not in dic.keys():
                dic['PreActionWaiting'] = 0.0
            if not isinstance(dic['PreActionWaiting'], float) and not isinstance(dic['PreActionWaiting'], int):
                raise ValueError('PreActionWaiting should be a float. Input dictionary %d.' % index)
            if 'PreActionOrder' not in dic.keys():
                dic['PreActionOrder'] = [0] * len(dic['PreAction'])
            if not isinstance(dic['PreActionOrder'], list):
                raise ValueError('PreActionOrder should be a list. Input dictionary %d.' % index)
        else:
            dic['PreAction'] = []
            dic['PreActionWaiting'] = 0.0
            dic['PreActionOrder'] = [0] * len(dic['PreAction'])

    def _setup_inloop_pre_actions(self, index, dic):
        """Validate and default the 'In-loopPreAction' entries of one dimension."""
        if 'In-loopPreAction' in dic.keys():
            if not isinstance(dic['In-loopPreAction'], list):
                raise ValueError('In-loopPreAction should be a list. Input dictionary %d.' % index)
            for l in dic['In-loopPreAction']:
                if not isinstance(l, list):
                    raise ValueError('Every In-loopPreAction should be a list. '
                                     'Input dictionary ' + str(index) + '.')
                if len(l) != 5:
                    if not l[0] == 'SpecialAction':
                        raise ValueError('Every In-loopPreAction should be in a form of '
                                         '[Ch-set, Ch-read, Value, Tolerance, Timeout]. '
                                         'Input dictionary ' + str(index) + '.')
            if 'In-loopPreActionWaiting' not in dic.keys():
                dic['In-loopPreActionWaiting'] = 0.0
            if not isinstance(dic['In-loopPreActionWaiting'], float) and not isinstance(
                    dic['In-loopPreActionWaiting'], int):
                raise ValueError('In-loopPreActionWaiting should be a float. Input dictionary %d.' % index)
            if 'In-loopPreActionOrder' not in dic.keys():
                dic['In-loopPreActionOrder'] = [0] * len(dic['In-loopPreAction'])
            if not isinstance(dic['In-loopPreActionOrder'], list):
                raise ValueError('In-loopPreActionOrder should be a list. Input dictionary %d.' % index)
        else:
            dic['In-loopPreAction'] = []
            dic['In-loopPreActionWaiting'] = 0.0
            dic['In-loopPreActionOrder'] = [0] * len(dic['In-loopPreAction'])

    def _setup_post_action(self, index, dic):
        """Expand the 'PostAction' entry; 'Restore' becomes set-knobs-back-to-saved actions."""
        if 'PostAction' in dic.keys():
            if dic['PostAction'] == 'Restore':
                PA = []
                # Renamed loop variable: it used to shadow the dictionary index used
                # in the error messages below.
                for knob_index in range(0, len(dic['Knob'])):
                    k = dic['Knob'][knob_index]
                    v = dic['KnobSaved'][knob_index]
                    PA.append([k, k, v, 1.0, 10])
                dic['PostAction'] = PA
            elif not isinstance(dic['PostAction'], list):
                raise ValueError('PostAction should be a list. Input dictionary %d.' % index)
            Restore = 0
            for action_index in range(0, len(dic['PostAction'])):
                l = dic['PostAction'][action_index]
                if l == 'Restore':
                    Restore = 1
                    PA = []
                    for j in range(0, len(dic['Knob'])):
                        k = dic['Knob'][j]
                        v = dic['KnobSaved'][j]
                        PA.append([k, k, v, 1.0, 10])
                elif not isinstance(l, list):
                    raise ValueError('Every PostAction should be a list. Input dictionary %d.' % index)
                elif len(l) != 5:
                    if not l[0] == 'SpecialAction':
                        raise ValueError('Every PostAction should be in a form of '
                                         '[Ch-set, Ch-read, Value, Tolerance, Timeout]. '
                                         'Input dictionary %d.' % index)
            if Restore:
                # Replace the 'Restore' marker with the expanded restore actions.
                dic['PostAction'].remove('Restore')
                dic['PostAction'] = dic['PostAction'] + PA
        else:
            dic['PostAction'] = []

    def _setup_inloop_post_action(self, index, dic):
        """Expand the 'In-loopPostAction' entry, mirroring _setup_post_action."""
        if 'In-loopPostAction' in dic.keys():
            if dic['In-loopPostAction'] == 'Restore':
                PA = []
                for knob_index in range(0, len(dic['Knob'])):
                    k = dic['Knob'][knob_index]
                    v = dic['KnobSaved'][knob_index]
                    PA.append([k, k, v, 1.0, 10])
                dic['In-loopPostAction'] = PA
            elif not isinstance(dic['In-loopPostAction'], list):
                raise ValueError('In-loopPostAction should be a list. Input dictionary %d.' % index)
            Restore = 0
            for action_index in range(0, len(dic['In-loopPostAction'])):
                l = dic['In-loopPostAction'][action_index]
                if l == 'Restore':
                    Restore = 1
                    PA = []
                    for j in range(0, len(dic['Knob'])):
                        k = dic['Knob'][j]
                        v = dic['KnobSaved'][j]
                        PA.append([k, k, v, 1.0, 10])
                    # Fixed: the element is no longer overwritten with PA here - doing so
                    # made the remove('Restore') below raise, since the marker was gone.
                elif not isinstance(l, list):
                    raise ValueError('Every In-loopPostAction should be a list. '
                                     'Input dictionary %d.' % index)
                elif len(l) != 5:
                    # SpecialAction exemption added for consistency with the other
                    # three action-setup methods.
                    if not l[0] == 'SpecialAction':
                        raise ValueError('Every In-loopPostAction should be in a form of '
                                         '[Ch-set, Ch-read, Value, Tolerance, Timeout]. '
                                         'Input dictionary %d.' % index)
            if Restore:
                dic['In-loopPostAction'].remove('Restore')
                dic['In-loopPostAction'] = dic['In-loopPostAction'] + PA
        else:
            dic['In-loopPostAction'] = []

    def _setup_monitors(self, dic):
        """Validate and default all Monitor* entries of the last dimension."""
        if ('Monitor' in dic.keys()) and (dic['Monitor']):
            if isinstance(dic['Monitor'], str):
                dic['Monitor'] = [dic['Monitor']]
            # Initialize monitor group and check if all monitor PVs are valid.
            self.epics_dal.add_reader_group(MONITOR_GROUP, dic["Monitor"])
            if 'MonitorValue' not in dic.keys():
                dic["MonitorValue"] = self.epics_dal.get_group(MONITOR_GROUP).read()
            elif not isinstance(dic['MonitorValue'], list):
                dic['MonitorValue'] = [dic['MonitorValue']]
            if len(dic['MonitorValue']) != len(dic['Monitor']):
                raise ValueError('The length of MonitorValue does not meet to the length of Monitor.')
            # Try to construct the monitor tolerance, if not given.
            if 'MonitorTolerance' not in dic.keys():
                dic['MonitorTolerance'] = []
                for value in self.epics_dal.get_group(MONITOR_GROUP).read():
                    if isinstance(value, str):
                        # No tolerance for string values.
                        dic['MonitorTolerance'].append(None)
                    elif value == 0:
                        # Default tolerance for unknown values is 0.1.
                        dic['MonitorTolerance'].append(0.1)
                    else:
                        # 10% of the current value will be the tolerance when not given
                        dic['MonitorTolerance'].append(abs(value * 0.1))
            elif not isinstance(dic['MonitorTolerance'], list):
                dic['MonitorTolerance'] = [dic['MonitorTolerance']]
            if len(dic['MonitorTolerance']) != len(dic['Monitor']):
                raise ValueError('The length of MonitorTolerance does not meet to the length of Monitor.')
            if 'MonitorAction' not in dic.keys():
                raise ValueError('MonitorAction is not given though Monitor is given.')
            if not isinstance(dic['MonitorAction'], list):
                dic['MonitorAction'] = [dic['MonitorAction']]
            for m in dic['MonitorAction']:
                if m != 'Abort' and m != 'Wait' and m != 'WaitAndAbort':
                    raise ValueError('MonitorAction should be Wait, Abort, or WaitAndAbort.')
            if 'MonitorTimeout' not in dic.keys():
                dic['MonitorTimeout'] = [30.0] * len(dic['Monitor'])
            elif not isinstance(dic['MonitorTimeout'], list):
                # Fixed copy-paste bug: this branch used to re-wrap and re-validate
                # MonitorValue instead of MonitorTimeout.
                dic['MonitorTimeout'] = [dic['MonitorTimeout']]
            if len(dic['MonitorTimeout']) != len(dic['Monitor']):
                raise ValueError('The length of MonitorTimeout does not meet to the length of Monitor.')
            for m in dic['MonitorTimeout']:
                try:
                    float(m)
                except:
                    raise ValueError('MonitorTimeout should be a list of float(or int).')
        else:
            dic['Monitor'] = []
            dic['MonitorValue'] = []
            dic['MonitorTolerance'] = []
            dic['MonitorAction'] = []
            dic['MonitorTimeout'] = []

    def startScan(self):
        """Run the initialized scan and return the result dictionary."""
        if self.outdict['ErrorMessage']:
            if 'After the last scan,' not in self.outdict['ErrorMessage']:
                self.outdict['ErrorMessage'] = 'It seems that the initialization was not successful... ' \
                                               'No scan was performed.'
            return self.outdict
        # Execute the scan.
        self.outdict['TimeStampStart'] = datetime.now()
        self.execute_scan()
        self.outdict['TimeStampEnd'] = datetime.now()
        self.outdict['ErrorMessage'] = 'Measurement finalized (finished/aborted) normally. ' \
                                       'Need initialisation before next measurement.'
        # Cleanup after the scan.
        self.epics_dal.close_all_groups()
        return self.outdict

View File

@@ -0,0 +1,41 @@
from pyscan.utils import flat_list_generator
class PyScanDataProcessor(object):
    """Splits raw pyscan readings into the KnobReadback / Validation / Observable outputs."""

    def __init__(self, output, n_readbacks, n_validations, n_observables, n_measurements):
        # Counts of values per measurement, in the order they appear in a raw reading.
        self.n_readbacks = n_readbacks
        self.n_validations = n_validations
        self.n_observables = n_observables
        self.n_measurements = n_measurements
        self.output = output
        # Flat iterators over the preallocated nested output lists; each process()
        # call consumes the next target cell of each.
        self.KnobReadback_output_position = flat_list_generator(self.output["KnobReadback"])
        self.Validation_output_position = flat_list_generator(self.output["Validation"])
        self.Observable_output_position = flat_list_generator(self.output["Observable"])

    def process(self, position, data):
        """Distribute one position's raw data into the three output structures."""
        single_measurement = self.n_measurements == 1
        # Normalize to a list of measurements so the slicing below is uniform.
        measurements = [data] if single_measurement else data
        readback_slice = slice(0, self.n_readbacks)
        validation_slice = slice(self.n_readbacks, self.n_readbacks + self.n_validations)
        observable_slice = slice(self.n_readbacks + self.n_validations,
                                 self.n_readbacks + self.n_validations + self.n_observables)
        readbacks = [measurement[readback_slice] for measurement in measurements]
        validations = [measurement[validation_slice] for measurement in measurements]
        observables = [measurement[observable_slice] for measurement in measurements]
        if single_measurement:
            # A single measurement is stored flattened into the target cell.
            readbacks, validations, observables = readbacks[0], validations[0], observables[0]
        next(self.KnobReadback_output_position).extend(readbacks)
        next(self.Validation_output_position).extend(validations)
        next(self.Observable_output_position).extend(observables)

    def get_data(self):
        """Return the output structure filled so far."""
        return self.output

View File

View File

@@ -0,0 +1,184 @@
import math
from copy import copy
from pyscan.utils import convert_to_list
class AreaPositioner(object):
    """Raster scan over an n-dimensional rectangular area.

    For every value of an outer axis, the inner axes sweep their full range
    again from the start; each pass repeats the whole pattern.
    """

    def _validate_parameters(self):
        """Check start/end/n_steps/step_size/passes/offsets consistency; raise ValueError on mismatch."""
        if not len(self.start) == len(self.end):
            raise ValueError("Number of start %s and end %s positions do not match." %
                             (self.start, self.end))
        # Exactly one of n_steps / step_size must be provided.
        if (self.n_steps and self.step_size) or (not self.n_steps and not self.step_size):
            raise ValueError("N_steps (%s) or step_sizes (%s) must be set, but not none "
                             "or both of them at the same time." % (self.step_size, self.n_steps))
        if self.n_steps and (not len(self.n_steps) == len(self.start)):
            raise ValueError("The number of n_steps %s does not match the number of start positions %s." %
                             (self.n_steps, self.start))
        if self.n_steps and not all(isinstance(x, int) for x in self.n_steps):
            raise ValueError("The n_steps %s must have only integers." % self.n_steps)
        if self.step_size and (not len(self.step_size) == len(self.start)):
            raise ValueError("The number of step sizes %s does not match the number of start positions %s." %
                             (self.step_size, self.start))
        if not isinstance(self.passes, int) or self.passes < 1:
            raise ValueError("Passes must be a positive integer value, but %s was given." % self.passes)
        if self.offsets and (not len(self.offsets) == len(self.start)):
            raise ValueError("Number of offsets %s does not match the number of start positions %s." %
                             (self.offsets, self.start))

    def __init__(self, start, end, n_steps=None, step_size=None, passes=1, offsets=None):
        """
        :param start: Start position (scalar or list, one entry per axis).
        :param end: End position (scalar or list, one entry per axis).
        :param n_steps: Number of steps per axis (exclusive with step_size).
        :param step_size: Step size per axis (exclusive with n_steps).
        :param passes: Number of times to repeat the whole area scan.
        :param offsets: Optional per-axis offsets added to start and end.
        """
        self.start = convert_to_list(start)
        self.end = convert_to_list(end)
        self.n_steps = convert_to_list(n_steps)
        self.step_size = convert_to_list(step_size)
        self.passes = passes
        self.offsets = convert_to_list(offsets)
        self._validate_parameters()
        # Get the number of axis to scan.
        self.n_axis = len(self.start)
        # Fix the offsets if provided.
        if self.offsets:
            self.start = [offset + original_value for original_value, offset in zip(self.start, self.offsets)]
            self.end = [offset + original_value for original_value, offset in zip(self.end, self.offsets)]
        # Number of steps case.
        if self.n_steps:
            # n_steps increments per axis, so each axis ends exactly at its end value.
            self.step_size = [(end - start) / steps for start, end, steps
                              in zip(self.start, self.end, self.n_steps)]
        # Step size case.
        elif self.step_size:
            # Derived step count is floored, so the last position may fall short of end.
            self.n_steps = [math.floor((end - start) / step_size) for start, end, step_size
                            in zip(self.start, self.end, self.step_size)]

    def get_generator(self):
        """Yield position lists in raster order; the start corner is yielded first.

        NOTE: the recursion mutates a single shared `positions` list and yields
        copies of it; each axis is reset to its start value after its sweep.
        """
        for _ in range(self.passes):
            positions = copy(self.start)
            # Return the initial state.
            yield copy(positions)

            # Recursive call to print all axis values.
            def scan_axis(axis_number):
                # We should not scan axis that do not exist.
                if not axis_number < self.n_axis:
                    return
                # Output all position on the next axis while this axis is still at the start position.
                yield from scan_axis(axis_number + 1)
                # Move axis step by step.
                for _ in range(self.n_steps[axis_number]):
                    positions[axis_number] = positions[axis_number] + self.step_size[axis_number]
                    yield copy(positions)
                    # Output all positions from the next axis for each value of this axis.
                    yield from scan_axis(axis_number + 1)
                # Clean up after the loop - return the axis value back to the start value.
                positions[axis_number] = self.start[axis_number]
            yield from scan_axis(0)
class ZigZagAreaPositioner(AreaPositioner):
    """Area positioner sweeping each axis alternately up and down (boustrophedon
    order), so consecutive positions stay adjacent - no flyback to the start.
    """

    def get_generator(self):
        for _ in range(self.passes):
            # Per-axis travel direction: +1 ascending, -1 descending.
            direction = [1] * self.n_axis
            point = copy(self.start)
            # Emit the starting corner first.
            yield copy(point)

            def sweep(axis):
                # Recursion bottoms out past the last axis.
                if axis >= self.n_axis:
                    return
                # Inner axes first, while this axis holds its current value.
                yield from sweep(axis + 1)
                for _ in range(self.n_steps[axis]):
                    point[axis] = point[axis] + (self.step_size[axis] * direction[axis])
                    yield copy(point)
                    yield from sweep(axis + 1)
                # Next visit of this axis runs the other way; the position is
                # deliberately NOT reset - that is what makes it a zigzag.
                direction[axis] *= -1

            yield from sweep(0)
class MultiAreaPositioner(object):
    """Area positioner where each axis moves a whole vector of values per step.

    Unlike AreaPositioner, start/end are lists of lists: one inner list of
    positions per axis, and every step advances all values of an axis together.

    :param start: Per-axis lists of start positions.
    :param end: Per-axis lists of end positions.
    :param steps: Per-axis lists; all ints = number of steps, all floats = step sizes.
    :param passes: Number of times to repeat the whole scan (default 1).
    :param offsets: Optional per-axis lists of offsets added to start and end.
    :raises ValueError: If steps is neither int-based nor float-based.
    """

    def __init__(self, start, end, steps, passes=1, offsets=None):
        self.offsets = offsets
        self.passes = passes
        self.end = end
        self.start = start
        # Get the number of axis to scan.
        self.n_axis = len(self.start)
        # Fix the offsets if provided.
        if self.offsets:
            self.start = [[original_value + offset for original_value, offset in zip(original_values, offsets)]
                          for original_values, offsets in zip(self.start, self.offsets)]
            self.end = [[original_value + offset for original_value, offset in zip(original_values, offsets)]
                        for original_values, offsets in zip(self.end, self.offsets)]
        # Number of steps case.
        if isinstance(steps[0][0], int):
            # TODO: Verify that each axis has positive steps and that all are ints (all steps or step_size)
            self.n_steps = steps
            self.step_size = [[(end - start) / steps for start, end, steps in zip(starts, ends, line_steps)]
                              for starts, ends, line_steps in zip(self.start, self.end, steps)]
        # Step size case.
        elif isinstance(steps[0][0], float):
            # TODO: Verify that each axis has the same number of steps and that the step_size is correct (positive etc.)
            self.n_steps = [[math.floor((end - start) / step) for start, end, step in zip(starts, ends, line_steps)]
                            for starts, ends, line_steps in zip(self.start, self.end, steps)]
            self.step_size = steps
        else:
            # Previously this case was silently ignored, leaving the instance without
            # n_steps/step_size and failing much later in get_generator.
            raise ValueError("Steps must be given as ints (number of steps) or floats "
                             "(step sizes), but %s was given." % (steps,))

    def get_generator(self):
        """Yield lists of per-axis position lists in raster order, once per pass."""
        for _ in range(self.passes):
            positions = copy(self.start)
            # Return the initial state.
            yield copy(positions)

            # Recursive call to print all axis values.
            def scan_axis(axis_number):
                # We should not scan axis that do not exist.
                if not axis_number < self.n_axis:
                    return
                # Output all position on the next axis while this axis is still at the start position.
                yield from scan_axis(axis_number + 1)
                # Move axis step by step.
                # TODO: Figure out what to do with this steps.
                for _ in range(self.n_steps[axis_number][0]):
                    positions[axis_number] = [position + step_size for position, step_size
                                              in zip(positions[axis_number], self.step_size[axis_number])]
                    yield copy(positions)
                    # Output all positions from the next axis for each value of this axis.
                    yield from scan_axis(axis_number + 1)
                # Clean up after the loop - return the axis value back to the start value.
                positions[axis_number] = self.start[axis_number]
            yield from scan_axis(0)

View File

@@ -0,0 +1,21 @@
class BsreadPositioner(object):
    """Positioner that steps through a fixed number of consecutive stream messages."""

    def __init__(self, n_messages):
        """
        Acquire N consecutive messages from the stream.
        :param n_messages: Number of messages to acquire.
        """
        self.n_messages = n_messages
        self.bs_reader = None

    def set_bs_reader(self, bs_reader):
        """Attach the stream reader used to pull messages."""
        self.bs_reader = bs_reader

    def get_generator(self):
        """Yield message indices 0..n_messages-1, reading one message per step."""
        if self.bs_reader is None:
            raise RuntimeError("Set bs_reader before using this generator.")
        message_index = 0
        while message_index < self.n_messages:
            self.bs_reader.read(message_index)
            yield message_index
            message_index += 1

View File

@@ -0,0 +1,21 @@
from copy import copy
from pyscan.utils import convert_to_list
class CompoundPositioner(object):
    """
    Given a list of positioners, it compounds them in given order, getting values from each of them at every step.
    """

    def __init__(self, positioners):
        # Positioners are combined left to right: the last one varies fastest.
        self.positioners = positioners
        self.n_positioners = len(positioners)

    def get_generator(self):
        """Yield the concatenated positions of all nested positioners."""
        def expand(level, prefix):
            # All positioners consumed - the accumulated prefix is one full position.
            if level == self.n_positioners:
                yield copy(prefix)
                return
            # Otherwise extend the prefix with every value of the current positioner.
            for value in self.positioners[level].get_generator():
                yield from expand(level + 1, prefix + convert_to_list(value))

        yield from expand(0, [])

View File

@@ -0,0 +1,91 @@
import math
from copy import copy
from pyscan.utils import convert_to_list
class LinePositioner(object):
    """
    Scan a line from start to end positions in equidistant steps.

    Exactly one of n_steps and step_size must be provided; the other is
    derived from the start/end positions in the constructor.
    """

    def _validate_parameters(self):
        """Raise ValueError if the constructor arguments are inconsistent."""
        if not len(self.start) == len(self.end):
            raise ValueError("Number of start %s and end %s positions do not match." %
                             (self.start, self.end))

        # Only 1 among n_steps and step_sizes must be set.
        # BUG FIX: the message arguments were swapped (step_size reported as N_steps).
        if (self.n_steps is not None and self.step_size) or (self.n_steps is None and not self.step_size):
            raise ValueError("N_steps (%s) or step_sizes (%s) must be set, but not none "
                             "or both of them at the same time." % (self.n_steps, self.step_size))

        # If n_steps is set, than it must be an integer greater than 0.
        if (self.n_steps is not None) and (not isinstance(self.n_steps, int) or self.n_steps < 1):
            raise ValueError("Steps must be a positive integer value, but %s was given." % self.n_steps)

        if self.step_size and (not len(self.step_size) == len(self.start)):
            raise ValueError("The number of step sizes %s does not match the number of start positions %s." %
                             (self.step_size, self.start))

        if not isinstance(self.passes, int) or self.passes < 1:
            raise ValueError("Passes must be a positive integer value, but %s was given." % self.passes)

        if self.offsets and (not len(self.offsets) == len(self.start)):
            raise ValueError("Number of offsets %s does not match the number of start positions %s." %
                             (self.offsets, self.start))

    def __init__(self, start, end, n_steps=None, step_size=None, passes=1, offsets=None):
        """
        :param start: Start position, scalar or list (one value per axis).
        :param end: End position, same number of axes as start.
        :param n_steps: Number of steps between start and end (alternative to step_size).
        :param step_size: Step size per axis (alternative to n_steps).
        :param passes: How many times to repeat the scan.
        :param offsets: Optional per-axis offset added to both start and end.
        """
        self.start = convert_to_list(start)
        self.end = convert_to_list(end)
        self.n_steps = n_steps
        self.step_size = convert_to_list(step_size)
        self.passes = passes
        self.offsets = convert_to_list(offsets)

        self._validate_parameters()

        # Fix the offsets if provided.
        if self.offsets:
            self.start = [offset + original_value for original_value, offset in zip(self.start, self.offsets)]
            self.end = [offset + original_value for original_value, offset in zip(self.end, self.offsets)]

        # Number of steps case: derive the step size for each axis.
        if self.n_steps:
            self.step_size = [(end - start) / self.n_steps for start, end in zip(self.start, self.end)]
        # Step size case: derive the (common) number of steps.
        elif self.step_size:
            n_steps_per_axis = [math.floor((end - start) / step_size) for start, end, step_size
                                in zip(self.start, self.end, self.step_size)]

            # Verify that all axis do the same number of steps.
            # BUG FIX: the original format string contained "% pair", an invalid
            # conversion that raised "unsupported format character" instead of
            # producing this message. It is now "%s pair".
            if not all(x == n_steps_per_axis[0] for x in n_steps_per_axis):
                raise ValueError("The step sizes %s must give the same number of steps for each start %s "
                                 "and end %s pair." % (self.step_size, self.start, self.end))

            # All the elements in n_steps_per_axis must be the same anyway.
            self.n_steps = n_steps_per_axis[0]

    def get_generator(self):
        """Yield the start position, then one position per step, for each pass."""
        for _ in range(self.passes):
            # The initial position is always the start position.
            current_positions = copy(self.start)
            yield current_positions

            for __ in range(self.n_steps):
                current_positions = [position + step_size for position, step_size
                                     in zip(current_positions, self.step_size)]
                yield current_positions
class ZigZagLinePositioner(LinePositioner):
    """Line positioner that reverses direction on every pass instead of jumping back."""

    def get_generator(self):
        """Yield the start position, then n_steps positions per pass, alternating direction."""
        # The initial position is always the start position.
        current = copy(self.start)
        yield current

        for pass_index in range(self.passes):
            # Even passes move forward, odd passes move backward.
            sign = 1 if pass_index % 2 == 0 else -1
            for __ in range(self.n_steps):
                current = [value + (delta * sign) for value, delta
                           in zip(current, self.step_size)]
                yield current

View File

@@ -0,0 +1,40 @@
from copy import copy
from pyscan.utils import convert_to_list
class SerialPositioner(object):
    """
    Scan over all provided points, one by one, returning the previous to the initial state.
    Each axis is treated as a separate line.
    """

    def __init__(self, positions, initial_positions, passes=1, offsets=None):
        """
        :param positions: Points to visit, per axis.
        :param initial_positions: Rest position of each axis.
        :param passes: How many times to repeat the scan.
        :param offsets: Optional per-axis offset added to every position in place.
        """
        self.positions = positions
        self.passes = passes
        self.offsets = offsets

        if passes < 1:
            raise ValueError("Number of passes cannot be less than 1, but %d was provided." % passes)

        self.initial_positions = initial_positions
        self.n_axis = len(self.initial_positions)

        # In case only 1 axis is provided, still wrap it in a list, because it makes the generator code easier.
        if self.n_axis == 1:
            self.positions = [positions]

        # Fix the offset if provided.
        if self.offsets:
            for axis_positions, offset in zip(self.positions, self.offsets):
                axis_positions[:] = [original_position + offset for original_position in axis_positions]

    def get_generator(self):
        """Yield one state per point: all axes at rest except the one currently scanned."""
        for _ in range(self.passes):
            for axis_index in range(self.n_axis):
                # Each axis scan restarts from the rest (initial) state.
                state = copy(self.initial_positions)
                axis_points = self.positions[axis_index]
                for point_index in range(len(axis_points)):
                    state[axis_index] = convert_to_list(axis_points)[point_index]
                    yield copy(state)

View File

@@ -0,0 +1,12 @@
class StaticPositioner(object):
    """Positioner that does not move: it only counts image acquisitions."""

    def __init__(self, n_images):
        """
        Acquire N consecutive images in a static position.
        :param n_images: Number of images to acquire.
        """
        self.n_images = n_images

    def get_generator(self):
        """Yield the acquisition index for each of the n_images images."""
        yield from range(self.n_images)

View File

@@ -0,0 +1,52 @@
from time import time, sleep
from pyscan.config import max_time_tolerance
# Exponential-smoothing weight applied to the computed sleep time.
smoothing_factor = 0.95


class TimePositioner(object):
    """Positioner that triggers one acquisition per fixed time interval."""

    def __init__(self, time_interval, n_intervals, tolerance=None):
        """
        Time interval at which to read data.
        :param time_interval: Time interval in seconds.
        :param n_intervals: How many intervals to measure.
        :param tolerance: How far behind schedule (seconds) an iteration may fall
                          before the scan aborts. Clamped to at least the
                          configured max_time_tolerance.
        """
        self.time_interval = time_interval

        # Tolerance cannot be less than the min set tolerance.
        if tolerance is None or tolerance < max_time_tolerance:
            tolerance = max_time_tolerance
        self.tolerance = tolerance

        # Minimum one measurement.
        if n_intervals < 1:
            n_intervals = 1
        self.n_intervals = n_intervals

    def get_generator(self):
        """Yield the start timestamp of each interval, sleeping to keep the pace."""
        measurement_time_start = time()
        last_time_to_sleep = 0

        for _ in range(self.n_intervals):
            measurement_time_stop = time()

            # How much time did the measurement take.
            measurement_time = measurement_time_stop - measurement_time_start
            time_to_sleep = self.time_interval - measurement_time

            # Use the smoothing factor to attenuate variations in the measurement time.
            time_to_sleep = (smoothing_factor * time_to_sleep) + ((1 - smoothing_factor) * last_time_to_sleep)

            # Time to sleep is negative (more time has elapsed than the interval):
            # we cannot achieve the requested time interval.
            # BUG FIX: compare against self.tolerance (the clamped, user-supplied
            # value) instead of the global max_time_tolerance - the tolerance
            # constructor argument was previously stored but never used.
            if time_to_sleep < (-1 * self.tolerance):
                raise ValueError("The requested time interval cannot be achieved. Last iteration took %.2f seconds, "
                                 "but a %.2f seconds time interval was set." % (measurement_time, self.time_interval))

            # Sleep only if time to sleep is positive.
            if time_to_sleep > 0:
                sleep(time_to_sleep)

            last_time_to_sleep = time_to_sleep
            measurement_time_start = time()
            # Return the timestamp at which the measurement should begin.
            yield measurement_time_start

View File

@@ -0,0 +1,52 @@
from itertools import cycle, chain
from pyscan.utils import convert_to_list
class VectorPositioner(object):
    """
    Moves over the provided positions.
    """

    def _validate_parameters(self):
        """Raise ValueError when positions, passes, or offsets are inconsistent."""
        if not all(len(convert_to_list(p)) == len(convert_to_list(self.positions[0])) for p in self.positions):
            raise ValueError("All positions %s must have the same number of axis." % self.positions)

        if not isinstance(self.passes, int) or self.passes < 1:
            raise ValueError("Passes must be a positive integer value, but %s was given." % self.passes)

        if self.offsets and (not len(self.offsets) == len(self.positions[0])):
            raise ValueError("Number of offsets %s does not match the number of positions %s." %
                             (self.offsets, self.positions[0]))

    def __init__(self, positions, passes=1, offsets=None):
        """
        :param positions: List of positions to visit, in order.
        :param passes: How many times to repeat the whole sequence.
        :param offsets: Optional per-axis offset added in place to every position.
        """
        self.positions = convert_to_list(positions)
        self.passes = passes
        self.offsets = convert_to_list(offsets)

        self._validate_parameters()

        # Number of positions to move to.
        self.n_positions = len(self.positions)

        # Fix the offset if provided.
        if self.offsets:
            for point in self.positions:
                point[:] = [value + offset for value, offset in zip(point, self.offsets)]

    def get_generator(self):
        """Yield each position in order, repeating the whole sequence once per pass."""
        for _ in range(self.passes):
            yield from self.positions
class ZigZagVectorPositioner(VectorPositioner):
    """Vector positioner that sweeps the positions back and forth instead of wrapping."""

    def get_generator(self):
        """Yield positions forward, then backward, without repeating the turning points."""
        # Infinite index sequence [0, 1, ..., n-1, n-2, ..., 1, 0, 1, ...].
        zigzag_indexes = cycle(chain(range(0, self.n_positions, 1), range(self.n_positions - 2, 0, -1)))
        # First pass has the full number of items, each subsequent has one less (extreme sequence item).
        total = self.n_positions + ((self.passes - 1) * (self.n_positions - 1))
        for _ in range(total):
            yield self.positions[next(zigzag_indexes)]

260
packages/pyscan/scan.py Normal file
View File

@@ -0,0 +1,260 @@
import logging
from pyscan.dal import epics_dal, bsread_dal, function_dal
from pyscan.dal.function_dal import FunctionProxy
from pyscan.positioner.bsread import BsreadPositioner
from pyscan.scanner import Scanner
from pyscan.scan_parameters import EPICS_PV, EPICS_CONDITION, BS_PROPERTY, BS_CONDITION, scan_settings, convert_input, \
FUNCTION_VALUE, FUNCTION_CONDITION, convert_conditions, ConditionAction, ConditionComparison
from pyscan.utils import convert_to_list, SimpleDataProcessor, ActionExecutor, compare_channel_value
# Instances to use.
# Module-level aliases for the concrete DAL/processor implementations used by
# scan()/scanner() below.
EPICS_WRITER = epics_dal.WriteGroupInterface
EPICS_READER = epics_dal.ReadGroupInterface
BS_READER = bsread_dal.ReadGroupInterface
FUNCTION_PROXY = function_dal.FunctionProxy
DATA_PROCESSOR = SimpleDataProcessor
ACTION_EXECUTOR = ActionExecutor

# Module-level logger for this file.
_logger = logging.getLogger(__name__)
def scan(positioner, readables, writables=None, conditions=None, before_read=None, after_read=None, initialization=None,
         finalization=None, settings=None, data_processor=None, before_move=None, after_move=None):
    """Convenience wrapper: build a scanner with the given parameters and run a discrete scan."""
    # Initialize the scanner instance.
    scanner_instance = scanner(positioner=positioner, readables=readables, writables=writables,
                               conditions=conditions, before_read=before_read, after_read=after_read,
                               initialization=initialization, finalization=finalization,
                               settings=settings, data_processor=data_processor,
                               before_move=before_move, after_move=after_move)
    # Run the scan and return the collected data.
    return scanner_instance.discrete_scan()
def scanner(positioner, readables, writables=None, conditions=None, before_read=None, after_read=None,
            initialization=None, finalization=None, settings=None, data_processor=None,
            before_move=None, after_move=None):
    """
    Assemble a Scanner instance wired to the epics, bsread and function data access layers.

    :param positioner: Positioner providing the generator of positions to visit.
    :param readables: Values read at each position (EPICS_PV / BS_PROPERTY / FUNCTION_VALUE, strings, or callables).
    :param writables: Values written at each position (EPICS_PV or FUNCTION_VALUE).
    :param conditions: Conditions validated at each position.
    :param before_read: Actions executed before each measurement.
    :param after_read: Actions executed after each measurement.
    :param initialization: Actions executed once before the first move.
    :param finalization: Actions executed after the scan (and on error).
    :param settings: scan_settings() tuple; defaults are used when None.
    :param data_processor: Where measurements are accumulated; DATA_PROCESSOR() by default.
    :param before_move: Actions executed before each move.
    :param after_move: Actions executed after each move.
    :return: Configured Scanner instance (call discrete_scan() on it to run).
    """
    # Allow a list or a single value to be passed. Initialize None values.
    writables = convert_input(convert_to_list(writables) or [])
    readables = convert_input(convert_to_list(readables) or [])
    conditions = convert_conditions(convert_to_list(conditions) or [])
    before_read = convert_to_list(before_read) or []
    after_read = convert_to_list(after_read) or []
    before_move = convert_to_list(before_move) or []
    after_move = convert_to_list(after_move) or []
    initialization = convert_to_list(initialization) or []
    finalization = convert_to_list(finalization) or []
    settings = settings or scan_settings()

    # TODO: Ugly. The scanner should not depend on a particular positioner implementation.
    if isinstance(positioner, BsreadPositioner) and settings.n_measurements > 1:
        raise ValueError("When using BsreadPositioner the maximum number of n_measurements = 1.")

    bs_reader = _initialize_bs_dal(readables, conditions, settings.bs_read_filter, positioner)

    epics_writer, epics_pv_reader, epics_condition_reader = _initialize_epics_dal(writables,
                                                                                  readables,
                                                                                  conditions,
                                                                                  settings)

    function_writer, function_reader, function_condition = _initialize_function_dal(writables,
                                                                                    readables,
                                                                                    conditions)

    # Source type of each writable, used to route position values to the right writer.
    writables_order = [type(writable) for writable in writables]

    # Write function needs to merge PV and function proxy data.
    def write_data(positions):
        positions = convert_to_list(positions)
        # Split the position values by destination, preserving their relative order.
        pv_values = [x for x, source in zip(positions, writables_order) if source == EPICS_PV]
        function_values = [x for x, source in zip(positions, writables_order) if source == FUNCTION_VALUE]

        if epics_writer:
            epics_writer.set_and_match(pv_values)

        if function_writer:
            function_writer.write(function_values)

    # Order of value sources, needed to reconstruct the correct order of the result.
    readables_order = [type(readable) for readable in readables]

    # Read function needs to merge BS, PV, and function proxy data.
    def read_data(current_position_index, retry=False):
        _logger.debug("Reading data for position index %s." % current_position_index)

        # Each reader returns its values in readables order; iterate them in lockstep.
        bs_values = iter(bs_reader.read(current_position_index, retry) if bs_reader else [])
        epics_values = iter(epics_pv_reader.read(current_position_index) if epics_pv_reader else [])
        function_values = iter(function_reader.read(current_position_index) if function_reader else [])

        # Interleave the values correctly.
        result = []
        for source in readables_order:
            if source == BS_PROPERTY:
                next_result = next(bs_values)
            elif source == EPICS_PV:
                next_result = next(epics_values)
            elif source == FUNCTION_VALUE:
                next_result = next(function_values)
            else:
                raise ValueError("Unknown type of readable %s used." % source)

            # We flatten the result, whenever possible.
            if isinstance(next_result, list) and source != FUNCTION_VALUE:
                result.extend(next_result)
            else:
                result.append(next_result)

        return result

    # Order of value sources, needed to reconstruct the correct order of the result.
    conditions_order = [type(condition) for condition in conditions]

    # Validate function needs to validate both BS, PV, and function proxy data.
    def validate_data(current_position_index, data):
        _logger.debug("Reading data for position index %s." % current_position_index)

        bs_values = iter(bs_reader.read_cached_conditions() if bs_reader else [])
        epics_values = iter(epics_condition_reader.read(current_position_index) if epics_condition_reader else [])
        function_values = iter(function_condition.read(current_position_index) if function_condition else [])

        for index, source in enumerate(conditions_order):
            if source == BS_CONDITION:
                value = next(bs_values)
            elif source == EPICS_CONDITION:
                value = next(epics_values)
            elif source == FUNCTION_CONDITION:
                value = next(function_values)
            else:
                raise ValueError("Unknown type of condition %s used." % source)

            value_valid = False

            # Function conditions are self contained.
            if source == FUNCTION_CONDITION:
                if value:
                    value_valid = True
            else:
                expected_value = conditions[index].value
                tolerance = conditions[index].tolerance
                operation = conditions[index].operation

                if compare_channel_value(value, expected_value, tolerance, operation):
                    value_valid = True

            if not value_valid:
                # Retry conditions report failure to the caller; all others abort the scan.
                if conditions[index].action == ConditionAction.Retry:
                    return False

                if source == FUNCTION_CONDITION:
                    raise ValueError("Function condition %s returned False." % conditions[index].identifier)
                else:
                    raise ValueError("Condition %s failed, expected value %s, actual value %s, "
                                     "tolerance %s, operation %s." %
                                     (conditions[index].identifier,
                                      conditions[index].value,
                                      value,
                                      conditions[index].tolerance,
                                      conditions[index].operation))

        return True

    if not data_processor:
        data_processor = DATA_PROCESSOR()

    # Before acquisition hook.
    before_measurement_executor = None
    if before_read:
        before_measurement_executor = ACTION_EXECUTOR(before_read).execute

    # After acquisition hook.
    after_measurement_executor = None
    if after_read:
        after_measurement_executor = ACTION_EXECUTOR(after_read).execute

    # Executor before each move.
    before_move_executor = None
    if before_move:
        before_move_executor = ACTION_EXECUTOR(before_move).execute

    # Executor after each move.
    after_move_executor = None
    if after_move:
        after_move_executor = ACTION_EXECUTOR(after_move).execute

    # Initialization (before move to first position) hook.
    initialization_executor = None
    if initialization:
        initialization_executor = ACTION_EXECUTOR(initialization).execute

    # Finalization (after last acquisition AND on error) hook.
    finalization_executor = None
    if finalization:
        finalization_executor = ACTION_EXECUTOR(finalization).execute

    scanner = Scanner(positioner=positioner, data_processor=data_processor, reader=read_data,
                      writer=write_data, before_measurement_executor=before_measurement_executor,
                      after_measurement_executor=after_measurement_executor,
                      initialization_executor=initialization_executor,
                      finalization_executor=finalization_executor, data_validator=validate_data, settings=settings,
                      before_move_executor=before_move_executor, after_move_executor=after_move_executor)

    return scanner
def _initialize_epics_dal(writables, readables, conditions, settings):
    """Create the epics writer and readers required by the given writables/readables/conditions."""
    epics_writer = None
    if writables:
        epics_writables = [w for w in writables if isinstance(w, EPICS_PV)]
        if epics_writables:
            # Instantiate the PVs to move the motors.
            epics_writer = EPICS_WRITER(pv_names=[pv.pv_name for pv in epics_writables],
                                        readback_pv_names=[pv.readback_pv_name for pv in epics_writables],
                                        tolerances=[pv.tolerance for pv in epics_writables],
                                        timeout=settings.write_timeout)

    epics_readables_pv_names = [r.pv_name for r in readables if isinstance(r, EPICS_PV)]
    epics_conditions_pv_names = [c.pv_name for c in conditions if isinstance(c, EPICS_CONDITION)]

    # Reading epics PV values.
    epics_pv_reader = EPICS_READER(pv_names=epics_readables_pv_names) if epics_readables_pv_names else None

    # Reading epics condition values.
    epics_condition_reader = EPICS_READER(pv_names=epics_conditions_pv_names) if epics_conditions_pv_names else None

    return epics_writer, epics_pv_reader, epics_condition_reader
def _initialize_bs_dal(readables, conditions, filter_function, positioner):
    """Create the bsread reader when bsread readables or conditions are requested, else None."""
    bs_readables = [r for r in readables if isinstance(r, BS_PROPERTY)]
    bs_conditions = [c for c in conditions if isinstance(c, BS_CONDITION)]

    # No bsread channels requested at all.
    if not (bs_readables or bs_conditions):
        return None

    # TODO: The scanner should not depend on a particular positioner. Refactor.
    if isinstance(positioner, BsreadPositioner):
        # The bsread positioner drives the reads itself, so the reader is injected into it.
        reader = bsread_dal.ImmediateReadGroupInterface(properties=bs_readables,
                                                        conditions=bs_conditions,
                                                        filter_function=filter_function)
        positioner.set_bs_reader(reader)
        return reader

    return BS_READER(properties=bs_readables, conditions=bs_conditions, filter_function=filter_function)
def _initialize_function_dal(writables, readables, conditions):
    """Wrap function-based writables, readables and conditions into FunctionProxy instances."""
    writer_functions = [w for w in writables if isinstance(w, FUNCTION_VALUE)]
    reader_functions = [r for r in readables if isinstance(r, FUNCTION_VALUE)]
    condition_functions = [c for c in conditions if isinstance(c, FUNCTION_CONDITION)]

    return (FunctionProxy(writer_functions),
            FunctionProxy(reader_functions),
            FunctionProxy(condition_functions))

View File

@@ -0,0 +1,58 @@
from collections import namedtuple
from pyscan import config, convert_input
from pyscan.scan import EPICS_WRITER, EPICS_READER
from pyscan.scan_parameters import epics_pv
from pyscan.utils import convert_to_list
# Parameters needed to set an epics PV as a scan action.
SET_EPICS_PV = namedtuple("SET_EPICS_PV", ["pv_name", "value", "readback_pv_name", "tolerance", "timeout"])
# Marker tuple: restore the initial values of the writable PVs.
RESTORE_WRITABLE_PVS = namedtuple("RESTORE_WRITABLE_PVS", [])
def action_set_epics_pv(pv_name, value, readback_pv_name=None, tolerance=None, timeout=None):
    """
    Build an action that sets an epics PV to a value and waits for the readback to match.
    :param pv_name: Name of the PV.
    :param value: Value to set the PV to.
    :param readback_pv_name: Name of the readback PV.
    :param tolerance: Tolerance if the PV is writable.
    :param timeout: Timeout for setting the pv value.
    :return: Zero-argument function that performs the set when called.
    """
    # Normalize/validate the PV parameters via the shared epics_pv factory.
    # (The readback_pv_value element is discarded - not relevant for a set action.)
    _, pv_name, readback_pv_name, tolerance, readback_pv_value = epics_pv(pv_name, readback_pv_name, tolerance)

    if value is None:
        raise ValueError("pv value not specified.")

    # Missing or negative timeouts fall back to the configured default.
    if not timeout or timeout < 0:
        timeout = config.epics_default_set_and_match_timeout

    def execute():
        # Open a writer, set the value (waiting for the readback to match), then clean up.
        writer = EPICS_WRITER(pv_name, readback_pv_name, tolerance, timeout)
        writer.set_and_match(value)
        writer.close()

    return execute
def action_restore(writables):
    """
    Restore the initial state of the writable PVs.
    The current values are captured immediately (at construction time); the
    returned function writes them back when invoked.
    :param writables: Writable PVs whose current values should be captured and later restored.
    :return: Zero-argument function that writes the captured initial values back.
    """
    writables = convert_input(convert_to_list(writables))
    pv_names = [pv.pv_name for pv in writables]
    readback_pv_names = [pv.readback_pv_name for pv in writables]
    tolerances = [pv.tolerance for pv in writables]

    # Get the initial values.
    reader = EPICS_READER(pv_names)
    initial_values = reader.read()
    reader.close()

    def execute():
        # NOTE(review): no timeout is passed here, so the writer's default applies - confirm intended.
        writer = EPICS_WRITER(pv_names, readback_pv_names, tolerances)
        writer.set_and_match(initial_values)
        writer.close()

    return execute

View File

@@ -0,0 +1,280 @@
from collections import namedtuple
from enum import Enum
from pyscan import config
# Readable/writable epics process variable.
EPICS_PV = namedtuple("EPICS_PV", ["identifier", "pv_name", "readback_pv_name", "tolerance", "readback_pv_value"])
# Condition evaluated by reading an epics PV.
EPICS_CONDITION = namedtuple("EPICS_CONDITION", ["identifier", "pv_name", "value", "action", "tolerance", "operation"])
# Readable property from the bsread stream.
BS_PROPERTY = namedtuple("BS_PROPERTY", ["identifier", "property", "default_value"])
# Condition evaluated on a bsread stream property.
BS_CONDITION = namedtuple("BS_CONDITION", ["identifier", "property", "value", "action", "tolerance", "operation",
                                           "default_value"])
# Scan-wide settings; see scan_settings() for the defaults.
SCAN_SETTINGS = namedtuple("SCAN_SETTINGS", ["measurement_interval", "n_measurements",
                                             "write_timeout", "settling_time", "progress_callback", "bs_read_filter"])
# Value produced by calling a user-supplied function.
FUNCTION_VALUE = namedtuple("FUNCTION_VALUE", ["identifier", "call_function"])
# Condition evaluated by calling a user-supplied function.
FUNCTION_CONDITION = namedtuple("FUNCTION_CONDITION", ["identifier", "call_function", "action"])
class ConditionComparison(Enum):
    """How a measured value is compared against a condition's expected value."""
    EQUAL = 0
    NOT_EQUAL = 1
    LOWER = 2
    LOWER_OR_EQUAL = 3
    HIGHER = 4
    HIGHER_OR_EQUAL = 5
class ConditionAction(Enum):
    """What the scan does when a condition check fails: abort, or retry the acquisition."""
    Abort = 1
    Retry = 2
# Used to determine if a parameter was passed or the default value is used.
# A private sentinel object - None cannot be used because it is a valid user value.
_default_value_placeholder = object()
def function_value(call_function, name=None):
    """
    Construct a tuple for function representation.
    :param call_function: Function to invoke.
    :param name: Name to assign to this function.
    :return: Tuple of ("identifier", "call_function")
    """
    # If the name is not specified, use a counter to set the function name.
    if not name:
        name = "function_%d" % function_value.function_count
        function_value.function_count += 1

    return FUNCTION_VALUE(name, call_function)


# Counter used to generate unique names for anonymous functions.
function_value.function_count = 0
def function_condition(call_function, name=None, action=None):
    """
    Construct a tuple for condition checking function representation.
    :param call_function: Function to invoke.
    :param name: Name to assign to this function.
    :param action: What to do then the return value is False.
                   ('ConditionAction.Abort' and 'ConditionAction.Retry' supported)
    :return: Tuple of ("identifier", "call_function", "action")
    """
    # If the name is not specified, use a counter to set the function name.
    if not name:
        name = "function_condition_%d" % function_condition.function_count
        function_condition.function_count += 1

    # The default action is Abort - used for conditions.
    return FUNCTION_CONDITION(name, call_function, action or ConditionAction.Abort)


# Counter used to generate unique names for anonymous condition functions.
function_condition.function_count = 0
def epics_pv(pv_name, readback_pv_name=None, tolerance=None, readback_pv_value=None):
    """
    Construct a tuple for PV representation
    :param pv_name: Name of the PV.
    :param readback_pv_name: Name of the readback PV.
    :param tolerance: Tolerance if the PV is writable.
    :param readback_pv_value: If the readback_pv_value is set, the readback is compared against this instead of
                              comparing it to the setpoint.
    :return: Tuple of (identifier, pv_name, pv_readback, tolerance)
    """
    if not pv_name:
        raise ValueError("pv_name not specified.")

    # The readback defaults to the setpoint PV itself.
    readback_pv_name = readback_pv_name or pv_name

    # Clamp the tolerance to the configured minimum.
    if not tolerance or tolerance < config.max_float_tolerance:
        tolerance = config.max_float_tolerance

    # The PV name doubles as the identifier.
    return EPICS_PV(pv_name, pv_name, readback_pv_name, tolerance, readback_pv_value)
def epics_condition(pv_name, value, action=None, tolerance=None, operation=ConditionComparison.EQUAL):
    """
    Construct a tuple for an epics condition representation.
    :param pv_name: Name of the PV to monitor.
    :param value: Value we expect the PV to be in.
    :param action: What to do when the condition fails.
                   ('ConditionAction.Abort' and 'ConditionAction.Retry' supported)
    :param tolerance: Tolerance within which the condition needs to be.
    :param operation: How to compare the received value with the expected value.
                      Allowed values: ConditionComparison.[EQUAL, NOT_EQUAL, LOWER, LOWER_OR_EQUAL,
                      HIGHER, HIGHER_OR_EQUAL]
    :return: Tuple of ("pv_name", "value", "action", "tolerance", "timeout", "operation")
    """
    if not pv_name:
        raise ValueError("pv_name not specified.")

    if value is None:
        raise ValueError("pv value not specified.")

    # the default action is Abort.
    action = action or ConditionAction.Abort

    # Clamp the tolerance to the configured minimum.
    if not tolerance or tolerance < config.max_float_tolerance:
        tolerance = config.max_float_tolerance

    # The PV name doubles as the identifier.
    return EPICS_CONDITION(pv_name, pv_name, value, action, tolerance, operation)
def bs_property(name, default_value=_default_value_placeholder):
    """
    Construct a tuple for bs read property representation.
    :param name: Complete property name.
    :param default_value: The default value that is assigned to the property if it is missing.
    :return: Tuple of ("identifier", "property", "default_value")
    """
    if not name:
        raise ValueError("name not specified.")

    # We need this to allow the user to change the config at runtime.
    if default_value is _default_value_placeholder:
        default_value = config.bs_default_missing_property_value

    # The property name doubles as the identifier.
    return BS_PROPERTY(name, name, default_value)
def bs_condition(name, value, action=None, tolerance=None, operation=ConditionComparison.EQUAL,
                 default_value=_default_value_placeholder):
    """
    Construct a tuple for bs condition property representation.
    :param name: Complete property name.
    :param value: Expected value.
    :param action: What to do when the condition fails.
                   ('ConditionAction.Abort' and 'ConditionAction.Retry' supported)
    :param tolerance: Tolerance within which the condition needs to be.
    :param operation: How to compare the received value with the expected value.
                      Allowed values: ConditionComparison.[EQUAL, NOT_EQUAL, LOWER, LOWER_OR_EQUAL,
                      HIGHER, HIGHER_OR_EQUAL]
    :param default_value: Default value of a condition, if not present in the bs stream.
    :return: Tuple of ("identifier", "property", "value", "action", "tolerance", "operation", "default_value")
    """
    if not name:
        raise ValueError("name not specified.")

    if value is None:
        raise ValueError("value not specified.")

    # Clamp the tolerance to the configured minimum.
    if not tolerance or tolerance < config.max_float_tolerance:
        tolerance = config.max_float_tolerance

    # The default action is Abort.
    action = action or ConditionAction.Abort

    # We need this to allow the user to change the config at runtime.
    if default_value is _default_value_placeholder:
        default_value = config.bs_default_missing_property_value

    # The property name doubles as the identifier.
    return BS_CONDITION(name, name, value, action, tolerance, operation, default_value)
def scan_settings(measurement_interval=None, n_measurements=None, write_timeout=None, settling_time=None,
                  progress_callback=None, bs_read_filter=None):
    """
    Set the scan settings.
    :param measurement_interval: Default 0. Interval between each measurement, in case n_measurements is more than 1.
    :param n_measurements: Default 1. How many measurements to make at each position.
    :param write_timeout: How much time to wait in seconds for set_and_match operations on epics PVs.
    :param settling_time: How much time to wait in seconds after the motors have reached the desired destination.
    :param progress_callback: Function to call after each scan step is completed.
                              Signature: def callback(current_position, total_positions)
    :param bs_read_filter: Filter to apply to the bs read receive function, to filter incoming messages.
                           Signature: def callback(message)
    :return: Scan settings named tuple.
    """
    # Replace missing or invalid values with the configured defaults.
    if not measurement_interval or measurement_interval < 0:
        measurement_interval = config.scan_default_measurement_interval

    if not n_measurements or n_measurements < 1:
        n_measurements = config.scan_default_n_measurements

    if not write_timeout or write_timeout < 0:
        write_timeout = config.epics_default_set_and_match_timeout

    if not settling_time or settling_time < 0:
        settling_time = config.epics_default_settling_time

    if not progress_callback:
        # Default: print the completion percentage to stdout.
        def default_progress_callback(current_position, total_positions):
            completed_percentage = 100.0 * (current_position / total_positions)
            print("Scan: %.2f %% completed (%d/%d)" % (completed_percentage, current_position, total_positions))

        progress_callback = default_progress_callback

    return SCAN_SETTINGS(measurement_interval, n_measurements, write_timeout, settling_time, progress_callback,
                         bs_read_filter)
def convert_input(input_parameters):
    """
    Convert any type of input parameter into appropriate named tuples.
    :param input_parameters: Parameter input from the user.
    :return: Inputs converted into named tuples.
    :raises ValueError: On empty strings, unknown protocols, or unsupported types.
    """
    converted_inputs = []
    # FIX: renamed the loop variable from 'input', which shadowed the builtin.
    for item in input_parameters:
        # Input already of correct type.
        if isinstance(item, (EPICS_PV, BS_PROPERTY, FUNCTION_VALUE)):
            converted_inputs.append(item)
        # We need to convert it.
        elif isinstance(item, str):
            # Check if the string is valid.
            if not item:
                raise ValueError("Input cannot be an empty string.")

            if "://" in item:
                # Epics PV!
                if item.lower().startswith("ca://"):
                    converted_inputs.append(epics_pv(item[5:]))
                # bs_read property.
                elif item.lower().startswith("bs://"):
                    converted_inputs.append(bs_property(item[5:]))
                # A new protocol we don't know about?
                else:
                    raise ValueError("Readable %s uses an unexpected protocol. "
                                     "'ca://' and 'bs://' are supported." % item)
            # No protocol specified, default is epics.
            else:
                converted_inputs.append(epics_pv(item))
        # Bare callables become function values.
        elif callable(item):
            converted_inputs.append(function_value(item))
        # Supported named tuples, strings or callables only - we cannot interpret the rest.
        else:
            raise ValueError("Input of unexpected type %s. Value: '%s'." % (type(item), item))

    return converted_inputs
def convert_conditions(input_conditions):
    """
    Convert any type of condition input parameter into appropriate named tuples.
    :param input_conditions: Condition input from the user.
    :return: Input conditions converted into named tuples.
    :raises ValueError: For inputs that are neither condition tuples nor callables.
    """
    converted_inputs = []
    # FIX: renamed the loop variable from 'input', which shadowed the builtin.
    for condition in input_conditions:
        # Input already of correct type.
        if isinstance(condition, (EPICS_CONDITION, BS_CONDITION, FUNCTION_CONDITION)):
            converted_inputs.append(condition)
        # Function call.
        elif callable(condition):
            converted_inputs.append(function_condition(condition))
        # Unknown.
        else:
            raise ValueError("Condition of unexpected type %s. Value: '%s'." % (type(condition), condition))

    return converted_inputs

202
packages/pyscan/scanner.py Normal file
View File

@@ -0,0 +1,202 @@
from itertools import count
from time import sleep
from pyscan import config
from pyscan.scan_parameters import scan_settings
# Possible scanner states, reported by Scanner.get_status().
STATUS_INITIALIZED = "INITIALIZED"
STATUS_RUNNING = "RUNNING"
STATUS_FINISHED = "FINISHED"
STATUS_PAUSED = "PAUSED"
STATUS_ABORTED = "ABORTED"
class Scanner(object):
    """
    Perform discrete and continuous scans.

    The scan loop iterates over the positioner's positions, optionally writes
    each position via the writer, reads data through the reader, validates it
    and hands it to the data processor. Executor hooks run around moves and
    measurements; the scan can be paused, resumed or aborted through the
    public methods (the flags are polled between positions).
    """

    def __init__(self, positioner, data_processor, reader, writer=None, before_measurement_executor=None,
                 after_measurement_executor=None, initialization_executor=None, finalization_executor=None,
                 data_validator=None, settings=None, before_move_executor=None, after_move_executor=None):
        """
        Initialize scanner.
        :param positioner: Positioner should provide a generator to get the positions to move to.
        :param data_processor: How to store and handle the data.
        :param reader: Object that implements the read() method to return data to the data_processor.
        :param writer: Object that implements the write(position) method and sets the positions.
        :param before_measurement_executor: Callbacks executor that executed before measurements.
        :param after_measurement_executor: Callbacks executor that executed after measurements.
        :param initialization_executor: Callbacks executor invoked once, before the first position.
        :param finalization_executor: Callbacks executor invoked once, when the scan ends (also on abort).
        :param data_validator: Callable (position, data) -> bool; a falsy result triggers a re-read.
            Default: all data is considered valid.
        :param settings: Scan settings named tuple; defaults to scan_settings().
        :param before_move_executor: Callbacks executor that executes before each move.
        :param after_move_executor: Callbacks executor that executes after each move.
        """
        self.positioner = positioner
        self.writer = writer
        self.data_processor = data_processor
        self.reader = reader
        self.before_measurement_executor = before_measurement_executor
        self.after_measurement_executor = after_measurement_executor
        self.initialization_executor = initialization_executor
        self.finalization_executor = finalization_executor
        self.settings = settings or scan_settings()
        self.before_move_executor = before_move_executor
        self.after_move_executor = after_move_executor
        # If no data validator is provided, data is always valid.
        self.data_validator = data_validator or (lambda position, data: True)
        # Control flags set by the user-facing methods and polled by the scan loop.
        self._user_abort_scan_flag = False
        self._user_pause_scan_flag = False
        self._status = STATUS_INITIALIZED

    def abort_scan(self):
        """
        Abort the scan after the next measurement.
        """
        self._user_abort_scan_flag = True

    def pause_scan(self):
        """
        Pause the scan after the next measurement.
        """
        self._user_pause_scan_flag = True

    def get_status(self):
        """
        Return the current scan status - one of the STATUS_* module constants.
        """
        return self._status

    def resume_scan(self):
        """
        Resume the scan.
        """
        self._user_pause_scan_flag = False

    def _verify_scan_status(self):
        """
        Check if the conditions to pause or abort the scan are met.
        :raise Exception in case the conditions are met.
        """
        # Check if the abort flag is set.
        if self._user_abort_scan_flag:
            self._status = STATUS_ABORTED
            raise Exception("User aborted scan.")
        # If the scan is in pause, wait until it is resumed or the user aborts the scan.
        if self._user_pause_scan_flag:
            self._status = STATUS_PAUSED
            while self._user_pause_scan_flag:
                # Abort takes precedence over pause.
                if self._user_abort_scan_flag:
                    self._status = STATUS_ABORTED
                    raise Exception("User aborted scan in pause.")
                sleep(config.scan_pause_sleep_interval)
            # Once the pause flag is cleared, the scanning continues.
            self._status = STATUS_RUNNING

    def _perform_single_read(self, current_position_index):
        """
        Read a single result from the channel, retrying until the data is valid.
        :param current_position_index: Current position, passed to the validator.
        :return: Single result (all channels).
        :raise Exception if no valid data could be read within the retry limit.
        """
        n_current_acquisition = 0
        # Collect data until acquired data is valid or retry limit reached.
        while n_current_acquisition < config.scan_acquisition_retry_limit:
            # Every attempt after the first one is flagged as a retry for the reader.
            retry_acquisition = n_current_acquisition != 0
            single_measurement = self.reader(current_position_index, retry=retry_acquisition)
            # If the data is valid, break out of the loop.
            if self.data_validator(current_position_index, single_measurement):
                return single_measurement
            n_current_acquisition += 1
            sleep(config.scan_acquisition_retry_delay)
        # Could not read the data within the retry limit.
        # (while/else: runs only when the loop exhausted without returning.)
        else:
            raise Exception("Number of maximum read attempts (%d) exceeded. Cannot read valid data at position %s."
                            % (config.scan_acquisition_retry_limit, current_position_index))

    def _read_and_process_data(self, current_position):
        """
        Read the data and pass it on only if valid.
        :param current_position: Current position reached by the scan.
        :return: Current position scan data.
        """
        # We do a single acquisition per position.
        if self.settings.n_measurements == 1:
            result = self._perform_single_read(current_position)
        # Multiple acquisitions.
        else:
            result = []
            for n_measurement in range(self.settings.n_measurements):
                result.append(self._perform_single_read(current_position))
                sleep(self.settings.measurement_interval)
        # Process only valid data.
        self.data_processor.process(current_position, result)
        return result

    def discrete_scan(self):
        """
        Perform a discrete scan - set a position, read, continue. Return value at the end.
        :return: Data collected by the data processor.
        """
        try:
            self._status = STATUS_RUNNING
            # Get how many positions we have in total.
            # NOTE(review): one generator is consumed here for counting and a fresh one
            # below for the scan - assumes get_generator() is repeatable; confirm.
            n_of_positions = sum(1 for _ in self.positioner.get_generator())
            # Report the 0% completed.
            self.settings.progress_callback(0, n_of_positions)
            # Set up the experiment.
            if self.initialization_executor:
                self.initialization_executor(self)
            # count(1) supplies the 1-based position index for progress reporting.
            for position_index, next_positions in zip(count(1), self.positioner.get_generator()):
                # Execute before moving to the next position.
                if self.before_move_executor:
                    self.before_move_executor(next_positions)
                # Position yourself before reading.
                if self.writer:
                    self.writer(next_positions)
                # Settling time, wait after positions has been reached.
                sleep(self.settings.settling_time)
                # Execute the after move executor.
                if self.after_move_executor:
                    self.after_move_executor(next_positions)
                # Pre reading callbacks.
                if self.before_measurement_executor:
                    self.before_measurement_executor(next_positions)
                # Read and process the data in the current position.
                position_data = self._read_and_process_data(next_positions)
                # Post reading callbacks.
                if self.after_measurement_executor:
                    self.after_measurement_executor(next_positions, position_data)
                # Report about the progress.
                self.settings.progress_callback(position_index, n_of_positions)
                # Verify if the scan should continue (raises on user abort/pause-abort).
                self._verify_scan_status()
        finally:
            # Clean up after yourself.
            if self.finalization_executor:
                self.finalization_executor(self)
            # If the scan was aborted we do not change the status to finished.
            if self._status != STATUS_ABORTED:
                self._status = STATUS_FINISHED
        return self.data_processor.get_data()

    def continuous_scan(self):
        """
        Continuous scanning mode.
        """
        # TODO: Needs implementation.
        pass

216
packages/pyscan/utils.py Normal file
View File

@@ -0,0 +1,216 @@
import inspect
from collections import OrderedDict
from time import sleep
from epics.pv import PV
from pyscan import config
from pyscan.scan_parameters import convert_input, ConditionComparison
def compare_channel_value(current_value, expected_value, tolerance=0.0, operation=ConditionComparison.EQUAL):
    """
    Check if the pv value is the same as the expected value, within tolerance for int and float.
    :param current_value: Current value of the channel.
    :param expected_value: Expected value of the PV.
    :param tolerance: Tolerance for number comparison. Cannot be less than the configured minimum.
    :param operation: ConditionComparison operation to perform - works for int and floats.
    :return: True if the value matches.
    """
    # Enforce a lower bound on the tolerance.
    # NOTE(review): despite its name, config.max_float_tolerance acts here as the
    # MINIMUM allowed tolerance (max() of the two values) - confirm intended semantics.
    tolerance = max(tolerance, config.max_float_tolerance)

    def compare_value(value):
        # NOTE(review): the `value` parameter is never used inside this closure - every
        # branch compares the outer current_value/expected_value directly. In the list
        # branch at the bottom this makes all iterations evaluate the same comparison;
        # verify whether `value` was meant to stand in for expected_value here.
        # For numbers we compare them within tolerance.
        if isinstance(current_value, (float, int)):
            if operation == ConditionComparison.EQUAL:
                return abs(current_value - expected_value) <= tolerance
            elif operation == ConditionComparison.HIGHER:
                # Strictly higher: the difference must exceed the tolerance band.
                return (current_value - expected_value) > tolerance
            elif operation == ConditionComparison.HIGHER_OR_EQUAL:
                return (current_value - expected_value) >= tolerance
            elif operation == ConditionComparison.LOWER:
                # NOTE(review): unlike HIGHER, values within tolerance also satisfy LOWER
                # here - the numeric semantics are asymmetric; confirm this is intended.
                return (current_value - expected_value) < 0 or abs(current_value - expected_value) < tolerance
            elif operation == ConditionComparison.LOWER_OR_EQUAL:
                return (current_value - expected_value) <= 0 or abs(current_value - expected_value) <= tolerance
            elif operation == ConditionComparison.NOT_EQUAL:
                return abs(current_value - expected_value) > tolerance
        # Otherwise use the object comparison.
        else:
            try:
                if operation == ConditionComparison.EQUAL:
                    return current_value == expected_value
                elif operation == ConditionComparison.HIGHER:
                    return current_value > expected_value
                elif operation == ConditionComparison.HIGHER_OR_EQUAL:
                    return current_value >= expected_value
                elif operation == ConditionComparison.LOWER:
                    return current_value < expected_value
                elif operation == ConditionComparison.LOWER_OR_EQUAL:
                    return current_value <= expected_value
                elif operation == ConditionComparison.NOT_EQUAL:
                    return current_value != expected_value
            except:
                raise ValueError("Do not know how to compare current_value %s with expected_value %s and action %s."
                                 % (current_value, expected_value, operation))
        # Unrecognized operation (or numeric branch fell through): report no match.
        return False

    if isinstance(current_value, list):
        # In case of a list, any of the provided values will do.
        return any((compare_value(value) for value in expected_value))
    else:
        return compare_value(current_value)
def connect_to_pv(pv_name, n_connection_attempts=3):
    """
    Start a connection to a PV.
    :param pv_name: PV name to connect to.
    :param n_connection_attempts: How many times to try to connect before raising an exception.
    :return: Connected PV object.
    :raises ValueError if the PV cannot be connected to.
    """
    pv = PV(pv_name, auto_monitor=False)
    for attempt in range(n_connection_attempts):
        if pv.connect():
            return pv
        # Wait only between attempts - sleeping after the last failed attempt
        # would just delay the error by another interval.
        if attempt < n_connection_attempts - 1:
            sleep(0.1)
    raise ValueError("Cannot connect to PV '%s'." % pv_name)
def validate_lists_length(*args):
    """
    Verify that every provided list has the same length.
    :param args: Lists to compare.
    :raise ValueError if they are not all of the same length.
    """
    if not args:
        raise ValueError("Cannot compare lengths of None.")
    reference_length = len(args[0])
    if any(len(item) != reference_length for item in args):
        # Build the error message listing every offending list.
        message_parts = ["The provided lists must be of same length.\n"]
        message_parts.extend("%s\n" % item for item in args)
        raise ValueError("".join(message_parts))
def convert_to_list(value):
    """
    Wrap the given value into a list, unless it already is one.
    :return: The original value if it is None or a list, otherwise [value].
    """
    # None and lists pass through untouched; anything else becomes a 1-element list.
    needs_wrapping = value is not None and not isinstance(value, list)
    return [value] if needs_wrapping else value
def convert_to_position_list(axis_list):
    """
    Transpose a PER KNOB list of positions into a PER INDEX list of positions.
    :param axis_list: PER KNOB list of positions.
    :return: PER INDEX list of positions.
    """
    # zip(*...) groups the n-th position of every knob; each group becomes a list.
    return list(map(list, zip(*axis_list)))
def flat_list_generator(list_to_flatten):
    """Yield the innermost lists of an arbitrarily nested list structure."""
    if list_to_flatten and isinstance(list_to_flatten[0], list):
        # Not at the innermost level yet - recurse into each sub-list.
        for sub_list in list_to_flatten:
            yield from flat_list_generator(sub_list)
    else:
        # Innermost (or empty) list - yield it as-is.
        yield list_to_flatten
class ActionExecutor(object):
    """
    Execute all the given callbacks sequentially, in the calling thread.
    Each callback may accept up to 2 parameters: position and sampled values.
    """

    def __init__(self, actions):
        """
        Initialize the action executor.
        :param actions: Action or list of actions to execute.
        """
        self.actions = convert_to_list(actions)

    def execute(self, position, position_data=None):
        # Dispatch on the number of parameters each callback declares.
        for callback in self.actions:
            parameter_count = len(inspect.signature(callback).parameters)
            if parameter_count == 2:
                callback(position, position_data)
            elif parameter_count == 1:
                callback(position)
            else:
                callback()
class SimpleDataProcessor(object):
    """
    Store each visited position together with the data sampled at it.
    """

    def __init__(self, positions=None, data=None):
        """
        Initialize the simple data processor.
        :param positions: List used to store the visited positions. Default: new internal list.
        :param data: List used to store the data at each position. Default: new internal list.
        """
        self.positions = [] if positions is None else positions
        self.data = [] if data is None else data

    def process(self, position, data):
        # Positions and data stay aligned by index.
        self.positions.append(position)
        self.data.append(data)

    def get_data(self):
        return self.data

    def get_positions(self):
        return self.positions
class DictionaryDataProcessor(SimpleDataProcessor):
    """
    Store the visited positions; the data for each position is kept as a
    dictionary keyed by readable identifier.
    """

    def __init__(self, readables, positions=None, data=None):
        """
        Initialize the dictionary data processor.
        :param readables: Same readables that were passed to the scan function.
        :param positions: List used to store the visited positions. Default: internal list.
        :param data: List used to store the data at each position. Default: internal list.
        """
        super(DictionaryDataProcessor, self).__init__(positions=positions, data=data)
        self.readable_ids = [readable.identifier for readable in convert_input(readables)]

    def process(self, position, data):
        self.positions.append(position)
        # Map each readable identifier to its sampled value, preserving order.
        self.data.append(OrderedDict(zip(self.readable_ids, data)))

View File

@@ -0,0 +1,111 @@
Metadata-Version: 1.0
Name: elog
Version: 1.3.4
Summary: Python library to access Elog.
Home-page: https://github.com/paulscherrerinstitute/py_elog
Author: Paul Scherrer Institute (PSI)
Author-email: UNKNOWN
License: UNKNOWN
Description: [![Build Status](https://travis-ci.org/paulscherrerinstitute/py_elog.svg?branch=master)](https://travis-ci.org/paulscherrerinstitute/py_elog) [![Build status](https://ci.appveyor.com/api/projects/status/glo428gqw951y512?svg=true)](https://ci.appveyor.com/project/simongregorebner/py-elog)
# Overview
This Python module provides a native interface to [electronic logbooks](https://midas.psi.ch/elog/). It is compatible with Python versions 3.5 and higher.
# Usage
For accessing a logbook at ```http[s]://<hostename>:<port>/[<subdir>/]<logbook>/[<msg_id>]``` a logbook handle must be retrieved.
```python
import elog
# Open GFA SwissFEL test logbook
logbook = elog.open('https://elog-gfa.psi.ch/SwissFEL+test/')
# Constructor using detailed arguments
# Open demo logbook on local host: http://localhost:8080/demo/
logbook = elog.open('localhost', 'demo', port=8080, use_ssl=False)
```
Once you have hold of the logbook handle one of its public methods can be used to read, create, reply to, edit or delete the message.
## Get Existing Message Ids
Get all the existing message ids of a logbook
```python
message_ids = logbook.get_message_ids()
```
To get the id of the last inserted message
```python
last_message_id = logbook.get_last_message_id()
```
## Read Message
```python
# Read message with message ID = 23
message, attributes, attachments = logbook.read(23)
```
## Create Message
```python
# Create new message with some text, attributes (dict of attributes + kwargs) and attachments
new_msg_id = logbook.post('This is message text', attributes=dict_of_attributes, attachments=list_of_attachments, attribute_as_param='value')
```
What attributes are required is determined by the configuration of the elog server (keyword `Required Attributes`).
If the configuration looks like this:
```
Required Attributes = Author, Type
```
You have to provide author and type when posting a message.
In case the type needs to be specified, the supported keywords can also be found in the elog configuration with the key `Options Type`.
If the config looks like this:
```
Options Type = Routine, Software Installation, Problem Fixed, Configuration, Other
```
A working create call would look like this:
```python
new_msg_id = logbook.post('This is message text', author='me', type='Routine')
```
## Reply to Message
```python
# Reply to message with ID=23
new_msg_id = logbook.post('This is a reply', msg_id=23, reply=True, attributes=dict_of_attributes, attachments=list_of_attachments, attribute_as_param='value')
```
## Edit Message
```python
# Edit message with ID=23. Changed message text, some attributes (dict of edited attributes + kwargs) and new attachments
edited_msg_id = logbook.post('This is new message text', msg_id=23, attributes=dict_of_changed_attributes, attachments=list_of_new_attachments, attribute_as_param='new value')
```
## Delete Message (and all its replies)
```python
# Delete message with ID=23. All its replies will also be deleted.
logbook.delete(23)
```
__Note:__ Due to the way elog implements delete this function is only supported on english logbooks.
# Installation
The Elog module only depends on the `passlib` and `requests` libraries, used for password encryption and http(s) communication. It is packaged as an [anaconda package](https://anaconda.org/paulscherrerinstitute/elog) and can be installed as follows:
```bash
conda install -c paulscherrerinstitute elog
```
Keywords: elog,electronic,logbook
Platform: UNKNOWN

View File

@@ -0,0 +1,8 @@
setup.py
elog/__init__.py
elog/logbook.py
elog/logbook_exceptions.py
elog.egg-info/PKG-INFO
elog.egg-info/SOURCES.txt
elog.egg-info/dependency_links.txt
elog.egg-info/top_level.txt

View File

@@ -0,0 +1 @@

View File

@@ -0,0 +1 @@
elog

View File

@@ -0,0 +1 @@

View File

@@ -0,0 +1,78 @@
__version__ = "0.16.2"
from bigtree.binarytree.construct import list_to_binarytree
from bigtree.dag.construct import dataframe_to_dag, dict_to_dag, list_to_dag
from bigtree.dag.export import dag_to_dataframe, dag_to_dict, dag_to_dot, dag_to_list
from bigtree.node.basenode import BaseNode
from bigtree.node.binarynode import BinaryNode
from bigtree.node.dagnode import DAGNode
from bigtree.node.node import Node
from bigtree.tree.construct import (
add_dataframe_to_tree_by_name,
add_dataframe_to_tree_by_path,
add_dict_to_tree_by_name,
add_dict_to_tree_by_path,
add_path_to_tree,
dataframe_to_tree,
dataframe_to_tree_by_relation,
dict_to_tree,
list_to_tree,
list_to_tree_by_relation,
nested_dict_to_tree,
newick_to_tree,
str_to_tree,
)
from bigtree.tree.export import (
hprint_tree,
hyield_tree,
print_tree,
tree_to_dataframe,
tree_to_dict,
tree_to_dot,
tree_to_mermaid,
tree_to_nested_dict,
tree_to_newick,
tree_to_pillow,
yield_tree,
)
from bigtree.tree.helper import clone_tree, get_subtree, get_tree_diff, prune_tree
from bigtree.tree.modify import (
copy_and_replace_nodes_from_tree_to_tree,
copy_nodes,
copy_nodes_from_tree_to_tree,
copy_or_shift_logic,
replace_logic,
shift_and_replace_nodes,
shift_nodes,
)
from bigtree.tree.search import (
find,
find_attr,
find_attrs,
find_child,
find_child_by_name,
find_children,
find_full_path,
find_name,
find_names,
find_path,
find_paths,
find_relative_path,
findall,
)
from bigtree.utils.groot import speak_like_groot, whoami
from bigtree.utils.iterators import (
dag_iterator,
inorder_iter,
levelorder_iter,
levelordergroup_iter,
postorder_iter,
preorder_iter,
zigzag_iter,
zigzaggroup_iter,
)
from bigtree.utils.plot import reingold_tilford
from bigtree.workflows.app_calendar import Calendar
from bigtree.workflows.app_todo import AppToDo
sphinx_versions = ["latest", "0.16.2", "0.15.7", "0.14.8"]

View File

@@ -0,0 +1,53 @@
from typing import List, Type
from bigtree.node.binarynode import BinaryNode
__all__ = ["list_to_binarytree"]
def list_to_binarytree(
    heapq_list: List[int], node_type: Type[BinaryNode] = BinaryNode
) -> BinaryNode:
    """Construct a binary tree from a list of numbers (int or float) in heapq format.

    The element at index ``idx`` becomes a child of the element at index
    ``(idx - 1) // 2``, matching the implicit heap layout used by ``heapq``.

    Args:
        heapq_list (List[int]): list containing integer node names, ordered in heapq fashion
        node_type (Type[BinaryNode]): node type of tree to be created, defaults to ``BinaryNode``

    Returns:
        (BinaryNode)
    """
    if not len(heapq_list):
        raise ValueError("Input list does not contain any data, check `heapq_list`")
    created_nodes = [node_type(heapq_list[0])]
    for idx, num in enumerate(heapq_list[1:], start=1):
        # For both odd and even idx, the heapq parent lives at (idx - 1) // 2.
        parent = created_nodes[(idx - 1) // 2]
        created_nodes.append(node_type(num, parent=parent))  # type: ignore
    return created_nodes[0]

View File

@@ -0,0 +1,206 @@
from __future__ import annotations
from typing import Any, Dict, List, Tuple, Type
from bigtree.node.dagnode import DAGNode
from bigtree.utils.exceptions import optional_dependencies_pandas
try:
import pandas as pd
except ImportError: # pragma: no cover
pd = None
__all__ = ["list_to_dag", "dict_to_dag", "dataframe_to_dag"]
@optional_dependencies_pandas
def list_to_dag(
    relations: List[Tuple[str, str]],
    node_type: Type[DAGNode] = DAGNode,
) -> DAGNode:
    """Construct DAG from a list of (parent name, child name) tuples.

    Note that node names must be unique.

    Args:
        relations (List[Tuple[str, str]]): list containing tuple of parent-child names
        node_type (Type[DAGNode]): node type of DAG to be created, defaults to ``DAGNode``

    Returns:
        (DAGNode)
    """
    if not len(relations):
        raise ValueError("Input list does not contain any data, check `relations`")
    # Delegate the actual construction to the dataframe-based builder.
    relation_dataframe = pd.DataFrame(relations, columns=["parent", "child"])
    return dataframe_to_dag(
        relation_dataframe,
        child_col="child",
        parent_col="parent",
        node_type=node_type,
    )
def dict_to_dag(
    relation_attrs: Dict[str, Any],
    parent_key: str = "parents",
    node_type: Type[DAGNode] = DAGNode,
) -> DAGNode:
    """Construct DAG from a nested dictionary.

    Keys are child names; each value is a dictionary holding the list of parent
    names (under ``parent_key``) together with attribute name/value pairs.
    Note that node names must be unique.

    Args:
        relation_attrs (Dict[str, Any]): dictionary containing node, node parents, and node
            attribute information; key: child name, value: dictionary of parent names, node
            attribute, and attribute value
        parent_key (str): key of dictionary to retrieve list of parents name, defaults to 'parents'
        node_type (Type[DAGNode]): node type of DAG to be created, defaults to ``DAGNode``

    Returns:
        (DAGNode)
    """
    if not len(relation_attrs):
        raise ValueError("Dictionary does not contain any data, check `relation_attrs`")
    # Flatten the nested dictionary into a dataframe with one row per child.
    relation_data = pd.DataFrame(relation_attrs).T.rename_axis("_tmp_child").reset_index()
    if parent_key not in relation_data:
        raise ValueError(
            f"Parent key {parent_key} not in dictionary, check `relation_attrs` and `parent_key`"
        )
    # Explode the parents list so there is one row per (child, parent) pair.
    exploded_data = relation_data.explode(parent_key)
    return dataframe_to_dag(
        exploded_data,
        child_col="_tmp_child",
        parent_col=parent_key,
        node_type=node_type,
    )
@optional_dependencies_pandas
def dataframe_to_dag(
    data: pd.DataFrame,
    child_col: str = "",
    parent_col: str = "",
    attribute_cols: List[str] | None = None,
    node_type: Type[DAGNode] = DAGNode,
) -> DAGNode:
    """Construct DAG from pandas DataFrame.

    Note that node names must be unique.

    - `child_col` and `parent_col` specify columns for child name and parent name to construct DAG.
    - `attribute_cols` specify columns for node attribute for child name.
    - If columns are not specified, `child_col` takes first column, `parent_col` takes second column,
      and all other columns are `attribute_cols`.

    Args:
        data (pd.DataFrame): data containing path and node attribute information
        child_col (str): column of data containing child name information, defaults to ''
            if not set, it will take the first column of data
        parent_col (str): column of data containing parent name information, defaults to ''
            if not set, it will take the second column of data
        attribute_cols (List[str] | None): columns of data containing child node attribute information,
            if not set, it will take all columns of data except `child_col` and `parent_col`
        node_type (Type[DAGNode]): node type of DAG to be created, defaults to ``DAGNode``

    Returns:
        (DAGNode)

    Raises:
        ValueError: on empty data, unknown columns, duplicate child rows with
            conflicting attributes, or null child names.
    """
    data = data.copy()
    if not len(data.columns):
        raise ValueError("Data does not contain any columns, check `data`")
    if not len(data):
        raise ValueError("Data does not contain any rows, check `data`")
    if not child_col:
        child_col = data.columns[0]
    elif child_col not in data.columns:
        raise ValueError(f"Child column not in data, check `child_col`: {child_col}")
    if not parent_col:
        parent_col = data.columns[1]
    elif parent_col not in data.columns:
        raise ValueError(f"Parent column not in data, check `parent_col`: {parent_col}")
    # Default of None (instead of a mutable [] default) avoids the shared mutable
    # default argument pitfall; None or [] means "infer from remaining columns".
    if not attribute_cols:
        attribute_cols = list(data.columns)
        attribute_cols.remove(child_col)
        attribute_cols.remove(parent_col)
    elif any([col not in data.columns for col in attribute_cols]):
        raise ValueError(
            f"One or more attribute column(s) not in data, check `attribute_cols`: {attribute_cols}"
        )
    # A child may legitimately appear once per parent, but its attributes must
    # agree across those rows - detect children with conflicting attribute sets.
    data_check = data.copy()[[child_col, parent_col] + attribute_cols].drop_duplicates(
        subset=[child_col] + attribute_cols
    )
    _duplicate_check = (
        data_check[child_col]
        .value_counts()
        .to_frame("counts")
        .rename_axis(child_col)
        .reset_index()
    )
    _duplicate_check = _duplicate_check[_duplicate_check["counts"] > 1]
    if len(_duplicate_check):
        raise ValueError(
            f"There exists duplicate child name with different attributes\n"
            f"Check {_duplicate_check}"
        )
    if sum(data[child_col].isnull()):
        raise ValueError(f"Child name cannot be empty, check column: {child_col}")
    node_dict: Dict[str, DAGNode] = dict()
    parent_node = DAGNode()
    for row in data.reset_index(drop=True).to_dict(orient="index").values():
        child_name = row[child_col]
        parent_name = row[parent_col]
        node_attrs = row.copy()
        del node_attrs[child_col]
        del node_attrs[parent_col]
        # Drop null attributes so missing values do not overwrite real ones.
        node_attrs = {k: v for k, v in node_attrs.items() if not pd.isnull(v)}
        child_node = node_dict.get(child_name, node_type(child_name))
        child_node.set_attrs(node_attrs)
        node_dict[child_name] = child_node
        if not pd.isnull(parent_name):
            parent_node = node_dict.get(parent_name, node_type(parent_name))
            node_dict[parent_name] = parent_node
            child_node.parents = [parent_node]
    # Returns the most recently linked parent; callers traverse the DAG from there.
    return parent_node

View File

@@ -0,0 +1,298 @@
from __future__ import annotations
from typing import Any, Dict, List, Tuple, TypeVar, Union
from bigtree.node.dagnode import DAGNode
from bigtree.utils.exceptions import (
optional_dependencies_image,
optional_dependencies_pandas,
)
from bigtree.utils.iterators import dag_iterator
try:
import pandas as pd
except ImportError: # pragma: no cover
pd = None
try:
import pydot
except ImportError: # pragma: no cover
pydot = None
__all__ = ["dag_to_list", "dag_to_dict", "dag_to_dataframe", "dag_to_dot"]
T = TypeVar("T", bound=DAGNode)
def dag_to_list(
    dag: T,
) -> List[Tuple[str, str]]:
    """Export DAG to a list of (parent name, child name) tuples.

    Args:
        dag (DAGNode): DAG to be exported

    Returns:
        (List[Tuple[str, str]])
    """
    # One tuple per edge, in the order produced by dag_iterator.
    return [
        (parent_node.node_name, child_node.node_name)
        for parent_node, child_node in dag_iterator(dag)
    ]
def dag_to_dict(
    dag: T,
    parent_key: str = "parents",
    attr_dict: Dict[str, str] | None = None,
    all_attrs: bool = False,
) -> Dict[str, Any]:
    """Export DAG to dictionary.

    Exported dictionary will have key as child name, and parent names and node attributes as a
    nested dictionary.

    Args:
        dag (DAGNode): DAG to be exported
        parent_key (str): dictionary key for `node.parent.node_name`, defaults to `parents`
        attr_dict (Dict[str, str] | None): dictionary mapping node attributes to dictionary key,
            key: node attributes, value: corresponding dictionary key, optional
        all_attrs (bool): indicator whether to retrieve all `Node` attributes, defaults to False

    Returns:
        (Dict[str, Any])
    """
    # Default of None (instead of a mutable {} default) avoids the shared mutable
    # default argument pitfall.
    if attr_dict is None:
        attr_dict = {}
    # Work on a copy so the exported DAG is not modified.
    dag = dag.copy()
    data_dict: Dict[str, Any] = {}
    for parent_node, child_node in dag_iterator(dag):
        if parent_node.is_root:
            # Root nodes carry no parent entry, only their (selected) attributes.
            data_parent: Dict[str, Any] = {}
            if all_attrs:
                data_parent.update(
                    parent_node.describe(
                        exclude_attributes=["name"], exclude_prefix="_"
                    )
                )
            else:
                for k, v in attr_dict.items():
                    data_parent[v] = parent_node.get_attr(k)
            data_dict[parent_node.node_name] = data_parent
        if data_dict.get(child_node.node_name):
            # Child already recorded through another parent - append this parent.
            data_dict[child_node.node_name][parent_key].append(parent_node.node_name)
        else:
            data_child = {parent_key: [parent_node.node_name]}
            if all_attrs:
                data_child.update(
                    child_node.describe(exclude_attributes=["name"], exclude_prefix="_")
                )
            else:
                for k, v in attr_dict.items():
                    data_child[v] = child_node.get_attr(k)
            data_dict[child_node.node_name] = data_child
    return data_dict
@optional_dependencies_pandas
def dag_to_dataframe(
    dag: T,
    name_col: str = "name",
    parent_col: str = "parent",
    attr_dict: Dict[str, str] | None = None,
    all_attrs: bool = False,
) -> pd.DataFrame:
    """Export DAG to pandas DataFrame, one row per (node, parent) pair.

    Args:
        dag (DAGNode): DAG to be exported
        name_col (str): column name for `node.node_name`, defaults to 'name'
        parent_col (str): column name for `node.parent.node_name`, defaults to 'parent'
        attr_dict (Dict[str, str] | None): dictionary mapping node attributes to column name,
            key: node attributes, value: corresponding column in dataframe, optional
        all_attrs (bool): indicator whether to retrieve all `Node` attributes, defaults to False

    Returns:
        (pd.DataFrame)
    """
    # Default of None (instead of a mutable {} default) avoids the shared mutable
    # default argument pitfall.
    if attr_dict is None:
        attr_dict = {}
    # Work on a copy so the exported DAG is not modified.
    dag = dag.copy()
    data_list: List[Dict[str, Any]] = []
    for parent_node, child_node in dag_iterator(dag):
        if parent_node.is_root:
            # Root nodes get a row of their own with a null parent.
            data_parent = {name_col: parent_node.node_name, parent_col: None}
            if all_attrs:
                data_parent.update(
                    parent_node.describe(
                        exclude_attributes=["name"], exclude_prefix="_"
                    )
                )
            else:
                for k, v in attr_dict.items():
                    data_parent[v] = parent_node.get_attr(k)
            data_list.append(data_parent)
        data_child = {name_col: child_node.node_name, parent_col: parent_node.node_name}
        if all_attrs:
            data_child.update(
                child_node.describe(exclude_attributes=["name"], exclude_prefix="_")
            )
        else:
            for k, v in attr_dict.items():
                data_child[v] = child_node.get_attr(k)
        data_list.append(data_child)
    # A node reachable via several iteration paths appears once per edge; drop
    # fully identical rows before returning.
    return pd.DataFrame(data_list).drop_duplicates().reset_index(drop=True)
@optional_dependencies_image("pydot")
def dag_to_dot(
    dag: Union[T, List[T]],
    rankdir: str = "TB",
    bg_colour: str = "",
    node_colour: str = "",
    node_shape: str = "",
    edge_colour: str = "",
    node_attr: str = "",
    edge_attr: str = "",
) -> pydot.Dot:
    r"""Export DAG or list of DAGs to image.
    Note that node names must be unique.
    Possible node attributes include style, fillcolor, shape.

    Examples:
        >>> from bigtree import DAGNode, dag_to_dot
        >>> a = DAGNode("a", step=1)
        >>> b = DAGNode("b", step=1)
        >>> c = DAGNode("c", step=2, parents=[a, b])
        >>> d = DAGNode("d", step=2, parents=[a, c])
        >>> e = DAGNode("e", step=3, parents=[d])
        >>> dag_graph = dag_to_dot(a)

        Display image directly without saving (requires IPython)

        >>> from IPython.display import Image, display
        >>> plt = Image(dag_graph.create_png())
        >>> display(plt)
        <IPython.core.display.Image object>

        Export to image, dot file, etc.

        >>> dag_graph.write_png("assets/docstr/tree_dag.png")
        >>> dag_graph.write_dot("assets/docstr/tree_dag.dot")

        Export to string

        >>> dag_graph.to_string()
        'strict digraph G {\nrankdir=TB;\nc [label=c];\na [label=a];\na -> c;\nd [label=d];\na [label=a];\na -> d;\nc [label=c];\nb [label=b];\nb -> c;\nd [label=d];\nc [label=c];\nc -> d;\ne [label=e];\nd [label=d];\nd -> e;\n}\n'

    Args:
        dag (Union[DAGNode, List[DAGNode]]): DAG or list of DAGs to be exported
        rankdir (str): set direction of graph layout, defaults to 'TB', can be 'BT, 'LR', 'RL'
        bg_colour (str): background color of image, defaults to ''
        node_colour (str): fill colour of nodes, defaults to ''
        node_shape (str): shape of nodes, defaults to ''
            Possible node_shape include "circle", "square", "diamond", "triangle"
        edge_colour (str): colour of edges, defaults to ''
        node_attr (str): node attribute for style, overrides node_colour, defaults to ''
            Possible node attributes include {"style": "filled", "fillcolor": "gold"}
        edge_attr (str): edge attribute for style, overrides edge_colour, defaults to ''
            Possible edge attributes include {"style": "bold", "label": "edge label", "color": "black"}

    Returns:
        (pydot.Dot)
    """
    # Translate the colour/shape arguments into pydot style dictionaries;
    # empty strings mean "use pydot defaults" and are simply omitted.
    graph_style: Dict[str, Any] = {"bgcolor": bg_colour} if bg_colour else {}
    node_style: Dict[str, Any] = (
        {"style": "filled", "fillcolor": node_colour} if node_colour else {}
    )
    if node_shape:
        node_style["shape"] = node_shape
    edge_style: Dict[str, Any] = {"color": edge_colour} if edge_colour else {}

    graph = pydot.Dot(
        graph_type="digraph", strict=True, rankdir=rankdir, **graph_style
    )

    dag_list = dag if isinstance(dag, list) else [dag]
    for single_dag in dag_list:
        if not isinstance(single_dag, DAGNode):
            raise TypeError(
                "Tree should be of type `DAGNode`, or inherit from `DAGNode`"
            )
        single_dag = single_dag.copy()
        for parent_node, child_node in dag_iterator(single_dag):
            # Per-edge styles: node_attr/edge_attr attributes on the child node
            # override the global node/edge styling.
            child_name = child_node.name
            child_style = node_style.copy()
            if node_attr and child_node.get_attr(node_attr):
                child_style.update(child_node.get_attr(node_attr))
            child_edge_style = edge_style.copy()
            if edge_attr and child_node.get_attr(edge_attr):
                child_edge_style.update(child_node.get_attr(edge_attr))
            graph.add_node(
                pydot.Node(name=child_name, label=child_name, **child_style)
            )

            parent_name = parent_node.name
            parent_style = node_style.copy()
            if node_attr and parent_node.get_attr(node_attr):
                parent_style.update(parent_node.get_attr(node_attr))
            graph.add_node(
                pydot.Node(name=parent_name, label=parent_name, **parent_style)
            )

            graph.add_edge(pydot.Edge(parent_name, child_name, **child_edge_style))
    return graph

View File

@@ -0,0 +1,3 @@
import os
ASSERTIONS: bool = bool(os.environ.get("BIGTREE_CONF_ASSERTIONS", True))

View File

@@ -0,0 +1,780 @@
from __future__ import annotations
import copy
from typing import Any, Dict, Generator, Iterable, List, Optional, Set, Tuple, TypeVar
from bigtree.globals import ASSERTIONS
from bigtree.utils.exceptions import CorruptedTreeError, LoopError, TreeError
from bigtree.utils.iterators import preorder_iter
class BaseNode:
    """
    BaseNode extends any Python class to a tree node.
    Nodes can have attributes if they are initialized from `Node`, *dictionary*, or *pandas DataFrame*.
    Nodes can be linked to each other with `parent` and `children` setter methods,
    or using bitshift operator with the convention `parent_node >> child_node` or `child_node << parent_node`.
    Examples:
        >>> from bigtree import Node, print_tree
        >>> root = Node("a", age=90)
        >>> b = Node("b", age=65)
        >>> c = Node("c", age=60)
        >>> d = Node("d", age=40)
        >>> root.children = [b, c]
        >>> d.parent = b
        >>> print_tree(root, attr_list=["age"])
        a [age=90]
        ├── b [age=65]
        │   └── d [age=40]
        └── c [age=60]
        >>> from bigtree import Node
        >>> root = Node("a", age=90)
        >>> b = Node("b", age=65)
        >>> c = Node("c", age=60)
        >>> d = Node("d", age=40)
        >>> root >> b
        >>> root >> c
        >>> d << b
        >>> print_tree(root, attr_list=["age"])
        a [age=90]
        ├── b [age=65]
        │   └── d [age=40]
        └── c [age=60]
        Directly passing `parent` argument.
        >>> from bigtree import Node
        >>> root = Node("a")
        >>> b = Node("b", parent=root)
        >>> c = Node("c", parent=root)
        >>> d = Node("d", parent=b)
        Directly passing `children` argument.
        >>> from bigtree import Node
        >>> d = Node("d")
        >>> c = Node("c")
        >>> b = Node("b", children=[d])
        >>> a = Node("a", children=[b, c])
    **BaseNode Creation**
    Node can be created by instantiating a `BaseNode` class or by using a *dictionary*.
    If node is created with dictionary, all keys of dictionary will be stored as class attributes.
    >>> from bigtree import Node
    >>> root = Node.from_dict({"name": "a", "age": 90})
    **BaseNode Attributes**
    These are node attributes that have getter and/or setter methods.
    Get and set other `BaseNode`
    1. ``parent``: Get/set parent node
    2. ``children``: Get/set child nodes
    Get other `BaseNode`
    1. ``ancestors``: Get ancestors of node excluding self, iterator
    2. ``descendants``: Get descendants of node excluding self, iterator
    3. ``leaves``: Get all leaf node(s) from self, iterator
    4. ``siblings``: Get siblings of self
    5. ``left_sibling``: Get sibling left of self
    6. ``right_sibling``: Get sibling right of self
    Get `BaseNode` configuration
    1. ``node_path``: Get tuple of nodes from root
    2. ``is_root``: Get indicator if self is root node
    3. ``is_leaf``: Get indicator if self is leaf node
    4. ``root``: Get root node of tree
    5. ``depth``: Get depth of self
    6. ``max_depth``: Get maximum depth from root to leaf node
    **BaseNode Methods**
    These are methods available to be performed on `BaseNode`.
    Constructor methods
    1. ``from_dict()``: Create BaseNode from dictionary
    `BaseNode` methods
    1. ``describe()``: Get node information sorted by attributes, return list of tuples
    2. ``get_attr(attr_name: str)``: Get value of node attribute
    3. ``set_attrs(attrs: dict)``: Set node attribute name(s) and value(s)
    4. ``go_to(node: Self)``: Get a path from own node to another node from same tree
    5. ``append(node: Self)``: Add child to node
    6. ``extend(nodes: List[Self])``: Add multiple children to node
    7. ``copy()``: Deep copy self
    8. ``sort()``: Sort child nodes
    ----
    """
    def __init__(
        self,
        parent: Optional[T] = None,
        children: Optional[List[T]] = None,
        **kwargs: Any,
    ):
        # Double-underscore names are name-mangled to _BaseNode__parent /
        # _BaseNode__children; the setters below rely on this to reach into
        # other BaseNode instances' private state.
        self.__parent: Optional[T] = None
        self.__children: List[T] = []
        if children is None:
            children = []
        # Route through the property setters so all type/loop checks run.
        self.parent = parent
        self.children = children  # type: ignore
        if "parents" in kwargs:
            raise AttributeError(
                "Attempting to set `parents` attribute, do you mean `parent`?"
            )
        # Any remaining keyword arguments become node attributes.
        self.__dict__.update(**kwargs)
    @staticmethod
    def __check_parent_type(new_parent: T) -> None:
        """Check parent type
        Args:
            new_parent (Self): parent node
        """
        if not (isinstance(new_parent, BaseNode) or new_parent is None):
            raise TypeError(
                f"Expect parent to be BaseNode type or NoneType, received input type {type(new_parent)}"
            )
    def __check_parent_loop(self, new_parent: T) -> None:
        """Check that assigning `new_parent` would not create a cycle
        Args:
            new_parent (Self): parent node
        """
        if new_parent is not None:
            if new_parent is self:
                raise LoopError("Error setting parent: Node cannot be parent of itself")
            if any(
                ancestor is self
                for ancestor in new_parent.ancestors
                if new_parent.ancestors
            ):
                raise LoopError(
                    "Error setting parent: Node cannot be ancestor of itself"
                )
    @property
    def parent(self: T) -> Optional[T]:
        """Get parent node
        Returns:
            (Optional[Self])
        """
        return self.__parent
    @parent.setter
    def parent(self: T, new_parent: T) -> None:
        """Set parent node
        Args:
            new_parent (Self): parent node
        """
        # ASSERTIONS (bigtree.globals) gates the potentially expensive
        # type/loop validation.
        if ASSERTIONS:
            self.__check_parent_type(new_parent)
            self.__check_parent_loop(new_parent)
        # Remember original state so the operation can be rolled back if the
        # pre/post assign hooks (overridable by subclasses) raise.
        current_parent = self.parent
        current_child_idx = None
        # Assign new parent - rollback if error
        self.__pre_assign_parent(new_parent)
        try:
            # Remove self from old parent
            if current_parent is not None:
                if not any(
                    child is self for child in current_parent.children
                ):  # pragma: no cover
                    raise CorruptedTreeError(
                        "Error setting parent: Node does not exist as children of its parent"
                    )
                current_child_idx = current_parent.__children.index(self)
                current_parent.__children.remove(self)
            # Assign self to new parent
            self.__parent = new_parent
            if new_parent is not None:
                new_parent.__children.append(self)
            self.__post_assign_parent(new_parent)
        except Exception as exc_info:
            # Remove self from new parent
            if new_parent is not None:
                new_parent.__children.remove(self)
            # Reassign self to old parent, restoring the original child order
            self.__parent = current_parent
            if current_child_idx is not None:
                current_parent.__children.insert(current_child_idx, self)
            # NOTE(review): original exception is wrapped as TreeError
            raise TreeError(exc_info)
    def __pre_assign_parent(self, new_parent: T) -> None:
        """Custom method to check before attaching parent
        Can be overridden with `_BaseNode__pre_assign_parent()`
        Args:
            new_parent (Self): new parent to be added
        """
        pass
    def __post_assign_parent(self, new_parent: T) -> None:
        """Custom method to check after attaching parent
        Can be overridden with `_BaseNode__post_assign_parent()`
        Args:
            new_parent (Self): new parent to be added
        """
        pass
    @property
    def parents(self) -> None:
        """Do not allow `parents` attribute to be accessed
        Raises:
            AttributeError: No such attribute
        """
        raise AttributeError(
            "Attempting to access `parents` attribute, do you mean `parent`?"
        )
    @parents.setter
    def parents(self, new_parent: T) -> None:
        """Do not allow `parents` attribute to be set
        Args:
            new_parent (Self): parent node
        Raises:
            AttributeError: No such attribute
        """
        raise AttributeError(
            "Attempting to set `parents` attribute, do you mean `parent`?"
        )
    def __check_children_type(
        self: T, new_children: List[T] | Tuple[T] | Set[T]
    ) -> None:
        """Check child type
        Args:
            new_children (Iterable[Self]): child node
        """
        if (
            not isinstance(new_children, list)
            and not isinstance(new_children, tuple)
            and not isinstance(new_children, set)
        ):
            raise TypeError(
                f"Expect children to be List or Tuple or Set type, received input type {type(new_children)}"
            )
    def __check_children_loop(self: T, new_children: Iterable[T]) -> None:
        """Check child loop
        Args:
            new_children (Iterable[Self]): child node
        """
        seen_children = []
        for new_child in new_children:
            # Check type
            if not isinstance(new_child, BaseNode):
                raise TypeError(
                    f"Expect children to be BaseNode type, received input type {type(new_child)}"
                )
            # Check for loop and tree structure
            if new_child is self:
                raise LoopError("Error setting child: Node cannot be child of itself")
            if any(child is new_child for child in self.ancestors):
                raise LoopError(
                    "Error setting child: Node cannot be ancestor of itself"
                )
            # Check for duplicate children (by identity, via id())
            if id(new_child) in seen_children:
                raise TreeError(
                    "Error setting child: Node cannot be added multiple times as a child"
                )
            else:
                seen_children.append(id(new_child))
    @property
    def children(self: T) -> Tuple[T, ...]:
        """Get child nodes
        Returns:
            (Tuple[Self, ...])
        """
        return tuple(self.__children)
    @children.setter
    def children(self: T, new_children: List[T] | Tuple[T] | Set[T]) -> None:
        """Set child nodes
        Args:
            new_children (List[Self]): child node
        """
        if ASSERTIONS:
            self.__check_children_type(new_children)
            self.__check_children_loop(new_children)
        new_children = list(new_children)
        # Remember each incoming child's current (index, parent) so the whole
        # operation can be rolled back if a pre/post assign hook raises.
        current_new_children = {
            new_child: (new_child.parent.__children.index(new_child), new_child.parent)
            for new_child in new_children
            if new_child.parent is not None
        }
        current_new_orphan = [
            new_child for new_child in new_children if new_child.parent is None
        ]
        current_children = list(self.children)
        # Assign new children - rollback if error
        self.__pre_assign_children(new_children)
        try:
            # Remove old children from self
            del self.children
            # Assign new children to self
            self.__children = new_children
            for new_child in new_children:
                # Detach each new child from its previous parent, then claim it
                if new_child.parent:
                    new_child.parent.__children.remove(new_child)
                new_child.__parent = self
            self.__post_assign_children(new_children)
        except Exception as exc_info:
            # Reassign new children to their original parent
            for child, idx_parent in current_new_children.items():
                child_idx, parent = idx_parent
                child.__parent = parent
                parent.__children.insert(child_idx, child)
            for child in current_new_orphan:
                child.__parent = None
            # Reassign old children to self
            self.__children = current_children
            for child in current_children:
                child.__parent = self
            raise TreeError(exc_info)
    @children.deleter
    def children(self) -> None:
        """Delete child node(s)"""
        # `self.children` returns a tuple copy, so mutating the underlying
        # list inside the loop is safe.
        for child in self.children:
            child.parent.__children.remove(child)  # type: ignore
            child.__parent = None
    def __pre_assign_children(self: T, new_children: Iterable[T]) -> None:
        """Custom method to check before attaching children
        Can be overridden with `_BaseNode__pre_assign_children()`
        Args:
            new_children (Iterable[Self]): new children to be added
        """
        pass
    def __post_assign_children(self: T, new_children: Iterable[T]) -> None:
        """Custom method to check after attaching children
        Can be overridden with `_BaseNode__post_assign_children()`
        Args:
            new_children (Iterable[Self]): new children to be added
        """
        pass
    @property
    def ancestors(self: T) -> Iterable[T]:
        """Get iterator to yield all ancestors of self, does not include self
        Note: despite being a property, this is a generator (lazily evaluated).
        Returns:
            (Iterable[Self])
        """
        node = self.parent
        while node is not None:
            yield node
            node = node.parent
    @property
    def descendants(self: T) -> Iterable[T]:
        """Get iterator to yield all descendants of self, does not include self
        Returns:
            (Iterable[Self])
        """
        yield from preorder_iter(self, filter_condition=lambda _node: _node != self)
    @property
    def leaves(self: T) -> Iterable[T]:
        """Get iterator to yield all leaf nodes from self
        Returns:
            (Iterable[Self])
        """
        yield from preorder_iter(self, filter_condition=lambda _node: _node.is_leaf)
    @property
    def siblings(self: T) -> Iterable[T]:
        """Get siblings of self
        Returns:
            (Iterable[Self])
        """
        if self.parent is None:
            return ()
        return tuple(child for child in self.parent.children if child is not self)
    @property
    def left_sibling(self: T) -> T:
        """Get sibling left of self
        Implicitly returns None when self is root or has no left sibling.
        Returns:
            (Self)
        """
        if self.parent:
            children = self.parent.children
            child_idx = children.index(self)
            if child_idx:
                return self.parent.children[child_idx - 1]
    @property
    def right_sibling(self: T) -> T:
        """Get sibling right of self
        Implicitly returns None when self is root or has no right sibling.
        Returns:
            (Self)
        """
        if self.parent:
            children = self.parent.children
            child_idx = children.index(self)
            if child_idx + 1 < len(children):
                return self.parent.children[child_idx + 1]
    @property
    def node_path(self: T) -> Iterable[T]:
        """Get tuple of nodes starting from root
        NOTE(review): returns a list for a root node but a tuple otherwise —
        callers should treat the result as a generic sequence.
        Returns:
            (Iterable[Self])
        """
        if self.parent is None:
            return [self]
        return tuple(list(self.parent.node_path) + [self])
    @property
    def is_root(self) -> bool:
        """Get indicator if self is root node
        Returns:
            (bool)
        """
        return self.parent is None
    @property
    def is_leaf(self) -> bool:
        """Get indicator if self is leaf node
        Returns:
            (bool)
        """
        return not len(list(self.children))
    @property
    def root(self: T) -> T:
        """Get root node of tree
        Returns:
            (Self)
        """
        if self.parent is None:
            return self
        return self.parent.root
    @property
    def depth(self) -> int:
        """Get depth of self, indexing starts from 1
        Returns:
            (int)
        """
        if self.parent is None:
            return 1
        return self.parent.depth + 1
    @property
    def max_depth(self) -> int:
        """Get maximum depth from root to leaf node
        Returns:
            (int)
        """
        return max(
            [self.root.depth] + [node.depth for node in list(self.root.descendants)]
        )
    @classmethod
    def from_dict(cls, input_dict: Dict[str, Any]) -> BaseNode:
        """Construct node from dictionary, all keys of dictionary will be stored as class attributes
        Input dictionary must have key `name` if not `Node` will not have any name
        Examples:
            >>> from bigtree import Node
            >>> a = Node.from_dict({"name": "a", "age": 90})
        Args:
            input_dict (Dict[str, Any]): dictionary with node information, key: attribute name, value: attribute value
        Returns:
            (BaseNode)
        """
        return cls(**input_dict)
    def describe(
        self, exclude_attributes: List[str] = [], exclude_prefix: str = ""
    ) -> List[Tuple[str, Any]]:
        """Get node information sorted by attribute name, returns list of tuples
        NOTE(review): `exclude_attributes` uses a mutable `[]` default; it is
        only read here, so this is harmless, but callers must not mutate it.
        Examples:
            >>> from bigtree.node.node import Node
            >>> a = Node('a', age=90)
            >>> a.describe()
            [('_BaseNode__children', []), ('_BaseNode__parent', None), ('_sep', '/'), ('age', 90), ('name', 'a')]
            >>> a.describe(exclude_prefix="_")
            [('age', 90), ('name', 'a')]
            >>> a.describe(exclude_prefix="_", exclude_attributes=["name"])
            [('age', 90)]
        Args:
            exclude_attributes (List[str]): list of attributes to exclude
            exclude_prefix (str): prefix of attributes to exclude
        Returns:
            (List[Tuple[str, Any]])
        """
        return [
            item
            for item in sorted(self.__dict__.items(), key=lambda item: item[0])
            if (item[0] not in exclude_attributes)
            and (not len(exclude_prefix) or not item[0].startswith(exclude_prefix))
        ]
    def get_attr(self, attr_name: str, default_value: Any = None) -> Any:
        """Get value of node attribute
        Returns default value if attribute name does not exist
        Examples:
            >>> from bigtree.node.node import Node
            >>> a = Node('a', age=90)
            >>> a.get_attr("age")
            90
        Args:
            attr_name (str): attribute name
            default_value (Any): default value if attribute does not exist, defaults to None
        Returns:
            (Any)
        """
        try:
            return getattr(self, attr_name)
        except AttributeError:
            return default_value
    def set_attrs(self, attrs: Dict[str, Any]) -> None:
        """Set node attributes
        Examples:
            >>> from bigtree.node.node import Node
            >>> a = Node('a')
            >>> a.set_attrs({"age": 90})
            >>> a
            Node(/a, age=90)
        Args:
            attrs (Dict[str, Any]): attribute dictionary,
                key: attribute name, value: attribute value
        """
        self.__dict__.update(attrs)
    def go_to(self: T, node: T) -> Iterable[T]:
        """Get path from current node to specified node from same tree
        Examples:
            >>> from bigtree import Node, print_tree
            >>> a = Node(name="a")
            >>> b = Node(name="b", parent=a)
            >>> c = Node(name="c", parent=a)
            >>> d = Node(name="d", parent=b)
            >>> e = Node(name="e", parent=b)
            >>> f = Node(name="f", parent=c)
            >>> g = Node(name="g", parent=e)
            >>> h = Node(name="h", parent=e)
            >>> print_tree(a)
            a
            ├── b
            │   ├── d
            │   └── e
            │       ├── g
            │       └── h
            └── c
                └── f
            >>> d.go_to(d)
            [Node(/a/b/d, )]
            >>> d.go_to(g)
            [Node(/a/b/d, ), Node(/a/b, ), Node(/a/b/e, ), Node(/a/b/e/g, )]
            >>> d.go_to(f)
            [Node(/a/b/d, ), Node(/a/b, ), Node(/a, ), Node(/a/c, ), Node(/a/c/f, )]
        Args:
            node (Self): node to travel to from current node, inclusive of start and end node
        Returns:
            (Iterable[Self])
        """
        if not isinstance(node, BaseNode):
            raise TypeError(
                f"Expect node to be BaseNode type, received input type {type(node)}"
            )
        if self.root != node.root:
            raise TreeError(
                f"Nodes are not from the same tree. Check {self} and {node}"
            )
        if self == node:
            return [self]
        # Walk up from self and down to node, joining the two paths at the
        # closest common ancestor (the common node nearest to self).
        self_path = [self] + list(self.ancestors)
        node_path = ([node] + list(node.ancestors))[::-1]
        common_nodes = set(self_path).intersection(set(node_path))
        self_min_index, min_common_node = sorted(
            [(self_path.index(_node), _node) for _node in common_nodes]
        )[0]
        node_min_index = node_path.index(min_common_node)
        return self_path[:self_min_index] + node_path[node_min_index:]
    def append(self: T, other: T) -> None:
        """Add other as child of self
        Args:
            other (Self): other node, child to be added
        """
        other.parent = self
    def extend(self: T, others: List[T]) -> None:
        """Add others as children of self
        Args:
            others (Self): other nodes, children to be added
        """
        for child in others:
            child.parent = self
    def copy(self: T) -> T:
        """Deep copy self; clone self
        Examples:
            >>> from bigtree.node.node import Node
            >>> a = Node('a')
            >>> a_copy = a.copy()
        Returns:
            (Self)
        """
        return copy.deepcopy(self)
    def sort(self: T, **kwargs: Any) -> None:
        """Sort children in place, possible keyword arguments include ``key=lambda node: node.name``, ``reverse=True``
        Examples:
            >>> from bigtree import Node, print_tree
            >>> a = Node('a')
            >>> c = Node("c", parent=a)
            >>> b = Node("b", parent=a)
            >>> print_tree(a)
            a
            ├── c
            └── b
            >>> a.sort(key=lambda node: node.name)
            >>> print_tree(a)
            a
            ├── b
            └── c
        """
        # Sort a copy, then swap in the sorted list as the new child list.
        children = list(self.children)
        children.sort(**kwargs)
        self.__children = children
    def __copy__(self: T) -> T:
        """Shallow copy self
        Examples:
            >>> import copy
            >>> from bigtree.node.node import Node
            >>> a = Node('a')
            >>> a_copy = copy.copy(a)
        Returns:
            (Self)
        """
        obj: T = type(self).__new__(self.__class__)
        obj.__dict__.update(self.__dict__)
        return obj
    def __repr__(self) -> str:
        """Print format of BaseNode
        Returns:
            (str)
        """
        class_name = self.__class__.__name__
        node_dict = self.describe(exclude_prefix="_")
        node_description = ", ".join([f"{k}={v}" for k, v in node_dict])
        return f"{class_name}({node_description})"
    def __rshift__(self: T, other: T) -> None:
        """Set children using >> bitshift operator for self >> children (other)
        Args:
            other (Self): other node, children
        """
        other.parent = self
    def __lshift__(self: T, other: T) -> None:
        """Set parent using << bitshift operator for self << parent (other)
        Args:
            other (Self): other node, parent
        """
        self.parent = other
    def __iter__(self) -> Generator[T, None, None]:
        """Iterate through child nodes
        Returns:
            (Self): child node
        """
        yield from self.children  # type: ignore
    def __contains__(self, other_node: T) -> bool:
        """Check if child node exists
        Args:
            other_node (T): child node
        Returns:
            (bool)
        """
        return other_node in self.children
# Type variable bound to BaseNode so subclass methods are typed as returning
# the subclass, not the base class.
T = TypeVar("T", bound=BaseNode)

View File

@@ -0,0 +1,418 @@
from __future__ import annotations
from typing import Any, List, Optional, Tuple, TypeVar, Union
from bigtree.globals import ASSERTIONS
from bigtree.node.node import Node
from bigtree.utils.exceptions import CorruptedTreeError, LoopError, TreeError
class BinaryNode(Node):
"""
BinaryNode is an extension of Node, and is able to extend to any Python class for Binary Tree implementation.
Nodes can have attributes if they are initialized from `BinaryNode`, *dictionary*, or *pandas DataFrame*.
BinaryNode can be linked to each other with `children`, `left`, or `right` setter methods.
If initialized with `children`, it must be length 2, denoting left and right child.
Examples:
>>> from bigtree import BinaryNode, print_tree
>>> a = BinaryNode(1)
>>> b = BinaryNode(2)
>>> c = BinaryNode(3)
>>> d = BinaryNode(4)
>>> a.children = [b, c]
>>> b.right = d
>>> print_tree(a)
1
├── 2
│ └── 4
└── 3
Directly passing `left`, `right`, or `children` argument.
>>> from bigtree import BinaryNode
>>> d = BinaryNode(4)
>>> c = BinaryNode(3)
>>> b = BinaryNode(2, right=d)
>>> a = BinaryNode(1, children=[b, c])
**BinaryNode Creation**
Node can be created by instantiating a `BinaryNode` class or by using a *dictionary*.
If node is created with dictionary, all keys of dictionary will be stored as class attributes.
>>> from bigtree import BinaryNode
>>> a = BinaryNode.from_dict({"name": "1"})
>>> a
BinaryNode(name=1, val=1)
**BinaryNode Attributes**
These are node attributes that have getter and/or setter methods.
Get `BinaryNode` configuration
1. ``left``: Get left children
2. ``right``: Get right children
----
"""
def __init__(
self,
name: Union[str, int] = "",
left: Optional[T] = None,
right: Optional[T] = None,
parent: Optional[T] = None,
children: Optional[List[Optional[T]]] = None,
**kwargs: Any,
):
try:
self.val: Union[str, int] = int(name)
except ValueError:
self.val = str(name)
self.name = str(name)
self._sep = "/"
self.__parent: Optional[T] = None
self.__children: List[Optional[T]] = [None, None]
if not children:
children = []
if len(children):
if len(children) and len(children) != 2:
raise ValueError("Children input must have length 2")
if left and left != children[0]:
raise ValueError(
f"Error setting child: Attempting to set both left and children with mismatched values\n"
f"Check left {left} and children {children}"
)
if right and right != children[1]:
raise ValueError(
f"Error setting child: Attempting to set both right and children with mismatched values\n"
f"Check right {right} and children {children}"
)
else:
children = [left, right]
self.parent = parent
self.children = children # type: ignore
if "parents" in kwargs:
raise AttributeError(
"Attempting to set `parents` attribute, do you mean `parent`?"
)
self.__dict__.update(**kwargs)
@property
def left(self: T) -> T:
"""Get left children
Returns:
(Self)
"""
return self.__children[0]
@left.setter
def left(self: T, left_child: Optional[T]) -> None:
"""Set left children
Args:
left_child (Optional[Self]): left child
"""
self.children = [left_child, self.right] # type: ignore
@property
def right(self: T) -> T:
"""Get right children
Returns:
(Self)
"""
return self.__children[1]
@right.setter
def right(self: T, right_child: Optional[T]) -> None:
"""Set right children
Args:
right_child (Optional[Self]): right child
"""
self.children = [self.left, right_child] # type: ignore
@staticmethod
def __check_parent_type(new_parent: Optional[T]) -> None:
"""Check parent type
Args:
new_parent (Optional[Self]): parent node
"""
if not (isinstance(new_parent, BinaryNode) or new_parent is None):
raise TypeError(
f"Expect parent to be BinaryNode type or NoneType, received input type {type(new_parent)}"
)
@property
def parent(self: T) -> Optional[T]:
"""Get parent node
Returns:
(Optional[Self])
"""
return self.__parent
@parent.setter
def parent(self: T, new_parent: Optional[T]) -> None:
"""Set parent node
Args:
new_parent (Optional[Self]): parent node
"""
if ASSERTIONS:
self.__check_parent_type(new_parent)
self._BaseNode__check_parent_loop(new_parent) # type: ignore
current_parent = self.parent
current_child_idx = None
# Assign new parent - rollback if error
self.__pre_assign_parent(new_parent)
try:
# Remove self from old parent
if current_parent is not None:
if not any(
child is self for child in current_parent.children
): # pragma: no cover
raise CorruptedTreeError(
"Error setting parent: Node does not exist as children of its parent"
)
current_child_idx = current_parent.__children.index(self)
current_parent.__children[current_child_idx] = None
# Assign self to new parent
self.__parent = new_parent
if new_parent is not None:
inserted = False
for child_idx, child in enumerate(new_parent.__children):
if not child and not inserted:
new_parent.__children[child_idx] = self
inserted = True
if not inserted:
raise TreeError(f"Parent {new_parent} already has 2 children")
self.__post_assign_parent(new_parent)
except Exception as exc_info:
# Remove self from new parent
if new_parent is not None and self in new_parent.__children:
child_idx = new_parent.__children.index(self)
new_parent.__children[child_idx] = None
# Reassign self to old parent
self.__parent = current_parent
if current_child_idx is not None:
current_parent.__children[current_child_idx] = self
raise TreeError(exc_info)
def __pre_assign_parent(self: T, new_parent: Optional[T]) -> None:
"""Custom method to check before attaching parent
Can be overridden with `_BinaryNode__pre_assign_parent()`
Args:
new_parent (Optional[Self]): new parent to be added
"""
pass
def __post_assign_parent(self: T, new_parent: Optional[T]) -> None:
"""Custom method to check after attaching parent
Can be overridden with `_BinaryNode__post_assign_parent()`
Args:
new_parent (Optional[Self]): new parent to be added
"""
pass
def __check_children_type(
self: T, new_children: List[Optional[T]]
) -> List[Optional[T]]:
"""Check child type
Args:
new_children (List[Optional[Self]]): child node
Returns:
(List[Optional[Self]])
"""
if not len(new_children):
new_children = [None, None]
if len(new_children) != 2:
raise ValueError("Children input must have length 2")
return new_children
def __check_children_loop(self: T, new_children: List[Optional[T]]) -> None:
"""Check child loop
Args:
new_children (List[Optional[Self]]): child node
"""
seen_children = []
for new_child in new_children:
# Check type
if new_child is not None and not isinstance(new_child, BinaryNode):
raise TypeError(
f"Expect children to be BinaryNode type or NoneType, received input type {type(new_child)}"
)
# Check for loop and tree structure
if new_child is self:
raise LoopError("Error setting child: Node cannot be child of itself")
if any(child is new_child for child in self.ancestors):
raise LoopError(
"Error setting child: Node cannot be ancestor of itself"
)
# Check for duplicate children
if new_child is not None:
if id(new_child) in seen_children:
raise TreeError(
"Error setting child: Node cannot be added multiple times as a child"
)
else:
seen_children.append(id(new_child))
@property
def children(self: T) -> Tuple[T, ...]:
"""Get child nodes
Returns:
(Tuple[Optional[Self]])
"""
return tuple(self.__children)
@children.setter
def children(self: T, _new_children: List[Optional[T]]) -> None:
"""Set child nodes
Args:
_new_children (List[Optional[Self]]): child node
"""
self._BaseNode__check_children_type(_new_children) # type: ignore
new_children = self.__check_children_type(_new_children)
if ASSERTIONS:
self.__check_children_loop(new_children)
current_new_children = {
new_child: (
new_child.parent.__children.index(new_child),
new_child.parent,
)
for new_child in new_children
if new_child is not None and new_child.parent is not None
}
current_new_orphan = [
new_child
for new_child in new_children
if new_child is not None and new_child.parent is None
]
current_children = list(self.children)
# Assign new children - rollback if error
self.__pre_assign_children(new_children)
try:
# Remove old children from self
del self.children
# Assign new children to self
self.__children = new_children
for new_child in new_children:
if new_child is not None:
if new_child.parent:
child_idx = new_child.parent.__children.index(new_child)
new_child.parent.__children[child_idx] = None
new_child.__parent = self
self.__post_assign_children(new_children)
except Exception as exc_info:
# Reassign new children to their original parent
for child, idx_parent in current_new_children.items():
child_idx, parent = idx_parent
child.__parent = parent
parent.__children[child_idx] = child
for child in current_new_orphan:
child.__parent = None
# Reassign old children to self
self.__children = current_children
for child in current_children:
if child:
child.__parent = self
raise TreeError(exc_info)
@children.deleter
def children(self) -> None:
"""Delete child node(s)"""
for child in self.children:
if child is not None:
child.parent.__children.remove(child) # type: ignore
child.__parent = None
def __pre_assign_children(self: T, new_children: List[Optional[T]]) -> None:
"""Custom method to check before attaching children
Can be overridden with `_BinaryNode__pre_assign_children()`
Args:
new_children (List[Optional[Self]]): new children to be added
"""
pass
def __post_assign_children(self: T, new_children: List[Optional[T]]) -> None:
"""Custom method to check after attaching children
Can be overridden with `_BinaryNode__post_assign_children()`
Args:
new_children (List[Optional[Self]]): new children to be added
"""
pass
@property
def is_leaf(self) -> bool:
"""Get indicator if self is leaf node
Returns:
(bool)
"""
return not len([child for child in self.children if child])
def sort(self, **kwargs: Any) -> None:
"""Sort children, possible keyword arguments include ``key=lambda node: node.val``, ``reverse=True``
Examples:
>>> from bigtree import BinaryNode, print_tree
>>> a = BinaryNode(1)
>>> c = BinaryNode(3, parent=a)
>>> b = BinaryNode(2, parent=a)
>>> print_tree(a)
1
├── 3
└── 2
>>> a.sort(key=lambda node: node.val)
>>> print_tree(a)
1
├── 2
└── 3
"""
children = [child for child in self.children if child]
if len(children) == 2:
children.sort(**kwargs)
self.__children = children # type: ignore
def __repr__(self) -> str:
    """Print format of BinaryNode
    Returns:
        (str)
    """
    # Public (non-underscore) attributes only, rendered as k=v pairs.
    attributes = self.describe(exclude_prefix="_", exclude_attributes=[])
    details = ", ".join(f"{key}={value}" for key, value in attributes)
    return f"{self.__class__.__name__}({details})"
T = TypeVar("T", bound=BinaryNode)

View File

@@ -0,0 +1,672 @@
from __future__ import annotations
import copy
from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple, TypeVar
from bigtree.globals import ASSERTIONS
from bigtree.utils.exceptions import LoopError, TreeError
from bigtree.utils.iterators import preorder_iter
class DAGNode:
    """
    Base DAGNode extends any Python class to a DAG node, for DAG implementation.
    In DAG implementation, a node can have multiple parents.

    Parents and children cannot be reassigned once assigned, as Nodes are allowed to have multiple parents and children.
    If each node only has one parent, use `Node` class.
    DAGNodes can have attributes if they are initialized from `DAGNode` or dictionary.

    DAGNode can be linked to each other with `parents` and `children` setter methods,
    or using bitshift operator with the convention `parent_node >> child_node` or `child_node << parent_node`.

    Examples:
        >>> from bigtree import DAGNode
        >>> a = DAGNode("a")
        >>> b = DAGNode("b")
        >>> c = DAGNode("c")
        >>> d = DAGNode("d")
        >>> c.parents = [a, b]
        >>> c.children = [d]

        >>> from bigtree import DAGNode
        >>> a = DAGNode("a")
        >>> b = DAGNode("b")
        >>> c = DAGNode("c")
        >>> d = DAGNode("d")
        >>> a >> c
        >>> b >> c
        >>> d << c

        Directly passing `parents` argument.

        >>> from bigtree import DAGNode
        >>> a = DAGNode("a")
        >>> b = DAGNode("b")
        >>> c = DAGNode("c", parents=[a, b])
        >>> d = DAGNode("d", parents=[c])

        Directly passing `children` argument.

        >>> from bigtree import DAGNode
        >>> d = DAGNode("d")
        >>> c = DAGNode("c", children=[d])
        >>> b = DAGNode("b", children=[c])
        >>> a = DAGNode("a", children=[c])

    **DAGNode Creation**

    Node can be created by instantiating a `DAGNode` class or by using a *dictionary*.
    If node is created with dictionary, all keys of dictionary will be stored as class attributes.

    >>> from bigtree import DAGNode
    >>> a = DAGNode.from_dict({"name": "a", "age": 90})

    **DAGNode Attributes**

    These are node attributes that have getter and/or setter methods.

    Get and set other `DAGNode`

    1. ``parents``: Get/set parent nodes
    2. ``children``: Get/set child nodes

    Get other `DAGNode`

    1. ``ancestors``: Get ancestors of node excluding self, iterator
    2. ``descendants``: Get descendants of node excluding self, iterator
    3. ``siblings``: Get siblings of self

    Get `DAGNode` configuration

    1. ``node_name``: Get node name, without accessing `name` directly
    2. ``is_root``: Get indicator if self is root node
    3. ``is_leaf``: Get indicator if self is leaf node

    **DAGNode Methods**

    These are methods available to be performed on `DAGNode`.

    Constructor methods

    1. ``from_dict()``: Create DAGNode from dictionary

    `DAGNode` methods

    1. ``describe()``: Get node information sorted by attributes, return list of tuples
    2. ``get_attr(attr_name: str)``: Get value of node attribute
    3. ``set_attrs(attrs: dict)``: Set node attribute name(s) and value(s)
    4. ``go_to(node: Self)``: Get a path from own node to another node from same DAG
    5. ``copy()``: Deep copy self

    ----

    """

    def __init__(
        self,
        name: str = "",
        parents: Optional[List[T]] = None,
        children: Optional[List[T]] = None,
        **kwargs: Any,
    ):
        self.name = name
        # Private link lists; only mutated through the `parents`/`children` setters,
        # which keep the forward and reverse links consistent.
        self.__parents: List[T] = []
        self.__children: List[T] = []
        if parents is None:
            parents = []
        if children is None:
            children = []
        self.parents = parents
        self.children = children
        # Guard against the common typo of the single-parent `Node` API.
        if "parent" in kwargs:
            raise AttributeError(
                "Attempting to set `parent` attribute, do you mean `parents`?"
            )
        self.__dict__.update(**kwargs)

    @property
    def parent(self) -> None:
        """Do not allow `parent` attribute to be accessed

        Raises:
            AttributeError: No such attribute
        """
        raise AttributeError(
            "Attempting to access `parent` attribute, do you mean `parents`?"
        )

    @parent.setter
    def parent(self, new_parent: T) -> None:
        """Do not allow `parent` attribute to be set

        Args:
            new_parent (Self): parent node

        Raises:
            AttributeError
        """
        raise AttributeError(
            "Attempting to set `parent` attribute, do you mean `parents`?"
        )

    @staticmethod
    def __check_parent_type(new_parents: List[T]) -> None:
        """Check that the parents argument is a list

        Args:
            new_parents (List[Self]): parent nodes
        """
        if not isinstance(new_parents, list):
            raise TypeError(
                f"Parents input should be list type, received input type {type(new_parents)}"
            )

    def __check_parent_loop(self: T, new_parents: List[T]) -> None:
        """Check parent nodes for type errors, loops, and duplicates

        Args:
            new_parents (List[Self]): parent nodes
        """
        seen_parent = []
        for new_parent in new_parents:
            # Check type
            if not isinstance(new_parent, DAGNode):
                raise TypeError(
                    f"Expect parent to be DAGNode type, received input type {type(new_parent)}"
                )
            # Check for loop and tree structure
            if new_parent is self:
                raise LoopError("Error setting parent: Node cannot be parent of itself")
            if new_parent.ancestors:
                if any(ancestor is self for ancestor in new_parent.ancestors):
                    raise LoopError(
                        "Error setting parent: Node cannot be ancestor of itself"
                    )
            # Check for duplicate children (compare by identity, not equality)
            if id(new_parent) in seen_parent:
                raise TreeError(
                    "Error setting parent: Node cannot be added multiple times as a parent"
                )
            else:
                seen_parent.append(id(new_parent))

    @property
    def parents(self: T) -> Iterable[T]:
        """Get parent nodes

        Returns:
            (Iterable[Self])
        """
        # Returned as a tuple so callers cannot mutate the internal list.
        return tuple(self.__parents)

    @parents.setter
    def parents(self: T, new_parents: List[T]) -> None:
        """Set parent node

        Args:
            new_parents (List[Self]): parent nodes
        """
        if ASSERTIONS:
            self.__check_parent_type(new_parents)
            self.__check_parent_loop(new_parents)

        # Snapshot for rollback if the post-assign hook raises.
        current_parents = self.__parents.copy()

        # Assign new parents - rollback if error
        self.__pre_assign_parents(new_parents)
        try:
            # Assign self to new parent
            for new_parent in new_parents:
                if new_parent not in self.__parents:
                    self.__parents.append(new_parent)
                    new_parent.__children.append(self)

            self.__post_assign_parents(new_parents)
        except Exception as exc_info:
            # Remove self from new parent (only links added above, per snapshot)
            for new_parent in new_parents:
                if new_parent not in current_parents:
                    self.__parents.remove(new_parent)
                    new_parent.__children.remove(self)
            raise TreeError(exc_info)

    def __pre_assign_parents(self: T, new_parents: List[T]) -> None:
        """Custom hook invoked before attaching parents; default is a no-op
        Can be overridden with `_DAGNode__pre_assign_parents()` (name-mangled form)

        Args:
            new_parents (List[Self]): new parents to be added
        """
        pass

    def __post_assign_parents(self: T, new_parents: List[T]) -> None:
        """Custom hook invoked after attaching parents; default is a no-op
        Can be overridden with `_DAGNode__post_assign_parents()` (name-mangled form);
        raising here triggers rollback in the `parents` setter

        Args:
            new_parents (List[Self]): new parents to be added
        """
        pass

    def __check_children_type(self: T, new_children: Iterable[T]) -> None:
        """Check that the children argument is iterable

        Args:
            new_children (Iterable[Self]): child node
        """
        if not isinstance(new_children, Iterable):
            raise TypeError(
                f"Expect children to be Iterable type, received input type {type(new_children)}"
            )

    def __check_children_loop(self: T, new_children: Iterable[T]) -> None:
        """Check child nodes for type errors, loops, and duplicates

        Args:
            new_children (Iterable[Self]): child node
        """
        # NOTE(review): this iterates `new_children`; a one-shot generator would be
        # consumed here before assignment when ASSERTIONS is on — confirm callers pass lists.
        seen_children = []
        for new_child in new_children:
            # Check type
            if not isinstance(new_child, DAGNode):
                raise TypeError(
                    f"Expect children to be DAGNode type, received input type {type(new_child)}"
                )
            # Check for loop and tree structure
            if new_child is self:
                raise LoopError("Error setting child: Node cannot be child of itself")
            if any(child is new_child for child in self.ancestors):
                raise LoopError(
                    "Error setting child: Node cannot be ancestor of itself"
                )
            # Check for duplicate children (compare by identity, not equality)
            if id(new_child) in seen_children:
                raise TreeError(
                    "Error setting child: Node cannot be added multiple times as a child"
                )
            else:
                seen_children.append(id(new_child))

    @property
    def children(self: T) -> Iterable[T]:
        """Get child nodes

        Returns:
            (Iterable[Self])
        """
        # Returned as a tuple so callers cannot mutate the internal list.
        return tuple(self.__children)

    @children.setter
    def children(self: T, new_children: Iterable[T]) -> None:
        """Set child nodes

        Args:
            new_children (Iterable[Self]): child node
        """
        if ASSERTIONS:
            self.__check_children_type(new_children)
            self.__check_children_loop(new_children)

        # Snapshot for rollback if the post-assign hook raises.
        current_children = list(self.children)

        # Assign new children - rollback if error
        self.__pre_assign_children(new_children)
        try:
            # Assign new children to self
            for new_child in new_children:
                if self not in new_child.__parents:
                    new_child.__parents.append(self)
                    self.__children.append(new_child)
            self.__post_assign_children(new_children)
        except Exception as exc_info:
            # Undo only the links added above, per snapshot
            for new_child in new_children:
                if new_child not in current_children:
                    new_child.__parents.remove(self)
                    self.__children.remove(new_child)
            raise TreeError(exc_info)

    @children.deleter
    def children(self) -> None:
        """Delete child node(s)"""
        for child in self.children:
            self.__children.remove(child)  # type: ignore
            child.__parents.remove(self)  # type: ignore

    def __pre_assign_children(self: T, new_children: Iterable[T]) -> None:
        """Custom hook invoked before attaching children; default is a no-op
        Can be overridden with `_DAGNode__pre_assign_children()` (name-mangled form)

        Args:
            new_children (List[Self]): new children to be added
        """
        pass

    def __post_assign_children(self: T, new_children: Iterable[T]) -> None:
        """Custom hook invoked after attaching children; default is a no-op
        Can be overridden with `_DAGNode__post_assign_children()` (name-mangled form);
        raising here triggers rollback in the `children` setter

        Args:
            new_children (List[Self]): new children to be added
        """
        pass

    @property
    def ancestors(self: T) -> Iterable[T]:
        """Get iterator to yield all ancestors of self, does not include self

        Returns:
            (Iterable[Self])
        """
        if not len(list(self.parents)):
            return ()

        def _recursive_parent(node: T) -> Iterable[T]:
            """Recursively yield parent of current node, returns earliest to latest ancestor

            Args:
                node (DAGNode): current node

            Returns:
                (Iterable[DAGNode])
            """
            for _node in node.parents:
                yield from _recursive_parent(_node)
                yield _node

        ancestors = list(_recursive_parent(self))
        # Deduplicate while preserving first-seen (earliest-to-latest) order.
        return list(dict.fromkeys(ancestors))

    @property
    def descendants(self: T) -> Iterable[T]:
        """Get iterator to yield all descendants of self, does not include self

        Returns:
            (Iterable[Self])
        """
        descendants = preorder_iter(self, filter_condition=lambda _node: _node != self)
        # Deduplicate while preserving pre-order; nodes reachable via several
        # parents would otherwise appear multiple times.
        return list(dict.fromkeys(descendants))

    @property
    def siblings(self: T) -> Iterable[T]:
        """Get siblings of self

        Returns:
            (Iterable[Self])
        """
        if self.is_root:
            return ()
        return tuple(
            child
            for parent in self.parents
            for child in parent.children
            if child is not self
        )

    @property
    def node_name(self) -> str:
        """Get node name

        Returns:
            (str)
        """
        return self.name

    @property
    def is_root(self) -> bool:
        """Get indicator if self is root node

        Returns:
            (bool)
        """
        return not len(list(self.parents))

    @property
    def is_leaf(self) -> bool:
        """Get indicator if self is leaf node

        Returns:
            (bool)
        """
        return not len(list(self.children))

    @classmethod
    def from_dict(cls, input_dict: Dict[str, Any]) -> DAGNode:
        """Construct node from dictionary, all keys of dictionary will be stored as class attributes
        Input dictionary must have key `name` if not `Node` will not have any name

        Examples:
            >>> from bigtree import DAGNode
            >>> a = DAGNode.from_dict({"name": "a", "age": 90})

        Args:
            input_dict (Dict[str, Any]): dictionary with node information, key: attribute name, value: attribute value

        Returns:
            (DAGNode)
        """
        return cls(**input_dict)

    def describe(
        self, exclude_attributes: List[str] = [], exclude_prefix: str = ""
    ) -> List[Tuple[str, Any]]:
        """Get node information sorted by attribute name, returns list of tuples

        Args:
            exclude_attributes (List[str]): list of attributes to exclude;
                mutable default is safe here as the list is never mutated
            exclude_prefix (str): prefix of attributes to exclude

        Returns:
            (List[Tuple[str, Any]])
        """
        return [
            item
            for item in sorted(self.__dict__.items(), key=lambda item: item[0])
            if (item[0] not in exclude_attributes)
            and (not len(exclude_prefix) or not item[0].startswith(exclude_prefix))
        ]

    def get_attr(self, attr_name: str, default_value: Any = None) -> Any:
        """Get value of node attribute
        Returns default value if attribute name does not exist

        Args:
            attr_name (str): attribute name
            default_value (Any): default value if attribute does not exist, defaults to None

        Returns:
            (Any)
        """
        try:
            return getattr(self, attr_name)
        except AttributeError:
            return default_value

    def set_attrs(self, attrs: Dict[str, Any]) -> None:
        """Set node attributes

        Examples:
            >>> from bigtree.node.dagnode import DAGNode
            >>> a = DAGNode('a')
            >>> a.set_attrs({"age": 90})
            >>> a
            DAGNode(a, age=90)

        Args:
            attrs (Dict[str, Any]): attribute dictionary,
                key: attribute name, value: attribute value
        """
        self.__dict__.update(attrs)

    def go_to(self: T, node: T) -> List[List[T]]:
        """Get list of possible paths from current node to specified node from same tree

        Examples:
            >>> from bigtree import DAGNode
            >>> a = DAGNode("a")
            >>> b = DAGNode("b")
            >>> c = DAGNode("c")
            >>> d = DAGNode("d")
            >>> a >> c
            >>> b >> c
            >>> c >> d
            >>> a >> d
            >>> a.go_to(c)
            [[DAGNode(a, ), DAGNode(c, )]]
            >>> a.go_to(d)
            [[DAGNode(a, ), DAGNode(c, ), DAGNode(d, )], [DAGNode(a, ), DAGNode(d, )]]
            >>> a.go_to(b)
            Traceback (most recent call last):
                ...
            bigtree.utils.exceptions.TreeError: It is not possible to go to DAGNode(b, )

        Args:
            node (Self): node to travel to from current node, inclusive of start and end node

        Returns:
            (List[List[Self]])
        """
        if not isinstance(node, DAGNode):
            raise TypeError(
                f"Expect node to be DAGNode type, received input type {type(node)}"
            )
        if self == node:
            return [[self]]
        if node not in self.descendants:
            raise TreeError(f"It is not possible to go to {node}")

        # Accumulator for completed paths; stored on the instance so the nested
        # DFS closure can append to it. Reset on every call.
        self.__path: List[List[T]] = []

        def _recursive_path(_node: T, _path: List[T]) -> Optional[List[T]]:
            """Get path to specified node

            Args:
                _node (DAGNode): current node
                _path (List[DAGNode]): current path, from start node to current node, excluding current node

            Returns:
                (List[DAGNode])
            """
            if _node:  # pragma: no cover
                _path.append(_node)
                if _node == node:
                    return _path
                for _child in _node.children:
                    # Copy the path so sibling branches do not share state.
                    ans = _recursive_path(_child, _path.copy())
                    if ans:
                        self.__path.append(ans)
            return None

        _recursive_path(self, [])
        return self.__path

    def copy(self: T) -> T:
        """Deep copy self; clone DAGNode

        Examples:
            >>> from bigtree.node.dagnode import DAGNode
            >>> a = DAGNode('a')
            >>> a_copy = a.copy()

        Returns:
            (Self)
        """
        return copy.deepcopy(self)

    def __copy__(self: T) -> T:
        """Shallow copy self

        Examples:
            >>> import copy
            >>> from bigtree.node.dagnode import DAGNode
            >>> a = DAGNode('a')
            >>> a_copy = copy.copy(a)

        Returns:
            (Self)
        """
        obj: T = type(self).__new__(self.__class__)
        obj.__dict__.update(self.__dict__)
        return obj

    def __getitem__(self, child_name: str) -> T:
        """Get child by name identifier

        Args:
            child_name (str): name of child node

        Returns:
            (Self): child node
        """
        from bigtree.tree.search import find_child_by_name

        return find_child_by_name(self, child_name)  # type: ignore

    def __delitem__(self, child_name: str) -> None:
        """Delete child by name identifier, will not throw error if child does not exist

        Args:
            child_name (str): name of child node
        """
        from bigtree.tree.search import find_child_by_name

        child = find_child_by_name(self, child_name)
        if child:
            self.__children.remove(child)  # type: ignore
            child.__parents.remove(self)  # type: ignore

    def __repr__(self) -> str:
        """Print format of DAGNode

        Returns:
            (str)
        """
        class_name = self.__class__.__name__
        node_dict = self.describe(exclude_attributes=["name"])
        node_description = ", ".join(
            [f"{k}={v}" for k, v in node_dict if not k.startswith("_")]
        )
        return f"{class_name}({self.node_name}, {node_description})"

    def __rshift__(self: T, other: T) -> None:
        """Set children using >> bitshift operator for self >> children (other)

        Args:
            other (Self): other node, children
        """
        other.parents = [self]

    def __lshift__(self: T, other: T) -> None:
        """Set parent using << bitshift operator for self << parent (other)

        Args:
            other (Self): other node, parent
        """
        self.parents = [other]

    def __iter__(self) -> Generator[T, None, None]:
        """Iterate through child nodes

        Returns:
            (Self): child node
        """
        yield from self.children  # type: ignore

    def __contains__(self, other_node: T) -> bool:
        """Check if child node exists

        Args:
            other_node (T): child node

        Returns:
            (bool)
        """
        return other_node in self.children
T = TypeVar("T", bound=DAGNode)

View File

@@ -0,0 +1,261 @@
from __future__ import annotations
from collections import Counter
from typing import Any, List, TypeVar
from bigtree.node.basenode import BaseNode
from bigtree.utils.exceptions import TreeError
class Node(BaseNode):
    """
    Node is an extension of BaseNode, and is able to extend to any Python class.
    Nodes can have attributes if they are initialized from `Node`, *dictionary*, or *pandas DataFrame*.

    Nodes can be linked to each other with `parent` and `children` setter methods.

    Examples:
        >>> from bigtree import Node
        >>> a = Node("a")
        >>> b = Node("b")
        >>> c = Node("c")
        >>> d = Node("d")
        >>> b.parent = a
        >>> b.children = [c, d]

        Directly passing `parent` argument.

        >>> from bigtree import Node
        >>> a = Node("a")
        >>> b = Node("b", parent=a)
        >>> c = Node("c", parent=b)
        >>> d = Node("d", parent=b)

        Directly passing `children` argument.

        >>> from bigtree import Node
        >>> d = Node("d")
        >>> c = Node("c")
        >>> b = Node("b", children=[c, d])
        >>> a = Node("a", children=[b])

    **Node Creation**

    Node can be created by instantiating a `Node` class or by using a *dictionary*.
    If node is created with dictionary, all keys of dictionary will be stored as class attributes.

    >>> from bigtree import Node
    >>> a = Node.from_dict({"name": "a", "age": 90})

    **Node Attributes**

    These are node attributes that have getter and/or setter methods.

    Get and set `Node` configuration

    1. ``sep``: Get/set separator for path name

    Get `Node` configuration

    1. ``node_name``: Get node name, without accessing `name` directly
    2. ``path_name``: Get path name from root, separated by `sep`

    **Node Methods**

    These are methods available to be performed on `Node`.

    `Node` methods

    1. ``show()``: Print tree to console
    2. ``hshow()``: Print tree in horizontal orientation to console

    ----

    """

    def __init__(self, name: str = "", sep: str = "/", **kwargs: Any):
        self.name = name
        # Stored locally but only authoritative on the root; see `sep` property.
        self._sep = sep
        super().__init__(**kwargs)
        # A name is mandatory: path-based lookups would be ambiguous without it.
        if not self.node_name:
            raise TreeError("Node must have a `name` attribute")

    @property
    def sep(self) -> str:
        """Get separator, gets from root node

        Returns:
            (str)
        """
        # Walk up to the root so the whole tree shares one separator.
        if self.parent is None:
            return self._sep
        return self.parent.sep

    @sep.setter
    def sep(self, value: str) -> None:
        """Set separator, affects root node

        Args:
            value (str): separator to replace default separator
        """
        self.root._sep = value

    @property
    def node_name(self) -> str:
        """Get node name

        Returns:
            (str)
        """
        return self.name

    @property
    def path_name(self) -> str:
        """Get path name, separated by self.sep

        Returns:
            (str)
        """
        ancestors = [self] + list(self.ancestors)
        # Last ancestor is the root, which owns the separator.
        sep = ancestors[-1].sep
        return sep + sep.join([str(node.name) for node in reversed(ancestors)])

    def __pre_assign_children(self: T, new_children: List[T]) -> None:
        """Custom hook invoked before attaching children; default is a no-op
        Can be overridden with `_Node__pre_assign_children()` (name-mangled form)

        Args:
            new_children (List[Self]): new children to be added
        """
        pass

    def __post_assign_children(self: T, new_children: List[T]) -> None:
        """Custom hook invoked after attaching children; default is a no-op
        Can be overridden with `_Node__post_assign_children()` (name-mangled form)

        Args:
            new_children (List[Self]): new children to be added
        """
        pass

    def __pre_assign_parent(self: T, new_parent: T) -> None:
        """Custom hook invoked before attaching parent; default is a no-op
        Can be overridden with `_Node__pre_assign_parent()` (name-mangled form)

        Args:
            new_parent (Self): new parent to be added
        """
        pass

    def __post_assign_parent(self: T, new_parent: T) -> None:
        """Custom hook invoked after attaching parent; default is a no-op
        Can be overridden with `_Node__post_assign_parent()` (name-mangled form)

        Args:
            new_parent (Self): new parent to be added
        """
        pass

    def _BaseNode__pre_assign_parent(self: T, new_parent: T) -> None:
        """Do not allow duplicate nodes of same path
        Implements BaseNode's name-mangled pre-assign hook, then delegates to
        Node's own overridable `__pre_assign_parent` hook.

        Args:
            new_parent (Self): new parent to be added
        """
        self.__pre_assign_parent(new_parent)
        if new_parent is not None:
            # Sibling names must be unique for path lookups to be unambiguous.
            if any(
                child.node_name == self.node_name and child is not self
                for child in new_parent.children
            ):
                raise TreeError(
                    f"Duplicate node with same path\n"
                    f"There exist a node with same path {new_parent.path_name}{new_parent.sep}{self.node_name}"
                )

    def _BaseNode__post_assign_parent(self: T, new_parent: T) -> None:
        """No rules; delegates to Node's own overridable `__post_assign_parent` hook

        Args:
            new_parent (Self): new parent to be added
        """
        self.__post_assign_parent(new_parent)

    def _BaseNode__pre_assign_children(self: T, new_children: List[T]) -> None:
        """Do not allow duplicate nodes of same path
        Implements BaseNode's name-mangled pre-assign hook, then delegates to
        Node's own overridable `__pre_assign_children` hook.

        Args:
            new_children (List[Self]): new children to be added
        """
        self.__pre_assign_children(new_children)
        children_names = [node.node_name for node in new_children]
        duplicate_names = [
            item[0] for item in Counter(children_names).items() if item[1] > 1
        ]
        if len(duplicate_names):
            duplicate_names_str = " and ".join(
                [f"{self.path_name}{self.sep}{name}" for name in duplicate_names]
            )
            raise TreeError(
                f"Duplicate node with same path\n"
                f"Attempting to add nodes with same path {duplicate_names_str}"
            )

    def _BaseNode__post_assign_children(self: T, new_children: List[T]) -> None:
        """No rules; delegates to Node's own overridable `__post_assign_children` hook

        Args:
            new_children (List[Self]): new children to be added
        """
        self.__post_assign_children(new_children)

    def show(self, **kwargs: Any) -> None:
        """Print tree to console, takes in same keyword arguments as `print_tree` function"""
        from bigtree.tree.export import print_tree

        print_tree(self, **kwargs)

    def hshow(self, **kwargs: Any) -> None:
        """Print tree in horizontal orientation to console, takes in same keyword arguments as `hprint_tree` function"""
        from bigtree.tree.export import hprint_tree

        hprint_tree(self, **kwargs)

    def __getitem__(self, child_name: str) -> T:
        """Get child by name identifier

        Args:
            child_name (str): name of child node

        Returns:
            (Self): child node
        """
        from bigtree.tree.search import find_child_by_name

        return find_child_by_name(self, child_name)  # type: ignore

    def __delitem__(self, child_name: str) -> None:
        """Delete child by name identifier, will not throw error if child does not exist

        Args:
            child_name (str): name of child node
        """
        from bigtree.tree.search import find_child_by_name

        child = find_child_by_name(self, child_name)
        if child:
            # Detaching the parent link is enough; the child keeps its subtree.
            child.parent = None

    def __repr__(self) -> str:
        """Print format of Node

        Returns:
            (str)
        """
        class_name = self.__class__.__name__
        node_dict = self.describe(exclude_prefix="_", exclude_attributes=["name"])
        node_description = ", ".join([f"{k}={v}" for k, v in node_dict])
        return f"{class_name}({self.path_name}, {node_description})"
T = TypeVar("T", bound=Node)

View File

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,415 @@
from collections import deque
from typing import Any, Deque, Dict, List, Set, Type, TypeVar, Union
from bigtree.node.basenode import BaseNode
from bigtree.node.binarynode import BinaryNode
from bigtree.node.node import Node
from bigtree.tree.construct import add_dict_to_tree_by_path, dataframe_to_tree
from bigtree.tree.export import tree_to_dataframe
from bigtree.tree.search import find_path
from bigtree.utils.exceptions import NotFoundError
from bigtree.utils.iterators import levelordergroup_iter
__all__ = ["clone_tree", "get_subtree", "prune_tree", "get_tree_diff"]
BaseNodeT = TypeVar("BaseNodeT", bound=BaseNode)
BinaryNodeT = TypeVar("BinaryNodeT", bound=BinaryNode)
NodeT = TypeVar("NodeT", bound=Node)
def clone_tree(tree: BaseNode, node_type: Type[BaseNodeT]) -> BaseNodeT:
    """Clone tree to another ``Node`` type.
    If the same type is needed, simply do a tree.copy().

    Examples:
        >>> from bigtree import BaseNode, Node, clone_tree
        >>> root = BaseNode(name="a")
        >>> b = BaseNode(name="b", parent=root)
        >>> clone_tree(root, Node)
        Node(/a, )

    Args:
        tree (BaseNode): tree to be cloned, must inherit from BaseNode
        node_type (Type[BaseNode]): type of cloned tree

    Returns:
        (BaseNode)
    """
    if not isinstance(tree, BaseNode):
        raise TypeError("Tree should be of type `BaseNode`, or inherit from `BaseNode`")

    # Clone starting from the root; public attributes carry over via describe().
    source_root = tree.root
    cloned_root = node_type(**dict(source_root.describe(exclude_prefix="_")))

    # Iterative DFS over (clone, source) pairs; children of each node are
    # attached in original order.
    pending = [(cloned_root, source_root)]
    while pending:
        cloned_parent, source_parent = pending.pop()
        for source_child in source_parent.children:
            if source_child:
                cloned_child = node_type(
                    **dict(source_child.describe(exclude_prefix="_"))
                )
                cloned_child.parent = cloned_parent
                pending.append((cloned_child, source_child))
    return cloned_root
def get_subtree(
    tree: NodeT,
    node_name_or_path: str = "",
    max_depth: int = 0,
) -> NodeT:
    """Get subtree based on node name or node path, and/or maximum depth of tree.
    Works on a copy; the original tree is left untouched.

    Examples:
        >>> from bigtree import Node, get_subtree
        >>> root = Node("a")
        >>> b = Node("b", parent=root)
        >>> c = Node("c", parent=b)
        >>> d = Node("d", parent=b)
        >>> e = Node("e", parent=root)
        >>> root.show()
        a
        ├── b
        │   ├── c
        │   └── d
        └── e

        Get subtree

        >>> root_subtree = get_subtree(root, "b")
        >>> root_subtree.show()
        b
        ├── c
        └── d

    Args:
        tree (Node): existing tree
        node_name_or_path (str): node name or path to get subtree, defaults to "" (keep current root)
        max_depth (int): maximum depth of subtree, based on `depth` attribute, defaults to 0 (no limit)

    Returns:
        (Node)
    """
    tree = tree.copy()

    if node_name_or_path:
        tree = find_path(tree, node_name_or_path)
        if not tree:
            raise ValueError(f"Node name or path {node_name_or_path} not found")
        # Detach the found node so it becomes the new root of the subtree.
        if not tree.is_root:
            tree.parent = None

    if max_depth:
        tree = prune_tree(tree, max_depth=max_depth)
    return tree
def prune_tree(
    tree: Union[BinaryNodeT, NodeT],
    prune_path: Union[List[str], str] = "",
    exact: bool = False,
    sep: str = "/",
    max_depth: int = 0,
) -> Union[BinaryNodeT, NodeT]:
    """Prune tree by path or depth, returns the root of a *copy* of the original tree.

    For pruning by `prune_path`,

    - All siblings along the prune path will be removed.
    - If ``exact=True``, all descendants of prune path will be removed.
    - Prune path can be string (only one path) or a list of strings (multiple paths).
    - Prune path name should be unique, can be full path, partial path (trailing part of path), or node name.

    For pruning by `max_depth`,

    - All nodes that are beyond `max_depth` will be removed.

    Path should contain ``Node`` name, separated by `sep`.

    - For example: Path string "a/b" refers to Node("b") with parent Node("a").

    Examples:
        >>> from bigtree import Node, prune_tree
        >>> root = Node("a")
        >>> b = Node("b", parent=root)
        >>> c = Node("c", parent=b)
        >>> d = Node("d", parent=b)
        >>> e = Node("e", parent=root)
        >>> root.show()
        a
        ├── b
        │   ├── c
        │   └── d
        └── e

        Prune (default is keep descendants)

        >>> root_pruned = prune_tree(root, "a/b")
        >>> root_pruned.show()
        a
        └── b
            ├── c
            └── d

        Prune exact path

        >>> root_pruned = prune_tree(root, "a/b", exact=True)
        >>> root_pruned.show()
        a
        └── b

        Prune multiple paths

        >>> root_pruned = prune_tree(root, ["a/b/d", "a/e"])
        >>> root_pruned.show()
        a
        ├── b
        │   └── d
        └── e

        Prune by depth

        >>> root_pruned = prune_tree(root, max_depth=2)
        >>> root_pruned.show()
        a
        ├── b
        └── e

    Args:
        tree (Union[BinaryNode, Node]): existing tree
        prune_path (List[str] | str): prune path(s), all siblings along the prune path(s) will be removed
        exact (bool): prune path(s) to be exactly the path, defaults to False (descendants of the path are retained)
        sep (str): path separator of `prune_path`
        max_depth (int): maximum depth of pruned tree, based on `depth` attribute, defaults to 0 (no depth limit)

    Returns:
        (Union[BinaryNode, Node])
    """
    # Normalise to a list; an empty string means "no path pruning".
    if isinstance(prune_path, str):
        prune_path = [prune_path] if prune_path else []

    if not len(prune_path) and not max_depth:
        raise ValueError("Please specify either `prune_path` or `max_depth` or both.")

    # Operate on a copy so the caller's tree is never mutated.
    tree_copy = tree.copy()

    # Prune by path (prune bottom-up)
    if len(prune_path):
        # Nodes on any prune path (kept), and their ancestors whose other
        # children must be detached.
        ancestors_to_prune: Set[Union[BinaryNodeT, NodeT]] = set()
        nodes_to_prune: Set[Union[BinaryNodeT, NodeT]] = set()
        for path in prune_path:
            # Translate the caller's separator into the tree's separator.
            path = path.replace(sep, tree.sep)
            child = find_path(tree_copy, path)
            if not child:
                raise NotFoundError(
                    f"Cannot find any node matching path_name ending with {path}"
                )
            nodes_to_prune.add(child)
            ancestors_to_prune.update(list(child.ancestors))

        # With exact=True, descendants of the prune target are dropped too.
        if exact:
            ancestors_to_prune.update(nodes_to_prune)

        for node in ancestors_to_prune:
            for child in node.children:
                if (
                    child
                    and child not in ancestors_to_prune
                    and child not in nodes_to_prune
                ):
                    child.parent = None

    # Prune by depth (prune top-down)
    if max_depth:
        for depth, level_nodes in enumerate(levelordergroup_iter(tree_copy), 1):
            if depth == max_depth:
                # Drop everything below this level.
                for level_node in level_nodes:
                    del level_node.children
    return tree_copy
def get_tree_diff(
    tree: Node, other_tree: Node, only_diff: bool = True, attr_list: List[str] = []
) -> Node:
    """Get difference of `tree` to `other_tree`, changes are relative to `tree`.

    Compares the difference in tree structure (default), but can also compare tree attributes using `attr_list`.
    Function can return only the differences (default), or all original tree nodes and differences.

    Comparing tree structure:

    - (+) and (-) will be added to node name relative to `tree`.
    - For example: (+) refers to nodes that are in `other_tree` but not `tree`.
    - For example: (-) refers to nodes that are in `tree` but not `other_tree`.

    Examples:
        >>> # Create original tree
        >>> from bigtree import Node, get_tree_diff, list_to_tree
        >>> root = list_to_tree(["Downloads/Pictures/photo1.jpg", "Downloads/file1.doc", "Downloads/photo2.jpg"])
        >>> root.show()
        Downloads
        ├── Pictures
        │   └── photo1.jpg
        ├── file1.doc
        └── photo2.jpg

        >>> # Create other tree
        >>> root_other = list_to_tree(["Downloads/Pictures/photo1.jpg", "Downloads/Pictures/photo2.jpg", "Downloads/file1.doc"])
        >>> root_other.show()
        Downloads
        ├── Pictures
        │   ├── photo1.jpg
        │   └── photo2.jpg
        └── file1.doc

        >>> # Get tree differences
        >>> tree_diff = get_tree_diff(root, root_other)
        >>> tree_diff.show()
        Downloads
        ├── photo2.jpg (-)
        └── Pictures
            └── photo2.jpg (+)

        >>> tree_diff = get_tree_diff(root, root_other, only_diff=False)
        >>> tree_diff.show()
        Downloads
        ├── Pictures
        │   ├── photo1.jpg
        │   └── photo2.jpg (+)
        ├── file1.doc
        └── photo2.jpg (-)

    Comparing tree attributes

    - (~) will be added to node name if there are differences in tree attributes defined in `attr_list`.
    - The node's attributes will be a list of [value in `tree`, value in `other_tree`]

        >>> # Create original tree
        >>> root = Node("Downloads")
        >>> picture_folder = Node("Pictures", parent=root)
        >>> photo2 = Node("photo1.jpg", tags="photo1", parent=picture_folder)
        >>> file1 = Node("file1.doc", tags="file1", parent=root)
        >>> root.show(attr_list=["tags"])
        Downloads
        ├── Pictures
        │   └── photo1.jpg [tags=photo1]
        └── file1.doc [tags=file1]

        >>> # Create other tree
        >>> root_other = Node("Downloads")
        >>> picture_folder = Node("Pictures", parent=root_other)
        >>> photo1 = Node("photo1.jpg", tags="photo1-edited", parent=picture_folder)
        >>> photo2 = Node("photo2.jpg", tags="photo2-new", parent=picture_folder)
        >>> file1 = Node("file1.doc", tags="file1", parent=root_other)
        >>> root_other.show(attr_list=["tags"])
        Downloads
        ├── Pictures
        │   ├── photo1.jpg [tags=photo1-edited]
        │   └── photo2.jpg [tags=photo2-new]
        └── file1.doc [tags=file1]

        >>> # Get tree differences
        >>> tree_diff = get_tree_diff(root, root_other, attr_list=["tags"])
        >>> tree_diff.show(attr_list=["tags"])
        Downloads
        └── Pictures
            ├── photo1.jpg (~) [tags=('photo1', 'photo1-edited')]
            └── photo2.jpg (+)

    Args:
        tree (Node): tree to be compared against
        other_tree (Node): tree to be compared with
        only_diff (bool): indicator to show all nodes or only nodes that are different (+/-), defaults to True
        attr_list (List[str]): tree attributes to check for difference, defaults to empty list
            (mutable default is safe here as the list is never mutated)

    Returns:
        (Node)
    """
    # NOTE(review): this mutates `other_tree` (its separator) — confirm callers accept this.
    other_tree.sep = tree.sep
    name_col = "name"
    path_col = "PATH"
    indicator_col = "Exists"

    data, data_other = (
        tree_to_dataframe(
            _tree,
            name_col=name_col,
            path_col=path_col,
            attr_dict={k: k for k in attr_list},
        )
        for _tree in (tree, other_tree)
    )

    # Check tree structure difference: outer merge marks each path as
    # left_only (removed), right_only (added), or both.
    data_both = data[[path_col, name_col] + attr_list].merge(
        data_other[[path_col, name_col] + attr_list],
        how="outer",
        on=[path_col, name_col],
        indicator=indicator_col,
    )

    # Handle tree structure difference; reversed order processes deeper paths
    # first so ancestor markers compose with already-marked descendants.
    nodes_removed = list(data_both[data_both[indicator_col] == "left_only"][path_col])[
        ::-1
    ]
    nodes_added = list(data_both[data_both[indicator_col] == "right_only"][path_col])[
        ::-1
    ]
    for node_removed in nodes_removed:
        # Bug fix: paths are literal strings, not regex patterns. With
        # regex=True a path such as "photo2.jpg" lets "." match any character
        # (and other metacharacters break outright); regex=False performs the
        # intended literal substring replacement.
        data_both[path_col] = data_both[path_col].str.replace(
            node_removed, f"{node_removed} (-)", regex=False
        )
    for node_added in nodes_added:
        data_both[path_col] = data_both[path_col].str.replace(
            node_added, f"{node_added} (+)", regex=False
        )

    # Check tree attribute difference: rows present in both trees whose
    # attribute values differ (ignoring rows where both sides are null).
    path_changes_list_of_dict: List[Dict[str, Dict[str, Any]]] = []
    path_changes_deque: Deque[str] = deque([])
    for attr_change in attr_list:
        condition_diff = (
            (
                ~data_both[f"{attr_change}_x"].isnull()
                | ~data_both[f"{attr_change}_y"].isnull()
            )
            & (data_both[f"{attr_change}_x"] != data_both[f"{attr_change}_y"])
            & (data_both[indicator_col] == "both")
        )
        data_diff = data_both[condition_diff]
        if len(data_diff):
            tuple_diff = zip(
                data_diff[f"{attr_change}_x"], data_diff[f"{attr_change}_y"]
            )
            dict_attr_diff = [{attr_change: v} for v in tuple_diff]
            dict_path_diff = dict(list(zip(data_diff[path_col], dict_attr_diff)))
            path_changes_list_of_dict.append(dict_path_diff)
            path_changes_deque.extend(list(data_diff[path_col]))

    if only_diff:
        data_both = data_both[
            (data_both[indicator_col] != "both")
            | (data_both[path_col].isin(path_changes_deque))
        ]
    data_both = data_both[[path_col]]
    if len(data_both):
        tree_diff = dataframe_to_tree(data_both, node_type=tree.__class__)
        # Handle tree attribute difference: rename changed nodes with (~) and
        # attach (old, new) attribute tuples.
        if len(path_changes_deque):
            path_changes_list = sorted(path_changes_deque, reverse=True)
            name_changes_list = [
                {k: {"name": f"{k.split(tree.sep)[-1]} (~)"} for k in path_changes_list}
            ]
            path_changes_list_of_dict.extend(name_changes_list)
            for attr_change_dict in path_changes_list_of_dict:
                tree_diff = add_dict_to_tree_by_path(tree_diff, attr_change_dict)
        return tree_diff
    # Implicitly returns None when there is no difference and only_diff=True.

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,479 @@
from typing import Any, Callable, Iterable, List, Tuple, TypeVar, Union
from bigtree.node.basenode import BaseNode
from bigtree.node.dagnode import DAGNode
from bigtree.node.node import Node
from bigtree.utils.exceptions import SearchError
from bigtree.utils.iterators import preorder_iter
__all__ = [
"findall",
"find",
"find_name",
"find_names",
"find_relative_path",
"find_full_path",
"find_path",
"find_paths",
"find_attr",
"find_attrs",
"find_children",
"find_child",
"find_child_by_name",
]
T = TypeVar("T", bound=BaseNode)
NodeT = TypeVar("NodeT", bound=Node)
DAGNodeT = TypeVar("DAGNodeT", bound=DAGNode)
def findall(
    tree: T,
    condition: Callable[[T], bool],
    max_depth: int = 0,
    min_count: int = 0,
    max_count: int = 0,
) -> Tuple[T, ...]:
    """Search ``tree`` for every node satisfying ``condition``.

    Traversal is pre-order, so results are topologically sorted (parents
    before children).

    Examples:
        >>> from bigtree import Node, findall
        >>> root = Node("a", age=90)
        >>> b = Node("b", age=65, parent=root)
        >>> c = Node("c", age=60, parent=root)
        >>> d = Node("d", age=40, parent=c)
        >>> findall(root, lambda node: node.age > 62)
        (Node(/a, age=90), Node(/a/b, age=65))

    Args:
        tree (BaseNode): tree to search
        condition (Callable): predicate taking a node; the node is kept when it returns `True`
        max_depth (int): maximum depth to search, based on the `depth` attribute, 0 searches all depths
        min_count (int): minimum number of matches required, 0 disables the check
        max_count (int): maximum number of matches allowed, 0 disables the check

    Returns:
        (Tuple[BaseNode, ...])

    Raises:
        SearchError: if the number of matches violates ``min_count`` or ``max_count``
    """
    matches = tuple(
        preorder_iter(tree, filter_condition=condition, max_depth=max_depth)
    )
    n_found = len(matches)
    if min_count and n_found < min_count:
        raise SearchError(
            f"Expected more than {min_count} element(s), found {n_found} elements\n{matches}"
        )
    if max_count and n_found > max_count:
        raise SearchError(
            f"Expected less than {max_count} element(s), found {n_found} elements\n{matches}"
        )
    return matches
def find(tree: T, condition: Callable[[T], bool], max_depth: int = 0) -> T:
    """Search ``tree`` for the *single* node satisfying ``condition``.

    Delegates to ``findall`` with ``max_count=1``: more than one match raises
    ``SearchError``; no match returns ``None``.

    Examples:
        >>> from bigtree import Node, find
        >>> root = Node("a", age=90)
        >>> b = Node("b", age=65, parent=root)
        >>> c = Node("c", age=60, parent=root)
        >>> find(root, lambda node: node.age == 65)
        Node(/a/b, age=65)

    Args:
        tree (BaseNode): tree to search
        condition (Callable): predicate taking a node; the node is kept when it returns `True`
        max_depth (int): maximum depth to search, based on the `depth` attribute, 0 searches all depths

    Returns:
        (BaseNode)
    """
    matches = findall(tree, condition, max_depth, max_count=1)
    return matches[0] if matches else None
def find_name(tree: NodeT, name: str, max_depth: int = 0) -> NodeT:
    """Search ``tree`` for the single node whose name equals ``name``.

    Examples:
        >>> from bigtree import Node, find_name
        >>> root = Node("a", age=90)
        >>> c = Node("c", age=60, parent=root)
        >>> find_name(root, "c")
        Node(/a/c, age=60)

    Args:
        tree (Node): tree to search
        name (str): value to match for name attribute
        max_depth (int): maximum depth to search, based on the `depth` attribute, 0 searches all depths

    Returns:
        (Node)
    """

    def _name_matches(node: NodeT) -> bool:
        return node.node_name == name

    return find(tree, _name_matches, max_depth)
def find_names(tree: NodeT, name: str, max_depth: int = 0) -> Iterable[NodeT]:
    """Search ``tree`` for all nodes whose name equals ``name``.

    Examples:
        >>> from bigtree import Node, find_names
        >>> root = Node("a", age=90)
        >>> b = Node("b", age=65, parent=root)
        >>> c = Node("c", age=60, parent=root)
        >>> d = Node("b", age=40, parent=c)
        >>> find_names(root, "b")
        (Node(/a/b, age=65), Node(/a/c/b, age=40))

    Args:
        tree (Node): tree to search
        name (str): value to match for name attribute
        max_depth (int): maximum depth to search, based on the `depth` attribute, 0 searches all depths

    Returns:
        (Iterable[Node])
    """

    def _name_matches(node: NodeT) -> bool:
        return node.node_name == name

    return findall(tree, _name_matches, max_depth)
def find_relative_path(tree: NodeT, path_name: str) -> Iterable[NodeT]:
    r"""
    Search tree for node(s) matching relative path.

    - Supports unix folder expression for relative path, i.e., '../../node_name'
    - Supports wildcards, i.e., '\*/node_name'
    - If path name starts with leading separator symbol, it will start at root node.

    Examples:
        >>> from bigtree import Node, find_relative_path
        >>> root = Node("a", age=90)
        >>> b = Node("b", age=65, parent=root)
        >>> c = Node("c", age=60, parent=root)
        >>> d = Node("d", age=40, parent=c)
        >>> find_relative_path(d, "..")
        (Node(/a/c, age=60),)
        >>> find_relative_path(d, "../../b")
        (Node(/a/b, age=65),)
        >>> find_relative_path(d, "../../*")
        (Node(/a/b, age=65), Node(/a/c, age=60))

    Args:
        tree (Node): tree to search; the node the relative path is resolved from
        path_name (str): value to match (relative path) of path_name attribute

    Returns:
        (Iterable[Node]): tuple of resolved node(s)

    Raises:
        SearchError: if '..' goes beyond the root node, or a named component
            cannot be found and the path contains no wildcard
    """
    sep = tree.sep
    # Leading separator means an absolute path: delegate to full-path lookup.
    if path_name.startswith(sep):
        resolved_node = find_full_path(tree, path_name)
        return (resolved_node,)
    path_name = path_name.rstrip(sep).lstrip(sep)
    path_list = path_name.split(sep)
    # With a wildcard anywhere in the path, unresolvable branches are
    # silently skipped instead of raising SearchError.
    wildcard_indicator = "*" in path_name
    resolved_nodes: List[NodeT] = []

    def resolve(node: NodeT, path_idx: int) -> None:
        """Resolve one path component at a time, recursively; fully resolved
        nodes are appended to the enclosing ``resolved_nodes`` list.

        Args:
            node (Node): current node
            path_idx (int): current index in path_list
        """
        if path_idx == len(path_list):
            # All components consumed: this node is a match.
            resolved_nodes.append(node)
        else:
            path_component = path_list[path_idx]
            if path_component == ".":
                # '.' stays on the current node.
                resolve(node, path_idx + 1)
            elif path_component == "..":
                if node.is_root:
                    raise SearchError("Invalid path name. Path goes beyond root node.")
                resolve(node.parent, path_idx + 1)
            elif path_component == "*":
                # Wildcard: branch into every child.
                for child in node.children:
                    resolve(child, path_idx + 1)
            else:
                node = find_child_by_name(node, path_component)
                if not node:
                    if not wildcard_indicator:
                        raise SearchError(
                            f"Invalid path name. Node {path_component} cannot be found."
                        )
                else:
                    resolve(node, path_idx + 1)

    resolve(tree, 0)
    return tuple(resolved_nodes)
def find_full_path(tree: NodeT, path_name: str) -> NodeT:
    """Search ``tree`` for the single node at full path ``path_name``.

    - Path name can be with or without a leading tree path separator symbol.
    - Path name must be the full path; works like `find_path` but faster since
      it walks the tree level by level instead of scanning every node.

    Examples:
        >>> from bigtree import Node, find_full_path
        >>> root = Node("a", age=90)
        >>> c = Node("c", age=60, parent=root)
        >>> d = Node("d", age=40, parent=c)
        >>> find_full_path(root, "/a/c/d")
        Node(/a/c/d, age=40)

    Args:
        tree (Node): tree to search
        path_name (str): value to match (full path) of path_name attribute

    Returns:
        (Node)

    Raises:
        ValueError: if the first path component does not match the root name
    """
    sep = tree.sep
    components = path_name.strip(sep).split(sep)
    root = tree.root
    if components[0] != root.node_name:
        raise ValueError(
            f"Path {path_name} does not match the root node name {root.node_name}"
        )
    # Descend one component at a time; a missing child short-circuits the walk.
    current = root
    for component in components[1:]:
        nxt = find_child_by_name(current, component)
        if not nxt:
            return nxt
        current = nxt
    return current
def find_path(tree: NodeT, path_name: str) -> NodeT:
    """Search ``tree`` for the single node whose path ends with ``path_name``.

    - Path name can be with or without a leading tree path separator symbol.
    - Path name can be the full path, a trailing part of the path, or a node name.

    Examples:
        >>> from bigtree import Node, find_path
        >>> root = Node("a", age=90)
        >>> c = Node("c", age=60, parent=root)
        >>> find_path(root, "/c")
        Node(/a/c, age=60)

    Args:
        tree (Node): tree to search
        path_name (str): value to match (full path) or trailing part (partial path) of path_name attribute

    Returns:
        (Node)
    """
    suffix = path_name.rstrip(tree.sep)

    def _path_matches(node: NodeT) -> bool:
        return node.path_name.endswith(suffix)

    return find(tree, _path_matches)
def find_paths(tree: NodeT, path_name: str) -> Tuple[NodeT, ...]:
    """Search ``tree`` for all nodes whose path ends with ``path_name``.

    - Path name can be with or without a leading tree path separator symbol.
    - Path name can be a trailing part of the path or a node name.

    Examples:
        >>> from bigtree import Node, find_paths
        >>> root = Node("a", age=90)
        >>> c = Node("c", age=60, parent=root)
        >>> d = Node("c", age=40, parent=c)
        >>> find_paths(root, "/c")
        (Node(/a/c, age=60), Node(/a/c/c, age=40))

    Args:
        tree (Node): tree to search
        path_name (str): value to match (full path) or trailing part (partial path) of path_name attribute

    Returns:
        (Tuple[Node, ...])
    """
    suffix = path_name.rstrip(tree.sep)

    def _path_matches(node: NodeT) -> bool:
        return node.path_name.endswith(suffix)

    return findall(tree, _path_matches)
def find_attr(
    tree: BaseNode, attr_name: str, attr_value: Any, max_depth: int = 0
) -> BaseNode:
    """Search ``tree`` for the single node whose ``attr_name`` attribute equals ``attr_value``.

    Examples:
        >>> from bigtree import Node, find_attr
        >>> root = Node("a", age=90)
        >>> b = Node("b", age=65, parent=root)
        >>> find_attr(root, "age", 65)
        Node(/a/b, age=65)

    Args:
        tree (BaseNode): tree to search
        attr_name (str): attribute name to perform matching
        attr_value (Any): value to match for attr_name attribute
        max_depth (int): maximum depth to search, based on the `depth` attribute, 0 searches all depths

    Returns:
        (BaseNode)
    """

    def _attr_matches(node: BaseNode) -> bool:
        return bool(node.get_attr(attr_name) == attr_value)

    return find(tree, _attr_matches, max_depth)
def find_attrs(
    tree: BaseNode, attr_name: str, attr_value: Any, max_depth: int = 0
) -> Tuple[BaseNode, ...]:
    """Search ``tree`` for all nodes whose ``attr_name`` attribute equals ``attr_value``.

    Examples:
        >>> from bigtree import Node, find_attrs
        >>> root = Node("a", age=90)
        >>> b = Node("b", age=65, parent=root)
        >>> c = Node("c", age=65, parent=root)
        >>> find_attrs(root, "age", 65)
        (Node(/a/b, age=65), Node(/a/c, age=65))

    Args:
        tree (BaseNode): tree to search
        attr_name (str): attribute name to perform matching
        attr_value (Any): value to match for attr_name attribute
        max_depth (int): maximum depth to search, based on the `depth` attribute, 0 searches all depths

    Returns:
        (Tuple[BaseNode, ...])
    """

    def _attr_matches(node: BaseNode) -> bool:
        return bool(node.get_attr(attr_name) == attr_value)

    return findall(tree, _attr_matches, max_depth)
def find_children(
    tree: Union[T, DAGNodeT],
    condition: Callable[[Union[T, DAGNodeT]], bool],
    min_count: int = 0,
    max_count: int = 0,
) -> Tuple[Union[T, DAGNodeT], ...]:
    """Search the direct children of ``tree`` for nodes satisfying ``condition``.

    Falsy children (e.g. ``None`` placeholders) are skipped.

    Examples:
        >>> from bigtree import Node, find_children
        >>> root = Node("a", age=90)
        >>> b = Node("b", age=65, parent=root)
        >>> c = Node("c", age=60, parent=root)
        >>> find_children(root, lambda node: node.age > 30)
        (Node(/a/b, age=65), Node(/a/c, age=60))

    Args:
        tree (BaseNode/DAGNode): tree to search for its children
        condition (Callable): predicate taking a node; the node is kept when it returns `True`
        min_count (int): minimum number of matches required, 0 disables the check
        max_count (int): maximum number of matches allowed, 0 disables the check

    Returns:
        (Tuple[BaseNode/DAGNode, ...])

    Raises:
        SearchError: if the number of matches violates ``min_count`` or ``max_count``
    """
    matches = tuple(
        child for child in tree.children if child and condition(child)
    )
    n_found = len(matches)
    if min_count and n_found < min_count:
        raise SearchError(
            f"Expected more than {min_count} element(s), found {n_found} elements\n{matches}"
        )
    if max_count and n_found > max_count:
        raise SearchError(
            f"Expected less than {max_count} element(s), found {n_found} elements\n{matches}"
        )
    return matches
def find_child(
    tree: Union[T, DAGNodeT],
    condition: Callable[[Union[T, DAGNodeT]], bool],
) -> Union[T, DAGNodeT]:
    """Search the direct children of ``tree`` for the *single* node satisfying ``condition``.

    Delegates to ``find_children`` with ``max_count=1``: more than one match
    raises ``SearchError``; no match returns ``None``.

    Examples:
        >>> from bigtree import Node, find_child
        >>> root = Node("a", age=90)
        >>> b = Node("b", age=65, parent=root)
        >>> find_child(root, lambda node: node.age > 62)
        Node(/a/b, age=65)

    Args:
        tree (BaseNode/DAGNode): tree to search for its child
        condition (Callable): predicate taking a node; the node is kept when it returns `True`

    Returns:
        (BaseNode/DAGNode)
    """
    matches = find_children(tree, condition, max_count=1)
    return matches[0] if matches else None
def find_child_by_name(
    tree: Union[NodeT, DAGNodeT], name: str
) -> Union[NodeT, DAGNodeT]:
    """Search the direct children of ``tree`` for the single child named ``name``.

    Examples:
        >>> from bigtree import Node, find_child_by_name
        >>> root = Node("a", age=90)
        >>> c = Node("c", age=60, parent=root)
        >>> find_child_by_name(root, "c")
        Node(/a/c, age=60)

    Args:
        tree (Node/DAGNode): tree to search, parent node
        name (str): value to match for name attribute, child node

    Returns:
        (Node/DAGNode)
    """

    def _name_matches(node: Union[NodeT, DAGNodeT]) -> bool:
        return node.node_name == name

    return find_child(tree, _name_matches)

View File

@@ -0,0 +1,53 @@
from typing import Any, Dict, List
def assert_style_in_dict(
    parameter: Any,
    accepted_parameters: Dict[str, Any],
) -> None:
    """Raise ValueError if `parameter` is not a key of `accepted_parameters`.

    The special value "custom" is always accepted, allowing callers to define
    their own style.

    Args:
        parameter (Any): argument input for parameter
        accepted_parameters (Dict[str, Any]): dictionary of accepted parameters

    Raises:
        ValueError: if `parameter` is neither an accepted key nor "custom"
    """
    if parameter not in accepted_parameters and parameter != "custom":
        raise ValueError(
            f"Choose one of {accepted_parameters.keys()} style, use `custom` to define own style"
        )
def assert_str_in_list(
    parameter_name: str,
    parameter: Any,
    accepted_parameters: List[Any],
) -> None:
    """Validate that `parameter` appears in `accepted_parameters`.

    Args:
        parameter_name (str): parameter name, used in the error message
        parameter (Any): argument input for parameter
        accepted_parameters (List[Any]): list of accepted parameters

    Raises:
        ValueError: if `parameter` is not in `accepted_parameters`
    """
    if parameter in accepted_parameters:
        return
    raise ValueError(
        f"Invalid input, check `{parameter_name}` should be one of {accepted_parameters}"
    )
def assert_key_in_dict(
    parameter_name: str,
    parameter: Any,
    accepted_parameters: Dict[Any, Any],
) -> None:
    """Validate that `parameter` is a key of `accepted_parameters`.

    Args:
        parameter_name (str): parameter name, used in the error message
        parameter (Any): argument input for parameter
        accepted_parameters (Dict[Any, Any]): dictionary of accepted parameters

    Raises:
        ValueError: if `parameter` is not a key of `accepted_parameters`
    """
    if parameter in accepted_parameters:
        return
    raise ValueError(
        f"Invalid input, check `{parameter_name}` should be one of {accepted_parameters.keys()}"
    )

View File

@@ -0,0 +1,165 @@
from enum import Enum, auto
from typing import Dict, List, Tuple
class ExportConstants:
    """Unicode box-drawing glyphs, plus the prefix/glyph style presets
    assembled from them, used for rendering trees as text."""

    # Light (thin) box-drawing glyphs.
    DOWN_RIGHT = "\u250c"  # ┌
    VERTICAL_RIGHT = "\u251c"  # ├
    VERTICAL_LEFT = "\u2524"  # ┤
    VERTICAL_HORIZONTAL = "\u253c"  # ┼
    UP_RIGHT = "\u2514"  # └
    VERTICAL = "\u2502"  # │
    HORIZONTAL = "\u2500"  # ─
    # Rounded-corner variants.
    DOWN_RIGHT_ROUNDED = "\u256D"  # ╭
    UP_RIGHT_ROUNDED = "\u2570"  # ╰
    # Heavy (bold) variants.
    DOWN_RIGHT_BOLD = "\u250F"  # ┏
    VERTICAL_RIGHT_BOLD = "\u2523"  # ┣
    VERTICAL_LEFT_BOLD = "\u252B"  # ┫
    VERTICAL_HORIZONTAL_BOLD = "\u254B"  # ╋
    UP_RIGHT_BOLD = "\u2517"  # ┗
    VERTICAL_BOLD = "\u2503"  # ┃
    HORIZONTAL_BOLD = "\u2501"  # ━
    # Double-line variants.
    DOWN_RIGHT_DOUBLE = "\u2554"  # ╔
    VERTICAL_RIGHT_DOUBLE = "\u2560"  # ╠
    VERTICAL_LEFT_DOUBLE = "\u2563"  # ╣
    VERTICAL_HORIZONTAL_DOUBLE = "\u256C"  # ╬
    UP_RIGHT_DOUBLE = "\u255a"  # ╚
    VERTICAL_DOUBLE = "\u2551"  # ║
    HORIZONTAL_DOUBLE = "\u2550"  # ═

    # Style presets for vertical tree printing. Each value is a 3-tuple of
    # prefixes: (continuation line, non-last child branch, last child branch).
    # The f-strings below can reference the class attributes because the
    # class body is still executing when they are evaluated.
    PRINT_STYLES: Dict[str, Tuple[str, str, str]] = {
        "ansi": ("| ", "|-- ", "`-- "),
        "ascii": ("| ", "|-- ", "+-- "),
        "const": (
            f"{VERTICAL} ",
            f"{VERTICAL_RIGHT}{HORIZONTAL}{HORIZONTAL} ",
            f"{UP_RIGHT}{HORIZONTAL}{HORIZONTAL} ",
        ),
        "const_bold": (
            f"{VERTICAL_BOLD} ",
            f"{VERTICAL_RIGHT_BOLD}{HORIZONTAL_BOLD}{HORIZONTAL_BOLD} ",
            f"{UP_RIGHT_BOLD}{HORIZONTAL_BOLD}{HORIZONTAL_BOLD} ",
        ),
        "rounded": (
            f"{VERTICAL} ",
            f"{VERTICAL_RIGHT}{HORIZONTAL}{HORIZONTAL} ",
            f"{UP_RIGHT_ROUNDED}{HORIZONTAL}{HORIZONTAL} ",
        ),
        "double": (
            f"{VERTICAL_DOUBLE} ",
            f"{VERTICAL_RIGHT_DOUBLE}{HORIZONTAL_DOUBLE}{HORIZONTAL_DOUBLE} ",
            f"{UP_RIGHT_DOUBLE}{HORIZONTAL_DOUBLE}{HORIZONTAL_DOUBLE} ",
        ),
    }
    # Glyph presets for horizontal tree printing. Each value is a 7-tuple,
    # matching the constant names used in the "const" entry:
    # (down-right corner, vertical-right tee, vertical-left tee, cross,
    #  up-right corner, vertical bar, horizontal bar).
    HPRINT_STYLES: Dict[str, Tuple[str, str, str, str, str, str, str]] = {
        "ansi": ("/", "+", "+", "+", "\\", "|", "-"),
        "ascii": ("+", "+", "+", "+", "+", "|", "-"),
        "const": (
            DOWN_RIGHT,
            VERTICAL_RIGHT,
            VERTICAL_LEFT,
            VERTICAL_HORIZONTAL,
            UP_RIGHT,
            VERTICAL,
            HORIZONTAL,
        ),
        "const_bold": (
            DOWN_RIGHT_BOLD,
            VERTICAL_RIGHT_BOLD,
            VERTICAL_LEFT_BOLD,
            VERTICAL_HORIZONTAL_BOLD,
            UP_RIGHT_BOLD,
            VERTICAL_BOLD,
            HORIZONTAL_BOLD,
        ),
        "rounded": (
            DOWN_RIGHT_ROUNDED,
            VERTICAL_RIGHT,
            VERTICAL_LEFT,
            VERTICAL_HORIZONTAL,
            UP_RIGHT_ROUNDED,
            VERTICAL,
            HORIZONTAL,
        ),
        "double": (
            DOWN_RIGHT_DOUBLE,
            VERTICAL_RIGHT_DOUBLE,
            VERTICAL_LEFT_DOUBLE,
            VERTICAL_HORIZONTAL_DOUBLE,
            UP_RIGHT_DOUBLE,
            VERTICAL_DOUBLE,
            HORIZONTAL_DOUBLE,
        ),
    }
class MermaidConstants:
    """String constants mirroring Mermaid flowchart syntax, used when
    exporting a tree/DAG to a Mermaid diagram."""

    # Allowed flowchart directions: top-bottom, bottom-top, left-right, right-left.
    RANK_DIR: List[str] = ["TB", "BT", "LR", "RL"]
    # Allowed edge curve interpolation names (Mermaid `curve` setting).
    LINE_SHAPES: List[str] = [
        "basis",
        "bumpX",
        "bumpY",
        "cardinal",
        "catmullRom",
        "linear",
        "monotoneX",
        "monotoneY",
        "natural",
        "step",
        "stepAfter",
        "stepBefore",
    ]
    # Node shape templates with a `{label}` placeholder; presumably filled in
    # via str.format by the exporter -- confirm at call site.
    NODE_SHAPES: Dict[str, str] = {
        "rounded_edge": """("{label}")""",
        "stadium": """(["{label}"])""",
        "subroutine": """[["{label}"]]""",
        "cylindrical": """[("{label}")]""",
        "circle": """(("{label}"))""",
        "asymmetric": """>"{label}"]""",
        "rhombus": """{{"{label}"}}""",
        "hexagon": """{{{{"{label}"}}}}""",
        "parallelogram": """[/"{label}"/]""",
        "parallelogram_alt": """[\\"{label}"\\]""",
        "trapezoid": """[/"{label}"\\]""",
        "trapezoid_alt": """[\\"{label}"/]""",
        "double_circle": """((("{label}")))""",
    }
    # Mermaid edge/link arrow tokens, keyed by style name.
    EDGE_ARROWS: Dict[str, str] = {
        "normal": "-->",
        "bold": "==>",
        "dotted": "-.->",
        "open": "---",
        "bold_open": "===",
        "dotted_open": "-.-",
        "invisible": "~~~",
        "circle": "--o",
        "cross": "--x",
        "double_normal": "<-->",
        "double_circle": "o--o",
        "double_cross": "x--x",
    }
class NewickState(Enum):
    """Parser states for reading Newick notation: whether the current
    characters belong to a node string, an attribute name, or an attribute
    value (per the state names; confirm in the parser that consumes them)."""

    PARSE_STRING = auto()
    PARSE_ATTRIBUTE_NAME = auto()
    PARSE_ATTRIBUTE_VALUE = auto()
class NewickCharacter(str, Enum):
    """Special characters of the Newick string format.

    Subclasses ``str`` so members compare equal to their character values.
    """

    OPEN_BRACKET = "("
    CLOSE_BRACKET = ")"
    ATTR_START = "["
    ATTR_END = "]"
    ATTR_KEY_VALUE = "="
    ATTR_QUOTE = "'"
    SEP = ":"
    NODE_SEP = ","

    @classmethod
    def values(cls) -> List[str]:
        """Return the character values of all members as a list of strings."""
        return [c.value for c in cls]

View File

@@ -0,0 +1,126 @@
from functools import wraps
from typing import Any, Callable, TypeVar
from warnings import simplefilter, warn
T = TypeVar("T")
class TreeError(Exception):
    """Generic tree exception; base class for the tree errors defined here"""

    pass


class LoopError(TreeError):
    """Error during node creation"""

    # NOTE(review): the name suggests it is raised when linking nodes would
    # create a loop/cycle -- confirm at raise sites.
    pass


class CorruptedTreeError(TreeError):
    """Error during node creation"""

    # NOTE(review): docstring duplicates LoopError's; the name suggests an
    # inconsistent parent/child linkage -- confirm at raise sites.
    pass


class DuplicatedNodeError(TreeError):
    """Error during tree creation"""

    pass


class NotFoundError(TreeError):
    """Error during tree pruning or modification"""

    pass


class SearchError(TreeError):
    """Error during tree search"""

    pass
def deprecated(
    alias: str,
) -> Callable[[Callable[..., T]], Callable[..., T]]:  # pragma: no cover
    """Decorator factory marking a function as deprecated.

    Each call to the decorated function emits a ``DeprecationWarning``
    pointing callers at ``alias``, then delegates to the original function.
    Source: https://stackoverflow.com/a/30253848

    Args:
        alias (str): name of the replacement callable to mention in the warning

    Returns:
        (Callable): decorator wrapping the target function
    """

    def decorator(func: Callable[..., T]) -> Callable[..., T]:
        @wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> T:
            # Temporarily force DeprecationWarning to always be shown,
            # then restore the default filter.
            simplefilter("always", DeprecationWarning)
            warn(
                f"{func.__name__} is going to be deprecated, use {alias} instead",
                category=DeprecationWarning,
                stacklevel=2,
            )
            simplefilter("default", DeprecationWarning)  # reset filter
            return func(*args, **kwargs)

        return wrapper

    return decorator
def optional_dependencies_pandas(
    func: Callable[..., T]
) -> Callable[..., T]:  # pragma: no cover
    """Decorator importing the optional pandas dependency before each call.

    Raises an ImportError with installation instructions when pandas is not
    installed; otherwise delegates to the wrapped function unchanged.
    """

    @wraps(func)
    def inner(*args: Any, **kwargs: Any) -> T:
        try:
            import pandas  # noqa: F401
        except ImportError:
            raise ImportError(
                "pandas not available. Please perform a\n\n"
                "pip install 'bigtree[pandas]'\n\nto install required dependencies"
            ) from None
        return func(*args, **kwargs)

    return inner
def optional_dependencies_image(
    package_name: str = "",
) -> Callable[[Callable[..., T]], Callable[..., T]]:
    """Decorator factory importing optional imaging dependencies before each call.

    Raises an ImportError with installation instructions when the required
    package is not installed; otherwise delegates to the wrapped function.

    Args:
        package_name (str): restrict the check to "pydot" or "Pillow";
            an empty string checks both

    Returns:
        (Callable): decorator wrapping the target function
    """

    def decorator(func: Callable[..., T]) -> Callable[..., T]:
        @wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> T:
            want_pydot = package_name in ("", "pydot")
            want_pillow = package_name in ("", "Pillow")
            if want_pydot:
                try:
                    import pydot  # noqa: F401
                except ImportError:  # pragma: no cover
                    raise ImportError(
                        "pydot not available. Please perform a\n\n"
                        "pip install 'bigtree[image]'\n\nto install required dependencies"
                    ) from None
            if want_pillow:
                try:
                    from PIL import Image, ImageDraw, ImageFont  # noqa: F401
                except ImportError:  # pragma: no cover
                    raise ImportError(
                        "Pillow not available. Please perform a\n\n"
                        "pip install 'bigtree[image]'\n\nto install required dependencies"
                    ) from None
            return func(*args, **kwargs)

        return wrapper

    return decorator

View File

@@ -0,0 +1,19 @@
def whoami() -> str:
    """Return Groot's only line of dialogue.

    Returns:
        (str)
    """
    return "I am Groot!"
def speak_like_groot(sentence: str) -> str:
    """Convert sentence into Groot language.

    Every whitespace-separated word of ``sentence`` is replaced by Groot's
    catchphrase, so the word count of the input is preserved.

    Args:
        sentence (str): sentence to convert to Groot language

    Returns:
        (str)
    """
    # One catchphrase per word; an empty sentence yields an empty string.
    return " ".join(whoami() for _ in sentence.split())

View File

@@ -0,0 +1,587 @@
from __future__ import annotations
from typing import (
TYPE_CHECKING,
Callable,
Iterable,
List,
Optional,
Tuple,
TypeVar,
Union,
)
if TYPE_CHECKING:
from bigtree.node.basenode import BaseNode
from bigtree.node.binarynode import BinaryNode
from bigtree.node.dagnode import DAGNode
BaseNodeT = TypeVar("BaseNodeT", bound=BaseNode)
BinaryNodeT = TypeVar("BinaryNodeT", bound=BinaryNode)
DAGNodeT = TypeVar("DAGNodeT", bound=DAGNode)
T = TypeVar("T", bound=Union[BaseNode, DAGNode])
__all__ = [
"inorder_iter",
"preorder_iter",
"postorder_iter",
"levelorder_iter",
"levelordergroup_iter",
"zigzag_iter",
"zigzaggroup_iter",
"dag_iterator",
]
def inorder_iter(
    tree: BinaryNodeT,
    filter_condition: Optional[Callable[[BinaryNodeT], bool]] = None,
    max_depth: int = 0,
) -> Iterable[BinaryNodeT]:
    """Iterate through all children of a binary tree.

    In-Order Iteration Algorithm, LNR
        1. Recursively traverse the current node's left subtree.
        2. Visit the current node.
        3. Recursively traverse the current node's right subtree.

    Examples:
        >>> from bigtree import BinaryNode, list_to_binarytree, inorder_iter
        >>> num_list = [1, 2, 3, 4, 5, 6, 7, 8]
        >>> root = list_to_binarytree(num_list)
        >>> [node.node_name for node in inorder_iter(root)]
        ['8', '4', '2', '5', '1', '6', '3', '7']
        >>> [node.node_name for node in inorder_iter(root, filter_condition=lambda x: x.node_name in ["1", "4", "3", "6", "7"])]
        ['4', '1', '6', '3', '7']
        >>> [node.node_name for node in inorder_iter(root, max_depth=3)]
        ['4', '2', '5', '1', '6', '3', '7']

    Args:
        tree (BinaryNode): input tree
        filter_condition (Optional[Callable[[BinaryNode], bool]]): function that takes in node as argument, optional
            Return node if condition evaluates to `True`
        max_depth (int): maximum depth of iteration, based on `depth` attribute, 0 iterates all depths

    Returns:
        (Iterable[BinaryNode])
    """
    # Recursion terminates at falsy children (BinaryNode children may be
    # None) and at nodes deeper than max_depth.
    if tree and (not max_depth or not tree.depth > max_depth):
        yield from inorder_iter(tree.left, filter_condition, max_depth)
        # filter_condition hides a node from the output but does not stop
        # traversal of its subtrees.
        if not filter_condition or filter_condition(tree):
            yield tree
        yield from inorder_iter(tree.right, filter_condition, max_depth)
def preorder_iter(
    tree: T,
    filter_condition: Optional[Callable[[T], bool]] = None,
    stop_condition: Optional[Callable[[T], bool]] = None,
    max_depth: int = 0,
) -> Iterable[T]:
    """Iterate through all children of a tree.

    Pre-Order Iteration Algorithm, NLR
        1. Visit the current node.
        2. Recursively traverse each of the current node's subtrees, from left to right.

    It is topologically sorted because a parent node is processed before its child nodes.

    Examples:
        >>> from bigtree import Node, list_to_tree, preorder_iter
        >>> path_list = ["a/b/d", "a/b/e/g", "a/b/e/h", "a/c/f"]
        >>> root = list_to_tree(path_list)
        >>> [node.node_name for node in preorder_iter(root)]
        ['a', 'b', 'd', 'e', 'g', 'h', 'c', 'f']
        >>> [node.node_name for node in preorder_iter(root, filter_condition=lambda x: x.node_name in ["a", "d", "e", "f", "g"])]
        ['a', 'd', 'e', 'g', 'f']
        >>> [node.node_name for node in preorder_iter(root, stop_condition=lambda x: x.node_name == "e")]
        ['a', 'b', 'd', 'c', 'f']
        >>> [node.node_name for node in preorder_iter(root, max_depth=3)]
        ['a', 'b', 'd', 'e', 'c', 'f']

    Args:
        tree (Union[BaseNode, DAGNode]): input tree
        filter_condition (Optional[Callable[[T], bool]]): function that takes in node as argument, optional
            Return node if condition evaluates to `True`
        stop_condition (Optional[Callable[[T], bool]]): function that takes in node as argument, optional
            Stops iteration (of that node and its whole subtree) if condition evaluates to `True`
        max_depth (int): maximum depth of iteration, based on `depth` attribute, 0 iterates all depths

    Returns:
        (Union[Iterable[BaseNode], Iterable[DAGNode]])
    """
    # get_attr("depth") instead of .depth -- presumably because some node
    # types accepted here may lack the attribute; confirm for DAGNode.
    if (
        tree
        and (not max_depth or not tree.get_attr("depth") > max_depth)
        and (not stop_condition or not stop_condition(tree))
    ):
        # filter_condition hides a node from the output but traversal of its
        # children continues; stop_condition prunes the whole subtree.
        if not filter_condition or filter_condition(tree):
            yield tree
        for child in tree.children:
            yield from preorder_iter(child, filter_condition, stop_condition, max_depth)  # type: ignore
def postorder_iter(
    tree: BaseNodeT,
    filter_condition: Optional[Callable[[BaseNodeT], bool]] = None,
    stop_condition: Optional[Callable[[BaseNodeT], bool]] = None,
    max_depth: int = 0,
) -> Iterable[BaseNodeT]:
    """Iterate through all children of a tree.

    Post-Order Iteration Algorithm, LRN
        1. Recursively traverse each of the current node's subtrees, from left to right.
        2. Visit the current node.

    Examples:
        >>> from bigtree import Node, list_to_tree, postorder_iter
        >>> path_list = ["a/b/d", "a/b/e/g", "a/b/e/h", "a/c/f"]
        >>> root = list_to_tree(path_list)
        >>> [node.node_name for node in postorder_iter(root)]
        ['d', 'g', 'h', 'e', 'b', 'f', 'c', 'a']
        >>> [node.node_name for node in postorder_iter(root, filter_condition=lambda x: x.node_name in ["a", "d", "e", "f", "g"])]
        ['d', 'g', 'e', 'f', 'a']
        >>> [node.node_name for node in postorder_iter(root, stop_condition=lambda x: x.node_name == "e")]
        ['d', 'b', 'f', 'c', 'a']
        >>> [node.node_name for node in postorder_iter(root, max_depth=3)]
        ['d', 'e', 'b', 'f', 'c', 'a']

    Args:
        tree (BaseNode): input tree
        filter_condition (Optional[Callable[[BaseNode], bool]]): function that takes in node as argument, optional
            Return node if condition evaluates to `True`
        stop_condition (Optional[Callable[[BaseNode], bool]]): function that takes in node as argument, optional
            Stops iteration (of that node and its whole subtree) if condition evaluates to `True`
        max_depth (int): maximum depth of iteration, based on `depth` attribute, 0 iterates all depths

    Returns:
        (Iterable[BaseNode])
    """
    if (
        tree
        and (not max_depth or not tree.depth > max_depth)
        and (not stop_condition or not stop_condition(tree))
    ):
        # Children first (left to right), then the node itself.
        for child in tree.children:
            yield from postorder_iter(
                child, filter_condition, stop_condition, max_depth
            )
        # filter_condition hides a node from the output but traversal of its
        # children still happened above.
        if not filter_condition or filter_condition(tree):
            yield tree
def levelorder_iter(
    tree: BaseNodeT,
    filter_condition: Optional[Callable[[BaseNodeT], bool]] = None,
    stop_condition: Optional[Callable[[BaseNodeT], bool]] = None,
    max_depth: int = 0,
) -> Iterable[BaseNodeT]:
    """Iterate through all children of a tree.

    Level-Order Iteration Algorithm
        1. Recursively traverse the nodes on same level.

    Examples:
        >>> from bigtree import Node, list_to_tree, levelorder_iter
        >>> path_list = ["a/b/d", "a/b/e/g", "a/b/e/h", "a/c/f"]
        >>> root = list_to_tree(path_list)
        >>> [node.node_name for node in levelorder_iter(root)]
        ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h']
        >>> [node.node_name for node in levelorder_iter(root, filter_condition=lambda x: x.node_name in ["a", "d", "e", "f", "g"])]
        ['a', 'd', 'e', 'f', 'g']
        >>> [node.node_name for node in levelorder_iter(root, stop_condition=lambda x: x.node_name == "e")]
        ['a', 'b', 'c', 'd', 'f']
        >>> [node.node_name for node in levelorder_iter(root, max_depth=3)]
        ['a', 'b', 'c', 'd', 'e', 'f']

    Args:
        tree (BaseNode): input tree
        filter_condition (Optional[Callable[[BaseNode], bool]]): function that takes in node as argument, optional
            Return node if condition evaluates to `True`
        stop_condition (Optional[Callable[[BaseNode], bool]]): function that takes in node as argument, optional
            Stops iteration (of that node and its whole subtree) if condition evaluates to `True`
        max_depth (int): maximum depth of iteration, based on `depth` attribute, 0 iterates all depths

    Returns:
        (Iterable[BaseNode])
    """

    def _levelorder_iter(trees: List[BaseNodeT]) -> Iterable[BaseNodeT]:
        """Yield the given level's nodes, then recurse into the next level.

        Args:
            trees (List[BaseNode]): trees to get children for next level

        Returns:
            (Iterable[BaseNode])
        """
        next_level = []
        for _tree in trees:
            if _tree:
                # stop_condition and max_depth prune the node AND its
                # subtree; filter_condition only hides the node itself.
                if (not max_depth or not _tree.depth > max_depth) and (
                    not stop_condition or not stop_condition(_tree)
                ):
                    if not filter_condition or filter_condition(_tree):
                        yield _tree
                    next_level.extend(list(_tree.children))
        if len(next_level):
            yield from _levelorder_iter(next_level)

    yield from _levelorder_iter([tree])
def levelordergroup_iter(
    tree: BaseNodeT,
    filter_condition: Optional[Callable[[BaseNodeT], bool]] = None,
    stop_condition: Optional[Callable[[BaseNodeT], bool]] = None,
    max_depth: int = 0,
) -> Iterable[Iterable[BaseNodeT]]:
    """Iterate through all children of a tree.

    Level-Order Group Iteration Algorithm
        1. Recursively traverse the nodes on same level, returns nodes level by level in a nested list.

    Examples:
        >>> from bigtree import Node, list_to_tree, levelordergroup_iter
        >>> path_list = ["a/b/d", "a/b/e/g", "a/b/e/h", "a/c/f"]
        >>> root = list_to_tree(path_list)
        >>> [[node.node_name for node in group] for group in levelordergroup_iter(root)]
        [['a'], ['b', 'c'], ['d', 'e', 'f'], ['g', 'h']]
        >>> [[node.node_name for node in group] for group in levelordergroup_iter(root, filter_condition=lambda x: x.node_name in ["a", "d", "e", "f", "g"])]
        [['a'], [], ['d', 'e', 'f'], ['g']]
        >>> [[node.node_name for node in group] for group in levelordergroup_iter(root, stop_condition=lambda x: x.node_name == "e")]
        [['a'], ['b', 'c'], ['d', 'f']]
        >>> [[node.node_name for node in group] for group in levelordergroup_iter(root, max_depth=3)]
        [['a'], ['b', 'c'], ['d', 'e', 'f']]

    Args:
        tree (BaseNode): input tree
        filter_condition (Optional[Callable[[BaseNode], bool]]): function that takes in node as argument, optional
            Return node if condition evaluates to `True`
        stop_condition (Optional[Callable[[BaseNode], bool]]): function that takes in node as argument, optional
            Stops iteration (of that node and its whole subtree) if condition evaluates to `True`
        max_depth (int): maximum depth of iteration, based on `depth` attribute, 0 iterates all depths

    Returns:
        (Iterable[Iterable[BaseNode]]): one tuple of nodes per tree level;
            a level may be an empty tuple when filter_condition excludes
            every node on it
    """

    def _levelordergroup_iter(trees: List[BaseNodeT]) -> Iterable[Iterable[BaseNodeT]]:
        """Yield the given level as one tuple, then recurse into the next level.

        Args:
            trees (List[BaseNode]): trees to get children for next level

        Returns:
            (Iterable[Iterable[BaseNode]])
        """
        current_tree = []
        next_level = []
        for _tree in trees:
            # stop_condition and max_depth prune the node AND its subtree;
            # filter_condition only hides the node from the current group.
            if (not max_depth or not _tree.depth > max_depth) and (
                not stop_condition or not stop_condition(_tree)
            ):
                if not filter_condition or filter_condition(_tree):
                    current_tree.append(_tree)
                next_level.extend([_child for _child in _tree.children if _child])
        yield tuple(current_tree)
        # Depth check on next_level[0] avoids recursing into a level that
        # lies entirely beyond max_depth (all nodes in a level share a depth).
        if len(next_level) and (not max_depth or not next_level[0].depth > max_depth):
            yield from _levelordergroup_iter(next_level)

    yield from _levelordergroup_iter([tree])
def zigzag_iter(
    tree: BaseNodeT,
    filter_condition: Optional[Callable[[BaseNodeT], bool]] = None,
    stop_condition: Optional[Callable[[BaseNodeT], bool]] = None,
    max_depth: int = 0,
) -> Iterable[BaseNodeT]:
    """Iterate through all children of a tree.

    ZigZag Iteration Algorithm: traverse the nodes level by level, flipping
    the left-to-right / right-to-left direction on each successive level.

    Args:
        tree (BaseNode): input tree
        filter_condition (Optional[Callable[[BaseNode], bool]]): function that takes in node as argument, optional
            Yield node only if condition evaluates to `True`; its subtree is still traversed
        stop_condition (Optional[Callable[[BaseNode], bool]]): function that takes in node as argument, optional
            Prunes node and its subtree if condition evaluates to `True`
        max_depth (int): maximum depth of iteration, based on `depth` attribute; 0 iterates the full tree

    Returns:
        (Iterable[BaseNode])
    """
    # Iterative breadth-first traversal; `reverse` toggles every level to
    # produce the zigzag order.
    level: List[BaseNodeT] = [tree]
    reverse = False
    while level:
        next_level: List[BaseNodeT] = []
        for node in level:
            if not node:
                continue
            # Prune this node (and its subtree) when it is too deep or the
            # stop condition fires: its children are never queued.
            if (max_depth and node.depth > max_depth) or (
                stop_condition and stop_condition(node)
            ):
                continue
            if not filter_condition or filter_condition(node):
                yield node
            children = list(node.children)
            if reverse:
                children.reverse()
            next_level.extend(children)
        # Reversing the gathered children flips the direction for the next level.
        level = next_level[::-1]
        reverse = not reverse
def zigzaggroup_iter(
    tree: BaseNodeT,
    filter_condition: Optional[Callable[[BaseNodeT], bool]] = None,
    stop_condition: Optional[Callable[[BaseNodeT], bool]] = None,
    max_depth: int = 0,
) -> Iterable[Iterable[BaseNodeT]]:
    """Iterate through all children of a tree.

    ZigZag Group Iteration Algorithm: traverse the nodes level by level in a
    zigzag manner across different levels, yielding each level as a tuple.
    A level whose nodes are all removed by `filter_condition` is yielded as
    an empty tuple.

    Args:
        tree (BaseNode): input tree
        filter_condition (Optional[Callable[[BaseNode], bool]]): function that takes in node as argument, optional
            Yield node only if condition evaluates to `True`; its subtree is still traversed
        stop_condition (Optional[Callable[[BaseNode], bool]]): function that takes in node as argument, optional
            Prunes node and its subtree if condition evaluates to `True`
        max_depth (int): maximum depth of iteration, based on `depth` attribute; 0 iterates the full tree

    Returns:
        (Iterable[Iterable[BaseNode]])
    """
    level: List[BaseNodeT] = [tree]
    reverse = False
    while level:
        group: List[BaseNodeT] = []
        next_level: List[BaseNodeT] = []
        for node in level:
            # Prune this node (and its subtree) when it is too deep or the
            # stop condition fires: its children are never queued.
            if (max_depth and node.depth > max_depth) or (
                stop_condition and stop_condition(node)
            ):
                continue
            if not filter_condition or filter_condition(node):
                group.append(node)
            children = [child for child in node.children if child]
            if reverse:
                children.reverse()
            next_level.extend(children)
        yield tuple(group)
        # Stop before emitting a trailing empty group: quit once there is no
        # next level or the next level is entirely beyond max_depth.
        if not next_level or (max_depth and next_level[0].depth > max_depth):
            break
        level = next_level[::-1]
        reverse = not reverse
def dag_iterator(dag: DAGNodeT) -> Iterable[Tuple[DAGNodeT, DAGNodeT]]:
    """Iterate through all nodes of a Directed Acyclic Graph (DAG), yielding
    (parent, child) edges.

    Note that node names must be unique.
    Note that DAG must at least have two nodes to be shown on graph.

    1. Visit the current node.
    2. Recursively traverse the current node's parents.
    3. Recursively traverse the current node's children.

    Args:
        dag (DAGNode): input dag

    Returns:
        (Iterable[Tuple[DAGNode, DAGNode]])
    """
    # Names of nodes whose incident edges have already been emitted.
    seen = set()

    def _edges_from(current: DAGNodeT) -> Iterable[Tuple[DAGNodeT, DAGNodeT]]:
        """Yield edges incident to `current`, then recurse into unvisited neighbours.

        Args:
            current (DAGNode): current node

        Returns:
            (Iterable[Tuple[DAGNode, DAGNode]])
        """
        seen.add(current.node_name)
        # Emit edges towards not-yet-visited parents first, then children;
        # edges to visited nodes were already emitted from the other endpoint.
        for neighbour in current.parents:
            if neighbour.node_name not in seen:
                yield neighbour, current
        for neighbour in current.children:
            if neighbour.node_name not in seen:
                yield current, neighbour
        # Recurse after this node's own edges; re-check `seen` each time since
        # an earlier recursive call may have visited the neighbour already.
        for neighbour in current.parents:
            if neighbour.node_name not in seen:
                yield from _edges_from(neighbour)
        for neighbour in current.children:
            if neighbour.node_name not in seen:
                yield from _edges_from(neighbour)

    yield from _edges_from(dag)

View File

@@ -0,0 +1,354 @@
from typing import Optional, TypeVar
from bigtree.node.basenode import BaseNode
T = TypeVar("T", bound=BaseNode)
__all__ = [
"reingold_tilford",
]
def reingold_tilford(
    tree_node: T,
    sibling_separation: float = 1.0,
    subtree_separation: float = 1.0,
    level_separation: float = 1.0,
    x_offset: float = 0.0,
    y_offset: float = 0.0,
) -> None:
    """
    Algorithm for drawing tree structure, retrieves `(x, y)` coordinates for a tree structure.
    Adds `x` and `y` attributes to every node in the tree. Modifies tree in-place.

    This algorithm[1] is an improvement over Reingold Tilford algorithm[2].
    According to Reingold Tilford's paper, a tree diagram should satisfy the following aesthetic rules,

    1. Nodes at the same depth should lie along a straight line, and the straight lines defining the depths should be parallel.
    2. A left child should be positioned to the left of its parent node and a right child to the right.
    3. A parent should be centered over their children.
    4. A tree and its mirror image should produce drawings that are reflections of one another; a subtree should be drawn the same way regardless of where it occurs in the tree.

    References
    - [1] Walker, J. (1991). Positioning Nodes for General Trees. https://www.drdobbs.com/positioning-nodes-for-general-trees/184402320?pgno=4
    - [2] Reingold, E., Tilford, J. (1981). Tidier Drawings of Trees. IEEE Transactions on Software Engineering. https://reingold.co/tidier-drawings.pdf

    Args:
        tree_node (BaseNode): tree to compute (x, y) coordinate
        sibling_separation (float): minimum distance between adjacent siblings of the tree
        subtree_separation (float): minimum distance between adjacent subtrees of the tree
        level_separation (float): fixed distance between adjacent levels of the tree
        x_offset (float): graph offset of x-coordinates
        y_offset (float): graph offset of y-coordinates
    """
    # Pass 1: assign preliminary `x`, `mod` and `shift` attributes bottom-up.
    _first_pass(tree_node, sibling_separation, subtree_separation)
    # Pass 2: finalise coordinates top-down; returns how far the leftmost
    # node undershoots x = 0 (0.0 when nothing is negative).
    leftward_overhang = _second_pass(tree_node, level_separation, x_offset, y_offset)
    # Pass 3: shift the whole drawing right so every x-coordinate is >= 0.
    _third_pass(tree_node, leftward_overhang)
def _first_pass(
    tree_node: T, sibling_separation: float, subtree_separation: float
) -> None:
    """
    Performs post-order traversal of tree and assigns `x`, `mod` and `shift` values to each node.
    Modifies tree in-place.

    Notation:
        - `lsibling`: left-sibling of node
        - `lchild`: last child of node
        - `fchild`: first child of node
        - `midpoint`: midpoint of node wrt children, :math:`midpoint = (lchild.x + fchild.x) / 2`
        - `sibling distance`: sibling separation
        - `subtree distance`: subtree separation

    There are two parts in the first pass,

    1. In the first part, we assign `x` and `mod` values to each node.
       `x` value is the initial x-position of each node purely based on the node's position
           - :math:`x = 0` for leftmost node and :math:`x = lsibling.x + sibling distance` for other nodes
           - Special case when leftmost node has children, then it will try to center itself, :math:`x = midpoint`
       `mod` value is the amount to shift the subtree (all descendant nodes excluding itself) to make the children centered with itself
           - :math:`mod = 0` for node does not have children (no need to shift subtree) or it is a leftmost node (parent is already centered, from above point)
           - Special case when non-leftmost nodes have children, :math:`mod = x - midpoint`
    2. In the second part, we assign `shift` value of nodes due to overlapping subtrees.
       For each node on the same level, ensure that the leftmost descendant does not intersect with the rightmost
       descendant of any left sibling at every subsequent level. Intersection happens when the subtrees are not
       at least `subtree distance` apart. If there are any intersections, shift the whole subtree by a new `shift`
       value, shift any left sibling by a fraction of `shift` value, and shift any right sibling by `shift` + a
       multiple of the fraction of `shift` value to keep nodes centralized at the level.

    Args:
        tree_node (BaseNode): tree to compute (x, y) coordinate
        sibling_separation (float): minimum distance between adjacent siblings of the tree
        subtree_separation (float): minimum distance between adjacent subtrees of the tree
    """
    # Post-order iteration (LRN): children are positioned before their parent.
    for child in tree_node.children:
        _first_pass(child, sibling_separation, subtree_separation)
    _x = 0.0
    _mod = 0.0
    _shift = 0.0
    _midpoint = 0.0
    if tree_node.is_root:
        # Root is simply centered over its children; no siblings to shift against.
        tree_node.set_attrs({"x": _get_midpoint_of_children(tree_node)})
        tree_node.set_attrs({"mod": _mod})
        tree_node.set_attrs({"shift": _shift})
    else:
        # First part - assign x and mod values
        if tree_node.children:
            _midpoint = _get_midpoint_of_children(tree_node)
        # Non-leftmost node
        if tree_node.left_sibling:
            _x = tree_node.left_sibling.get_attr("x") + sibling_separation
            if tree_node.children:
                # Shift descendants so the children re-center under the new x.
                _mod = _x - _midpoint
        # Leftmost node
        else:
            if tree_node.children:
                _x = _midpoint
        tree_node.set_attrs({"x": _x})
        tree_node.set_attrs({"mod": _mod})
        # Keep any shift already assigned by an earlier sibling's second part.
        tree_node.set_attrs({"shift": tree_node.get_attr("shift", _shift)})
        # Second part - assign shift values due to overlapping subtrees
        parent_node = tree_node.parent
        tree_node_idx = parent_node.children.index(tree_node)
        if tree_node_idx:
            # Take the largest shift needed against any left sibling's subtree.
            for idx_node in range(tree_node_idx):
                left_subtree = parent_node.children[idx_node]
                _shift = max(
                    _shift,
                    _get_subtree_shift(
                        left_subtree=left_subtree,
                        right_subtree=tree_node,
                        left_idx=idx_node,
                        right_idx=tree_node_idx,
                        subtree_separation=subtree_separation,
                    ),
                )
            # Shift siblings (left siblings, itself, right siblings) accordingly
            # Each sibling gets a proportional fraction so the level stays centered.
            for multiple, sibling in enumerate(parent_node.children):
                sibling.set_attrs(
                    {
                        "shift": sibling.get_attr("shift", 0)
                        + (_shift * multiple / tree_node_idx)
                    }
                )
def _get_midpoint_of_children(tree_node: BaseNode) -> float:
"""Get midpoint of children of a node
Args:
tree_node (BaseNode): tree node to obtain midpoint of their child/children
Returns:
(float)
"""
if tree_node.children:
first_child_x: float = tree_node.children[0].get_attr("x") + tree_node.children[
0
].get_attr("shift")
last_child_x: float = tree_node.children[-1].get_attr("x") + tree_node.children[
-1
].get_attr("shift")
return (last_child_x + first_child_x) / 2
return 0.0
def _get_subtree_shift(
    left_subtree: T,
    right_subtree: T,
    left_idx: int,
    right_idx: int,
    subtree_separation: float,
    left_cum_shift: float = 0,
    right_cum_shift: float = 0,
    cum_shift: float = 0,
    initial_run: bool = True,
) -> float:
    """Get shift amount to shift the right subtree towards the right such that it does not overlap with the left subtree.

    Walks the right contour of the left subtree and the left contour of the right
    subtree level by level, accumulating the shift needed to keep them at least
    `subtree_separation` apart.

    Args:
        left_subtree (BaseNode): left subtree, with right contour to be traversed
        right_subtree (BaseNode): right subtree, with left contour to be traversed
        left_idx (int): index of left subtree, to compute overlap for relative shift (constant across iteration)
        right_idx (int): index of right subtree, to compute overlap for relative shift (constant across iteration)
        subtree_separation (float): minimum distance between adjacent subtrees of the tree (constant across iteration)
        left_cum_shift (float): cumulative `mod + shift` for left subtree from the ancestors, defaults to 0
        right_cum_shift (float): cumulative `mod + shift` for right subtree from the ancestors, defaults to 0
        cum_shift (float): cumulative shift amount for right subtree, defaults to 0
        initial_run (bool): indicates whether left_subtree and right_subtree are the main subtrees, defaults to True

    Returns:
        (float)
    """
    new_shift = 0.0
    if not initial_run:
        # Effective x positions at this level, including ancestor offsets.
        x_left = (
            left_subtree.get_attr("x") + left_subtree.get_attr("shift") + left_cum_shift
        )
        x_right = (
            right_subtree.get_attr("x")
            + right_subtree.get_attr("shift")
            + right_cum_shift
            + cum_shift
        )
        # Extra shift needed so x_right clears x_left by subtree_separation.
        # The (1 - left_idx / right_idx) denominator scales the shift up because
        # intermediate siblings later absorb a proportional fraction of it.
        new_shift = max(
            (x_left + subtree_separation - x_right) / (1 - left_idx / right_idx), 0
        )
    # Search for a left sibling of left_subtree that has children
    # (continue the right contour when this branch of it bottoms out).
    while left_subtree and not left_subtree.children and left_subtree.left_sibling:
        left_subtree = left_subtree.left_sibling
    # Search for a right sibling of right_subtree that has children
    while (
        right_subtree and not right_subtree.children and right_subtree.right_sibling
    ):
        right_subtree = right_subtree.right_sibling
    if left_subtree.children and right_subtree.children:
        # Iterate down the level, for the rightmost child of left_subtree and the leftmost child of right_subtree
        return _get_subtree_shift(
            left_subtree=left_subtree.children[-1],
            right_subtree=right_subtree.children[0],
            left_idx=left_idx,
            right_idx=right_idx,
            subtree_separation=subtree_separation,
            left_cum_shift=(
                left_cum_shift
                + left_subtree.get_attr("mod")
                + left_subtree.get_attr("shift")
            ),
            right_cum_shift=(
                right_cum_shift
                + right_subtree.get_attr("mod")
                + right_subtree.get_attr("shift")
            ),
            cum_shift=cum_shift + new_shift,
            initial_run=False,
        )
    return cum_shift + new_shift
def _second_pass(
tree_node: T,
level_separation: float,
x_offset: float,
y_offset: float,
cum_mod: Optional[float] = 0.0,
max_depth: Optional[int] = None,
x_adjustment: Optional[float] = 0.0,
) -> float:
"""
Performs pre-order traversal of tree and determines the final `x` and `y` values for each node.
Modifies tree in-place.
Notation:
- `depth`: maximum depth of tree
- `distance`: level separation
- `x'`: x offset
- `y'`: y offset
Final position of each node
- :math:`x = node.x + node.shift + sum(ancestor.mod) + x'`
- :math:`y = (depth - node.depth) * distance + y'`
Args:
tree_node (BaseNode): tree to compute (x, y) coordinate
level_separation (float): fixed distance between adjacent levels of the tree (constant across iteration)
x_offset (float): graph offset of x-coordinates (constant across iteration)
y_offset (float): graph offset of y-coordinates (constant across iteration)
cum_mod (Optional[float]): cumulative `mod + shift` for tree/subtree from the ancestors
max_depth (Optional[int]): maximum depth of tree (constant across iteration)
x_adjustment (Optional[float]): amount of x-adjustment for third pass, in case any x-coordinates goes below 0
Returns
(float)
"""
if not max_depth:
max_depth = tree_node.max_depth
final_x: float = (
tree_node.get_attr("x") + tree_node.get_attr("shift") + cum_mod + x_offset
)
final_y: float = (max_depth - tree_node.depth) * level_separation + y_offset
tree_node.set_attrs({"x": final_x, "y": final_y})
# Pre-order iteration (NLR)
if tree_node.children:
return max(
[
_second_pass(
child,
level_separation,
x_offset,
y_offset,
cum_mod + tree_node.get_attr("mod") + tree_node.get_attr("shift"),
max_depth,
x_adjustment,
)
for child in tree_node.children
]
)
return max(x_adjustment, -final_x)
def _third_pass(tree_node: BaseNode, x_adjustment: float) -> None:
"""Adjust all x-coordinates by an adjustment value so that every x-coordinate is greater than or equal to 0.
Modifies tree in-place.
Args:
tree_node (BaseNode): tree to compute (x, y) coordinate
x_adjustment (float): amount of adjustment for x-coordinates (constant across iteration)
"""
if x_adjustment:
tree_node.set_attrs({"x": tree_node.get_attr("x") + x_adjustment})
# Pre-order iteration (NLR)
for child in tree_node.children:
_third_pass(child, x_adjustment)

View File

@@ -0,0 +1,200 @@
from __future__ import annotations
import datetime as dt
from typing import Any, Optional, Union
from bigtree.node.node import Node
from bigtree.tree.construct import add_path_to_tree
from bigtree.tree.export import tree_to_dataframe
from bigtree.tree.search import find_full_path, findall
try:
import pandas as pd
except ImportError: # pragma: no cover
pd = None
class Calendar:
    """
    Calendar Implementation with Big Tree.

    - Calendar has four levels - year, month, day, and event name (with event attributes)

    Examples:
        *Initializing and Adding Events*

        >>> from bigtree import Calendar
        >>> calendar = Calendar("My Calendar")
        >>> calendar.add_event("Gym", "2023-01-01 18:00")
        >>> calendar.add_event("Dinner", "2023-01-01", date_format="%Y-%m-%d", budget=20)
        >>> calendar.add_event("Gym", "2023-01-02 18:00")
        >>> calendar.show()
        My Calendar
        2023-01-01 00:00:00 - Dinner (budget: 20)
        2023-01-01 18:00:00 - Gym
        2023-01-02 18:00:00 - Gym

        *Search for Events*

        >>> calendar.find_event("Gym")
        2023-01-01 18:00:00 - Gym
        2023-01-02 18:00:00 - Gym

        *Removing Events*

        >>> import datetime as dt
        >>> calendar.delete_event("Gym", dt.date(2023, 1, 1))
        >>> calendar.show()
        My Calendar
        2023-01-01 00:00:00 - Dinner (budget: 20)
        2023-01-02 18:00:00 - Gym

        *Export Calendar*

        >>> calendar.to_dataframe()
                                     path    name        date      time  budget
        0  /My Calendar/2023/01/01/Dinner  Dinner  2023-01-01  00:00:00    20.0
        1     /My Calendar/2023/01/02/Gym     Gym  2023-01-02  18:00:00     NaN
    """

    def __init__(self, name: str):
        # Root node of the calendar tree; events hang off year/month/day levels.
        self.calendar = Node(name)
        # Dirty flag: set to False by add_event, reset to True by _sort().
        self.__sorted = True

    def add_event(
        self,
        event_name: str,
        event_datetime: Union[str, dt.datetime],
        date_format: str = "%Y-%m-%d %H:%M",
        **kwargs: Any,
    ) -> None:
        """Add event to calendar

        Args:
            event_name (str): event name to be added
            event_datetime (Union[str, dt.datetime]): event date and time
            date_format (str): specify datetime format if event_datetime is str
            **kwargs: additional attributes stored on the event node
        """
        if isinstance(event_datetime, str):
            event_datetime = dt.datetime.strptime(event_datetime, date_format)
        # Zero-pad month/day so lexicographic node-name order matches
        # chronological order.
        year, month, day, date, time = (
            event_datetime.year,
            str(event_datetime.month).zfill(2),
            str(event_datetime.day).zfill(2),
            event_datetime.date(),
            event_datetime.time(),
        )
        event_path = f"{self.calendar.node_name}/{year}/{month}/{day}/{event_name}"
        event_attr = {"date": date, "time": time, **kwargs}
        if find_full_path(self.calendar, event_path):
            print(
                f"Event {event_name} exists on {date}, overwriting information for {event_name}"
            )
        # Creates intermediate year/month/day nodes as needed; an existing
        # event at the same path gets its attributes overwritten.
        add_path_to_tree(
            tree=self.calendar,
            path=event_path,
            node_attrs=event_attr,
        )
        self.__sorted = False

    def delete_event(
        self, event_name: str, event_date: Optional[dt.date] = None
    ) -> None:
        """Delete event from calendar

        Args:
            event_name (str): event name to be deleted
            event_date (dt.date): event date to be deleted; when omitted,
                all events with this name are deleted
        """
        if event_date:
            year, month, day = (
                event_date.year,
                str(event_date.month).zfill(2),
                str(event_date.day).zfill(2),
            )
            event_path = f"{self.calendar.node_name}/{year}/{month}/{day}/{event_name}"
            event = find_full_path(self.calendar, event_path)
            if event:
                self._delete_event(event)
            else:
                print(f"Event {event_name} does not exist on {event_date}")
        else:
            for event in findall(
                self.calendar, lambda node: node.node_name == event_name
            ):
                self._delete_event(event)

    def find_event(self, event_name: str) -> None:
        """Find event by name, prints result to console

        Args:
            event_name (str): event name
        """
        if not self.__sorted:
            self._sort()
        for event in findall(self.calendar, lambda node: node.node_name == event_name):
            self._show(event)

    def show(self) -> None:
        """Show calendar, prints result to console

        Raises:
            Exception: if the calendar has no events
        """
        if not len(self.calendar.children):
            raise Exception("Calendar is empty!")
        if not self.__sorted:
            self._sort()
        print(self.calendar.node_name)
        # Leaves of the tree are the event nodes.
        for event in self.calendar.leaves:
            self._show(event)

    def to_dataframe(self) -> pd.DataFrame:
        """
        Export calendar to DataFrame

        Returns:
            (pd.DataFrame)

        Raises:
            Exception: if the calendar has no events
        """
        if not len(self.calendar.children):
            raise Exception("Calendar is empty!")
        data = tree_to_dataframe(self.calendar, all_attrs=True, leaf_only=True)
        compulsory_cols = ["path", "name", "date", "time"]
        # NOTE(review): set difference makes the ordering of the extra columns
        # non-deterministic across runs - confirm callers do not rely on it.
        other_cols = list(set(data.columns) - set(compulsory_cols))
        return data[compulsory_cols + other_cols]

    def _delete_event(self, event: Node) -> None:
        """Private method to delete event, delete parent node as well
        (prunes now-empty day/month/year nodes on the way up)

        Args:
            event (Node): event to be deleted
        """
        if len(event.parent.children) == 1:
            # Event is the only child: the parent becomes empty too, so
            # recurse upwards before detaching.
            if event.parent.parent:
                self._delete_event(event.parent)
            event.parent.parent = None
        else:
            event.parent = None

    def _sort(self) -> None:
        """Private method to sort calendar by event date, followed by event time"""
        # Depth <= 4 covers year/month/day/event levels under the root.
        for day_event in findall(self.calendar, lambda node: node.depth <= 4):
            if day_event.depth < 4:
                # Year/month/day nodes: zero-padded names sort chronologically.
                day_event.sort(key=lambda attr: attr.node_name)
            else:
                # Event nodes within a day: sort by time-of-day attribute.
                day_event.sort(key=lambda attr: attr.time)
        self.__sorted = True

    @staticmethod
    def _show(event: Node) -> None:
        """Private method to show event, handles the formatting of event
        Prints result to console

        Args:
            event (Node): event
        """
        event_datetime = dt.datetime.combine(
            event.get_attr("date"), event.get_attr("time")
        )
        # Collect user-supplied attributes, skipping internals and the
        # date/time/name fields already shown.
        event_attrs = event.describe(
            exclude_attributes=["date", "time", "name"], exclude_prefix="_"
        )
        event_attrs_str = ", ".join([f"{attr[0]}: {attr[1]}" for attr in event_attrs])
        if event_attrs_str:
            event_attrs_str = f" ({event_attrs_str})"
        print(f"{event_datetime} - {event.node_name}{event_attrs_str}")

View File

@@ -0,0 +1,261 @@
from __future__ import annotations
import json
import logging
from typing import Any, List, Union
from bigtree.node.node import Node
from bigtree.tree.construct import nested_dict_to_tree
from bigtree.tree.export import print_tree, tree_to_nested_dict
from bigtree.tree.search import find_child_by_name, find_name
logging.getLogger(__name__).addHandler(logging.NullHandler())
class AppToDo:
    """
    To-Do List Implementation with Big Tree.

    - To-Do List has three levels - app name, list name, and item name.
    - If list name is not given, item will be assigned to a `General` list.

    Examples:
        *Initializing and Adding Items*

        >>> from bigtree import AppToDo
        >>> app = AppToDo("To Do App")
        >>> app.add_item(item_name="Homework 1", list_name="School")
        >>> app.add_item(item_name=["Milk", "Bread"], list_name="Groceries", description="Urgent")
        >>> app.add_item(item_name="Cook")
        >>> app.show()
        To Do App
        ├── School
        │   └── Homework 1
        ├── Groceries
        │   ├── Milk [description=Urgent]
        │   └── Bread [description=Urgent]
        └── General
            └── Cook

        *Reorder List and Item*

        >>> app.prioritize_list(list_name="General")
        >>> app.prioritize_item(item_name="Bread")

        *Removing Items*

        >>> app.remove_item("Homework 1")

        *Exporting and Importing List*

        >>> app.save("assets/docstr/list.json")
        >>> app2 = AppToDo.load("assets/docstr/list.json")
    """

    def __init__(
        self,
        app_name: str = "",
    ):
        """Initialize To-Do app

        Args:
            app_name (str): name of to-do app, optional
        """
        # Root of the to-do tree; lists are its children, items grandchildren.
        self._root = Node(app_name)

    def add_list(self, list_name: str, **kwargs: Any) -> Node:
        """Add list to app
        If list is present, return list node, else a new list will be created

        Args:
            list_name (str): name of list
            **kwargs: additional attributes for a newly created list node

        Returns:
            (Node)
        """
        list_node = find_child_by_name(self._root, list_name)
        if not list_node:
            list_node = Node(list_name, parent=self._root, **kwargs)
            logging.info(f"Created list {list_name}")
        return list_node

    def prioritize_list(self, list_name: str) -> None:
        """Prioritize list in app, shift it to be the first list

        Args:
            list_name (str): name of list

        Raises:
            ValueError: if the list does not exist
        """
        list_node = find_child_by_name(self._root, list_name)
        if not list_node:
            raise ValueError(f"List {list_name} not found")
        # Reorder by rebuilding the child sequence with the target first.
        current_children = list(self._root.children)
        current_children.remove(list_node)
        current_children.insert(0, list_node)
        self._root.children = current_children  # type: ignore

    def add_item(
        self, item_name: Union[str, List[str]], list_name: str = "", **kwargs: Any
    ) -> None:
        """Add items to list

        Args:
            item_name (str/List[str]): items to be added
            list_name (str): list to add items to, optional
            **kwargs: additional attributes stored on each item node

        Raises:
            TypeError: if item_name is neither str nor list
        """
        if not isinstance(item_name, str) and not isinstance(item_name, list):
            raise TypeError("Invalid data type for item")
        if isinstance(item_name, str):
            item_name = [item_name]
        # Get list to add to
        if list_name:
            list_node = self.add_list(list_name)
        else:
            # Items without an explicit list go into the default `General` list.
            list_node = self.add_list("General")
        # Add items to list
        for _item in item_name:
            _ = Node(_item, parent=list_node, **kwargs)
        logging.info(f"Created item(s) {', '.join(item_name)}")

    def remove_item(
        self, item_name: Union[str, List[str]], list_name: str = ""
    ) -> None:
        """Remove items from list

        Args:
            item_name (str/List[str]): items to be removed
            list_name (str): list to remove items from, optional
                (searches the whole app when omitted)

        Raises:
            TypeError: if item_name is neither str nor list
            ValueError: if the list or any item cannot be found, or a match
                is not at item level
        """
        if not isinstance(item_name, str) and not isinstance(item_name, list):
            raise TypeError("Invalid data type for item")
        if isinstance(item_name, str):
            item_name = [item_name]
        # Check if items can be found
        # (validate everything first so no partial removal happens on error)
        items_to_remove = []
        parent_to_check: set[Node] = set()
        if list_name:
            list_node = find_child_by_name(self._root, list_name)
            if not list_node:
                raise ValueError(f"List {list_name} does not exist!")
            for _item in item_name:
                item_node = find_child_by_name(list_node, _item)
                if not item_node:
                    raise ValueError(f"Item {_item} does not exist!")
                assert isinstance(item_node.parent, Node)  # for mypy type checking
                items_to_remove.append(item_node)
                parent_to_check.add(item_node.parent)
        else:
            for _item in item_name:
                item_node = find_name(self._root, _item)
                if not item_node:
                    raise ValueError(f"Item {_item} does not exist!")
                assert isinstance(item_node.parent, Node)  # for mypy type checking
                items_to_remove.append(item_node)
                parent_to_check.add(item_node.parent)
        # Remove items
        for item_to_remove in items_to_remove:
            # Depth 3 = app/list/item; guards against removing a list node
            # that happened to match by name.
            if item_to_remove.depth != 3:
                raise ValueError(
                    f"Check item to remove {item_to_remove} is an item at item-level"
                )
            item_to_remove.parent = None
        logging.info(
            f"Removed item(s) {', '.join(item.node_name for item in items_to_remove)}"
        )
        # Remove list if empty
        for list_node in parent_to_check:
            if not len(list(list_node.children)):
                list_node.parent = None
                logging.info(f"Removed list {list_node.node_name}")

    def prioritize_item(self, item_name: str) -> None:
        """Prioritize item in list, shift it to be the first item in list

        Args:
            item_name (str): name of item

        Raises:
            ValueError: if the item is not found or is not at item level
        """
        item_node = find_name(self._root, item_name)
        if not item_node:
            raise ValueError(f"Item {item_name} not found")
        if item_node.depth != 3:
            raise ValueError(f"{item_name} is not an item")
        assert isinstance(item_node.parent, Node)  # for mypy type checking
        current_parent = item_node.parent
        current_children = list(current_parent.children)
        current_children.remove(item_node)
        current_children.insert(0, item_node)
        current_parent.children = current_children  # type: ignore

    def show(self, **kwargs: Any) -> None:
        """Print tree to console"""
        print_tree(self._root, all_attrs=True, **kwargs)

    @staticmethod
    def load(json_path: str) -> AppToDo:
        """Load To-Do app from json

        Args:
            json_path (str): json load path

        Returns:
            (Self)

        Raises:
            ValueError: if json_path does not end with .json
        """
        if not json_path.endswith(".json"):
            raise ValueError("Path should end with .json")
        with open(json_path, "r") as fp:
            app_dict = json.load(fp)
        _app = AppToDo("dummy")
        # Swap the dummy root for the deserialized tree.
        AppToDo.__setattr__(_app, "_root", nested_dict_to_tree(app_dict["root"]))
        return _app

    def save(self, json_path: str) -> None:
        """Save To-Do app as json

        Args:
            json_path (str): json save path

        Raises:
            ValueError: if json_path does not end with .json
        """
        if not json_path.endswith(".json"):
            raise ValueError("Path should end with .json")
        node_dict = tree_to_nested_dict(self._root, all_attrs=True)
        app_dict = {"root": node_dict}
        with open(json_path, "w") as fp:
            json.dump(app_dict, fp)

View File

@@ -0,0 +1 @@
./elog-1.3.4-py3.7.egg

View File

@@ -0,0 +1,13 @@
from elog.logbook import Logbook
from elog.logbook import LogbookError, LogbookAuthenticationError, LogbookServerProblem, LogbookMessageRejected, \
LogbookInvalidMessageID, LogbookInvalidAttachmentType
def open(*args, **kwargs):
    """
    Will return a Logbook object. All arguments are passed to the logbook constructor.

    NOTE: intentionally shadows the builtin ``open`` within this module so the
    package exposes ``elog.open(...)``; the builtin remains reachable via
    ``builtins.open``.

    :param args: positional arguments forwarded to the ``Logbook`` constructor
    :param kwargs: keyword arguments forwarded to the ``Logbook`` constructor
    :return: Logbook() instance
    """
    return Logbook(*args, **kwargs)

View File

@@ -0,0 +1,571 @@
import requests
import urllib.parse
import os
import builtins
import re
from elog.logbook_exceptions import *
from datetime import datetime
class Logbook(object):
    """
    Logbook provides methods to interface with logbook on location: "server:port/subdir/logbook". User can create,
    edit, delete logbook messages.
    """

    def __init__(self, hostname, logbook='', port=None, user=None, password=None, subdir='', use_ssl=True,
                 encrypt_pwd=True):
        """
        :param hostname: elog server hostname. If whole url is specified here, it will be parsed and arguments:
                         "logbook, port, subdir, use_ssl" will be overwritten by parsed values.
        :param logbook: name of the logbook on the elog server
        :param port: elog server port (if not specified will default to '80' if use_ssl=False or '443' if use_ssl=True
        :param user: username (if authentication needed)
        :param password: password (if authentication needed) Password will be encrypted with sha256 unless
                         encrypt_pwd=False (default: True)
        :param subdir: subdirectory of logbooks locations
        :param use_ssl: connect using ssl (ignored if url starts with 'http://'' or 'https://'?
        :param encrypt_pwd: To avoid exposing password in the code, this flag can be set to False and password
                            will then be handled as it is (user needs to provide sha256 encrypted password with
                            salt= '' and rounds=5000)
        :return:
        """
        hostname = hostname.strip()
        # parse url to see if some parameters are defined with url
        parsed_url = urllib.parse.urlsplit(hostname)

        # ---- handle SSL -----
        # hostname must be modified according to use_ssl flag. If hostname starts with https:// or http://
        # the use_ssl flag is ignored
        url_scheme = parsed_url.scheme
        if url_scheme == 'http':
            use_ssl = False
        elif url_scheme == 'https':
            use_ssl = True
        elif not url_scheme:
            # add http or https
            if use_ssl:
                url_scheme = 'https'
            else:
                url_scheme = 'http'

        # ---- handle port -----
        # 1) by default use port defined in the url
        # 2) remove any 'default' ports such as 80 for http and 443 for https
        # 3) if port not defined in url and not 'default' add it to netloc
        netloc = parsed_url.netloc
        if netloc == "" and "localhost" in hostname:
            netloc = 'localhost'
        netloc_split = netloc.split(':')
        if len(netloc_split) > 1:
            # port defined in url --> remove if needed
            # Bug fix: the port parsed out of the url is a string; it must be converted to
            # int, otherwise the default-port comparisons below are always False and the
            # default port (80/443) is never stripped from netloc.
            port = int(netloc_split[1]) if netloc_split[1].isdigit() else netloc_split[1]
            if (port == 80 and not use_ssl) or (port == 443 and use_ssl):
                netloc = netloc_split[0]
        else:
            # add port info if needed
            if port is not None and not (port == 80 and not use_ssl) and not (port == 443 and use_ssl):
                netloc += ':{}'.format(port)

        # ---- handle subdir and logbook -----
        # parsed_url.path = /<subdir>/<logbook>/
        # Remove last '/' for easier parsing
        url_path = parsed_url.path
        if url_path.endswith('/'):
            url_path = url_path[:-1]
        splitted_path = url_path.split('/')
        if url_path and len(splitted_path) > 1:
            # If here ... then at least some part of path is defined.
            # If logbook defined --> treat path current path as subdir and add logbook at the end
            # to define the full path. Else treat existing path as <subdir>/<logbook>.
            # Put first and last '/' back on its place
            if logbook:
                url_path += '/{}'.format(logbook)
            else:
                logbook = splitted_path[-1]
        else:
            # There is nothing. Use arguments.
            url_path = subdir + '/' + logbook

        # urllib.parse.quote replaces special characters with %xx escapes
        # self._logbook_path = urllib.parse.quote('/' + url_path + '/').replace('//', '/')
        self._logbook_path = ('/' + url_path + '/').replace('//', '/')
        self._url = url_scheme + '://' + netloc + self._logbook_path
        self.logbook = logbook
        self._user = user
        self._password = _handle_pswd(password, encrypt_pwd)

    def post(self, message, msg_id=None, reply=False, attributes=None, attachments=None, encoding=None,
             **kwargs):
        """
        Posts message to the logbook. If msg_id is not specified new message will be created, otherwise existing
        message will be edited, or a reply (if reply=True) to it will be created. This method returns the msg_id
        of the newly created message.

        :param message: string with message text
        :param msg_id: ID number of message to edit or reply. If not specified new message is created.
        :param reply: If 'True' reply to existing message is created instead of editing it
        :param attributes: Dictionary of attributes. Following attributes are used internally by the elog and will be
                           ignored: Text, Date, Encoding, Reply to, In reply to, Locked by, Attachment
        :param attachments: list of:
                            - file like objects which read() will return bytes (if file_like_object.name is not
                              defined, default name "attachment<i>" will be used.
                            - paths to the files
                            All items will be appended as attachment to the elog entry. In case of unknown
                            attachment an exception LogbookInvalidAttachment will be raised.
        :param encoding: Defines encoding of the message. Can be: 'plain' -> plain text, 'html'->html-text,
                         'ELCode' --> elog formatting syntax
        :param kwargs: Anything in the kwargs will be interpreted as attribute. e.g.: logbook.post('Test text',
                       Author='Rok Vintar), "Author" will be sent as an attribute. If named same as one of the
                       attributes defined in "attributes", kwargs will have priority.
        :return: msg_id
        """
        attributes = attributes or {}
        attributes = {**attributes, **kwargs}  # kwargs as attributes with higher priority
        attachments = attachments or []

        if encoding is not None:
            if encoding not in ['plain', 'HTML', 'ELCode']:
                raise LogbookMessageRejected('Invalid message encoding. Valid options: plain, HTML, ELCode.')
            attributes['Encoding'] = encoding

        attributes_to_edit = dict()
        if msg_id:
            # Message exists, we can continue
            if reply:
                # Verify that there is a message on the server, otherwise do not reply to it!
                self._check_if_message_on_server(msg_id)  # raises exception in case of none existing message
                attributes['reply_to'] = str(msg_id)
            else:  # Edit existing
                attributes['edit_id'] = str(msg_id)
                attributes['skiplock'] = '1'

                # Handle existing attachments
                msg_to_edit, attributes_to_edit, attach_to_edit = self.read(msg_id)

                i = 0
                for attachment in attach_to_edit:
                    if attachment:
                        # Existing attachments must be passed as regular arguments attachment<i> with value= file name
                        # Read message returnes full urls to existing attachments:
                        # <hostname>:[<port>][/<subdir]/<logbook>/<msg_id>/<file_name>
                        attributes['attachment' + str(i)] = os.path.basename(attachment)
                        i += 1

                # Overlay the user-supplied attributes over those read from the server,
                # skipping attributes explicitly set to None.
                # (Previously each value was looked up a second time via attributes.get();
                # the loop already yields it as ``data``.)
                for attribute, data in attributes.items():
                    if data is not None:
                        attributes_to_edit[attribute] = data
        else:
            # As we create a new message, specify creation time if not already specified in attributes
            if 'When' not in attributes:
                attributes['When'] = int(datetime.now().timestamp())

        if not attributes_to_edit:
            attributes_to_edit = attributes
        # Remove any attributes that should not be sent
        _remove_reserved_attributes(attributes_to_edit)

        if attachments:
            files_to_attach, objects_to_close = self._prepare_attachments(attachments)
        else:
            objects_to_close = list()
            files_to_attach = list()

        # Make requests module think that Text is a "file". This is the only way to force requests to send data as
        # multipart/form-data even if there are no attachments. Elog understands only multipart/form-data
        files_to_attach.append(('Text', ('', message)))

        # Base attributes are common to all messages
        self._add_base_msg_attributes(attributes_to_edit)
        # Keys in attributes cannot have certain characters like whitespaces or dashes for the http request
        attributes_to_edit = _replace_special_characters_in_attribute_keys(attributes_to_edit)

        try:
            # NOTE: certificate verification is disabled (verify=False), matching the
            # original behaviour for self-signed elog installations.
            response = requests.post(self._url, data=attributes_to_edit, files=files_to_attach, allow_redirects=False,
                                     verify=False)

            # Validate response. Any problems will raise an Exception.
            resp_message, resp_headers, resp_msg_id = _validate_response(response)

            # Close file like objects that were opened by the elog (if a path was given)
            for file_like_object in objects_to_close:
                if hasattr(file_like_object, 'close'):
                    file_like_object.close()

        except requests.RequestException as e:
            # Check if message on server.
            self._check_if_message_on_server(msg_id)  # raises exceptions if no message or no response from server

            # If here: message is on server but cannot be downloaded (should never happen)
            raise LogbookServerProblem('Cannot access logbook server to post a message, ' + 'because of:\n' +
                                       '{0}'.format(e))

        # Any error before here should raise an exception, but check again for any case.
        if not resp_msg_id or resp_msg_id < 1:
            raise LogbookInvalidMessageID('Invalid message ID: ' + str(resp_msg_id) + ' returned')
        return resp_msg_id

    def read(self, msg_id):
        """
        Reads message from the logbook server and returns tuple of (message, attributes, attachments) where:
        message: string with message body
        attributes: dictionary of all attributes returned by the logbook
        attachments: list of urls to attachments on the logbook server

        :param msg_id: ID of the message to be read
        :return: message, attributes, attachments
        """
        request_headers = dict()
        if self._user or self._password:
            request_headers['Cookie'] = self._make_user_and_pswd_cookie()

        try:
            self._check_if_message_on_server(msg_id)  # raises exceptions if no message or no response from server
            response = requests.get(self._url + str(msg_id) + '?cmd=download', headers=request_headers,
                                    allow_redirects=False, verify=False)

            # Validate response. If problems Exception will be thrown.
            resp_message, resp_headers, resp_msg_id = _validate_response(response)

        except requests.RequestException as e:
            # If here: message is on server but cannot be downloaded (should never happen)
            raise LogbookServerProblem('Cannot access logbook server to read the message with ID: ' + str(msg_id) +
                                       'because of:\n' + '{0}'.format(e))

        # Parse message to separate message body, attributes and attachments
        attributes = dict()
        attachments = list()

        returned_msg = resp_message.decode('utf-8', 'ignore').splitlines()
        delimiter_idx = returned_msg.index('========================================')

        message = '\n'.join(returned_msg[delimiter_idx + 1:])
        for line in returned_msg[0:delimiter_idx]:
            line = line.split(': ')
            data = ''.join(line[1:])
            if line[0] == 'Attachment':
                attachments = data.split(',')
                # Here are only attachment names, make a full url out of it, so they could be
                # recognisable by others, and downloaded if needed
                attachments = [self._url + '{0}'.format(i) for i in attachments]
            else:
                attributes[line[0]] = data

        return message, attributes, attachments

    def delete(self, msg_id):
        """
        Deletes message thread (!!!message + all replies!!!) from logbook.
        It also deletes all of attachments of corresponding messages from the server.

        :param msg_id: message to be deleted
        :return:
        """
        request_headers = dict()
        if self._user or self._password:
            request_headers['Cookie'] = self._make_user_and_pswd_cookie()

        try:
            self._check_if_message_on_server(msg_id)  # check if something to delete

            response = requests.get(self._url + str(msg_id) + '?cmd=Delete&confirm=Yes', headers=request_headers,
                                    allow_redirects=False, verify=False)

            _validate_response(response)  # raises exception if any other error identified

        except requests.RequestException as e:
            # If here: message is on server but cannot be downloaded (should never happen)
            raise LogbookServerProblem('Cannot access logbook server to delete the message with ID: ' + str(msg_id) +
                                       'because of:\n' + '{0}'.format(e))

        # Additional validation: If successfully deleted then status_code = 302. In case command was not executed at
        # all (not English language --> no download command supported) status_code = 200 and the content is just a
        # html page of this whole message.
        if response.status_code == 200:
            raise LogbookServerProblem('Cannot process delete command (only logbooks in English supported).')

    def search(self, search_term, n_results=20, scope="subtext"):
        """
        Searches the logbook and returns the message ids.

        :param search_term: term to search for
        :param n_results: maximum number of results to return (minimum 1)
        :param scope: elog search scope parameter name, defaults to "subtext"
        :return: list of message ids (newest first)
        """
        request_headers = dict()
        if self._user or self._password:
            request_headers['Cookie'] = self._make_user_and_pswd_cookie()

        # Putting n_results = 0 crashes the elog. also in the web-gui.
        n_results = 1 if n_results < 1 else n_results

        params = {
            "mode": "full",
            "reverse": "1",
            "npp": n_results,
            scope: search_term
        }

        try:
            response = requests.get(self._url, params=params, headers=request_headers,
                                    allow_redirects=False, verify=False)

            # Validate response. If problems Exception will be thrown.
            _validate_response(response)
            resp_message = response
        except requests.RequestException as e:
            # If here: message is on server but cannot be downloaded (should never happen)
            raise LogbookServerProblem('Cannot access logbook server to read message ids '
                                       'because of:\n' + '{0}'.format(e))

        from lxml import html
        tree = html.fromstring(resp_message.content)
        message_ids = tree.xpath('(//tr/td[@class="list1" or @class="list2"][1])/a/@href')
        message_ids = [int(m.split("/")[-1]) for m in message_ids]
        return message_ids

    def get_last_message_id(self):
        """Return the id of the newest message, or None if the logbook is empty."""
        ids = self.get_message_ids()
        if len(ids) > 0:
            return ids[0]
        else:
            return None

    def get_message_ids(self):
        """Return the list of all message ids in the logbook (newest first)."""
        request_headers = dict()
        if self._user or self._password:
            request_headers['Cookie'] = self._make_user_and_pswd_cookie()

        try:
            response = requests.get(self._url + 'page', headers=request_headers,
                                    allow_redirects=False, verify=False)

            # Validate response. If problems Exception will be thrown.
            _validate_response(response)
            resp_message = response
        except requests.RequestException as e:
            # If here: message is on server but cannot be downloaded (should never happen)
            raise LogbookServerProblem('Cannot access logbook server to read message ids '
                                       'because of:\n' + '{0}'.format(e))

        from lxml import html
        tree = html.fromstring(resp_message.content)
        message_ids = tree.xpath('(//tr/td[@class="list1" or @class="list2"][1])/a/@href')
        message_ids = [int(m.split("/")[-1]) for m in message_ids]
        return message_ids

    def _check_if_message_on_server(self, msg_id):
        """Try to load page for specific message. If there is a htm tag like <td class="errormsg"> then there is no
        such message.

        :param msg_id: ID of message to be checked
        :return:
        """
        request_headers = dict()
        if self._user or self._password:
            request_headers['Cookie'] = self._make_user_and_pswd_cookie()
        try:
            response = requests.get(self._url + str(msg_id), headers=request_headers, allow_redirects=False,
                                    verify=False)

            # If there is no message code 200 will be returned (OK) and _validate_response will not recognise it
            # but there will be some error in the html code.
            resp_message, resp_headers, resp_msg_id = _validate_response(response)
            # If there is no message, code 200 will be returned (OK) but there will be some error indication in
            # the html code.
            if re.findall('<td.*?class="errormsg".*?>.*?</td>',
                          resp_message.decode('utf-8', 'ignore'),
                          flags=re.DOTALL):
                raise LogbookInvalidMessageID('Message with ID: ' + str(msg_id) + ' does not exist on logbook.')

        except requests.RequestException as e:
            raise LogbookServerProblem('No response from the logbook server.\nDetails: ' + '{0}'.format(e))

    def _add_base_msg_attributes(self, data):
        """
        Adds base message attributes which are used by all messages.

        :param data: dict of current attributes (modified in place)
        :return: content string
        """
        data['cmd'] = 'Submit'
        data['exp'] = self.logbook
        if self._user:
            data['unm'] = self._user
        if self._password:
            data['upwd'] = self._password

    def _prepare_attachments(self, files):
        """
        Parses attachments to content objects. Attachments can be:
        - file like objects: must have method read() which returns bytes. If it has attribute .name it will be used
          for attachment name, otherwise generic attribute<i> name will be used.
        - path to the file on disk

        Note that if attachment is an url pointing to the existing Logbook server it will be ignored and no
        exceptions will be raised. This can happen if attachments returned with read_method are resend.

        :param files: list of file like objects or paths
        :return: content string
        """
        prepared = list()
        i = 0
        objects_to_close = list()  # objects that are created (opened) by elog must be later closed
        for file_obj in files:
            if hasattr(file_obj, 'read'):
                i += 1
                attribute_name = 'attfile' + str(i)

                filename = attribute_name  # If file like object has no name specified use this one
                # Bug fix: file-like objects without a ``name`` attribute previously raised
                # AttributeError here, and the fallback check tested the wrong variable
                # (``filename`` is always non-empty). Only use the object's own name when it
                # exists and is not empty.
                if hasattr(file_obj, 'name'):
                    candidate_filename = os.path.basename(file_obj.name)
                    if candidate_filename:  # use only if not empty string
                        filename = candidate_filename

            elif isinstance(file_obj, str):
                # Check if it is:
                # - a path to the file --> open file and append
                # - an url pointing to the existing Logbook server --> ignore
                filename = ""
                attribute_name = ""

                if os.path.isfile(file_obj):
                    i += 1
                    attribute_name = 'attfile' + str(i)
                    file_obj = builtins.open(file_obj, 'rb')
                    filename = os.path.basename(file_obj.name)
                    objects_to_close.append(file_obj)

                elif not file_obj.startswith(self._url):
                    raise LogbookInvalidAttachmentType('Invalid type of attachment: \"' + file_obj + '\".')
            else:
                raise LogbookInvalidAttachmentType('Invalid type of attachment[' + str(i) + '].')

            prepared.append((attribute_name, (filename, file_obj)))

        return prepared, objects_to_close

    def _make_user_and_pswd_cookie(self):
        """
        prepares user name and password cookie. It is sent in header when posting a message.

        :return: user name and password value for the Cookie header
        """
        cookie = ''
        if self._user:
            cookie += 'unm=' + self._user + ';'
        if self._password:
            cookie += 'upwd=' + self._password + ';'

        return cookie
def _remove_reserved_attributes(attributes):
"""
Removes elog reserved attributes (from the attributes dict) that can not be sent.
:param attributes: dictionary of attributes to be cleaned.
:return:
"""
if attributes:
attributes.get('$@MID@$', None)
attributes.pop('Date', None)
attributes.pop('Attachment', None)
attributes.pop('Text', None) # Remove this one because it will be send attachment like
def _replace_special_characters_in_attribute_keys(attributes):
"""
Replaces special characters in elog attribute keys by underscore, otherwise attribute values will be erased in
the http request. This is using the same replacement elog itself is using to handle these cases
:param attributes: dictionary of attributes to be cleaned.
:return: attributes with replaced keys
"""
return {re.sub('[^0-9a-zA-Z]', '_', key): value for key, value in attributes.items()}
def _validate_response(response):
""" Validate response of the request."""
msg_id = None
if response.status_code not in [200, 302]:
# 200 --> OK; 302 --> Found
# Html page is returned with error description (handling errors same way as on original client. Looks
# like there is no other way.
err = re.findall('<td.*?class="errormsg".*?>.*?</td>',
response.content.decode('utf-8', 'ignore'),
flags=re.DOTALL)
if len(err) > 0:
# Remove html tags
# If part of the message has: Please go back... remove this part since it is an instruction for
# the user when using browser.
err = re.sub('(?:<.*?>)', '', err[0])
if err:
raise LogbookMessageRejected('Rejected because of: ' + err)
else:
raise LogbookMessageRejected('Rejected because of unknown error.')
# Other unknown errors
raise LogbookMessageRejected('Rejected because of unknown error.')
else:
location = response.headers.get('Location')
if location is not None:
if 'has moved' in location:
raise LogbookServerProblem('Logbook server has moved to another location.')
elif 'fail' in location:
raise LogbookAuthenticationError('Invalid username or password.')
else:
# returned locations is something like: '<host>/<sub_dir>/<logbook>/<msg_id><query>
# with urllib.parse.urlparse returns attribute path=<sub_dir>/<logbook>/<msg_id>
msg_id = int(urllib.parse.urlsplit(location).path.split('/')[-1])
if b'form name=form1' in response.content or b'type=password' in response.content:
# Not to smart to check this way, but no other indication of this kind of error.
# C client does it the same way
raise LogbookAuthenticationError('Invalid username or password.')
return response.content, response.headers, msg_id
def _handle_pswd(password, encrypt=True):
"""
Takes password string and returns password as needed by elog. If encrypt=True then password will be
sha256 encrypted (salt='', rounds=5000). Before returning password, any trailing $5$$ will be removed
independent off encrypt flag.
:param password: password string
:param encrypt: encrypt password?
:return: elog prepared password
"""
if encrypt and password:
from passlib.hash import sha256_crypt
return sha256_crypt.encrypt(password, salt='', rounds=5000)[4:]
elif password and password.startswith('$5$$'):
return password[4:]
else:
return password

View File

@@ -0,0 +1,28 @@
class LogbookError(Exception):
    """Base class for all logbook exceptions."""
class LogbookAuthenticationError(LogbookError):
    """Raised when there is a problem with the username or password."""
class LogbookServerProblem(LogbookError):
    """Raised when there is a problem accessing the logbook server."""
class LogbookMessageRejected(LogbookError):
    """Raised when creating/manipulating a message was rejected by the server or composing it failed."""
class LogbookInvalidMessageID(LogbookMessageRejected):
    """Raised when no message with the specified ID exists on the server."""
class LogbookInvalidAttachmentType(LogbookMessageRejected):
    """Raised when a passed attachment has an invalid type."""

View File

@@ -0,0 +1,68 @@
__version__ = "0.7.2"
from bigtree.binarytree.construct import list_to_binarytree
from bigtree.dag.construct import dataframe_to_dag, dict_to_dag, list_to_dag
from bigtree.dag.export import dag_to_dataframe, dag_to_dict, dag_to_dot, dag_to_list
from bigtree.node.basenode import BaseNode
from bigtree.node.binarynode import BinaryNode
from bigtree.node.dagnode import DAGNode
from bigtree.node.node import Node
from bigtree.tree.construct import (
add_dataframe_to_tree_by_name,
add_dataframe_to_tree_by_path,
add_dict_to_tree_by_name,
add_dict_to_tree_by_path,
add_path_to_tree,
dataframe_to_tree,
dataframe_to_tree_by_relation,
dict_to_tree,
list_to_tree,
list_to_tree_by_relation,
nested_dict_to_tree,
str_to_tree,
)
from bigtree.tree.export import (
print_tree,
tree_to_dataframe,
tree_to_dict,
tree_to_dot,
tree_to_nested_dict,
tree_to_pillow,
yield_tree,
)
from bigtree.tree.helper import clone_tree, get_tree_diff, prune_tree
from bigtree.tree.modify import (
copy_nodes,
copy_nodes_from_tree_to_tree,
copy_or_shift_logic,
shift_nodes,
)
from bigtree.tree.search import (
find,
find_attr,
find_attrs,
find_children,
find_full_path,
find_name,
find_names,
find_path,
find_paths,
findall,
)
from bigtree.utils.exceptions import (
CorruptedTreeError,
DuplicatedNodeError,
LoopError,
NotFoundError,
SearchError,
TreeError,
)
from bigtree.utils.iterators import (
dag_iterator,
inorder_iter,
levelorder_iter,
levelordergroup_iter,
postorder_iter,
preorder_iter,
)
from bigtree.workflows.app_todo import AppToDo

View File

@@ -0,0 +1,50 @@
from typing import List, Type, Union
from bigtree.node.binarynode import BinaryNode
def list_to_binarytree(
    heapq_list: List[Union[int, float]], node_type: Type[BinaryNode] = BinaryNode
) -> BinaryNode:
    """Construct a binary tree from a list of numbers (int or float) laid out in heapq order.

    In heapq layout the parent of element ``idx`` is at ``(idx - 1) // 2``
    (this single formula covers both the odd and even index cases).

    >>> from bigtree import list_to_binarytree, print_tree
    >>> root = list_to_binarytree([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
    >>> root.node_name
    '1'

    Args:
        heapq_list (List[Union[int, float]]): list containing integer node names, ordered in heapq fashion
        node_type (Type[BinaryNode]): node type of tree to be created, defaults to BinaryNode

    Returns:
        (BinaryNode)
    """
    if not heapq_list:
        raise ValueError("Input list does not contain any data, check `heapq_list`")

    root = node_type(heapq_list[0])
    nodes = [root]
    for idx, num in enumerate(heapq_list[1:], start=1):
        parent = nodes[(idx - 1) // 2]
        nodes.append(node_type(num, parent=parent))
    return root

View File

@@ -0,0 +1,186 @@
from typing import List, Tuple, Type
import numpy as np
import pandas as pd
from bigtree.node.dagnode import DAGNode
__all__ = ["list_to_dag", "dict_to_dag", "dataframe_to_dag"]
def list_to_dag(
    relations: List[Tuple[str, str]],
    node_type: Type[DAGNode] = DAGNode,
) -> DAGNode:
    """Construct a DAG from a list of (parent, child) name tuples.

    Note that node names must be unique.

    >>> from bigtree import list_to_dag, dag_iterator
    >>> dag = list_to_dag([("a", "c"), ("a", "d"), ("b", "c"), ("c", "d"), ("d", "e")])

    Args:
        relations (list): list containing tuple of parent-child names
        node_type (Type[DAGNode]): node type of DAG to be created, defaults to DAGNode

    Returns:
        (DAGNode)
    """
    if not relations:
        raise ValueError("Input list does not contain any data, check `relations`")

    # Delegate to the dataframe-based constructor for the actual wiring.
    relation_frame = pd.DataFrame(relations, columns=["parent", "child"])
    return dataframe_to_dag(
        relation_frame, child_col="child", parent_col="parent", node_type=node_type
    )
def dict_to_dag(
    relation_attrs: dict,
    parent_key: str = "parents",
    node_type: Type[DAGNode] = DAGNode,
) -> DAGNode:
    """Construct a DAG from a nested dictionary: ``key`` is the child name, ``value`` is a dict
    holding the list of parent names plus attribute name/value pairs.

    Note that node names must be unique.

    >>> from bigtree import dict_to_dag, dag_iterator
    >>> relation_dict = {
    ...     "a": {"step": 1},
    ...     "b": {"step": 1},
    ...     "c": {"parents": ["a", "b"], "step": 2},
    ...     "d": {"parents": ["a", "c"], "step": 2},
    ...     "e": {"parents": ["d"], "step": 3},
    ... }
    >>> dag = dict_to_dag(relation_dict, parent_key="parents")

    Args:
        relation_attrs (dict): dictionary containing node, node parents, and node attribute information,
            key: child name, value: dict of parent names, node attribute and attribute value
        parent_key (str): key of dictionary to retrieve list of parents name, defaults to "parents"
        node_type (Type[DAGNode]): node type of DAG to be created, defaults to DAGNode

    Returns:
        (DAGNode)
    """
    if not relation_attrs:
        raise ValueError("Dictionary does not contain any data, check `relation_attrs`")

    # Convert dictionary to dataframe: one row per child, then explode the parents list
    # so each (parent, child) pair gets its own row.
    data = pd.DataFrame(relation_attrs).T.rename_axis("_tmp_child").reset_index()
    assert (
        parent_key in data
    ), f"Parent key {parent_key} not in dictionary, check `relation_attrs` and `parent_key`"
    data = data.explode(parent_key)
    return dataframe_to_dag(
        data,
        child_col="_tmp_child",
        parent_col=parent_key,
        node_type=node_type,
    )
def dataframe_to_dag(
    data: pd.DataFrame,
    child_col: str = None,
    parent_col: str = None,
    attribute_cols: list = None,
    node_type: Type[DAGNode] = DAGNode,
) -> DAGNode:
    """Construct DAG from pandas DataFrame.

    Note that node names must be unique.

    `child_col` and `parent_col` specify columns for child name and parent name to construct DAG.
    `attribute_cols` specify columns for node attribute for child name

    If columns are not specified, `child_col` takes first column, `parent_col` takes second column, and all other
    columns are `attribute_cols`.

    >>> import pandas as pd
    >>> from bigtree import dataframe_to_dag, dag_iterator
    >>> relation_data = pd.DataFrame([
    ...     ["a", None, 1],
    ...     ["b", None, 1],
    ...     ["c", "a", 2],
    ...     ["c", "b", 2],
    ...     ["d", "a", 2],
    ...     ["d", "c", 2],
    ...     ["e", "d", 3],
    ...     ],
    ...     columns=["child", "parent", "step"]
    ... )
    >>> dag = dataframe_to_dag(relation_data)

    Args:
        data (pandas.DataFrame): data containing path and node attribute information
        child_col (str): column of data containing child name information, defaults to None
            if not set, it will take the first column of data
        parent_col (str): column of data containing parent name information, defaults to None
            if not set, it will take the second column of data
        attribute_cols (list): columns of data containing child node attribute information, defaults to None
            if not set, it will take all columns of data except `child_col` and `parent_col`
            (Bug fix: the default used to be a shared mutable ``[]`` argument.)
        node_type (Type[DAGNode]): node type of DAG to be created, defaults to DAGNode

    Returns:
        (DAGNode)
    """
    if not len(data.columns):
        raise ValueError("Data does not contain any columns, check `data`")
    if not len(data):
        raise ValueError("Data does not contain any rows, check `data`")

    if not child_col:
        child_col = data.columns[0]
    if not parent_col:
        parent_col = data.columns[1]
    if not attribute_cols:
        # Derive attribute columns from the remaining columns.
        attribute_cols = list(data.columns)
        attribute_cols.remove(child_col)
        attribute_cols.remove(parent_col)

    # Reject children that appear multiple times with conflicting attribute values.
    data_check = data.copy()[[child_col] + attribute_cols].drop_duplicates()
    _duplicate_check = (
        data_check[child_col]
        .value_counts()
        .to_frame("counts")
        .rename_axis(child_col)
        .reset_index()
    )
    _duplicate_check = _duplicate_check[_duplicate_check["counts"] > 1]
    if len(_duplicate_check):
        raise ValueError(
            f"There exists duplicate child name with different attributes\nCheck {_duplicate_check}"
        )

    if np.any(data[child_col].isnull()):
        raise ValueError(f"Child name cannot be empty, check {child_col}")

    node_dict = dict()
    parent_node = None
    for row in data.reset_index(drop=True).to_dict(orient="index").values():
        child_name = row[child_col]
        parent_name = row[parent_col]
        # Everything except the child/parent columns becomes a node attribute;
        # NaN attribute values are dropped.
        node_attrs = row.copy()
        del node_attrs[child_col]
        del node_attrs[parent_col]
        node_attrs = {k: v for k, v in node_attrs.items() if not pd.isnull(v)}
        child_node = node_dict.get(child_name)
        if not child_node:
            child_node = node_type(child_name)
            node_dict[child_name] = child_node
        child_node.set_attrs(node_attrs)
        if not pd.isnull(parent_name):
            parent_node = node_dict.get(parent_name)
            if not parent_node:
                parent_node = node_type(parent_name)
                node_dict[parent_name] = parent_node
            child_node.parents = [parent_node]
    # NOTE(review): this returns the parent node of the last processed row (or None if no
    # row had a parent) — the DAG can be traversed from any of its nodes.
    return parent_node

View File

@@ -0,0 +1,269 @@
from typing import Any, Dict, List, Tuple, Union
import pandas as pd
from bigtree.node.dagnode import DAGNode
from bigtree.utils.iterators import dag_iterator
__all__ = ["dag_to_list", "dag_to_dict", "dag_to_dataframe", "dag_to_dot"]
def dag_to_list(
    dag: DAGNode,
) -> List[Tuple[str, str]]:
    """Export a DAG to a list of (parent name, child name) tuples.

    >>> from bigtree import DAGNode, dag_to_list
    >>> a = DAGNode("a", step=1)
    >>> b = DAGNode("b", step=1)
    >>> c = DAGNode("c", step=2, parents=[a, b])
    >>> d = DAGNode("d", step=2, parents=[a, c])
    >>> e = DAGNode("e", step=3, parents=[d])
    >>> dag_to_list(a)
    [('a', 'c'), ('a', 'd'), ('b', 'c'), ('c', 'd'), ('d', 'e')]

    Args:
        dag (DAGNode): DAG to be exported

    Returns:
        (List[Tuple[str, str]])
    """
    return [
        (parent_node.node_name, child_node.node_name)
        for parent_node, child_node in dag_iterator(dag)
    ]
def dag_to_dict(
    dag: DAGNode,
    parent_key: str = "parents",
    attr_dict: dict = {},
    all_attrs: bool = False,
) -> Dict[str, Any]:
    """Export a DAG to a dictionary.

    Each key is a child name; each value is a nested dictionary of parent names
    (under *parent_key*) and the exported node attributes.

    >>> from bigtree import DAGNode, dag_to_dict
    >>> a = DAGNode("a", step=1)
    >>> b = DAGNode("b", step=1)
    >>> c = DAGNode("c", step=2, parents=[a, b])
    >>> d = DAGNode("d", step=2, parents=[a, c])
    >>> e = DAGNode("e", step=3, parents=[d])
    >>> dag_to_dict(a, parent_key="parent", attr_dict={"step": "step no."})
    {'a': {'step no.': 1}, 'c': {'parent': ['a', 'b'], 'step no.': 2}, 'd': {'parent': ['a', 'c'], 'step no.': 2}, 'b': {'step no.': 1}, 'e': {'parent': ['d'], 'step no.': 3}}

    Args:
        dag (DAGNode): DAG to be exported
        parent_key (str): dictionary key for `node.parent.node_name`, defaults to `parents`
        attr_dict (dict): dictionary mapping node attributes to dictionary key,
            key: node attributes, value: corresponding dictionary key, optional
        all_attrs (bool): indicator whether to retrieve all `Node` attributes

    Returns:
        (dict)
    """

    def _collect_attrs(node, entry):
        # Fill *entry* with the exported attributes of *node*.
        if all_attrs:
            entry.update(node.describe(exclude_attributes=["name"], exclude_prefix="_"))
        else:
            for attr_name, dict_key in attr_dict.items():
                entry[dict_key] = node.get_attr(attr_name)
        return entry

    dag = dag.copy()
    data_dict = {}

    for parent_node, child_node in dag_iterator(dag):
        parent_name = parent_node.node_name
        child_name = child_node.node_name
        if parent_node.is_root:
            # Root nodes get an entry without a parents list.
            data_dict[parent_name] = _collect_attrs(parent_node, {})
        if data_dict.get(child_name):
            # Child already recorded --> just append the additional parent.
            data_dict[child_name][parent_key].append(parent_name)
        else:
            data_dict[child_name] = _collect_attrs(
                child_node, {parent_key: [parent_name]}
            )
    return data_dict
def dag_to_dataframe(
    dag: DAGNode,
    name_col: str = "name",
    parent_col: str = "parent",
    attr_dict: dict = None,
    all_attrs: bool = False,
) -> pd.DataFrame:
    """Export DAG to pandas DataFrame.

    >>> from bigtree import DAGNode, dag_to_dataframe
    >>> a = DAGNode("a", step=1)
    >>> b = DAGNode("b", step=1)
    >>> c = DAGNode("c", step=2, parents=[a, b])
    >>> d = DAGNode("d", step=2, parents=[a, c])
    >>> e = DAGNode("e", step=3, parents=[d])
    >>> dag_to_dataframe(a, name_col="name", parent_col="parent", attr_dict={"step": "step no."})
      name parent  step no.
    0    a   None         1
    1    c      a         2
    2    d      a         2
    3    b   None         1
    4    c      b         2
    5    d      c         2
    6    e      d         3

    Args:
        dag (DAGNode): DAG to be exported
        name_col (str): column name for `node.node_name`, defaults to 'name'
        parent_col (str): column name for `node.parent.node_name`, defaults to 'parent'
        attr_dict (dict): dictionary mapping node attributes to column name,
            key: node attributes, value: corresponding column in dataframe, optional
        all_attrs (bool): indicator whether to retrieve all `Node` attributes

    Returns:
        (pd.DataFrame)
    """
    # None sentinel instead of a mutable `{}` default: a shared default dict
    # would be the same object across every call to this function.
    if attr_dict is None:
        attr_dict = {}
    # Work on a copy so exporting never mutates the caller's DAG.
    dag = dag.copy()
    data_list = []

    def _row(node, parent_name) -> Dict[str, Any]:
        # One DataFrame row: name, parent name, then exported attributes.
        row = {name_col: node.node_name, parent_col: parent_name}
        if all_attrs:
            row.update(
                node.describe(exclude_attributes=["name"], exclude_prefix="_")
            )
        else:
            for k, v in attr_dict.items():
                row[v] = node.get_attr(k)
        return row

    for parent_node, child_node in dag_iterator(dag):
        if parent_node.is_root:
            # Roots never appear as a child of any edge, so emit them here.
            data_list.append(_row(parent_node, None))
        data_list.append(_row(child_node, parent_node.node_name))
    # A node with several parents is yielded once per edge; drop exact
    # duplicate rows and renumber the index.
    return pd.DataFrame(data_list).drop_duplicates().reset_index(drop=True)
def dag_to_dot(
    dag: Union[DAGNode, List[DAGNode]],
    rankdir: str = "TB",
    bg_colour: str = None,
    node_colour: str = None,
    edge_colour: str = None,
    node_attr: str = None,
    edge_attr: str = None,
):
    r"""Export DAG tree or list of DAG trees to image.

    Note that node names must be unique.
    Possible node attributes include style, fillcolor, shape.

    >>> from bigtree import DAGNode, dag_to_dot
    >>> a = DAGNode("a", step=1)
    >>> b = DAGNode("b", step=1)
    >>> c = DAGNode("c", step=2, parents=[a, b])
    >>> d = DAGNode("d", step=2, parents=[a, c])
    >>> e = DAGNode("e", step=3, parents=[d])
    >>> dag_graph = dag_to_dot(a)

    Export to image, dot file, etc.

    >>> dag_graph.write_png("tree_dag.png")
    >>> dag_graph.write_dot("tree_dag.dot")

    Export to string

    >>> dag_graph.to_string()
    'strict digraph G {\nrankdir=TB;\nc [label=c];\na [label=a];\na -> c;\nd [label=d];\na [label=a];\na -> d;\nc [label=c];\nb [label=b];\nb -> c;\nd [label=d];\nc [label=c];\nc -> d;\ne [label=e];\nd [label=d];\nd -> e;\n}\n'

    Args:
        dag (Union[DAGNode, List[DAGNode]]): DAG or list of DAGs to be exported
        rankdir (str): set direction of graph layout, defaults to 'TB', can be 'BT, 'LR', 'RL'
        bg_colour (str): background color of image, defaults to None
        node_colour (str): fill colour of nodes, defaults to None
        edge_colour (str): colour of edges, defaults to None
        node_attr (str): node attribute for style, overrides node_colour, defaults to None
            Possible node attributes include {"style": "filled", "fillcolor": "gold"}
        edge_attr (str): edge attribute for style, overrides edge_colour, defaults to None
            Possible edge attributes include {"style": "bold", "label": "edge label", "color": "black"}

    Returns:
        (pydot.Dot)
    """
    try:
        import pydot
    except ImportError:  # pragma: no cover
        raise ImportError(
            "pydot not available. Please perform a\n\npip install 'bigtree[image]'\n\nto install required dependencies"
        )
    # Base styles derived from the colour arguments; per-node/per-edge
    # attribute overrides are layered on copies of these below.
    if bg_colour:
        graph_style = dict(bgcolor=bg_colour)
    else:
        graph_style = dict()
    if node_colour:
        node_style = dict(style="filled", fillcolor=node_colour)
    else:
        node_style = dict()
    if edge_colour:
        edge_style = dict(color=edge_colour)
    else:
        edge_style = dict()
    _graph = pydot.Dot(
        graph_type="digraph", strict=True, rankdir=rankdir, **graph_style
    )
    if not isinstance(dag, list):
        dag = [dag]
    for _dag in dag:
        if not isinstance(_dag, DAGNode):
            raise ValueError(
                "Tree should be of type `DAGNode`, or inherit from `DAGNode`"
            )
        # Copy so that exporting never mutates the caller's DAG.
        _dag = _dag.copy()
        for parent_node, child_node in dag_iterator(_dag):
            child_name = child_node.name
            child_node_style = node_style.copy()
            if node_attr and child_node.get_attr(node_attr):
                child_node_style.update(child_node.get_attr(node_attr))
            # Bug fixes vs previous revision:
            # 1. copy() per edge so one node's edge attributes no longer leak
            #    onto every subsequent edge (edge_style was mutated in place);
            # 2. guard on get_attr() result so a node without the edge
            #    attribute no longer raises TypeError (update(None)).
            child_edge_style = edge_style.copy()
            if edge_attr and child_node.get_attr(edge_attr):
                child_edge_style.update(child_node.get_attr(edge_attr))
            pydot_child = pydot.Node(
                name=child_name, label=child_name, **child_node_style
            )
            _graph.add_node(pydot_child)
            parent_name = parent_node.name
            parent_node_style = node_style.copy()
            if node_attr and parent_node.get_attr(node_attr):
                parent_node_style.update(parent_node.get_attr(node_attr))
            pydot_parent = pydot.Node(
                name=parent_name, label=parent_name, **parent_node_style
            )
            _graph.add_node(pydot_parent)
            edge = pydot.Edge(parent_name, child_name, **child_edge_style)
            _graph.add_edge(edge)
    return _graph

View File

@@ -0,0 +1,696 @@
import copy
from typing import Any, Dict, Iterable, List
from bigtree.utils.exceptions import CorruptedTreeError, LoopError, TreeError
from bigtree.utils.iterators import preorder_iter
class BaseNode:
    """
    BaseNode extends any Python class to a tree node.

    Nodes can have attributes if they are initialized from `Node`, *dictionary*, or *pandas DataFrame*.
    Nodes can be linked to each other with `parent` and `children` setter methods,
    or using bitshift operator with the convention `parent_node >> child_node` or `child_node << parent_node`.

    >>> from bigtree import Node, print_tree
    >>> root = Node("a", age=90)
    >>> b = Node("b", age=65)
    >>> c = Node("c", age=60)
    >>> d = Node("d", age=40)
    >>> root.children = [b, c]
    >>> d.parent = b
    >>> print_tree(root, attr_list=["age"])
    a [age=90]
    ├── b [age=65]
    │   └── d [age=40]
    └── c [age=60]

    >>> from bigtree import Node
    >>> root = Node("a", age=90)
    >>> b = Node("b", age=65)
    >>> c = Node("c", age=60)
    >>> d = Node("d", age=40)
    >>> root >> b
    >>> root >> c
    >>> d << b
    >>> print_tree(root, attr_list=["age"])
    a [age=90]
    ├── b [age=65]
    │   └── d [age=40]
    └── c [age=60]

    Directly passing `parent` argument.

    >>> from bigtree import Node
    >>> root = Node("a")
    >>> b = Node("b", parent=root)
    >>> c = Node("c", parent=root)
    >>> d = Node("d", parent=b)

    Directly passing `children` argument.

    >>> from bigtree import Node
    >>> d = Node("d")
    >>> c = Node("c")
    >>> b = Node("b", children=[d])
    >>> a = Node("a", children=[b, c])

    **BaseNode Creation**

    Node can be created by instantiating a `BaseNode` class or by using a *dictionary*.
    If node is created with dictionary, all keys of dictionary will be stored as class attributes.

    >>> from bigtree import Node
    >>> root = Node.from_dict({"name": "a", "age": 90})

    **BaseNode Attributes**

    These are node attributes that have getter and/or setter methods.

    Get and set other `BaseNode`

    1. ``parent``: Get/set parent node
    2. ``children``: Get/set child nodes

    Get other `BaseNode`

    1. ``ancestors``: Get ancestors of node excluding self, iterator
    2. ``descendants``: Get descendants of node excluding self, iterator
    3. ``leaves``: Get all leaf node(s) from self, iterator
    4. ``siblings``: Get siblings of self
    5. ``left_sibling``: Get sibling left of self
    6. ``right_sibling``: Get sibling right of self

    Get `BaseNode` configuration

    1. ``node_path``: Get tuple of nodes from root
    2. ``is_root``: Get indicator if self is root node
    3. ``is_leaf``: Get indicator if self is leaf node
    4. ``root``: Get root node of tree
    5. ``depth``: Get depth of self
    6. ``max_depth``: Get maximum depth from root to leaf node

    **BaseNode Methods**

    These are methods available to be performed on `BaseNode`.

    Constructor methods

    1. ``from_dict()``: Create BaseNode from dictionary

    `BaseNode` methods

    1. ``describe()``: Get node information sorted by attributes, returns list of tuples
    2. ``get_attr(attr_name: str)``: Get value of node attribute
    3. ``set_attrs(attrs: dict)``: Set node attribute name(s) and value(s)
    4. ``go_to(node: BaseNode)``: Get a path from own node to another node from same tree
    5. ``copy()``: Deep copy BaseNode
    6. ``sort()``: Sort child nodes

    ----
    """

    def __init__(self, parent=None, children: List = None, **kwargs):
        # Name-mangled storage (_BaseNode__parent / _BaseNode__children);
        # subclasses interact with it via the property setters below.
        self.__parent = None
        self.__children = []
        if children is None:
            children = []
        # Go through the property setters so type/loop checks always run.
        self.parent = parent
        self.children = children
        if "parents" in kwargs:
            # Common user typo: a BaseNode has exactly one parent.
            raise ValueError(
                "Attempting to set `parents` attribute, do you mean `parent`?"
            )
        # Remaining keyword arguments become plain node attributes.
        self.__dict__.update(**kwargs)

    @property
    def parent(self):
        """Get parent node

        Returns:
            (Self)
        """
        return self.__parent

    @staticmethod
    def __check_parent_type(new_parent):
        """Check parent type

        Args:
            new_parent (Self): parent node
        """
        if not (isinstance(new_parent, BaseNode) or new_parent is None):
            raise TypeError(
                f"Expect input to be BaseNode type or NoneType, received input type {type(new_parent)}"
            )

    def __check_parent_loop(self, new_parent):
        """Check that assigning `new_parent` does not create a cycle

        Args:
            new_parent (Self): parent node
        """
        if new_parent is not None:
            if new_parent is self:
                raise LoopError("Error setting parent: Node cannot be parent of itself")
            if any(
                ancestor is self
                for ancestor in new_parent.ancestors
                if new_parent.ancestors
            ):
                raise LoopError(
                    "Error setting parent: Node cannot be ancestor of itself"
                )

    @parent.setter
    def parent(self, new_parent):
        """Set parent node

        Detach from the current parent (if any), attach to the new one, and
        roll back to the exact previous state (including child position) if
        any pre/post-assign hook raises.

        Args:
            new_parent (Self): parent node
        """
        self.__check_parent_type(new_parent)
        self.__check_parent_loop(new_parent)

        current_parent = self.parent
        current_child_idx = None

        # Assign new parent - rollback if error
        self.__pre_assign_parent(new_parent)
        try:
            # Remove self from old parent
            if current_parent is not None:
                if not any(
                    child is self for child in current_parent.children
                ):  # pragma: no cover
                    raise CorruptedTreeError(
                        "Error setting parent: Node does not exist as children of its parent"
                    )
                # `current_parent.__children` mangles to _BaseNode__children of
                # the OTHER instance - legal inside the defining class.
                current_child_idx = current_parent.__children.index(self)
                current_parent.__children.remove(self)

            # Assign self to new parent
            self.__parent = new_parent
            if new_parent is not None:
                new_parent.__children.append(self)

            self.__post_assign_parent(new_parent)
        except Exception as exc_info:
            # Remove self from new parent
            if new_parent is not None:
                new_parent.__children.remove(self)

            # Reassign self to old parent, restoring original child position
            self.__parent = current_parent
            if current_child_idx is not None:
                current_parent.__children.insert(current_child_idx, self)
            raise TreeError(exc_info)

    def __pre_assign_parent(self, new_parent):
        """Custom method to check before attaching parent
        Can be overriden with `_BaseNode__pre_assign_parent()`

        Args:
            new_parent (Self): new parent to be added
        """
        pass

    def __post_assign_parent(self, new_parent):
        """Custom method to check after attaching parent
        Can be overriden with `_BaseNode__post_assign_parent()`

        Args:
            new_parent (Self): new parent to be added
        """
        pass

    @property
    def parents(self) -> None:
        """Do not allow `parents` attribute to be accessed"""
        raise ValueError(
            "Attempting to access `parents` attribute, do you mean `parent`?"
        )

    @parents.setter
    def parents(self, new_parent):
        """Do not allow `parents` attribute to be set

        Args:
            new_parent (Self): parent node
        """
        raise ValueError("Attempting to set `parents` attribute, do you mean `parent`?")

    @property
    def children(self) -> Iterable:
        """Get child nodes

        Returns:
            (Iterable[Self])
        """
        # Tuple copy: callers cannot mutate the internal child list directly.
        return tuple(self.__children)

    def __check_children_type(self, new_children: List):
        """Check child type

        Args:
            new_children (List[Self]): child node
        """
        if not isinstance(new_children, list):
            raise TypeError(
                f"Children input should be list type, received input type {type(new_children)}"
            )

    def __check_children_loop(self, new_children: List):
        """Check that assigning `new_children` creates no cycle or duplicates

        Args:
            new_children (List[Self]): child node
        """
        seen_children = []
        for new_child in new_children:
            # Check type
            if not isinstance(new_child, BaseNode):
                raise TypeError(
                    f"Expect input to be BaseNode type, received input type {type(new_child)}"
                )

            # Check for loop and tree structure
            if new_child is self:
                raise LoopError("Error setting child: Node cannot be child of itself")
            if any(child is new_child for child in self.ancestors):
                raise LoopError(
                    "Error setting child: Node cannot be ancestors of itself"
                )

            # Check for duplicate children (by identity, via id())
            if id(new_child) in seen_children:
                raise TreeError(
                    "Error setting child: Node cannot be added multiple times as a child"
                )
            else:
                seen_children.append(id(new_child))

    @children.setter
    def children(self, new_children: List):
        """Set child nodes

        Detach the new children from their current parents, attach them to
        self, and roll back every node to its original parent and position if
        any pre/post-assign hook raises.

        Args:
            new_children (List[Self]): child node
        """
        self.__check_children_type(new_children)
        self.__check_children_loop(new_children)

        # Snapshot (original index, original parent) for rollback.
        current_new_children = {
            new_child: (new_child.parent.__children.index(new_child), new_child.parent)
            for new_child in new_children
            if new_child.parent is not None
        }
        current_new_orphan = [
            new_child for new_child in new_children if new_child.parent is None
        ]
        current_children = list(self.children)

        # Assign new children - rollback if error
        self.__pre_assign_children(new_children)
        try:
            # Remove old children from self
            del self.children

            # Assign new children to self
            # NOTE(review): stores the caller's list object itself, so the
            # caller could later mutate the tree by mutating that list.
            self.__children = new_children
            for new_child in new_children:
                if new_child.parent:
                    new_child.parent.__children.remove(new_child)
                new_child.__parent = self
            self.__post_assign_children(new_children)
        except Exception as exc_info:
            # Reassign new children to their original parent
            for child, idx_parent in current_new_children.items():
                child_idx, parent = idx_parent
                child.__parent = parent
                parent.__children.insert(child_idx, child)
            for child in current_new_orphan:
                child.__parent = None

            # Reassign old children to self
            self.__children = current_children
            for child in current_children:
                child.__parent = self
            raise TreeError(exc_info)

    @children.deleter
    def children(self):
        """Delete child node(s): detach every child from self"""
        for child in self.children:
            child.parent.__children.remove(child)
            child.__parent = None

    def __pre_assign_children(self, new_children: List):
        """Custom method to check before attaching children
        Can be overriden with `_BaseNode__pre_assign_children()`

        Args:
            new_children (List[Self]): new children to be added
        """
        pass

    def __post_assign_children(self, new_children: List):
        """Custom method to check after attaching children
        Can be overriden with `_BaseNode__post_assign_children()`

        Args:
            new_children (List[Self]): new children to be added
        """
        pass

    @property
    def ancestors(self) -> Iterable:
        """Get iterator to yield all ancestors of self, does not include self

        Returns:
            (Iterable[Self])
        """
        node = self.parent
        while node is not None:
            yield node
            node = node.parent

    @property
    def descendants(self) -> Iterable:
        """Get iterator to yield all descendants of self, does not include self

        Returns:
            (Iterable[Self])
        """
        yield from preorder_iter(self, filter_condition=lambda _node: _node != self)

    @property
    def leaves(self) -> Iterable:
        """Get iterator to yield all leaf nodes from self

        Returns:
            (Iterable[Self])
        """
        yield from preorder_iter(self, filter_condition=lambda _node: _node.is_leaf)

    @property
    def siblings(self) -> Iterable:
        """Get siblings of self

        Returns:
            (Iterable[Self])
        """
        if self.is_root:
            return ()
        return tuple(child for child in self.parent.children if child is not self)

    @property
    def left_sibling(self):
        """Get sibling left of self

        Returns:
            (Self)
        """
        if self.parent:
            children = self.parent.children
            child_idx = children.index(self)
            # Index 0 (falsy) means self is the leftmost child -> None.
            if child_idx:
                return self.parent.children[child_idx - 1]
        return None

    @property
    def right_sibling(self):
        """Get sibling right of self

        Returns:
            (Self)
        """
        if self.parent:
            children = self.parent.children
            child_idx = children.index(self)
            if child_idx + 1 < len(children):
                return self.parent.children[child_idx + 1]
        return None

    @property
    def node_path(self) -> Iterable:
        """Get tuple of nodes starting from root

        Returns:
            (Iterable[Self])
        """
        # NOTE(review): returns a list for a root node but a tuple otherwise;
        # callers should not rely on the concrete sequence type.
        if self.is_root:
            return [self]
        return tuple(list(self.parent.node_path) + [self])

    @property
    def is_root(self) -> bool:
        """Get indicator if self is root node

        Returns:
            (bool)
        """
        return self.parent is None

    @property
    def is_leaf(self) -> bool:
        """Get indicator if self is leaf node

        Returns:
            (bool)
        """
        return not len(list(self.children))

    @property
    def root(self):
        """Get root node of tree (recurses up via parent links)

        Returns:
            (Self)
        """
        if self.is_root:
            return self
        return self.parent.root

    @property
    def depth(self) -> int:
        """Get depth of self, indexing starts from 1

        Returns:
            (int)
        """
        if self.is_root:
            return 1
        return self.parent.depth + 1

    @property
    def max_depth(self) -> int:
        """Get maximum depth from root to leaf node

        Returns:
            (int)
        """
        return max(node.depth for node in list(preorder_iter(self.root)))

    @classmethod
    def from_dict(cls, input_dict: Dict[str, Any]):
        """Construct node from dictionary, all keys of dictionary will be stored as class attributes
        Input dictionary must have key `name` if not `Node` will not have any name

        >>> from bigtree import Node
        >>> a = Node.from_dict({"name": "a", "age": 90})

        Args:
            input_dict (Dict[str, Any]): dictionary with node information, key: attribute name, value: attribute value

        Returns:
            (Self)
        """
        return cls(**input_dict)

    def describe(self, exclude_attributes: List[str] = [], exclude_prefix: str = ""):
        """Get node information sorted by attribute name, returns list of tuples

        >>> from bigtree.node.node import Node
        >>> a = Node('a', age=90)
        >>> a.describe()
        [('_BaseNode__children', []), ('_BaseNode__parent', None), ('_sep', '/'), ('age', 90), ('name', 'a')]
        >>> a.describe(exclude_prefix="_")
        [('age', 90), ('name', 'a')]
        >>> a.describe(exclude_prefix="_", exclude_attributes=["name"])
        [('age', 90)]

        Args:
            exclude_attributes (List[str]): list of attributes to exclude
            exclude_prefix (str): prefix of attributes to exclude

        Returns:
            (List[str])
        """
        # NOTE(review): mutable default `[]` is harmless here because the
        # list is only read, never mutated.
        return [
            item
            for item in sorted(self.__dict__.items(), key=lambda item: item[0])
            if (item[0] not in exclude_attributes)
            and (not len(exclude_prefix) or not item[0].startswith(exclude_prefix))
        ]

    def get_attr(self, attr_name: str) -> Any:
        """Get value of node attribute
        Returns None if attribute name does not exist

        >>> from bigtree.node.node import Node
        >>> a = Node('a', age=90)
        >>> a.get_attr("age")
        90

        Args:
            attr_name (str): attribute name

        Returns:
            (Any)
        """
        try:
            return self.__getattribute__(attr_name)
        except AttributeError:
            return None

    def set_attrs(self, attrs: Dict[str, Any]):
        """Set node attributes

        >>> from bigtree.node.node import Node
        >>> a = Node('a')
        >>> a.set_attrs({"age": 90})
        >>> a
        Node(/a, age=90)

        Args:
            attrs (Dict[str, Any]): attribute dictionary,
                key: attribute name, value: attribute value
        """
        self.__dict__.update(attrs)

    def go_to(self, node) -> Iterable:
        """Get path from current node to specified node from same tree

        >>> from bigtree import Node, print_tree
        >>> a = Node(name="a")
        >>> b = Node(name="b", parent=a)
        >>> c = Node(name="c", parent=a)
        >>> d = Node(name="d", parent=b)
        >>> e = Node(name="e", parent=b)
        >>> f = Node(name="f", parent=c)
        >>> g = Node(name="g", parent=e)
        >>> h = Node(name="h", parent=e)
        >>> print_tree(a)
        a
        ├── b
        │   ├── d
        │   └── e
        │       ├── g
        │       └── h
        └── c
            └── f
        >>> d.go_to(d)
        [Node(/a/b/d, )]
        >>> d.go_to(g)
        [Node(/a/b/d, ), Node(/a/b, ), Node(/a/b/e, ), Node(/a/b/e/g, )]
        >>> d.go_to(f)
        [Node(/a/b/d, ), Node(/a/b, ), Node(/a, ), Node(/a/c, ), Node(/a/c/f, )]

        Args:
            node (Self): node to travel to from current node, inclusive of start and end node

        Returns:
            (Iterable)
        """
        if not isinstance(node, BaseNode):
            raise TypeError(
                f"Expect node to be BaseNode type, received input type {type(node)}"
            )
        if self.root != node.root:
            raise TreeError(
                f"Nodes are not from the same tree. Check {self} and {node}"
            )
        if self == node:
            return [self]
        # Walk up from self, down to node, joining at the lowest common
        # ancestor (the common node with the smallest index in self's path).
        self_path = [self] + list(self.ancestors)
        node_path = ([node] + list(node.ancestors))[::-1]
        common_nodes = set(self_path).intersection(set(node_path))
        self_min_index, min_common_node = sorted(
            [(self_path.index(_node), _node) for _node in common_nodes]
        )[0]
        node_min_index = node_path.index(min_common_node)
        return self_path[:self_min_index] + node_path[node_min_index:]

    def copy(self):
        """Deep copy self; clone BaseNode

        >>> from bigtree.node.node import Node
        >>> a = Node('a')
        >>> a_copy = a.copy()

        Returns:
            (Self)
        """
        return copy.deepcopy(self)

    def sort(self, **kwargs):
        """Sort children, possible keyword arguments include ``key=lambda node: node.name``, ``reverse=True``

        >>> from bigtree import Node, print_tree
        >>> a = Node('a')
        >>> c = Node("c", parent=a)
        >>> b = Node("b", parent=a)
        >>> print_tree(a)
        a
        ├── c
        └── b
        >>> a.sort(key=lambda node: node.name)
        >>> print_tree(a)
        a
        ├── b
        └── c
        """
        children = list(self.children)
        children.sort(**kwargs)
        self.__children = children

    def __copy__(self):
        """Shallow copy self

        >>> import copy
        >>> from bigtree.node.node import Node
        >>> a = Node('a')
        >>> a_copy = copy.deepcopy(a)

        Returns:
            (Self)
        """
        obj = type(self).__new__(self.__class__)
        obj.__dict__.update(self.__dict__)
        return obj

    def __repr__(self):
        """Print format of BaseNode, e.g. ``BaseNode(name=a, age=90)``"""
        class_name = self.__class__.__name__
        node_dict = self.describe(exclude_prefix="_")
        node_description = ", ".join([f"{k}={v}" for k, v in node_dict])
        return f"{class_name}({node_description})"

    def __rshift__(self, other):
        """Set children using >> bitshift operator for self >> other

        Args:
            other (Self): other node, children
        """
        other.parent = self

    def __lshift__(self, other):
        """Set parent using << bitshift operator for self << other

        Args:
            other (Self): other node, parent
        """
        self.parent = other

View File

@@ -0,0 +1,395 @@
from typing import Iterable, List, Union
from bigtree.node.node import Node
from bigtree.utils.exceptions import CorruptedTreeError, LoopError, TreeError
class BinaryNode(Node):
    """
    BinaryNode is an extension of Node, and is able to extend to any Python class for Binary Tree implementation.
    Nodes can have attributes if they are initialized from `BinaryNode`, *dictionary*, or *pandas DataFrame*.

    BinaryNode can be linked to each other with `children`, `left`, or `right` setter methods.
    If initialized with `children`, it must be length 2, denoting left and right child.

    >>> from bigtree import BinaryNode, print_tree
    >>> a = BinaryNode(1)
    >>> b = BinaryNode(2)
    >>> c = BinaryNode(3)
    >>> d = BinaryNode(4)
    >>> a.children = [b, c]
    >>> b.right = d
    >>> print_tree(a)
    1
    ├── 2
    │   └── 4
    └── 3

    Directly passing `left`, `right`, or `children` argument.

    >>> from bigtree import BinaryNode
    >>> d = BinaryNode(4)
    >>> c = BinaryNode(3)
    >>> b = BinaryNode(2, right=d)
    >>> a = BinaryNode(1, children=[b, c])

    **BinaryNode Creation**

    Node can be created by instantiating a `BinaryNode` class or by using a *dictionary*.
    If node is created with dictionary, all keys of dictionary will be stored as class attributes.

    >>> from bigtree import BinaryNode
    >>> a = BinaryNode.from_dict({"name": "1"})
    >>> a
    BinaryNode(name=1, val=1)

    **BinaryNode Attributes**

    These are node attributes that have getter and/or setter methods.

    Get `BinaryNode` configuration

    1. ``left``: Get left children
    2. ``right``: Get right children

    ----
    """

    def __init__(
        self,
        name: Union[str, int] = "",
        left=None,
        right=None,
        parent=None,
        children: List = None,
        **kwargs,
    ):
        # `name` must be convertible to int: int(name) raises for
        # non-numeric names.
        self.val = int(name)
        self.name = str(name)
        self._sep = "/"
        # Name-mangled storage (_BinaryNode__parent / _BinaryNode__children),
        # distinct from BaseNode's mangled attributes.
        self.__parent = None
        self.__children = []
        if not children:
            children = []
        if len(children):
            # NOTE(review): `len(children) and` is redundant inside this
            # branch - the guard above already ensures children is non-empty.
            if len(children) and len(children) != 2:
                raise ValueError("Children input must have length 2")
            # `left`/`right` may be given together with `children`, but must
            # agree with positions 0 and 1 respectively.
            if left and left != children[0]:
                raise ValueError(
                    f"Attempting to set both left and children with mismatched values\n"
                    f"Check left {left} and children {children}"
                )
            if right and right != children[1]:
                raise ValueError(
                    f"Attempting to set both right and children with mismatched values\n"
                    f"Check right {right} and children {children}"
                )
        else:
            # No children given: build the 2-slot list from left/right
            # (either may be None).
            children = [left, right]
        # Go through property setters so type/loop checks always run.
        self.parent = parent
        self.children = children
        if "parents" in kwargs:
            raise ValueError(
                "Attempting to set `parents` attribute, do you mean `parent`?"
            )
        self.__dict__.update(**kwargs)

    @property
    def left(self):
        """Get left child (slot 0; may be None)

        Returns:
            (Self)
        """
        return self.__children[0]

    @left.setter
    def left(self, left_child):
        """Set left child

        Args:
            left_child (Self): left child
        """
        self.children = [left_child, self.right]

    @property
    def right(self):
        """Get right child (slot 1; may be None)

        Returns:
            (Self)
        """
        return self.__children[1]

    @right.setter
    def right(self, right_child):
        """Set right child

        Args:
            right_child (Self): right child
        """
        self.children = [self.left, right_child]

    @property
    def parent(self):
        """Get parent node

        Returns:
            (Self)
        """
        return self.__parent

    @staticmethod
    def __check_parent_type(new_parent):
        """Check parent type

        Args:
            new_parent (Self): parent node
        """
        if not (isinstance(new_parent, BinaryNode) or new_parent is None):
            raise TypeError(
                f"Expect input to be BinaryNode type or NoneType, received input type {type(new_parent)}"
            )

    @parent.setter
    def parent(self, new_parent):
        """Set parent node

        Unlike BaseNode, children are kept as a fixed 2-slot list: detaching
        replaces the slot with None, attaching fills the first free slot.
        Rolls back to the previous state if any hook raises.

        Args:
            new_parent (Self): parent node
        """
        self.__check_parent_type(new_parent)
        # Reuse BaseNode's loop check via its mangled name.
        self._BaseNode__check_parent_loop(new_parent)

        current_parent = self.parent
        current_child_idx = None

        # Assign new parent - rollback if error
        self.__pre_assign_parent(new_parent)
        try:
            # Remove self from old parent
            if current_parent is not None:
                if not any(
                    child is self for child in current_parent.children
                ):  # pragma: no cover
                    raise CorruptedTreeError(
                        "Error setting parent: Node does not exist as children of its parent"
                    )
                current_child_idx = current_parent.__children.index(self)
                # Leave a None placeholder so left/right positions survive.
                current_parent.__children[current_child_idx] = None

            # Assign self to new parent
            self.__parent = new_parent
            if new_parent is not None:
                # Fill the first empty (None) slot; error if both are taken.
                inserted = False
                for child_idx, child in enumerate(new_parent.__children):
                    if not child and not inserted:
                        new_parent.__children[child_idx] = self
                        inserted = True
                if not inserted:
                    raise TreeError(f"Parent {new_parent} already has 2 children")

            self.__post_assign_parent(new_parent)
        except Exception as exc_info:
            # Remove self from new parent
            if new_parent is not None and self in new_parent.__children:
                child_idx = new_parent.__children.index(self)
                new_parent.__children[child_idx] = None

            # Reassign self to old parent, restoring the original slot
            self.__parent = current_parent
            if current_child_idx is not None:
                current_parent.__children[current_child_idx] = self
            raise TreeError(exc_info)

    def __pre_assign_parent(self, new_parent):
        """Custom method to check before attaching parent
        Can be overriden with `_BinaryNode__pre_assign_parent()`

        Args:
            new_parent (Self): new parent to be added
        """
        pass

    def __post_assign_parent(self, new_parent):
        """Custom method to check after attaching parent
        Can be overriden with `_BinaryNode__post_assign_parent()`

        Args:
            new_parent (Self): new parent to be added
        """
        pass

    @property
    def children(self) -> Iterable:
        """Get child nodes (always a 2-tuple, entries may be None)

        Returns:
            (Iterable[Self])
        """
        return tuple(self.__children)

    def __check_children_type(self, new_children: List) -> List:
        """Normalise and check children length: empty becomes [None, None],
        otherwise length must be exactly 2

        Args:
            new_children (List[Self]): child node
        """
        if not len(new_children):
            new_children = [None, None]
        if len(new_children) != 2:
            raise ValueError("Children input must have length 2")
        return new_children

    def __check_children_loop(self, new_children: List):
        """Check that assigning `new_children` creates no cycle or duplicates

        Args:
            new_children (List[Self]): child node
        """
        seen_children = []
        for new_child in new_children:
            # Check type (None placeholders are allowed)
            if new_child is not None and not isinstance(new_child, BinaryNode):
                raise TypeError(
                    f"Expect input to be BinaryNode type or NoneType, received input type {type(new_child)}"
                )

            # Check for loop and tree structure
            if new_child is self:
                raise LoopError("Error setting child: Node cannot be child of itself")
            if any(child is new_child for child in self.ancestors):
                raise LoopError(
                    "Error setting child: Node cannot be ancestors of itself"
                )

            # Check for duplicate children (None may repeat, real nodes not)
            if new_child is not None:
                if id(new_child) in seen_children:
                    raise TreeError(
                        "Error setting child: Node cannot be added multiple times as a child"
                    )
                else:
                    seen_children.append(id(new_child))

    @children.setter
    def children(self, new_children: List):
        """Set child nodes

        Detach the new children from their current parents (leaving None
        placeholders in the old parents' slots), attach them to self, and
        roll back everything to the original slots if any hook raises.

        Args:
            new_children (List[Self]): child node
        """
        # BaseNode's check enforces list type; the local one normalises length.
        self._BaseNode__check_children_type(new_children)
        new_children = self.__check_children_type(new_children)
        self.__check_children_loop(new_children)

        # Snapshot (original slot index, original parent) for rollback.
        current_new_children = {
            new_child: (
                new_child.parent.__children.index(new_child),
                new_child.parent,
            )
            for new_child in new_children
            if new_child is not None and new_child.parent is not None
        }
        current_new_orphan = [
            new_child
            for new_child in new_children
            if new_child is not None and new_child.parent is None
        ]
        current_children = list(self.children)

        # Assign new children - rollback if error
        self.__pre_assign_children(new_children)
        try:
            # Remove old children from self
            del self.children

            # Assign new children to self
            self.__children = new_children
            for new_child in new_children:
                if new_child is not None:
                    if new_child.parent:
                        child_idx = new_child.parent.__children.index(new_child)
                        new_child.parent.__children[child_idx] = None
                    new_child.__parent = self
            self.__post_assign_children(new_children)
        except Exception as exc_info:
            # Reassign new children to their original parent and slot
            for child, idx_parent in current_new_children.items():
                child_idx, parent = idx_parent
                child.__parent = parent
                parent.__children[child_idx] = child
            for child in current_new_orphan:
                child.__parent = None

            # Reassign old children to self
            self.__children = current_children
            for child in current_children:
                if child:
                    child.__parent = self
            raise TreeError(exc_info)

    @children.deleter
    def children(self):
        """Delete child node(s): detach every non-None child from self"""
        for child in self.children:
            if child is not None:
                child.parent.__children.remove(child)
                child.__parent = None

    def __pre_assign_children(self, new_children: List):
        """Custom method to check before attaching children
        Can be overriden with `_BinaryNode__pre_assign_children()`

        Args:
            new_children (List[Self]): new children to be added
        """
        pass

    def __post_assign_children(self, new_children: List):
        """Custom method to check after attaching children
        Can be overriden with `_BinaryNode__post_assign_children()`

        Args:
            new_children (List[Self]): new children to be added
        """
        pass

    @property
    def is_leaf(self) -> bool:
        """Get indicator if self is leaf node (both slots empty/None)

        Returns:
            (bool)
        """
        return not len([child for child in self.children if child])

    def sort(self, **kwargs):
        """Sort children, possible keyword arguments include ``key=lambda node: node.name``, ``reverse=True``

        >>> from bigtree import BinaryNode, print_tree
        >>> a = BinaryNode(1)
        >>> c = BinaryNode(3, parent=a)
        >>> b = BinaryNode(2, parent=a)
        >>> print_tree(a)
        1
        ├── 3
        └── 2
        >>> a.sort(key=lambda node: node.val)
        >>> print_tree(a)
        1
        ├── 2
        └── 3
        """
        # Only sorts when BOTH slots are filled; a single child keeps its slot.
        children = [child for child in self.children if child]
        if len(children) == 2:
            children.sort(**kwargs)
            self.__children = children

    def __repr__(self):
        """Print format of BinaryNode, e.g. ``BinaryNode(name=1, val=1)``"""
        class_name = self.__class__.__name__
        node_dict = self.describe(exclude_prefix="_", exclude_attributes=[])
        node_description = ", ".join([f"{k}={v}" for k, v in node_dict])
        return f"{class_name}({node_description})"

View File

@@ -0,0 +1,570 @@
import copy
from typing import Any, Dict, Iterable, List
from bigtree.utils.exceptions import LoopError, TreeError
from bigtree.utils.iterators import preorder_iter
class DAGNode:
"""
Base DAGNode extends any Python class to a DAG node, for DAG implementation.
In DAG implementation, a node can have multiple parents.
Parents and children cannot be reassigned once assigned, as Nodes are allowed to have multiple parents and children.
If each node only has one parent, use `Node` class.
DAGNodes can have attributes if they are initialized from `DAGNode` or dictionary.
DAGNode can be linked to each other with `parents` and `children` setter methods,
or using bitshift operator with the convention `parent_node >> child_node` or `child_node << parent_node`.
>>> from bigtree import DAGNode
>>> a = DAGNode("a")
>>> b = DAGNode("b")
>>> c = DAGNode("c")
>>> d = DAGNode("d")
>>> c.parents = [a, b]
>>> c.children = [d]
>>> from bigtree import DAGNode
>>> a = DAGNode("a")
>>> b = DAGNode("b")
>>> c = DAGNode("c")
>>> d = DAGNode("d")
>>> a >> c
>>> b >> c
>>> d << c
Directly passing `parents` argument.
>>> from bigtree import DAGNode
>>> a = DAGNode("a")
>>> b = DAGNode("b")
>>> c = DAGNode("c", parents=[a, b])
>>> d = DAGNode("d", parents=[c])
Directly passing `children` argument.
>>> from bigtree import DAGNode
>>> d = DAGNode("d")
>>> c = DAGNode("c", children=[d])
>>> b = DAGNode("b", children=[c])
>>> a = DAGNode("a", children=[c])
**DAGNode Creation**
Node can be created by instantiating a `DAGNode` class or by using a *dictionary*.
If node is created with dictionary, all keys of dictionary will be stored as class attributes.
>>> from bigtree import DAGNode
>>> a = DAGNode.from_dict({"name": "a", "age": 90})
**DAGNode Attributes**
These are node attributes that have getter and/or setter methods.
Get and set other `DAGNode`
1. ``parents``: Get/set parent nodes
2. ``children``: Get/set child nodes
Get other `DAGNode`
1. ``ancestors``: Get ancestors of node excluding self, iterator
2. ``descendants``: Get descendants of node excluding self, iterator
3. ``siblings``: Get siblings of self
Get `DAGNode` configuration
1. ``node_name``: Get node name, without accessing `name` directly
2. ``is_root``: Get indicator if self is root node
3. ``is_leaf``: Get indicator if self is leaf node
**DAGNode Methods**
These are methods available to be performed on `DAGNode`.
Constructor methods
1. ``from_dict()``: Create DAGNode from dictionary
`DAGNode` methods
1. ``describe()``: Get node information sorted by attributes, returns list of tuples
2. ``get_attr(attr_name: str)``: Get value of node attribute
3. ``set_attrs(attrs: dict)``: Set node attribute name(s) and value(s)
4. ``go_to(node: BaseNode)``: Get a path from own node to another node from same DAG
5. ``copy()``: Deep copy DAGNode
----
"""
def __init__(
self, name: str = "", parents: List = None, children: List = None, **kwargs
):
self.name = name
self.__parents = []
self.__children = []
if parents is None:
parents = []
if children is None:
children = []
self.parents = parents
self.children = children
if "parent" in kwargs:
raise ValueError(
"Attempting to set `parent` attribute, do you mean `parents`?"
)
self.__dict__.update(**kwargs)
@property
def parent(self) -> None:
"""Do not allow `parent` attribute to be accessed"""
raise ValueError(
"Attempting to access `parent` attribute, do you mean `parents`?"
)
@parent.setter
def parent(self, new_parent):
"""Do not allow `parent` attribute to be set
Args:
new_parent (Self): parent node
"""
raise ValueError("Attempting to set `parent` attribute, do you mean `parents`?")
@property
def parents(self) -> Iterable:
"""Get parent nodes
Returns:
(Iterable[Self])
"""
return tuple(self.__parents)
@staticmethod
def __check_parent_type(new_parents: List):
"""Check parent type
Args:
new_parents (List[Self]): parent nodes
"""
if not isinstance(new_parents, list):
raise TypeError(
f"Parents input should be list type, received input type {type(new_parents)}"
)
def __check_parent_loop(self, new_parents: List):
"""Check parent type
Args:
new_parents (List[Self]): parent nodes
"""
seen_parent = []
for new_parent in new_parents:
# Check type
if not isinstance(new_parent, DAGNode):
raise TypeError(
f"Expect input to be DAGNode type, received input type {type(new_parent)}"
)
# Check for loop and tree structure
if new_parent is self:
raise LoopError("Error setting parent: Node cannot be parent of itself")
if new_parent.ancestors:
if any(ancestor is self for ancestor in new_parent.ancestors):
raise LoopError(
"Error setting parent: Node cannot be ancestor of itself"
)
# Check for duplicate children
if id(new_parent) in seen_parent:
raise TreeError(
"Error setting parent: Node cannot be added multiple times as a parent"
)
else:
seen_parent.append(id(new_parent))
@parents.setter
def parents(self, new_parents: List):
    """Set parent nodes, linking both directions (each parent gains self as a child).

    Validates type/loops/duplicates first, then runs the pre-assign hook,
    links, and runs the post-assign hook. If linking or the post-assign hook
    raises, all newly added links are rolled back before re-raising as TreeError.

    Args:
        new_parents (List[Self]): parent nodes
    """
    self.__check_parent_type(new_parents)
    self.__check_parent_loop(new_parents)
    # Snapshot for rollback: parents that existed before this call.
    current_parents = self.__parents.copy()
    # Assign new parents - rollback if error
    self.__pre_assign_parents(new_parents)
    try:
        # Assign self to new parent
        # (private access on `new_parent.__children` works via name mangling
        # because this code lives inside the DAGNode class)
        for new_parent in new_parents:
            if new_parent not in self.__parents:
                self.__parents.append(new_parent)
                new_parent.__children.append(self)
        self.__post_assign_parents(new_parents)
    except Exception as exc_info:
        # Remove self from new parent - undo only links added in this call
        for new_parent in new_parents:
            if new_parent not in current_parents:
                self.__parents.remove(new_parent)
                new_parent.__children.remove(self)
        raise TreeError(
            f"{exc_info}, current parents {current_parents}, new parents {new_parents}"
        )
def __pre_assign_parents(self, new_parents: List):
"""Custom method to check before attaching parent
Can be overriden with `_DAGNode__pre_assign_parent()`
Args:
new_parents (List): new parents to be added
"""
pass
def __post_assign_parents(self, new_parents: List):
"""Custom method to check after attaching parent
Can be overriden with `_DAGNode__post_assign_parent()`
Args:
new_parents (List): new parents to be added
"""
pass
@property
def children(self) -> Iterable:
"""Get child nodes
Returns:
(Iterable[Self])
"""
return tuple(self.__children)
def __check_children_type(self, new_children: List):
"""Check child type
Args:
new_children (List[Self]): child node
"""
if not isinstance(new_children, list):
raise TypeError(
f"Children input should be list type, received input type {type(new_children)}"
)
def __check_children_loop(self, new_children: List):
"""Check child loop
Args:
new_children (List[Self]): child node
"""
seen_children = []
for new_child in new_children:
# Check type
if not isinstance(new_child, DAGNode):
raise TypeError(
f"Expect input to be DAGNode type, received input type {type(new_child)}"
)
# Check for loop and tree structure
if new_child is self:
raise LoopError("Error setting child: Node cannot be child of itself")
if any(child is new_child for child in self.ancestors):
raise LoopError(
"Error setting child: Node cannot be ancestors of itself"
)
# Check for duplicate children
if id(new_child) in seen_children:
raise TreeError(
"Error setting child: Node cannot be added multiple times as a child"
)
else:
seen_children.append(id(new_child))
@children.setter
def children(self, new_children: List):
    """Set child nodes, linking both directions (each child gains self as a parent).

    Validates type/loops/duplicates first, then runs the pre-assign hook,
    links, and runs the post-assign hook. If linking or the post-assign hook
    raises, all newly added links are rolled back before re-raising as TreeError.

    Args:
        new_children (List[Self]): child node
    """
    self.__check_children_type(new_children)
    self.__check_children_loop(new_children)
    # Snapshot for rollback: children that existed before this call.
    current_children = list(self.children)
    # Assign new children - rollback if error
    self.__pre_assign_children(new_children)
    try:
        # Assign new children to self
        # (private access on `new_child.__parents` works via name mangling
        # because this code lives inside the DAGNode class)
        for new_child in new_children:
            if self not in new_child.__parents:
                new_child.__parents.append(self)
                self.__children.append(new_child)
        self.__post_assign_children(new_children)
    except Exception as exc_info:
        # Reassign old children to self - undo only links added in this call
        for new_child in new_children:
            if new_child not in current_children:
                new_child.__parents.remove(self)
                self.__children.remove(new_child)
        raise TreeError(exc_info)
def __pre_assign_children(self, new_children: List):
"""Custom method to check before attaching children
Can be overriden with `_DAGNode__pre_assign_children()`
Args:
new_children (List[Self]): new children to be added
"""
pass
def __post_assign_children(self, new_children: List):
"""Custom method to check after attaching children
Can be overriden with `_DAGNode__post_assign_children()`
Args:
new_children (List[Self]): new children to be added
"""
pass
@property
def ancestors(self) -> Iterable:
"""Get iterator to yield all ancestors of self, does not include self
Returns:
(Iterable[Self])
"""
if not len(list(self.parents)):
return ()
def recursive_parent(node):
for _node in node.parents:
yield from recursive_parent(_node)
yield _node
ancestors = list(recursive_parent(self))
return list(dict.fromkeys(ancestors))
@property
def descendants(self) -> Iterable:
    """Get iterator to yield all descendants of self, does not include self

    Delegates traversal to module-level `preorder_iter`; nodes reachable via
    multiple DAG paths are de-duplicated while preserving first-visit order.

    Returns:
        (Iterable[Self])
    """
    descendants = list(
        preorder_iter(self, filter_condition=lambda _node: _node != self)
    )
    return list(dict.fromkeys(descendants))
@property
def siblings(self) -> Iterable:
"""Get siblings of self
Returns:
(Iterable[Self])
"""
if self.is_root:
return ()
return tuple(
child
for parent in self.parents
for child in parent.children
if child is not self
)
@property
def node_name(self) -> str:
"""Get node name
Returns:
(str)
"""
return self.name
@property
def is_root(self) -> bool:
"""Get indicator if self is root node
Returns:
(bool)
"""
return not len(list(self.parents))
@property
def is_leaf(self) -> bool:
"""Get indicator if self is leaf node
Returns:
(bool)
"""
return not len(list(self.children))
@classmethod
def from_dict(cls, input_dict: Dict[str, Any]):
"""Construct node from dictionary, all keys of dictionary will be stored as class attributes
Input dictionary must have key `name` if not `Node` will not have any name
>>> from bigtree import DAGNode
>>> a = DAGNode.from_dict({"name": "a", "age": 90})
Args:
input_dict (Dict[str, Any]): dictionary with node information, key: attribute name, value: attribute value
Returns:
(Self)
"""
return cls(**input_dict)
def describe(self, exclude_attributes: List[str] = [], exclude_prefix: str = ""):
"""Get node information sorted by attribute name, returns list of tuples
Args:
exclude_attributes (List[str]): list of attributes to exclude
exclude_prefix (str): prefix of attributes to exclude
Returns:
(List[str])
"""
return [
item
for item in sorted(self.__dict__.items(), key=lambda item: item[0])
if (item[0] not in exclude_attributes)
and (not len(exclude_prefix) or not item[0].startswith(exclude_prefix))
]
def get_attr(self, attr_name: str) -> Any:
"""Get value of node attribute
Returns None if attribute name does not exist
Args:
attr_name (str): attribute name
Returns:
(Any)
"""
try:
return self.__getattribute__(attr_name)
except AttributeError:
return None
def set_attrs(self, attrs: Dict[str, Any]):
"""Set node attributes
>>> from bigtree.node.dagnode import DAGNode
>>> a = DAGNode('a')
>>> a.set_attrs({"age": 90})
>>> a
DAGNode(a, age=90)
Args:
attrs (Dict[str, Any]): attribute dictionary,
key: attribute name, value: attribute value
"""
self.__dict__.update(attrs)
def go_to(self, node) -> Iterable[Iterable]:
    """Get list of possible paths from current node to specified node from same tree

    >>> from bigtree import DAGNode
    >>> a = DAGNode("a")
    >>> b = DAGNode("b")
    >>> c = DAGNode("c")
    >>> d = DAGNode("d")
    >>> a >> c
    >>> b >> c
    >>> c >> d
    >>> a >> d
    >>> a.go_to(c)
    [[DAGNode(a, ), DAGNode(c, )]]
    >>> a.go_to(d)
    [[DAGNode(a, ), DAGNode(c, ), DAGNode(d, )], [DAGNode(a, ), DAGNode(d, )]]
    >>> a.go_to(b)
    Traceback (most recent call last):
        ...
    bigtree.utils.exceptions.TreeError: It is not possible to go to DAGNode(b, )

    Args:
        node (Self): node to travel to from current node, inclusive of start and end node

    Returns:
        (Iterable[Iterable])
    """
    if not isinstance(node, DAGNode):
        raise TypeError(
            f"Expect node to be DAGNode type, received input type {type(node)}"
        )
    if self == node:
        # NOTE(review): returns a flat [self] here, not [[self]] as for other
        # destinations - callers should be aware of the asymmetry.
        return [self]
    if node not in self.descendants:
        raise TreeError(f"It is not possible to go to {node}")
    # Accumulator for complete paths found by the DFS below.
    self.__path = []

    def recursive_path(_node, _path, _ans):
        # Depth-first search: extend _path with _node; a branch returns the
        # completed path when it reaches the target, else recurses into
        # children, appending every completed path to self.__path.
        if _node:  # pragma: no cover
            _path.append(_node)
            if _node == node:
                return _path
            for _child in _node.children:
                ans = recursive_path(_child, _path.copy(), _ans)
                if ans:
                    self.__path.append(ans)

    recursive_path(self, [], [])
    return self.__path
def copy(self):
"""Deep copy self; clone DAGNode
>>> from bigtree.node.dagnode import DAGNode
>>> a = DAGNode('a')
>>> a_copy = a.copy()
Returns:
(Self)
"""
return copy.deepcopy(self)
def __copy__(self):
"""Shallow copy self
>>> import copy
>>> from bigtree.node.dagnode import DAGNode
>>> a = DAGNode('a')
>>> a_copy = copy.deepcopy(a)
Returns:
(Self)
"""
obj = type(self).__new__(self.__class__)
obj.__dict__.update(self.__dict__)
return obj
def __rshift__(self, other):
"""Set children using >> bitshift operator for self >> other
Args:
other (Self): other node, children
"""
other.parents = [self]
def __lshift__(self, other):
"""Set parent using << bitshift operator for self << other
Args:
other (Self): other node, parent
"""
self.parents = [other]
def __repr__(self):
class_name = self.__class__.__name__
node_dict = self.describe(exclude_attributes=["name"])
node_description = ", ".join(
[f"{k}={v}" for k, v in node_dict if not k.startswith("_")]
)
return f"{class_name}({self.node_name}, {node_description})"

View File

@@ -0,0 +1,204 @@
from collections import Counter
from typing import List
from bigtree.node.basenode import BaseNode
from bigtree.utils.exceptions import TreeError
class Node(BaseNode):
    """
    Node is an extension of BaseNode, and is able to extend to any Python class.
    Nodes can have attributes if they are initialized from `Node`, *dictionary*, or *pandas DataFrame*.

    Nodes can be linked to each other with `parent` and `children` setter methods.

    >>> from bigtree import Node
    >>> a = Node("a")
    >>> b = Node("b")
    >>> c = Node("c")
    >>> d = Node("d")
    >>> b.parent = a
    >>> b.children = [c, d]

    Directly passing `parent` argument.

    >>> from bigtree import Node
    >>> a = Node("a")
    >>> b = Node("b", parent=a)
    >>> c = Node("c", parent=b)
    >>> d = Node("d", parent=b)

    Directly passing `children` argument.

    >>> from bigtree import Node
    >>> d = Node("d")
    >>> c = Node("c")
    >>> b = Node("b", children=[c, d])
    >>> a = Node("a", children=[b])

    **Node Creation**

    Node can be created by instantiating a `Node` class or by using a *dictionary*.
    If node is created with dictionary, all keys of dictionary will be stored as class attributes.

    >>> from bigtree import Node
    >>> a = Node.from_dict({"name": "a", "age": 90})

    **Node Attributes**

    These are node attributes that have getter and/or setter methods.

    Get and set `Node` configuration

    1. ``sep``: Get/set separator for path name

    Get `Node` configuration

    1. ``node_name``: Get node name, without accessing `name` directly
    2. ``path_name``: Get path name from root, separated by `sep`

    ----
    """

    def __init__(self, name: str = "", **kwargs):
        """Construct a named Node; a non-empty name is mandatory.

        Args:
            name (str): node name
            **kwargs: forwarded to BaseNode (e.g. `parent`, `children`, attributes)

        Raises:
            TreeError: if the resulting node has no name
        """
        self.name = name
        # Default path separator; stored per-node but read from the root.
        self._sep: str = "/"
        super().__init__(**kwargs)
        if not self.node_name:
            raise TreeError("Node must have a `name` attribute")

    @property
    def node_name(self) -> str:
        """Get node name

        Returns:
            (str)
        """
        return self.name

    @property
    def sep(self) -> str:
        """Get separator, gets from root node

        Returns:
            (str)
        """
        # Walk up to the root so the whole tree shares one separator.
        if self.is_root:
            return self._sep
        return self.parent.sep

    @sep.setter
    def sep(self, value: str):
        """Set separator, affects root node

        Args:
            value (str): separator to replace default separator
        """
        self.root._sep = value

    @property
    def path_name(self) -> str:
        """Get path name, separated by self.sep

        Built recursively from the root, e.g. ``/a/b/c``.

        Returns:
            (str)
        """
        if self.is_root:
            return f"{self.sep}{self.name}"
        return f"{self.parent.path_name}{self.sep}{self.name}"

    def __pre_assign_children(self, new_children: List):
        """Custom method to check before attaching children; no-op by default.

        Can be overridden with `_Node__pre_assign_children()`.

        Args:
            new_children (List[Self]): new children to be added
        """
        pass

    def __post_assign_children(self, new_children: List):
        """Custom method to check after attaching children; no-op by default.

        Can be overridden with `_Node__post_assign_children()`.

        Args:
            new_children (List[Self]): new children to be added
        """
        pass

    def __pre_assign_parent(self, new_parent):
        """Custom method to check before attaching parent; no-op by default.

        Can be overridden with `_Node__pre_assign_parent()`.

        Args:
            new_parent (Self): new parent to be added
        """
        pass

    def __post_assign_parent(self, new_parent):
        """Custom method to check after attaching parent; no-op by default.

        Can be overridden with `_Node__post_assign_parent()`.

        Args:
            new_parent (Self): new parent to be added
        """
        pass

    # The `_BaseNode__*` methods below deliberately use BaseNode's mangled
    # hook names so they override the hooks invoked by BaseNode's setters;
    # each delegates to the user-overridable `_Node__*` hook first.
    def _BaseNode__pre_assign_parent(self, new_parent):
        """Do not allow duplicate nodes of same path

        Args:
            new_parent (Self): new parent to be added

        Raises:
            TreeError: if a sibling under `new_parent` already has this name
        """
        self.__pre_assign_parent(new_parent)
        if new_parent is not None:
            if any(
                child.node_name == self.node_name and child is not self
                for child in new_parent.children
            ):
                raise TreeError(
                    f"Error: Duplicate node with same path\n"
                    f"There exist a node with same path {new_parent.path_name}{self.sep}{self.node_name}"
                )

    def _BaseNode__post_assign_parent(self, new_parent):
        """No rules

        Args:
            new_parent (Self): new parent to be added
        """
        self.__post_assign_parent(new_parent)

    def _BaseNode__pre_assign_children(self, new_children: List):
        """Do not allow duplicate nodes of same path

        Args:
            new_children (List[Self]): new children to be added

        Raises:
            TreeError: if two incoming children share the same name
        """
        self.__pre_assign_children(new_children)
        children_names = [node.node_name for node in new_children]
        duplicated_names = [
            item[0] for item in Counter(children_names).items() if item[1] > 1
        ]
        if len(duplicated_names):
            duplicated_names = " and ".join(
                [f"{self.path_name}{self.sep}{name}" for name in duplicated_names]
            )
            raise TreeError(
                f"Error: Duplicate node with same path\n"
                f"Attempting to add nodes same path {duplicated_names}"
            )

    def _BaseNode__post_assign_children(self, new_children: List):
        """No rules

        Args:
            new_children (List[Self]): new children to be added
        """
        self.__post_assign_children(new_children)

    def __repr__(self):
        """Return debug representation, e.g. ``Node(/a/b, age=65)``."""
        class_name = self.__class__.__name__
        node_dict = self.describe(exclude_prefix="_", exclude_attributes=["name"])
        node_description = ", ".join([f"{k}={v}" for k, v in node_dict])
        return f"{class_name}({self.path_name}, {node_description})"

View File

@@ -0,0 +1,914 @@
import re
from collections import OrderedDict
from typing import List, Tuple, Type
import numpy as np
import pandas as pd
from bigtree.node.node import Node
from bigtree.tree.export import tree_to_dataframe
from bigtree.tree.search import find_children, find_name
from bigtree.utils.exceptions import DuplicatedNodeError, TreeError
# Public API of the tree-construction module: helpers that add to an existing
# tree first, then whole-tree constructors from various input formats.
__all__ = [
    "add_path_to_tree",
    "add_dict_to_tree_by_path",
    "add_dict_to_tree_by_name",
    "add_dataframe_to_tree_by_path",
    "add_dataframe_to_tree_by_name",
    "str_to_tree",
    "list_to_tree",
    "list_to_tree_by_relation",
    "dict_to_tree",
    "nested_dict_to_tree",
    "dataframe_to_tree",
    "dataframe_to_tree_by_relation",
]
def add_path_to_tree(
    tree: Node,
    path: str,
    sep: str = "/",
    duplicate_name_allowed: bool = True,
    node_attrs: dict = None,
) -> Node:
    """Add nodes and attributes to existing tree *in-place*, return node of added path.
    Adds to existing tree from list of path strings.

    Path should contain `Node` name, separated by `sep`.

    - For example: Path string "a/b" refers to Node("b") with parent Node("a").
    - Path separator `sep` is for the input `path` and can be different from that of existing tree.
    - Path can start from root node `name`, or start with `sep` (e.g. "/a/b" or "a/b").
    - All paths should start from the same root node.

    >>> from bigtree import add_path_to_tree, print_tree
    >>> root = Node("a")
    >>> add_path_to_tree(root, "a/b/c")
    Node(/a/b/c, )
    >>> print_tree(root)
    a
    └── b
        └── c

    Args:
        tree (Node): existing tree
        path (str): path to be added to tree
        sep (str): path separator for input `path`
        duplicate_name_allowed (bool): indicator if nodes with duplicated `Node` name is allowed, defaults to True
        node_attrs (dict): attributes to add to node, key: attribute name, value: attribute value, optional

    Returns:
        (Node)

    Raises:
        ValueError: if `path` is empty
        TreeError: if `path` does not start at the existing root
        DuplicatedNodeError: duplicate name found while `duplicate_name_allowed` is False
    """
    # `None` default avoids the shared-mutable-default pitfall of `= {}`.
    if node_attrs is None:
        node_attrs = {}
    if not len(path):
        raise ValueError("Path is empty, check `path`")
    tree_root = tree.root
    tree_sep = tree_root.sep
    node_type = tree_root.__class__
    branch = path.lstrip(sep).rstrip(sep).split(sep)
    if branch[0] != tree_root.node_name:
        raise TreeError(
            f"Error: Path does not have same root node, expected {tree_root.node_name}, received {branch[0]}\n"
            f"Check your input paths or verify that path separator `sep` is set correctly"
        )
    # Grow tree: walk the branch, reusing existing nodes and creating the rest
    node = tree_root
    parent_node = tree_root
    for idx in range(1, len(branch)):
        node_name = branch[idx]
        node_path = tree_sep.join(branch[: idx + 1])
        if not duplicate_name_allowed:
            # Global name lookup; reject same name at a different path
            node = find_name(tree_root, node_name)
            if node and not node.path_name.endswith(node_path):
                raise DuplicatedNodeError(
                    f"Node {node_name} already exists, try setting `duplicate_name_allowed` to True "
                    f"to allow `Node` with same node name"
                )
        else:
            node = find_children(parent_node, node_name)
        if not node:
            node = node_type(branch[idx])
            node.parent = parent_node
        parent_node = node
    # Attributes go to the leaf node of the added path only
    node.set_attrs(node_attrs)
    return node
def add_dict_to_tree_by_path(
    tree: Node,
    path_attrs: dict,
    sep: str = "/",
    duplicate_name_allowed: bool = True,
) -> Node:
    """Add nodes and attributes to tree *in-place*, return root of tree.
    Adds to existing tree from nested dictionary, ``key``: path, ``value``: dict of attribute name and attribute value.

    Path should contain `Node` name, separated by `sep`.

    - For example: Path string "a/b" refers to Node("b") with parent Node("a").
    - Path separator `sep` is for the input paths and can be different from that of existing tree.
    - Path can start from root node `name`, or start with `sep` (e.g. "/a/b" or "a/b").
    - All paths should start from the same root node.

    >>> from bigtree import Node, add_dict_to_tree_by_path, print_tree
    >>> root = Node("a")
    >>> path_dict = {
    ...     "a": {"age": 90},
    ...     "a/b": {"age": 65},
    ...     "a/c": {"age": 60},
    ...     "a/b/d": {"age": 40},
    ...     "a/b/e": {"age": 35},
    ...     "a/c/f": {"age": 38},
    ...     "a/b/e/g": {"age": 10},
    ...     "a/b/e/h": {"age": 6},
    ... }
    >>> root = add_dict_to_tree_by_path(root, path_dict)

    Args:
        tree (Node): existing tree
        path_attrs (dict): dictionary containing node path and attribute information,
            key: node path, value: dict of node attribute name and attribute value
        sep (str): path separator for input `path_attrs`
        duplicate_name_allowed (bool): indicator if nodes with duplicated `Node` name is allowed, defaults to True

    Returns:
        (Node)

    Raises:
        ValueError: if `path_attrs` is empty
    """
    if not len(path_attrs):
        raise ValueError("Dictionary does not contain any data, check `path_attrs`")
    tree_root = tree.root
    # Delegate each path (with its attributes) to add_path_to_tree
    for path, attrs in path_attrs.items():
        add_path_to_tree(
            tree_root,
            path,
            sep=sep,
            duplicate_name_allowed=duplicate_name_allowed,
            node_attrs=attrs,
        )
    return tree_root
def add_dict_to_tree_by_name(
    tree: Node, path_attrs: dict, join_type: str = "left"
) -> Node:
    """Add attributes to tree, return *new* root of tree.
    Adds to existing tree from nested dictionary, ``key``: name, ``value``: dict of attribute name and attribute value.

    Function can return all existing tree nodes or only tree nodes that are in the input dictionary keys.
    Input dictionary keys that are not existing node names will be ignored.
    Note that if multiple nodes have the same name, attributes will be added to all nodes sharing same name.

    >>> from bigtree import Node, add_dict_to_tree_by_name, print_tree
    >>> root = Node("a")
    >>> b = Node("b", parent=root)
    >>> name_dict = {
    ...     "a": {"age": 90},
    ...     "b": {"age": 65},
    ... }
    >>> root = add_dict_to_tree_by_name(root, name_dict)

    Args:
        tree (Node): existing tree
        path_attrs (dict): dictionary containing node name and attribute information,
            key: node name, value: dict of node attribute name and attribute value
        join_type (str): join type with attribute, default of 'left' takes existing tree nodes,
            if join_type is set to 'inner' it will only take tree nodes that are in `path_attrs` key and drop others

    Returns:
        (Node)

    Raises:
        ValueError: if `join_type` is invalid or `path_attrs` is empty
    """
    if join_type not in ["inner", "left"]:
        raise ValueError("`join_type` must be one of 'inner' or 'left'")
    if not len(path_attrs):
        raise ValueError("Dictionary does not contain any data, check `path_attrs`")
    # Reshape mapping into a dataframe with a NAME column, then delegate
    attr_frame = pd.DataFrame(path_attrs).T.rename_axis("NAME").reset_index()
    return add_dataframe_to_tree_by_name(tree, data=attr_frame, join_type=join_type)
def add_dataframe_to_tree_by_path(
    tree: Node,
    data: pd.DataFrame,
    path_col: str = "",
    attribute_cols: list = None,
    sep: str = "/",
    duplicate_name_allowed: bool = True,
) -> Node:
    """Add nodes and attributes to tree *in-place*, return root of tree.

    `path_col` and `attribute_cols` specify columns for node path and attributes to add to existing tree.
    If columns are not specified, `path_col` takes first column and all other columns are `attribute_cols`.

    Path in path column should contain `Node` name, separated by `sep`.

    - For example: Path string "a/b" refers to Node("b") with parent Node("a").
    - Path separator `sep` is for the input `path_col` and can be different from that of existing tree.
    - Path can start from root node `name`, or start with `sep` (e.g. "/a/b" or "a/b").
    - All paths should start from the same root node.

    >>> import pandas as pd
    >>> from bigtree import add_dataframe_to_tree_by_path, print_tree
    >>> root = Node("a")
    >>> path_data = pd.DataFrame([
    ...     ["a", 90],
    ...     ["a/b", 65],
    ...     ["a/c", 60],
    ... ],
    ...     columns=["PATH", "age"]
    ... )
    >>> root = add_dataframe_to_tree_by_path(root, path_data)

    Args:
        tree (Node): existing tree
        data (pandas.DataFrame): data containing node path and attribute information
        path_col (str): column of data containing `path_name` information,
            if not set, it will take the first column of data
        attribute_cols (list): columns of data containing node attribute information,
            if not set, it will take all columns of data except `path_col`
        sep (str): path separator for input `path_col`
        duplicate_name_allowed (bool): indicator if nodes with duplicated `Node` name is allowed, defaults to True

    Returns:
        (Node)

    Raises:
        ValueError: if `data` has no columns/rows, or duplicate paths carry
            conflicting attributes
    """
    if not len(data.columns):
        raise ValueError("Data does not contain any columns, check `data`")
    if not len(data):
        raise ValueError("Data does not contain any rows, check `data`")
    if not path_col:
        path_col = data.columns[0]
    # `None` default avoids the shared-mutable-default pitfall of `= []`.
    if not attribute_cols:
        attribute_cols = list(data.columns)
        attribute_cols.remove(path_col)
    tree_root = tree.root
    # Work on a copy: the original mutated the caller's dataframe in place
    # when normalising the path column.
    data = data.copy()
    data[path_col] = data[path_col].str.lstrip(sep).str.rstrip(sep)
    # Detect the same path appearing with different attribute values
    data2 = data[[path_col] + attribute_cols].astype(str).drop_duplicates()
    _duplicate_check = (
        data2[path_col]
        .value_counts()
        .to_frame("counts")
        .rename_axis(path_col)
        .reset_index()
    )
    _duplicate_check = _duplicate_check[_duplicate_check["counts"] > 1]
    if len(_duplicate_check):
        raise ValueError(
            f"There exists duplicate path with different attributes\nCheck {_duplicate_check}"
        )
    for row in data.to_dict(orient="index").values():
        node_attrs = row.copy()
        del node_attrs[path_col]
        # Drop all-null attributes so they do not overwrite existing values
        node_attrs = {k: v for k, v in node_attrs.items() if not np.all(pd.isnull(v))}
        add_path_to_tree(
            tree_root,
            row[path_col],
            sep=sep,
            duplicate_name_allowed=duplicate_name_allowed,
            node_attrs=node_attrs,
        )
    return tree_root
def add_dataframe_to_tree_by_name(
    tree: Node,
    data: pd.DataFrame,
    name_col: str = "",
    attribute_cols: list = None,
    join_type: str = "left",
) -> Node:
    """Add attributes to tree, return *new* root of tree.

    `name_col` and `attribute_cols` specify columns for node name and attributes to add to existing tree.
    If columns are not specified, the first column will be taken as name column and all other columns as attributes.

    Function can return all existing tree nodes or only tree nodes that are in the input data node names.
    Input data node names that are not existing node names will be ignored.
    Note that if multiple nodes have the same name, attributes will be added to all nodes sharing same name.

    >>> import pandas as pd
    >>> from bigtree import add_dataframe_to_tree_by_name, print_tree
    >>> root = Node("a")
    >>> b = Node("b", parent=root)
    >>> name_data = pd.DataFrame([
    ...     ["a", 90],
    ...     ["b", 65],
    ... ],
    ...     columns=["NAME", "age"]
    ... )
    >>> root = add_dataframe_to_tree_by_name(root, name_data)

    Args:
        tree (Node): existing tree
        data (pandas.DataFrame): data containing node name and attribute information
        name_col (str): column of data containing `name` information,
            if not set, it will take the first column of data
        attribute_cols (list): column(s) of data containing node attribute information,
            if not set, it will take all columns of data except `name_col`
        join_type (str): join type with attribute, default of 'left' takes existing tree nodes,
            if join_type is set to 'inner' it will only take tree nodes with attributes and drop the other nodes

    Returns:
        (Node)

    Raises:
        ValueError: if `join_type` is invalid, `data` has no columns/rows, or
            duplicate names carry conflicting attributes
    """
    if join_type not in ["inner", "left"]:
        raise ValueError("`join_type` must be one of 'inner' or 'left'")
    if not len(data.columns):
        raise ValueError("Data does not contain any columns, check `data`")
    if not len(data):
        raise ValueError("Data does not contain any rows, check `data`")
    if not name_col:
        name_col = data.columns[0]
    # `None` default avoids the shared-mutable-default pitfall of `= []`.
    if not attribute_cols:
        attribute_cols = list(data.columns)
        attribute_cols.remove(name_col)
    # Attribute data: detect same name appearing with conflicting attributes
    path_col = "PATH"
    data2 = data.copy()[[name_col] + attribute_cols].astype(str).drop_duplicates()
    _duplicate_check = (
        data2[name_col]
        .value_counts()
        .to_frame("counts")
        .rename_axis(name_col)
        .reset_index()
    )
    _duplicate_check = _duplicate_check[_duplicate_check["counts"] > 1]
    if len(_duplicate_check):
        raise ValueError(
            f"There exists duplicate name with different attributes\nCheck {_duplicate_check}"
        )
    # Tree data: flatten existing tree, drop columns that will be re-joined
    tree_root = tree.root
    sep = tree_root.sep
    node_type = tree_root.__class__
    data_tree = tree_to_dataframe(
        tree_root, name_col=name_col, path_col=path_col, all_attrs=True
    )
    common_cols = list(set(data_tree.columns).intersection(attribute_cols))
    data_tree = data_tree.drop(columns=common_cols)
    # Merge attributes onto the flattened tree, then rebuild the tree
    data_tree_attrs = pd.merge(data_tree, data, on=name_col, how=join_type)
    data_tree_attrs = data_tree_attrs.drop(columns=name_col)
    return dataframe_to_tree(
        data_tree_attrs, path_col=path_col, sep=sep, node_type=node_type
    )
def str_to_tree(
    tree_string: str,
    tree_prefix_list: List[str] = None,
    node_type: Type[Node] = Node,
) -> Node:
    r"""Construct tree from tree string.

    >>> from bigtree import str_to_tree, print_tree
    >>> tree_str = 'a\n├── b\n│   ├── d\n│   └── e\n│       ├── g\n│       └── h\n└── c\n    └── f'
    >>> root = str_to_tree(tree_str, tree_prefix_list=["├──", "└──"])

    Args:
        tree_string (str): String to construct tree
        tree_prefix_list (list): List of prefix to mark the end of tree branch/stem and start of node name, optional.
            If not specified, it will infer unicode characters and whitespace as prefix.
        node_type (Type[Node]): node type of tree to be created, defaults to Node

    Returns:
        (Node)

    Raises:
        ValueError: if `tree_string` is empty, the prefix cannot be inferred,
            or indentation is inconsistent
    """
    # `None` default avoids the shared-mutable-default pitfall of `= []`.
    if tree_prefix_list is None:
        tree_prefix_list = []
    tree_string = tree_string.strip("\n")
    if not len(tree_string):
        raise ValueError("Tree string does not contain any data, check `tree_string`")
    tree_list = tree_string.split("\n")
    tree_root = node_type(tree_list[0])
    # Infer prefix length from the first non-root line; depth = prefix
    # length divided by this unit
    prefix_length = None
    cur_parent = tree_root
    for node_str in tree_list[1:]:
        if len(tree_prefix_list):
            node_name = re.split("|".join(tree_prefix_list), node_str)[-1].lstrip()
        else:
            # No explicit prefixes: strip non-ascii stem characters
            node_name = node_str.encode("ascii", "ignore").decode("ascii").lstrip()
        # Find node parent
        if not prefix_length:
            prefix_length = node_str.index(node_name)
            if not prefix_length:
                raise ValueError(
                    f"Invalid prefix, prefix should be unicode character or whitespace, "
                    f"otherwise specify one or more prefixes in `tree_prefix_list`, check: {node_str}"
                )
        node_prefix_length = node_str.index(node_name)
        if node_prefix_length % prefix_length:
            raise ValueError(
                f"Tree string have different prefix length, check branch: {node_str}"
            )
        # Climb up until cur_parent is at the depth above this node
        while cur_parent.depth > node_prefix_length / prefix_length:
            cur_parent = cur_parent.parent
        # Link node
        child_node = node_type(node_name)
        child_node.parent = cur_parent
        cur_parent = child_node
    return tree_root
def list_to_tree(
    paths: list,
    sep: str = "/",
    duplicate_name_allowed: bool = True,
    node_type: Type[Node] = Node,
) -> Node:
    """Construct tree from list of path strings.

    Path should contain `Node` name, separated by `sep`.

    - For example: Path string "a/b" refers to Node("b") with parent Node("a").
    - Path can start from root node `name`, or start with `sep` (e.g. "/a/b" or "a/b").
    - All paths should start from the same root node.

    >>> from bigtree import list_to_tree, print_tree
    >>> path_list = ["a/b", "a/c", "a/b/d", "a/b/e", "a/c/f", "a/b/e/g", "a/b/e/h"]
    >>> root = list_to_tree(path_list)

    Args:
        paths (list): list containing path strings
        sep (str): path separator for input `paths` and created tree, defaults to `/`
        duplicate_name_allowed (bool): indicator if nodes with duplicated `Node` name is allowed, defaults to True
        node_type (Type[Node]): node type of tree to be created, defaults to Node

    Returns:
        (Node)

    Raises:
        ValueError: if `paths` is empty
    """
    if not len(paths):
        raise ValueError("Path list does not contain any data, check `paths`")
    # Remove duplicates while preserving first-seen order
    paths = list(OrderedDict.fromkeys(paths))
    # Construct root node from the first component of the first path
    root_name = paths[0].lstrip(sep).split(sep)[0]
    root_node = node_type(root_name)
    # Set sep once, before growing the tree (the original also re-set it
    # after the loop, which was redundant since sep is stored on the root).
    root_node.sep = sep
    for path in paths:
        add_path_to_tree(
            root_node, path, sep=sep, duplicate_name_allowed=duplicate_name_allowed
        )
    return root_node
def list_to_tree_by_relation(
    relations: List[Tuple[str, str]],
    node_type: Type[Node] = Node,
) -> Node:
    """Assemble a tree from (parent, child) name pairs and return its root.

    Because nodes are matched purely by name, every non-leaf name must be
    unique; only leaf-node names may repeat without causing ambiguity.

    >>> from bigtree import list_to_tree_by_relation, print_tree
    >>> root = list_to_tree_by_relation([("a", "b"), ("a", "c"), ("b", "d")])
    >>> print_tree(root)
    a
    ├── b
    │   └── d
    └── c

    Args:
        relations (list): tuples of (parent name, child name)
        node_type (Type[Node]): node type of tree to be created, defaults to Node

    Returns:
        (Node)
    """
    if not len(relations):
        raise ValueError("Path list does not contain any data, check `relations`")
    # Delegate to the DataFrame-based builder via a two-column frame.
    relation_frame = pd.DataFrame(relations, columns=["parent", "child"])
    return dataframe_to_tree_by_relation(
        relation_frame, child_col="child", parent_col="parent", node_type=node_type
    )
def dict_to_tree(
    path_attrs: dict,
    sep: str = "/",
    duplicate_name_allowed: bool = True,
    node_type: Type[Node] = Node,
) -> Node:
    """Build a tree from a dictionary keyed by node path and return its root.

    Each key is a `sep`-separated path string ("a/b" puts Node("b") under
    Node("a"); a leading `sep` is optional), and each value is a dictionary
    of attribute name -> attribute value for that node.  All paths must
    share the same root name.

    >>> from bigtree import dict_to_tree, print_tree
    >>> path_dict = {
    ...     "a": {"age": 90},
    ...     "a/b": {"age": 65},
    ...     "a/c": {"age": 60},
    ... }
    >>> root = dict_to_tree(path_dict)
    >>> print_tree(root, attr_list=["age"])
    a [age=90]
    ├── b [age=65]
    └── c [age=60]

    Args:
        path_attrs (dict): mapping of node path to its attribute dictionary
        sep (str): path separator of input `path_attrs` and created tree, defaults to `/`
        duplicate_name_allowed (bool): indicator if nodes with duplicated `Node` name is allowed, defaults to True
        node_type (Type[Node]): node type of tree to be created, defaults to Node

    Returns:
        (Node)
    """
    if not len(path_attrs):
        raise ValueError("Dictionary does not contain any data, check `path_attrs`")
    # Reshape {path: {attr: value}} into a frame with a "PATH" column, then
    # reuse the DataFrame-based constructor.
    path_frame = pd.DataFrame(path_attrs).T.rename_axis("PATH").reset_index()
    return dataframe_to_tree(
        path_frame,
        sep=sep,
        duplicate_name_allowed=duplicate_name_allowed,
        node_type=node_type,
    )
def nested_dict_to_tree(
    node_attrs: dict,
    name_key: str = "name",
    child_key: str = "children",
    node_type: Type[Node] = Node,
) -> Node:
    """Build a tree from a recursively nested dictionary and return its root.

    The dictionary holds the node name under `name_key`, an optional list of
    child dictionaries (of the same shape) under `child_key`, and any other
    keys become node attributes.

    >>> from bigtree import nested_dict_to_tree, print_tree
    >>> spec = {
    ...     "name": "a",
    ...     "age": 90,
    ...     "children": [{"name": "b", "age": 65}],
    ... }
    >>> root = nested_dict_to_tree(spec)
    >>> print_tree(root, attr_list=["age"])
    a [age=90]
    └── b [age=65]

    Args:
        node_attrs (dict): nested dictionary describing the root node, its
            attributes, and (recursively) its children
        name_key (str): key of node name, value is type str
        child_key (str): key of child list, value is type list
        node_type (Type[Node]): node type of tree to be created, defaults to Node

    Returns:
        (Node)
    """

    def _build(spec, parent=None):
        # Work on a shallow copy so the caller's dictionary is untouched.
        attrs = dict(spec)
        node_name = attrs.pop(name_key)
        child_specs = attrs.pop(child_key, [])
        current = node_type(node_name, parent=parent, **attrs)
        for child_spec in child_specs:
            _build(child_spec, parent=current)
        return current

    return _build(node_attrs)
def dataframe_to_tree(
    data: pd.DataFrame,
    path_col: str = "",
    attribute_cols: list = None,
    sep: str = "/",
    duplicate_name_allowed: bool = True,
    node_type: Type[Node] = Node,
) -> Node:
    """Construct tree from pandas DataFrame using path, return root of tree.

    `path_col` and `attribute_cols` specify columns for node path and attributes to construct tree.
    If columns are not specified, `path_col` takes first column and all other columns are `attribute_cols`.

    Path in path column can start from root node `name`, or start with `sep`.
    - For example: Path string can be "/a/b" or "a/b", if sep is "/".

    Path in path column should contain `Node` name, separated by `sep`.
    - For example: Path string "a/b" refers to Node("b") with parent Node("a").

    All paths should start from the same root node.
    The input DataFrame is never modified.

    >>> import pandas as pd
    >>> from bigtree import dataframe_to_tree, print_tree
    >>> path_data = pd.DataFrame([
    ...     ["a", 90],
    ...     ["a/b", 65],
    ...     ["a/c", 60],
    ... ],
    ...     columns=["PATH", "age"]
    ... )
    >>> root = dataframe_to_tree(path_data)
    >>> print_tree(root, attr_list=["age"])
    a [age=90]
    ├── b [age=65]
    └── c [age=60]

    Args:
        data (pandas.DataFrame): data containing path and node attribute information
        path_col (str): column of data containing `path_name` information,
            if not set, it will take the first column of data
        attribute_cols (list): columns of data containing node attribute information,
            if not set, it will take all columns of data except `path_col`
        sep (str): path separator of input `path_col` and created tree, defaults to `/`
        duplicate_name_allowed (bool): indicator if nodes with duplicated `Node` name is allowed, defaults to True
        node_type (Type[Node]): node type of tree to be created, defaults to Node

    Returns:
        (Node)

    Raises:
        ValueError: if `data` has no rows/columns, or the same path appears
            with conflicting attribute values
    """
    if not len(data.columns):
        raise ValueError("Data does not contain any columns, check `data`")
    if not len(data):
        raise ValueError("Data does not contain any rows, check `data`")
    # Work on a copy: the previous implementation wrote the stripped paths
    # back into the caller's DataFrame, mutating it as a side effect.
    data = data.copy()
    if not path_col:
        path_col = data.columns[0]
    if not attribute_cols:
        attribute_cols = list(data.columns)
        attribute_cols.remove(path_col)
    # Normalize paths: drop leading and trailing separators in one pass.
    data[path_col] = data[path_col].str.strip(sep)
    # A path appearing twice with differing attribute values is ambiguous.
    data2 = data[[path_col] + attribute_cols].astype(str).drop_duplicates()
    _duplicate_check = (
        data2[path_col]
        .value_counts()
        .to_frame("counts")
        .rename_axis(path_col)
        .reset_index()
    )
    _duplicate_check = _duplicate_check[_duplicate_check["counts"] > 1]
    if len(_duplicate_check):
        raise ValueError(
            f"There exists duplicate path with different attributes\nCheck {_duplicate_check}"
        )
    root_name = data[path_col].values[0].split(sep)[0]
    root_node = node_type(root_name)
    # NOTE(review): `path_col`/`attribute_cols` are not forwarded here; the
    # helper presumably re-derives them from column order — confirm behavior
    # when `path_col` is not the first column.
    add_dataframe_to_tree_by_path(
        root_node,
        data,
        sep=sep,
        duplicate_name_allowed=duplicate_name_allowed,
    )
    root_node.sep = sep
    return root_node
def dataframe_to_tree_by_relation(
    data: pd.DataFrame,
    child_col: str = "",
    parent_col: str = "",
    attribute_cols: list = None,
    node_type: Type[Node] = Node,
) -> Node:
    """Construct tree from pandas DataFrame using parent and child names, return root of tree.

    Note that node names must be unique since tree is created from parent-child names,
    except for leaf nodes - names of leaf nodes may be repeated as there is no confusion.

    `child_col` and `parent_col` specify columns for child name and parent name to construct tree.
    `attribute_cols` specify columns for node attributes of the child.
    If columns are not specified, `child_col` takes first column, `parent_col` takes second column,
    and all other columns are `attribute_cols`.

    >>> import pandas as pd
    >>> from bigtree import dataframe_to_tree_by_relation, print_tree
    >>> relation_data = pd.DataFrame([
    ...     ["a", None, 90],
    ...     ["b", "a", 65],
    ...     ["c", "a", 60],
    ... ],
    ...     columns=["child", "parent", "age"]
    ... )
    >>> root = dataframe_to_tree_by_relation(relation_data)
    >>> print_tree(root, attr_list=["age"])
    a [age=90]
    ├── b [age=65]
    └── c [age=60]

    Args:
        data (pandas.DataFrame): data containing path and node attribute information
        child_col (str): column of data containing child name information,
            if not set, it will take the first column of data
        parent_col (str): column of data containing parent name information,
            if not set, it will take the second column of data
        attribute_cols (list): columns of data containing node attribute information,
            if not set, it will take all columns of data except `child_col` and `parent_col`
        node_type (Type[Node]): node type of tree to be created, defaults to Node

    Returns:
        (Node)

    Raises:
        ValueError: if `data` has no rows/columns, a non-leaf name is
            duplicated with conflicting parents, or no unique root exists
    """
    if not len(data.columns):
        raise ValueError("Data does not contain any columns, check `data`")
    if not len(data):
        raise ValueError("Data does not contain any rows, check `data`")
    if not child_col:
        child_col = data.columns[0]
    if not parent_col:
        parent_col = data.columns[1]
    if not attribute_cols:
        attribute_cols = list(data.columns)
        attribute_cols.remove(child_col)
        attribute_cols.remove(parent_col)
    data_check = data.copy()[[child_col, parent_col]].drop_duplicates()
    # Filter for child nodes that are parent of other nodes: only those may
    # not carry duplicate names.
    data_check = data_check[data_check[child_col].isin(data_check[parent_col])]
    _duplicate_check = (
        data_check[child_col]
        .value_counts()
        .to_frame("counts")
        .rename_axis(child_col)
        .reset_index()
    )
    _duplicate_check = _duplicate_check[_duplicate_check["counts"] > 1]
    if len(_duplicate_check):
        raise ValueError(
            f"There exists duplicate child with different parent where the child is also a parent node.\n"
            f"Duplicated node names should not happen, but can only exist in leaf nodes to avoid confusion.\n"
            f"Check {_duplicate_check}"
        )
    # Root is the row with a null parent; failing that, the one name that
    # appears as a parent but never as a child.
    root_row = data[data[parent_col].isnull()]
    root_names = list(root_row[child_col])
    if not len(root_names):
        root_names = list(set(data[parent_col]) - set(data[child_col]))
    if len(root_names) != 1:
        raise ValueError(f"Unable to determine root node\nCheck {root_names}")
    root_name = root_names[0]
    root_node = node_type(root_name)

    def retrieve_attr(row):
        # Export only the requested attribute columns (previously the
        # `attribute_cols` argument was computed but ignored, so every
        # column leaked into the node attributes); drop all-null values.
        node_attrs = {
            k: v
            for k, v in row.items()
            if k in attribute_cols and not np.all(pd.isnull(v))
        }
        node_attrs["name"] = row[child_col]
        return node_attrs

    def recursive_create_child(parent_node):
        child_rows = data[data[parent_col] == parent_node.node_name]
        for row in child_rows.to_dict(orient="index").values():
            child_node = node_type(**retrieve_attr(row))
            child_node.parent = parent_node
            recursive_create_child(child_node)

    # Create root node attributes
    if len(root_row):
        row = list(root_row.to_dict(orient="index").values())[0]
        root_node.set_attrs(retrieve_attr(row))
    recursive_create_child(root_node)
    return root_node

View File

@@ -0,0 +1,831 @@
import collections
from typing import Any, Dict, List, Tuple, Union
import pandas as pd
from bigtree.node.node import Node
from bigtree.tree.search import find_path
from bigtree.utils.iterators import preorder_iter
# Public API of this export module.
__all__ = [
    "print_tree",
    "yield_tree",
    "tree_to_dict",
    "tree_to_nested_dict",
    "tree_to_dataframe",
    "tree_to_dot",
    "tree_to_pillow",
]
available_styles = {
"ansi": ("| ", "|-- ", "`-- "),
"ascii": ("| ", "|-- ", "+-- "),
"const": ("\u2502 ", "\u251c\u2500\u2500 ", "\u2514\u2500\u2500 "),
"const_bold": ("\u2503 ", "\u2523\u2501\u2501 ", "\u2517\u2501\u2501 "),
"rounded": ("\u2502 ", "\u251c\u2500\u2500 ", "\u2570\u2500\u2500 "),
"double": ("\u2551 ", "\u2560\u2550\u2550 ", "\u255a\u2550\u2550 "),
"custom": ("", "", ""),
}
def print_tree(
    tree: Node,
    node_name_or_path: str = "",
    max_depth: int = None,
    attr_list: List[str] = None,
    all_attrs: bool = False,
    attr_omit_null: bool = True,
    attr_bracket: List[str] = None,
    style: str = "const",
    custom_style: List[str] = None,
):
    """Print tree to console, starting from `tree`.

    - Able to select which node to print from, resulting in a subtree, using `node_name_or_path`
    - Able to customize for maximum depth to print, using `max_depth`
    - Able to choose which attributes to show or show all attributes, using `attr_list` and `all_attrs`
    - Able to omit showing of attributes if it is null, using `attr_omit_null`
    - Able to customize open and close brackets if attributes are shown, using `attr_bracket`
    - Able to customize style, to choose from `ansi`, `ascii`, `const`, `const_bold`, `rounded`, `double`, and `custom` style
      - Default style is `const` style
      - If style is set to custom, user can choose their own style for stem, branch and final stem icons
      - Stem, branch, and final stem symbol should have the same number of characters

    **Printing tree**

    >>> from bigtree import Node, print_tree
    >>> root = Node("a", age=90)
    >>> b = Node("b", age=65, parent=root)
    >>> c = Node("c", age=60, parent=root)
    >>> d = Node("d", age=40, parent=b)
    >>> e = Node("e", age=35, parent=b)
    >>> print_tree(root)
    a
    ├── b
    │   ├── d
    │   └── e
    └── c

    **Printing Sub-tree**

    >>> print_tree(root, node_name_or_path="b")
    b
    ├── d
    └── e

    >>> print_tree(root, max_depth=2)
    a
    ├── b
    └── c

    **Printing Attributes**

    >>> print_tree(root, attr_list=["age"])
    a [age=90]
    ├── b [age=65]
    │   ├── d [age=40]
    │   └── e [age=35]
    └── c [age=60]

    >>> print_tree(root, attr_list=["age"], attr_bracket=["*(", ")"])
    a *(age=90)
    ├── b *(age=65)
    │   ├── d *(age=40)
    │   └── e *(age=35)
    └── c *(age=60)

    **Available Styles**

    >>> print_tree(root, style="ansi")
    a
    |-- b
    |   |-- d
    |   `-- e
    `-- c

    >>> print_tree(root, style="ascii")
    a
    |-- b
    |   |-- d
    |   +-- e
    +-- c

    >>> print_tree(root, style="const")
    a
    ├── b
    │   ├── d
    │   └── e
    └── c

    >>> print_tree(root, style="const_bold")
    a
    ┣━━ b
    ┃   ┣━━ d
    ┃   ┗━━ e
    ┗━━ c

    >>> print_tree(root, style="rounded")
    a
    ├── b
    │   ├── d
    │   ╰── e
    ╰── c

    >>> print_tree(root, style="double")
    a
    ╠══ b
    ║   ╠══ d
    ║   ╚══ e
    ╚══ c

    Args:
        tree (Node): tree to print
        node_name_or_path (str): node to print from, becomes the root node of printing
        max_depth (int): maximum depth of tree to print, based on `depth` attribute, optional
        attr_list (list): list of node attributes to print, optional
        all_attrs (bool): indicator to show all attributes, overrides `attr_list`
        attr_omit_null (bool): indicator whether to omit showing of null attributes, defaults to True
        attr_bracket (List[str]): open and close bracket for `all_attrs` or `attr_list`, defaults to ["[", "]"]
        style (str): style of print, defaults to `const` style
        custom_style (List[str]): style of stem, branch and final stem, used when `style` is set to 'custom'

    Raises:
        ValueError: if `attr_bracket` does not hold exactly two strings
            while attributes are requested
    """
    # Avoid mutable default arguments; preserve the historical defaults.
    if attr_bracket is None:
        attr_bracket = ["[", "]"]
    if custom_style is None:
        custom_style = []
    need_attrs = all_attrs or attr_list
    if need_attrs:
        # Validate/unpack the brackets once up front instead of on every
        # loop iteration, as the original implementation did.
        if len(attr_bracket) != 2:
            raise ValueError(
                f"Expect open and close brackets in `attr_bracket`, received {attr_bracket}"
            )
        attr_bracket_open, attr_bracket_close = attr_bracket
    for pre_str, fill_str, _node in yield_tree(
        tree=tree,
        node_name_or_path=node_name_or_path,
        max_depth=max_depth,
        style=style,
        custom_style=custom_style,
    ):
        # Build the per-node attribute suffix, if any.
        attr_str = ""
        if need_attrs:
            if all_attrs:
                attrs = _node.describe(exclude_attributes=["name"], exclude_prefix="_")
                attr_str_list = [f"{k}={v}" for k, v in attrs]
            elif attr_omit_null:
                attr_str_list = [
                    f"{attr_name}={_node.get_attr(attr_name)}"
                    for attr_name in attr_list
                    if _node.get_attr(attr_name)
                ]
            else:
                attr_str_list = [
                    f"{attr_name}={_node.get_attr(attr_name)}"
                    for attr_name in attr_list
                ]
            attr_str = ", ".join(attr_str_list)
            if attr_str:
                attr_str = f" {attr_bracket_open}{attr_str}{attr_bracket_close}"
        print(f"{pre_str}{fill_str}{_node.node_name}{attr_str}")
def yield_tree(
    tree: Node,
    node_name_or_path: str = "",
    max_depth: int = None,
    style: str = "const",
    custom_style: List[str] = None,
):
    """Generator for customizing printing of tree, starting from `tree`.

    Yields 3-tuples of ``(pre_str, fill_str, node)`` — the indentation
    prefix, the branch symbol, and the node itself — so callers can format
    each line however they like.

    - Able to select which node to print from, resulting in a subtree, using `node_name_or_path`
    - Able to customize for maximum depth to print, using `max_depth`
    - Able to customize style, to choose from `ansi`, `ascii`, `const`, `const_bold`, `rounded`, `double`, and `custom` style
      - Default style is `const` style
      - If style is set to custom, user can choose their own style for stem, branch and final stem icons
      - Stem, branch, and final stem symbols should have the same number of characters

    >>> from bigtree import Node, yield_tree
    >>> root = Node("a", age=90)
    >>> b = Node("b", age=65, parent=root)
    >>> c = Node("c", age=60, parent=root)
    >>> d = Node("d", age=40, parent=b)
    >>> for branch, stem, node in yield_tree(root):
    ...     print(f"{branch}{stem}{node.node_name}")
    a
    ├── b
    │   └── d
    └── c

    **Printing Attributes**

    >>> for branch, stem, node in yield_tree(root, style="const"):
    ...     print(f"{branch}{stem}{node.node_name} [age={node.age}]")
    a [age=90]
    ├── b [age=65]
    │   └── d [age=40]
    └── c [age=60]

    Args:
        tree (Node): tree to print
        node_name_or_path (str): node to print from, becomes the root node of printing, optional
        max_depth (int): maximum depth of tree to print, based on `depth` attribute, optional
        style (str): style of print, defaults to `const` style
        custom_style (List[str]): style of stem, branch and final stem, used when `style` is set to 'custom'

    Raises:
        ValueError: if `style` is unknown, `node_name_or_path` cannot be
            found, or the style symbols have inconsistent lengths
    """
    # Avoid mutable default argument; preserve the historical default.
    if custom_style is None:
        custom_style = []
    if style not in available_styles.keys():
        raise ValueError(
            f"Choose one of {available_styles.keys()} style, use `custom` to define own style"
        )
    tree = tree.copy()
    if node_name_or_path:
        tree = find_path(tree, node_name_or_path)
        # Fail with a clear message instead of an AttributeError on None
        # (the original crashed on `tree.is_root` when no node matched).
        if not tree:
            raise ValueError(
                f"Node name or path {node_name_or_path} not found, check `node_name_or_path`"
            )
    if not tree.is_root:
        tree.parent = None
    # Set style
    if style == "custom":
        if len(custom_style) != 3:
            raise ValueError(
                "Custom style selected, please specify the style of stem, branch, and final stem in `custom_style`"
            )
        style_stem, style_branch, style_stem_final = custom_style
    else:
        style_stem, style_branch, style_stem_final = available_styles[style]
    if not len(style_stem) == len(style_branch) == len(style_stem_final):
        raise ValueError(
            "`style_stem`, `style_branch`, and `style_stem_final` are of different length"
        )
    gap_str = " " * len(style_stem)
    # Depths that still have a sibling pending below the current node; these
    # keep drawing a vertical stem in the prefix.
    unclosed_depth = set()
    initial_depth = tree.depth
    for _node in preorder_iter(tree, max_depth=max_depth):
        pre_str = ""
        fill_str = ""
        if not _node.is_root:
            node_depth = _node.depth - initial_depth
            # Get fill_str (style_branch or style_stem_final)
            if _node.right_sibling:
                unclosed_depth.add(node_depth)
                fill_str = style_branch
            else:
                if node_depth in unclosed_depth:
                    unclosed_depth.remove(node_depth)
                fill_str = style_stem_final
            # Get pre_str (style_stem, style_branch, style_stem_final, or gap)
            pre_str = ""
            for _depth in range(1, node_depth):
                if _depth in unclosed_depth:
                    pre_str += style_stem
                else:
                    pre_str += gap_str
        yield pre_str, fill_str, _node
def tree_to_dict(
    tree: Node,
    name_key: str = "name",
    parent_key: str = "",
    attr_dict: dict = {},
    all_attrs: bool = False,
    max_depth: int = None,
    skip_depth: int = None,
    leaf_only: bool = False,
) -> Dict[str, Any]:
    """Export tree to a flat dictionary keyed by node path.

    All descendants of `tree` are exported; `tree` may be the root node or
    any child node of a larger tree.  Each entry maps `node.path_name` to a
    dictionary of that node's exported fields.

    >>> from bigtree import Node, tree_to_dict
    >>> root = Node("a", age=90)
    >>> b = Node("b", age=65, parent=root)
    >>> tree_to_dict(root, name_key="name", attr_dict={"age": "person age"})
    {'/a': {'name': 'a', 'person age': 90}, '/a/b': {'name': 'b', 'person age': 65}}

    Args:
        tree (Node): tree to be exported
        name_key (str): dictionary key for `node.node_name`, defaults to 'name'
        parent_key (str): dictionary key for `node.parent.node_name`, optional
        attr_dict (dict): mapping of node attribute name -> output dictionary key, optional
        all_attrs (bool): indicator whether to retrieve all `Node` attributes
        max_depth (int): maximum depth to export tree, optional
        skip_depth (int): number of initial depths to skip, optional
        leaf_only (bool): indicator to retrieve only information from leaf nodes

    Returns:
        (dict)
    """
    tree = tree.copy()
    exported = {}

    def _collect(node):
        # Children are always traversed, even when the node itself is
        # filtered out by depth/leaf constraints.
        if not node:
            return
        within_depth = not max_depth or node.depth <= max_depth
        past_skip = not skip_depth or node.depth > skip_depth
        leaf_ok = not leaf_only or node.is_leaf
        if within_depth and past_skip and leaf_ok:
            entry = {}
            if name_key:
                entry[name_key] = node.node_name
            if parent_key:
                entry[parent_key] = node.parent.node_name if node.parent else None
            if all_attrs:
                entry.update(
                    dict(
                        node.describe(
                            exclude_attributes=["name"], exclude_prefix="_"
                        )
                    )
                )
            else:
                for attr_name, out_key in attr_dict.items():
                    entry[out_key] = node.get_attr(attr_name)
            exported[node.path_name] = entry
        for child in node.children:
            _collect(child)

    _collect(tree)
    return exported
def tree_to_nested_dict(
    tree: Node,
    name_key: str = "name",
    child_key: str = "children",
    attr_dict: dict = {},
    all_attrs: bool = False,
    max_depth: int = None,
) -> Dict[str, Any]:
    """Export tree to a nested (recursive) dictionary.

    All descendants of `tree` are exported; `tree` may be the root node or
    any child node of a larger tree.  The result maps `name_key` to the node
    name, exported attributes to their values, and `child_key` to the
    recursively-built list of child dictionaries (absent for leaves).

    >>> from bigtree import Node, tree_to_nested_dict
    >>> root = Node("a", age=90)
    >>> b = Node("b", age=65, parent=root)
    >>> tree_to_nested_dict(root, all_attrs=True)
    {'name': 'a', 'age': 90, 'children': [{'name': 'b', 'age': 65}]}

    Args:
        tree (Node): tree to be exported
        name_key (str): dictionary key for `node.node_name`, defaults to 'name'
        child_key (str): dictionary key for list of children, optional
        attr_dict (dict): mapping of node attribute name -> output dictionary key, optional
        all_attrs (bool): indicator whether to retrieve all `Node` attributes
        max_depth (int): maximum depth to export tree, optional

    Returns:
        (dict)
    """
    tree = tree.copy()
    wrapper = {}

    def _assemble(node, parent_entry):
        if not node:
            return
        # Nodes beyond max_depth are pruned together with their subtrees.
        if max_depth and node.depth > max_depth:
            return
        entry = {name_key: node.node_name}
        if all_attrs:
            entry.update(
                dict(
                    node.describe(
                        exclude_attributes=["name"], exclude_prefix="_"
                    )
                )
            )
        else:
            for attr_name, out_key in attr_dict.items():
                entry[out_key] = node.get_attr(attr_name)
        parent_entry.setdefault(child_key, []).append(entry)
        for child in node.children:
            _assemble(child, entry)

    _assemble(tree, wrapper)
    # The root was appended to a synthetic wrapper; unwrap it.
    return wrapper[child_key][0]
def tree_to_dataframe(
    tree: Node,
    path_col: str = "path",
    name_col: str = "name",
    parent_col: str = "",
    attr_dict: dict = {},
    all_attrs: bool = False,
    max_depth: int = None,
    skip_depth: int = None,
    leaf_only: bool = False,
) -> pd.DataFrame:
    """Export tree to pandas DataFrame, one row per exported node.

    All descendants of `tree` are exported; `tree` may be the root node or
    any child node of a larger tree.

    >>> from bigtree import Node, tree_to_dataframe
    >>> root = Node("a", age=90)
    >>> b = Node("b", age=65, parent=root)
    >>> tree_to_dataframe(root, name_col="name", path_col="path", attr_dict={"age": "person age"})
         path name  person age
    0      /a    a          90
    1    /a/b    b          65

    Args:
        tree (Node): tree to be exported
        path_col (str): column name for `node.path_name`, optional
        name_col (str): column name for `node.node_name`, defaults to 'name'
        parent_col (str): column name for `node.parent.node_name`, optional
        attr_dict (dict): mapping of node attribute name -> output column name, optional
        all_attrs (bool): indicator whether to retrieve all `Node` attributes
        max_depth (int): maximum depth to export tree, optional
        skip_depth (int): number of initial depths to skip, optional
        leaf_only (bool): indicator to retrieve only information from leaf nodes

    Returns:
        (pd.DataFrame)
    """
    tree = tree.copy()
    rows = []

    def _walk(node):
        # Children are always traversed, even when the node itself is
        # filtered out by depth/leaf constraints.
        if not node:
            return
        within_depth = not max_depth or node.depth <= max_depth
        past_skip = not skip_depth or node.depth > skip_depth
        leaf_ok = not leaf_only or node.is_leaf
        if within_depth and past_skip and leaf_ok:
            record = {}
            if path_col:
                record[path_col] = node.path_name
            if name_col:
                record[name_col] = node.node_name
            if parent_col:
                record[parent_col] = node.parent.node_name if node.parent else None
            if all_attrs:
                record.update(
                    node.describe(exclude_attributes=["name"], exclude_prefix="_")
                )
            else:
                for attr_name, col_name in attr_dict.items():
                    record[col_name] = node.get_attr(attr_name)
            rows.append(record)
        for child in node.children:
            _walk(child)

    _walk(tree)
    return pd.DataFrame(rows)
def tree_to_dot(
    tree: Union[Node, List[Node]],
    directed: bool = True,
    rankdir: str = "TB",
    bg_colour: str = None,
    node_colour: str = None,
    node_shape: str = None,
    edge_colour: str = None,
    node_attr: str = None,
    edge_attr: str = None,
):
    r"""Export tree or list of trees to image.
    Possible node attributes include style, fillcolor, shape.

    >>> from bigtree import Node, tree_to_dot
    >>> root = Node("a", age=90)
    >>> b = Node("b", age=65, parent=root)
    >>> c = Node("c", age=60, parent=root)
    >>> d = Node("d", age=40, parent=b)
    >>> e = Node("e", age=35, parent=b)
    >>> graph = tree_to_dot(root)

    Export to image, dot file, etc.

    >>> graph.write_png("tree.png")
    >>> graph.write_dot("tree.dot")

    Export to string

    >>> graph.to_string()
    'strict digraph G {\nrankdir=TB;\na0 [label=a];\nb0 [label=b];\na0 -> b0;\nd0 [label=d];\nb0 -> d0;\ne0 [label=e];\nb0 -> e0;\nc0 [label=c];\na0 -> c0;\n}\n'

    Defining node and edge attributes

    >>> class CustomNode(Node):
    ...     def __init__(self, name, node_shape="", edge_label="", **kwargs):
    ...         super().__init__(name, **kwargs)
    ...         self.node_shape = node_shape
    ...         self.edge_label = edge_label
    ...
    ...     @property
    ...     def edge_attr(self):
    ...         if self.edge_label:
    ...             return {"label": self.edge_label}
    ...         return {}
    ...
    ...     @property
    ...     def node_attr(self):
    ...         if self.node_shape:
    ...             return {"shape": self.node_shape}
    ...         return {}
    >>>
    >>>
    >>> root = CustomNode("a", node_shape="circle")
    >>> b = CustomNode("b", edge_label="child", parent=root)
    >>> c = CustomNode("c", edge_label="child", parent=root)
    >>> d = CustomNode("d", node_shape="square", edge_label="child", parent=b)
    >>> e = CustomNode("e", node_shape="square", edge_label="child", parent=b)
    >>> graph = tree_to_dot(root, node_colour="gold", node_shape="diamond", node_attr="node_attr", edge_attr="edge_attr")
    >>> graph.write_png("assets/custom_tree.png")

    .. image:: https://github.com/kayjan/bigtree/raw/master/assets/custom_tree.png

    Args:
        tree (Node/List[Node]): tree or list of trees to be exported
        directed (bool): indicator whether graph should be directed or undirected, defaults to True
        rankdir (str): set direction of graph layout, defaults to 'TB' (top to bottom), can be 'BT' (bottom to top),
            'LR' (left to right), 'RL' (right to left)
        bg_colour (str): background color of image, defaults to None
        node_colour (str): fill colour of nodes, defaults to None
        node_shape (str): shape of nodes, defaults to None
            Possible node_shape include "circle", "square", "diamond", "triangle"
        edge_colour (str): colour of edges, defaults to None
        node_attr (str): `Node` attribute for node style, overrides `node_colour` and `node_shape`, defaults to None.
            Possible node style (attribute value) include {"style": "filled", "fillcolor": "gold", "shape": "diamond"}
        edge_attr (str): `Node` attribute for edge style, overrides `edge_colour`, defaults to None.
            Possible edge style (attribute value) include {"style": "bold", "label": "edge label", "color": "black"}

    Returns:
        (pydot.Dot)
    """
    # Lazy import: pydot is an optional dependency (bigtree[image] extra).
    try:
        import pydot
    except ImportError:  # pragma: no cover
        raise ImportError(
            "pydot not available. Please perform a\n\npip install 'bigtree[image]'\n\nto install required dependencies"
        )
    # Get style
    if bg_colour:
        graph_style = dict(bgcolor=bg_colour)
    else:
        graph_style = dict()
    if node_colour:
        node_style = dict(style="filled", fillcolor=node_colour)
    else:
        node_style = dict()
    if node_shape:
        node_style["shape"] = node_shape
    if edge_colour:
        edge_style = dict(color=edge_colour)
    else:
        edge_style = dict()
    # NOTE(review): when `tree` is a list, list.copy() is shallow — the Node
    # objects themselves are shared with the caller, unlike the single-Node
    # case where Node.copy() is used; confirm this asymmetry is intended.
    tree = tree.copy()
    if directed:
        _graph = pydot.Dot(
            graph_type="digraph", strict=True, rankdir=rankdir, **graph_style
        )
    else:
        _graph = pydot.Dot(
            graph_type="graph", strict=True, rankdir=rankdir, **graph_style
        )
    if not isinstance(tree, list):
        tree = [tree]
    for _tree in tree:
        if not isinstance(_tree, Node):
            raise ValueError("Tree should be of type `Node`, or inherit from `Node`")
        # Tracks node labels already emitted so duplicate names receive
        # distinct pydot identifiers (a0, a1, ...).
        # NOTE(review): `name_dict` is re-created per tree, so equal labels
        # across different trees map to the SAME identifier and are merged
        # by the strict graph — confirm intended for multi-tree export.
        name_dict = collections.defaultdict(list)

        def recursive_create_node_and_edges(parent_name, child_node):
            # Per-node copies so attribute overrides never leak across nodes.
            _node_style = node_style.copy()
            _edge_style = edge_style.copy()
            child_label = child_node.node_name
            if child_node.path_name not in name_dict[child_label]:  # pragma: no cover
                name_dict[child_label].append(child_node.path_name)
            # Identifier = label + index of this path among same-label nodes.
            child_name = child_label + str(
                name_dict[child_label].index(child_node.path_name)
            )
            if node_attr and child_node.get_attr(node_attr):
                _node_style.update(child_node.get_attr(node_attr))
            if edge_attr:
                _edge_style.update(child_node.get_attr(edge_attr))
            node = pydot.Node(name=child_name, label=child_label, **_node_style)
            _graph.add_node(node)
            if parent_name is not None:
                edge = pydot.Edge(parent_name, child_name, **_edge_style)
                _graph.add_edge(edge)
            for child in child_node.children:
                if child:
                    recursive_create_node_and_edges(child_name, child)

        # Always start from the tree's root, even if a child node was passed.
        recursive_create_node_and_edges(None, _tree.root)
    return _graph
def tree_to_pillow(
    tree: Node,
    width: int = 0,
    height: int = 0,
    start_pos: Tuple[float, float] = (10, 10),
    font_family: str = "assets/DejaVuSans.ttf",
    font_size: int = 12,
    font_colour: Union[Tuple[float, float, float], str] = "black",
    bg_colour: Union[Tuple[float, float, float], str] = "white",
    **kwargs,
):
    """Export tree to image (JPG, PNG).
    Image will be similar format as `print_tree`, accepts additional keyword arguments as input to `yield_tree`

    >>> from bigtree import Node, tree_to_pillow
    >>> root = Node("a", age=90)
    >>> b = Node("b", age=65, parent=root)
    >>> c = Node("c", age=60, parent=root)
    >>> d = Node("d", age=40, parent=b)
    >>> e = Node("e", age=35, parent=b)
    >>> pillow_image = tree_to_pillow(root)

    Export to image (PNG, JPG) file, etc.

    >>> pillow_image.save("tree_pillow.png")
    >>> pillow_image.save("tree_pillow.jpg")

    Args:
        tree (Node): tree to be exported
        width (int): width of image, optional as width of image is calculated automatically
        height (int): height of image, optional as height of image is calculated automatically
        start_pos (Tuple[float, float]): start position of text, (x-offset, y-offset), defaults to (10, 10)
        font_family (str): file path of font family, requires .ttf file, defaults to DejaVuSans
        font_size (int): font size, defaults to 12
        font_colour (Union[Tuple[float, float, float], str]): font colour, accepts tuple of RGB values or string, defaults to black
        bg_colour (Union[Tuple[float, float, float], str]): background of image, accepts tuple of RGB values or string, defaults to white

    Returns:
        (PIL.Image.Image)
    """
    try:
        from PIL import Image, ImageDraw, ImageFont
    except ImportError:  # pragma: no cover
        raise ImportError(
            "Pillow not available. Please perform a\n\npip install 'bigtree[image]'\n\nto install required dependencies"
        )
    # Initialize font
    font = ImageFont.truetype(font_family, font_size)
    # Initialize text: one entry per tree row, rendered like `print_tree`
    image_text = []
    for branch, stem, node in yield_tree(tree, **kwargs):
        image_text.append(f"{branch}{stem}{node.node_name}\n")

    # Calculate image dimension from text, otherwise override with argument
    def get_list_of_text_dimensions(text_list):
        """Get list dimensions

        Args:
            text_list (List[str]): list of texts

        Returns:
            (List[Iterable[int]]): list of (left, top, right, bottom) bounding box
        """
        _image = Image.new("RGB", (0, 0))
        _draw = ImageDraw.Draw(_image)
        return [_draw.textbbox((0, 0), text_line, font=font) for text_line in text_list]

    text_dimensions = get_list_of_text_dimensions(image_text)
    text_height = sum(
        text_dimension[3] + text_dimension[1] for text_dimension in text_dimensions
    )
    text_width = max(
        text_dimension[2] + text_dimension[0] for text_dimension in text_dimensions
    )
    image_text = "".join(image_text)
    # Cast to int: `start_pos` may contain floats and PIL.Image.new requires an
    # integer (width, height) size tuple
    width = int(max(width, text_width + 2 * start_pos[0]))
    height = int(max(height, text_height + 2 * start_pos[1]))
    # Initialize and draw image
    image = Image.new("RGB", (width, height), bg_colour)
    image_draw = ImageDraw.Draw(image)
    image_draw.text(start_pos, image_text, font=font, fill=font_colour)
    return image

View File

@@ -0,0 +1,201 @@
from typing import Optional, Type
import numpy as np
from bigtree.node.basenode import BaseNode
from bigtree.node.binarynode import BinaryNode
from bigtree.node.node import Node
from bigtree.tree.construct import dataframe_to_tree
from bigtree.tree.export import tree_to_dataframe
from bigtree.tree.search import find_path
from bigtree.utils.exceptions import NotFoundError
__all__ = ["clone_tree", "prune_tree", "get_tree_diff"]
def clone_tree(tree: BaseNode, node_type: Type[BaseNode]) -> BaseNode:
    """Clone tree to another `Node` type.
    If the same type is needed, simply do a tree.copy().

    >>> from bigtree import BaseNode, Node, clone_tree
    >>> root = BaseNode(name="a")
    >>> b = BaseNode(name="b", parent=root)
    >>> clone_tree(root, Node)
    Node(/a, )

    Args:
        tree (BaseNode): tree to be cloned, must inherit from BaseNode
        node_type (Type[BaseNode]): type of cloned tree

    Returns:
        (BaseNode)
    """
    if not isinstance(tree, BaseNode):
        raise ValueError(
            "Tree should be of type `BaseNode`, or inherit from `BaseNode`"
        )
    # Clone the root first, then walk the original tree with an explicit stack,
    # attaching a freshly-constructed counterpart for every child visited.
    source_root = tree.root
    cloned_root = node_type(**dict(source_root.describe(exclude_prefix="_")))
    pending = [(cloned_root, source_root)]
    while pending:
        cloned_parent, source_parent = pending.pop()
        for source_child in source_parent.children:
            if source_child:
                cloned_child = node_type(
                    **dict(source_child.describe(exclude_prefix="_"))
                )
                cloned_child.parent = cloned_parent
                pending.append((cloned_child, source_child))
    return cloned_root
def prune_tree(tree: Node, prune_path: str, sep: str = "/") -> Node:
    """Prune tree to leave only the prune path, returns the root of a *copy* of the original tree.
    All siblings along the prune path will be removed.
    Prune path name should be unique, can be full path or partial path (trailing part of path) or node name.
    Path should contain `Node` name, separated by `sep`.

    - For example: Path string "a/b" refers to Node("b") with parent Node("a").

    >>> from bigtree import Node, prune_tree, print_tree
    >>> root = Node("a")
    >>> b = Node("b", parent=root)
    >>> c = Node("c", parent=root)
    >>> print_tree(root)
    a
    ├── b
    └── c
    >>> root_pruned = prune_tree(root, "a/b")
    >>> print_tree(root_pruned)
    a
    └── b

    Args:
        tree (Node): existing tree
        prune_path (str): prune path, all siblings along the prune path will be removed
        sep (str): path separator

    Returns:
        (Node)
    """
    # Normalize the user-supplied separator to the tree's own separator
    search_path = prune_path.replace(sep, tree.sep)
    tree_copy = tree.copy()
    keep_node = find_path(tree_copy, search_path)
    if not keep_node:
        raise NotFoundError(
            f"Cannot find any node matching path_name ending with {search_path}"
        )
    # Walk up from the retained node, detaching every sibling along the way.
    # BinaryNode children lists always carry two slots, hence the None filler.
    is_binary = isinstance(keep_node.parent, BinaryNode)
    while keep_node.parent:
        keep_node.parent.children = [keep_node, None] if is_binary else [keep_node]
        keep_node = keep_node.parent
    return tree_copy
def get_tree_diff(
    tree: Node, other_tree: Node, only_diff: bool = True
) -> Optional[Node]:
    """Get difference of `tree` to `other_tree`, changes are relative to `tree`.
    (+) and (-) will be added relative to `tree`.

    - For example: (+) refers to nodes that are in `other_tree` but not `tree`.
    - For example: (-) refers to nodes that are in `tree` but not `other_tree`.

    Note that only leaf nodes are compared and have (+) or (-) indicator. Intermediate parent nodes are not compared.
    Function can return all original tree nodes and differences, or only the differences.

    >>> from bigtree import Node, get_tree_diff, print_tree
    >>> root = Node("a")
    >>> b = Node("b", parent=root)
    >>> c = Node("c", parent=root)
    >>> d = Node("d", parent=b)
    >>> e = Node("e", parent=root)
    >>> print_tree(root)
    a
    ├── b
    │   └── d
    ├── c
    └── e
    >>> root_other = Node("a")
    >>> b_other = Node("b", parent=root_other)
    >>> c_other = Node("c", parent=b_other)
    >>> d_other = Node("d", parent=root_other)
    >>> e_other = Node("e", parent=root_other)
    >>> print_tree(root_other)
    a
    ├── b
    │   └── c
    ├── d
    └── e
    >>> tree_diff = get_tree_diff(root, root_other)
    >>> print_tree(tree_diff)
    a
    ├── b
    │   ├── c (+)
    │   └── d (-)
    ├── c (-)
    └── d (+)
    >>> tree_diff = get_tree_diff(root, root_other, only_diff=False)
    >>> print_tree(tree_diff)
    a
    ├── b
    │   ├── c (+)
    │   └── d (-)
    ├── c (-)
    ├── d (+)
    └── e

    Args:
        tree (Node): tree to be compared against
        other_tree (Node): tree to be compared with
        only_diff (bool): indicator to show all nodes or only nodes that are different (+/-), defaults to True

    Returns:
        (Node)
    """
    tree = tree.copy()
    other_tree = other_tree.copy()
    name_col = "name"
    path_col = "PATH"
    indicator_col = "Exists"
    # Compare only leaf nodes of both trees, keyed by their full path
    frame_left = tree_to_dataframe(
        tree, name_col=name_col, path_col=path_col, leaf_only=True
    )
    frame_right = tree_to_dataframe(
        other_tree, name_col=name_col, path_col=path_col, leaf_only=True
    )
    combined = frame_left[[path_col, name_col]].merge(
        frame_right[[path_col, name_col]], how="outer", indicator=indicator_col
    )
    # Tag one-sided leaves: (-) only in `tree`, (+) only in `other_tree`
    combined[name_col] = np.select(
        [
            combined[indicator_col] == "left_only",
            combined[indicator_col] == "right_only",
        ],
        [combined[name_col] + " (-)", combined[name_col] + " (+)"],
        default=combined[name_col],
    )
    if only_diff:
        combined = combined.query(f"{indicator_col} != 'both'")
    combined = combined.drop(columns=indicator_col).sort_values(path_col)
    # Returns None implicitly when there is no difference to report
    if len(combined):
        return dataframe_to_tree(
            combined,
            node_type=tree.__class__,
        )

View File

@@ -0,0 +1,856 @@
import logging
from typing import List, Optional
from bigtree.node.node import Node
from bigtree.tree.search import find_path
from bigtree.utils.exceptions import NotFoundError, TreeError
logging.getLogger(__name__).addHandler(logging.NullHandler())
__all__ = [
"shift_nodes",
"copy_nodes",
"copy_nodes_from_tree_to_tree",
"copy_or_shift_logic",
]
def shift_nodes(
    tree: Node,
    from_paths: List[str],
    to_paths: List[str],
    sep: str = "/",
    skippable: bool = False,
    overriding: bool = False,
    merge_children: bool = False,
    merge_leaves: bool = False,
    delete_children: bool = False,
) -> None:
    """Shift nodes from `from_paths` to `to_paths` *in-place*.

    - Creates intermediate nodes if to path is not present
    - Able to skip nodes if from path is not found, defaults to False (from-nodes must be found; not skippable).
    - Able to override existing node if it exists, defaults to False (to-nodes must not exist; not overridden).
    - Able to merge children and remove intermediate parent node, defaults to False (nodes are shifted; not merged).
    - Able to merge only leaf nodes and remove all intermediate nodes, defaults to False (nodes are shifted; not merged)
    - Able to shift node only and delete children, defaults to False (nodes are shifted together with children).

    For paths in `from_paths` and `to_paths`,

    - Path name can be with or without leading tree path separator symbol.
    - Path name can be partial path (trailing part of path) or node name.
    - Path name must be unique to one node.

    For paths in `to_paths`,

    - Can set to empty string or None to delete the path in `from_paths`, note that ``copy`` must be set to False.

    If ``merge_children=True``,

    - If `to_path` is not present, it shifts children of `from_path`.
    - If `to_path` is present, and ``overriding=False``, original and new children are merged.
    - If `to_path` is present and ``overriding=True``, it behaves like overriding and only new children are retained.

    If ``merge_leaves=True``,

    - If `to_path` is not present, it shifts leaves of `from_path`.
    - If `to_path` is present, and ``overriding=False``, original children and leaves are merged.
    - If `to_path` is present and ``overriding=True``, it behaves like overriding and only new leaves are retained,
      original node in `from_path` is retained.

    >>> from bigtree import Node, shift_nodes, print_tree
    >>> root = Node("a")
    >>> b = Node("b", parent=root)
    >>> c = Node("c", parent=root)
    >>> d = Node("d", parent=root)
    >>> print_tree(root)
    a
    ├── b
    ├── c
    └── d
    >>> shift_nodes(root, ["a/c", "a/d"], ["a/b/c", "a/dummy/d"])
    >>> print_tree(root)
    a
    ├── b
    │   └── c
    └── dummy
        └── d

    To delete node,

    >>> root = Node("a")
    >>> b = Node("b", parent=root)
    >>> c = Node("c", parent=root)
    >>> print_tree(root)
    a
    ├── b
    └── c
    >>> shift_nodes(root, ["a/b"], [None])
    >>> print_tree(root)
    a
    └── c

    In overriding case,

    >>> root = Node("a")
    >>> b = Node("b", parent=root)
    >>> c = Node("c", parent=root)
    >>> d = Node("d", parent=c)
    >>> c2 = Node("c", parent=b)
    >>> e = Node("e", parent=c2)
    >>> print_tree(root)
    a
    ├── b
    │   └── c
    │       └── e
    └── c
        └── d
    >>> shift_nodes(root, ["a/b/c"], ["a/c"], overriding=True)
    >>> print_tree(root)
    a
    ├── b
    └── c
        └── e

    In ``merge_children`` case, child nodes are shifted instead of the parent node.

    - If the path already exists, child nodes are merged with existing children.
    - If same node is shifted, the child nodes of the node are merged with the node's parent.

    >>> root = Node("a")
    >>> b = Node("b", parent=root)
    >>> c = Node("c", parent=root)
    >>> d = Node("d", parent=c)
    >>> c2 = Node("c", parent=b)
    >>> e = Node("e", parent=c2)
    >>> z = Node("z", parent=b)
    >>> y = Node("y", parent=z)
    >>> f = Node("f", parent=root)
    >>> g = Node("g", parent=f)
    >>> h = Node("h", parent=g)
    >>> print_tree(root)
    a
    ├── b
    │   ├── c
    │   │   └── e
    │   └── z
    │       └── y
    ├── c
    │   └── d
    └── f
        └── g
            └── h
    >>> shift_nodes(root, ["a/b/c", "z", "a/f"], ["a/c", "a/z", "a/f"], merge_children=True)
    >>> print_tree(root)
    a
    ├── b
    ├── c
    │   ├── d
    │   └── e
    ├── y
    └── g
        └── h

    In ``merge_leaves`` case, leaf nodes are copied instead of the parent node.

    - If the path already exists, leaf nodes are merged with existing children.
    - If same node is copied, the leaf nodes of the node are merged with the node's parent.

    >>> root = Node("a")
    >>> b = Node("b", parent=root)
    >>> c = Node("c", parent=root)
    >>> d = Node("d", parent=c)
    >>> c2 = Node("c", parent=b)
    >>> e = Node("e", parent=c2)
    >>> z = Node("z", parent=b)
    >>> y = Node("y", parent=z)
    >>> f = Node("f", parent=root)
    >>> g = Node("g", parent=f)
    >>> h = Node("h", parent=g)
    >>> print_tree(root)
    a
    ├── b
    │   ├── c
    │   │   └── e
    │   └── z
    │       └── y
    ├── c
    │   └── d
    └── f
        └── g
            └── h
    >>> shift_nodes(root, ["a/b/c", "z", "a/f"], ["a/c", "a/z", "a/f"], merge_leaves=True)
    >>> print_tree(root)
    a
    ├── b
    │   ├── c
    │   └── z
    ├── c
    │   ├── d
    │   └── e
    ├── f
    │   └── g
    ├── y
    └── h

    In ``delete_children`` case, only the node is shifted without its accompanying children/descendants.

    >>> root = Node("a")
    >>> b = Node("b", parent=root)
    >>> c = Node("c", parent=root)
    >>> d = Node("d", parent=c)
    >>> c2 = Node("c", parent=b)
    >>> e = Node("e", parent=c2)
    >>> z = Node("z", parent=b)
    >>> y = Node("y", parent=z)
    >>> print_tree(root)
    a
    ├── b
    │   ├── c
    │   │   └── e
    │   └── z
    │       └── y
    └── c
        └── d
    >>> shift_nodes(root, ["a/b/z"], ["a/z"], delete_children=True)
    >>> print_tree(root)
    a
    ├── b
    │   └── c
    │       └── e
    ├── c
    │   └── d
    └── z

    Args:
        tree (Node): tree to modify
        from_paths (list): original paths to shift nodes from
        to_paths (list): new paths to shift nodes to
        sep (str): path separator for input paths, applies to `from_path` and `to_path`
        skippable (bool): indicator to skip if from path is not found, defaults to False
        overriding (bool): indicator to override existing to path if there is clashes, defaults to False
        merge_children (bool): indicator to merge children and remove intermediate parent node, defaults to False
        merge_leaves (bool): indicator to merge leaf nodes and remove intermediate parent node(s), defaults to False
        delete_children (bool): indicator to shift node only without children, defaults to False
    """
    # Delegate to the shared implementation with copy=False: nodes are detached
    # from their original parent and re-attached at the destination (moved, not
    # duplicated).
    return copy_or_shift_logic(
        tree=tree,
        from_paths=from_paths,
        to_paths=to_paths,
        sep=sep,
        copy=False,
        skippable=skippable,
        overriding=overriding,
        merge_children=merge_children,
        merge_leaves=merge_leaves,
        delete_children=delete_children,
        to_tree=None,
    )  # pragma: no cover
def copy_nodes(
    tree: Node,
    from_paths: List[str],
    to_paths: List[str],
    sep: str = "/",
    skippable: bool = False,
    overriding: bool = False,
    merge_children: bool = False,
    merge_leaves: bool = False,
    delete_children: bool = False,
) -> None:
    """Copy nodes from `from_paths` to `to_paths` *in-place*.

    - Creates intermediate nodes if to path is not present
    - Able to skip nodes if from path is not found, defaults to False (from-nodes must be found; not skippable).
    - Able to override existing node if it exists, defaults to False (to-nodes must not exist; not overridden).
    - Able to merge children and remove intermediate parent node, defaults to False (nodes are shifted; not merged).
    - Able to merge only leaf nodes and remove all intermediate nodes, defaults to False (nodes are shifted; not merged)
    - Able to copy node only and delete children, defaults to False (nodes are copied together with children).

    For paths in `from_paths` and `to_paths`,

    - Path name can be with or without leading tree path separator symbol.
    - Path name can be partial path (trailing part of path) or node name.
    - Path name must be unique to one node.

    If ``merge_children=True``,

    - If `to_path` is not present, it copies children of `from_path`.
    - If `to_path` is present, and ``overriding=False``, original and new children are merged.
    - If `to_path` is present and ``overriding=True``, it behaves like overriding and only new children are retained.

    If ``merge_leaves=True``,

    - If `to_path` is not present, it copies leaves of `from_path`.
    - If `to_path` is present, and ``overriding=False``, original children and leaves are merged.
    - If `to_path` is present and ``overriding=True``, it behaves like overriding and only new leaves are retained.

    >>> from bigtree import Node, copy_nodes, print_tree
    >>> root = Node("a")
    >>> b = Node("b", parent=root)
    >>> c = Node("c", parent=root)
    >>> d = Node("d", parent=root)
    >>> print_tree(root)
    a
    ├── b
    ├── c
    └── d
    >>> copy_nodes(root, ["a/c", "a/d"], ["a/b/c", "a/dummy/d"])
    >>> print_tree(root)
    a
    ├── b
    │   └── c
    ├── c
    ├── d
    └── dummy
        └── d

    In overriding case,

    >>> root = Node("a")
    >>> b = Node("b", parent=root)
    >>> c = Node("c", parent=root)
    >>> d = Node("d", parent=c)
    >>> c2 = Node("c", parent=b)
    >>> e = Node("e", parent=c2)
    >>> print_tree(root)
    a
    ├── b
    │   └── c
    │       └── e
    └── c
        └── d
    >>> copy_nodes(root, ["a/b/c"], ["a/c"], overriding=True)
    >>> print_tree(root)
    a
    ├── b
    │   └── c
    │       └── e
    └── c
        └── e

    In ``merge_children`` case, child nodes are copied instead of the parent node.

    - If the path already exists, child nodes are merged with existing children.
    - If same node is copied, the child nodes of the node are merged with the node's parent.

    >>> root = Node("a")
    >>> b = Node("b", parent=root)
    >>> c = Node("c", parent=root)
    >>> d = Node("d", parent=c)
    >>> c2 = Node("c", parent=b)
    >>> e = Node("e", parent=c2)
    >>> z = Node("z", parent=b)
    >>> y = Node("y", parent=z)
    >>> f = Node("f", parent=root)
    >>> g = Node("g", parent=f)
    >>> h = Node("h", parent=g)
    >>> print_tree(root)
    a
    ├── b
    │   ├── c
    │   │   └── e
    │   └── z
    │       └── y
    ├── c
    │   └── d
    └── f
        └── g
            └── h
    >>> copy_nodes(root, ["a/b/c", "z", "a/f"], ["a/c", "a/z", "a/f"], merge_children=True)
    >>> print_tree(root)
    a
    ├── b
    │   ├── c
    │   │   └── e
    │   └── z
    │       └── y
    ├── c
    │   ├── d
    │   └── e
    ├── y
    └── g
        └── h

    In ``merge_leaves`` case, leaf nodes are copied instead of the parent node.

    - If the path already exists, leaf nodes are merged with existing children.
    - If same node is copied, the leaf nodes of the node are merged with the node's parent.

    >>> root = Node("a")
    >>> b = Node("b", parent=root)
    >>> c = Node("c", parent=root)
    >>> d = Node("d", parent=c)
    >>> c2 = Node("c", parent=b)
    >>> e = Node("e", parent=c2)
    >>> z = Node("z", parent=b)
    >>> y = Node("y", parent=z)
    >>> f = Node("f", parent=root)
    >>> g = Node("g", parent=f)
    >>> h = Node("h", parent=g)
    >>> print_tree(root)
    a
    ├── b
    │   ├── c
    │   │   └── e
    │   └── z
    │       └── y
    ├── c
    │   └── d
    └── f
        └── g
            └── h
    >>> copy_nodes(root, ["a/b/c", "z", "a/f"], ["a/c", "a/z", "a/f"], merge_leaves=True)
    >>> print_tree(root)
    a
    ├── b
    │   ├── c
    │   │   └── e
    │   └── z
    │       └── y
    ├── c
    │   ├── d
    │   └── e
    ├── f
    │   └── g
    │       └── h
    ├── y
    └── h

    In ``delete_children`` case, only the node is copied without its accompanying children/descendants.

    >>> root = Node("a")
    >>> b = Node("b", parent=root)
    >>> c = Node("c", parent=root)
    >>> d = Node("d", parent=c)
    >>> c2 = Node("c", parent=b)
    >>> e = Node("e", parent=c2)
    >>> z = Node("z", parent=b)
    >>> y = Node("y", parent=z)
    >>> print_tree(root)
    a
    ├── b
    │   ├── c
    │   │   └── e
    │   └── z
    │       └── y
    └── c
        └── d
    >>> copy_nodes(root, ["a/b/z"], ["a/z"], delete_children=True)
    >>> print_tree(root)
    a
    ├── b
    │   ├── c
    │   │   └── e
    │   └── z
    │       └── y
    ├── c
    │   └── d
    └── z

    Args:
        tree (Node): tree to modify
        from_paths (list): original paths to shift nodes from
        to_paths (list): new paths to shift nodes to
        sep (str): path separator for input paths, applies to `from_path` and `to_path`
        skippable (bool): indicator to skip if from path is not found, defaults to False
        overriding (bool): indicator to override existing to path if there is clashes, defaults to False
        merge_children (bool): indicator to merge children and remove intermediate parent node, defaults to False
        merge_leaves (bool): indicator to merge leaf nodes and remove intermediate parent node(s), defaults to False
        delete_children (bool): indicator to copy node only without children, defaults to False
    """
    # Delegate to the shared implementation with copy=True: each from-node is
    # deep-copied before being attached, so the original node stays in place.
    return copy_or_shift_logic(
        tree=tree,
        from_paths=from_paths,
        to_paths=to_paths,
        sep=sep,
        copy=True,
        skippable=skippable,
        overriding=overriding,
        merge_children=merge_children,
        merge_leaves=merge_leaves,
        delete_children=delete_children,
        to_tree=None,
    )  # pragma: no cover
def copy_nodes_from_tree_to_tree(
    from_tree: Node,
    to_tree: Node,
    from_paths: List[str],
    to_paths: List[str],
    sep: str = "/",
    skippable: bool = False,
    overriding: bool = False,
    merge_children: bool = False,
    merge_leaves: bool = False,
    delete_children: bool = False,
) -> None:
    """Copy nodes from `from_paths` to `to_paths` *in-place*.

    - Creates intermediate nodes if to path is not present
    - Able to skip nodes if from path is not found, defaults to False (from-nodes must be found; not skippable).
    - Able to override existing node if it exists, defaults to False (to-nodes must not exist; not overridden).
    - Able to merge children and remove intermediate parent node, defaults to False (nodes are shifted; not merged).
    - Able to merge only leaf nodes and remove all intermediate nodes, defaults to False (nodes are shifted; not merged)
    - Able to copy node only and delete children, defaults to False (nodes are copied together with children).

    For paths in `from_paths` and `to_paths`,

    - Path name can be with or without leading tree path separator symbol.
    - Path name can be partial path (trailing part of path) or node name.
    - Path name must be unique to one node.

    If ``merge_children=True``,

    - If `to_path` is not present, it copies children of `from_path`
    - If `to_path` is present, and ``overriding=False``, original and new children are merged
    - If `to_path` is present and ``overriding=True``, it behaves like overriding and only new leaves are retained.

    If ``merge_leaves=True``,

    - If `to_path` is not present, it copies leaves of `from_path`.
    - If `to_path` is present, and ``overriding=False``, original children and leaves are merged.
    - If `to_path` is present and ``overriding=True``, it behaves like overriding and only new leaves are retained.

    >>> from bigtree import Node, copy_nodes_from_tree_to_tree, print_tree
    >>> root = Node("a")
    >>> b = Node("b", parent=root)
    >>> c = Node("c", parent=root)
    >>> d = Node("d", parent=c)
    >>> e = Node("e", parent=root)
    >>> f = Node("f", parent=e)
    >>> g = Node("g", parent=f)
    >>> print_tree(root)
    a
    ├── b
    ├── c
    │   └── d
    └── e
        └── f
            └── g
    >>> root_other = Node("aa")
    >>> copy_nodes_from_tree_to_tree(root, root_other, ["a/b", "a/c", "a/e"], ["aa/b", "aa/b/c", "aa/dummy/e"])
    >>> print_tree(root_other)
    aa
    ├── b
    │   └── c
    │       └── d
    └── dummy
        └── e
            └── f
                └── g

    In overriding case,

    >>> root_other = Node("aa")
    >>> c = Node("c", parent=root_other)
    >>> e = Node("e", parent=c)
    >>> print_tree(root_other)
    aa
    └── c
        └── e
    >>> copy_nodes_from_tree_to_tree(root, root_other, ["a/b", "a/c"], ["aa/b", "aa/c"], overriding=True)
    >>> print_tree(root_other)
    aa
    ├── b
    └── c
        └── d

    In ``merge_children`` case, child nodes are copied instead of the parent node.

    - If the path already exists, child nodes are merged with existing children.

    >>> root_other = Node("aa")
    >>> c = Node("c", parent=root_other)
    >>> e = Node("e", parent=c)
    >>> print_tree(root_other)
    aa
    └── c
        └── e
    >>> copy_nodes_from_tree_to_tree(root, root_other, ["a/c", "e"], ["a/c", "a/e"], merge_children=True)
    >>> print_tree(root_other)
    aa
    ├── c
    │   ├── e
    │   └── d
    └── f
        └── g

    In ``merge_leaves`` case, leaf nodes are copied instead of the parent node.

    - If the path already exists, leaf nodes are merged with existing children.

    >>> root_other = Node("aa")
    >>> c = Node("c", parent=root_other)
    >>> e = Node("e", parent=c)
    >>> print_tree(root_other)
    aa
    └── c
        └── e
    >>> copy_nodes_from_tree_to_tree(root, root_other, ["a/c", "e"], ["a/c", "a/e"], merge_leaves=True)
    >>> print_tree(root_other)
    aa
    ├── c
    │   ├── e
    │   └── d
    └── g

    In ``delete_children`` case, only the node is copied without its accompanying children/descendants.

    >>> root_other = Node("aa")
    >>> print_tree(root_other)
    aa
    >>> copy_nodes_from_tree_to_tree(root, root_other, ["a/c", "e"], ["a/c", "a/e"], delete_children=True)
    >>> print_tree(root_other)
    aa
    ├── c
    └── e

    Args:
        from_tree (Node): tree to copy nodes from
        to_tree (Node): tree to copy nodes to
        from_paths (list): original paths to shift nodes from
        to_paths (list): new paths to shift nodes to
        sep (str): path separator for input paths, applies to `from_path` and `to_path`
        skippable (bool): indicator to skip if from path is not found, defaults to False
        overriding (bool): indicator to override existing to path if there is clashes, defaults to False
        merge_children (bool): indicator to merge children and remove intermediate parent node, defaults to False
        merge_leaves (bool): indicator to merge leaf nodes and remove intermediate parent node(s), defaults to False
        delete_children (bool): indicator to copy node only without children, defaults to False
    """
    # Delegate to the shared implementation: from-paths are resolved against
    # `from_tree`, to-paths against `to_tree` (cross-tree copy, always copy=True).
    return copy_or_shift_logic(
        tree=from_tree,
        from_paths=from_paths,
        to_paths=to_paths,
        sep=sep,
        copy=True,
        skippable=skippable,
        overriding=overriding,
        merge_children=merge_children,
        merge_leaves=merge_leaves,
        delete_children=delete_children,
        to_tree=to_tree,
    )  # pragma: no cover
def copy_or_shift_logic(
    tree: Node,
    from_paths: List[str],
    to_paths: List[str],
    sep: str = "/",
    copy: bool = False,
    skippable: bool = False,
    overriding: bool = False,
    merge_children: bool = False,
    merge_leaves: bool = False,
    delete_children: bool = False,
    to_tree: Optional[Node] = None,
) -> None:
    """Shift or copy nodes from `from_paths` to `to_paths` *in-place*.

    - Creates intermediate nodes if to path is not present
    - Able to copy node, defaults to False (nodes are shifted; not copied).
    - Able to skip nodes if from path is not found, defaults to False (from-nodes must be found; not skippable)
    - Able to override existing node if it exists, defaults to False (to-nodes must not exist; not overridden)
    - Able to merge children and remove intermediate parent node, defaults to False (nodes are shifted; not merged)
    - Able to merge only leaf nodes and remove all intermediate nodes, defaults to False (nodes are shifted; not merged)
    - Able to shift/copy node only and delete children, defaults to False (nodes are shifted/copied together with children).
    - Able to shift/copy nodes from one tree to another tree, defaults to None (shifting/copying happens within same tree)

    For paths in `from_paths` and `to_paths`,

    - Path name can be with or without leading tree path separator symbol.
    - Path name can be partial path (trailing part of path) or node name.
    - Path name must be unique to one node.

    For paths in `to_paths`,

    - Can set to empty string or None to delete the path in `from_paths`, note that ``copy`` must be set to False.

    If ``merge_children=True``,

    - If `to_path` is not present, it shifts/copies children of `from_path`.
    - If `to_path` is present, and ``overriding=False``, original and new children are merged.
    - If `to_path` is present and ``overriding=True``, it behaves like overriding and only new children are retained.

    If ``merge_leaves=True``,

    - If `to_path` is not present, it shifts/copies leaves of `from_path`.
    - If `to_path` is present, and ``overriding=False``, original children and leaves are merged.
    - If `to_path` is present and ``overriding=True``, it behaves like overriding and only new leaves are retained,
      original non-leaf nodes in `from_path` are retained.

    Args:
        tree (Node): tree to modify
        from_paths (list): original paths to shift nodes from
        to_paths (list): new paths to shift nodes to
        sep (str): path separator for input paths, applies to `from_path` and `to_path`
        copy (bool): indicator to copy node, defaults to False
        skippable (bool): indicator to skip if from path is not found, defaults to False
        overriding (bool): indicator to override existing to path if there is clashes, defaults to False
        merge_children (bool): indicator to merge children and remove intermediate parent node, defaults to False
        merge_leaves (bool): indicator to merge leaf nodes and remove intermediate parent node(s), defaults to False
        delete_children (bool): indicator to shift/copy node only without children, defaults to False
        to_tree (Node): tree to copy to, defaults to None

    Raises:
        ValueError: if both merge flags are set, paths are not lists, path lists
            differ in length, or a from/to pair does not name the same node
        NotFoundError: if a from-path (when not ``skippable``) or a to-path's
            parent cannot be resolved
        TreeError: if a node is shifted onto itself or a to-path exists and
            ``overriding`` is not set
    """
    # ---- Input validation ----
    if merge_children and merge_leaves:
        raise ValueError(
            "Invalid shifting, can only specify one type of merging, check `merge_children` and `merge_leaves`"
        )
    if not (isinstance(from_paths, list) and isinstance(to_paths, list)):
        raise ValueError(
            "Invalid type, `from_paths` and `to_paths` should be list type"
        )
    if len(from_paths) != len(to_paths):
        raise ValueError(
            f"Paths are different length, input `from_paths` have {len(from_paths)} entries, "
            f"while output `to_paths` have {len(to_paths)} entries"
        )
    for from_path, to_path in zip(from_paths, to_paths):
        if to_path:
            # Trailing path component must match: shifting may not rename a node
            if from_path.split(sep)[-1] != to_path.split(sep)[-1]:
                raise ValueError(
                    f"Unable to assign from_path {from_path} to to_path {to_path}\n"
                    f"Verify that `sep` is defined correctly for path\n"
                    f"Alternatively, check that `from_path` and `to_path` is reassigning the same node"
                )
    # ---- Resolve destination tree (cross-tree transfer when `to_tree` given) ----
    transfer_indicator = False
    node_type = tree.__class__
    tree_sep = tree.sep
    if to_tree:
        transfer_indicator = True
        node_type = to_tree.__class__
        tree_sep = to_tree.sep
    for from_path, to_path in zip(from_paths, to_paths):
        from_path = from_path.replace(sep, tree.sep)
        from_node = find_path(tree, from_path)
        # From node not found
        if not from_node:
            if not skippable:
                raise NotFoundError(
                    f"Unable to find from_path {from_path}\n"
                    f"Set `skippable` to True to skip shifting for nodes not found"
                )
            else:
                logging.info(f"Unable to find from_path {from_path}")
        # From node found
        else:
            # Node to be deleted
            if not to_path:
                to_node = None
            # Node to be copied/shifted
            else:
                to_path = to_path.replace(sep, tree_sep)
                if transfer_indicator:
                    to_node = find_path(to_tree, to_path)
                else:
                    to_node = find_path(tree, to_path)
                # To node found
                if to_node:
                    if from_node == to_node:
                        # Shifting a node onto itself only makes sense when merging;
                        # the node itself is removed and its children/leaves reattach
                        # to its parent
                        if merge_children:
                            parent = to_node.parent
                            to_node.parent = None
                            to_node = parent
                        elif merge_leaves:
                            to_node = to_node.parent
                        else:
                            raise TreeError(
                                f"Attempting to shift the same node {from_node.node_name} back to the same position\n"
                                f"Check from path {from_path} and to path {to_path}\n"
                                f"Alternatively, set `merge_children` or `merge_leaves` to True if intermediate node is to be removed"
                            )
                    elif merge_children:
                        # Specify override to remove existing node, else children are merged
                        if not overriding:
                            logging.info(
                                f"Path {to_path} already exists and children are merged"
                            )
                        else:
                            logging.info(
                                f"Path {to_path} already exists and its children be overridden by the merge"
                            )
                            parent = to_node.parent
                            to_node.parent = None
                            to_node = parent
                            # NOTE(review): this clears the flag for all remaining
                            # (from_path, to_path) pairs in this call, not just the
                            # current one — confirm this is intended
                            merge_children = False
                    elif merge_leaves:
                        # Specify override to remove existing node, else leaves are merged
                        if not overriding:
                            logging.info(
                                f"Path {to_path} already exists and leaves are merged"
                            )
                        else:
                            logging.info(
                                f"Path {to_path} already exists and its leaves be overridden by the merge"
                            )
                            del to_node.children
                    else:
                        if not overriding:
                            raise TreeError(
                                f"Path {to_path} already exists and unable to override\n"
                                f"Set `overriding` to True to perform overrides\n"
                                f"Alternatively, set `merge_children` to True if nodes are to be merged"
                            )
                        logging.info(
                            f"Path {to_path} already exists and will be overridden"
                        )
                        parent = to_node.parent
                        to_node.parent = None
                        to_node = parent
                # To node not found
                else:
                    # Find nearest existing ancestor of the to-path
                    to_path_list = to_path.split(tree_sep)
                    idx = 1
                    to_path_parent = tree_sep.join(to_path_list[:-idx])
                    if transfer_indicator:
                        to_node = find_path(to_tree, to_path_parent)
                    else:
                        to_node = find_path(tree, to_path_parent)
                    # Create intermediate parent node, if applicable
                    while not to_node and idx + 1 < len(to_path_list):
                        idx += 1
                        # Bug fix: join with `tree_sep` (the separator used to split
                        # `to_path_list` above), not the input `sep` — they may differ
                        to_path_parent = tree_sep.join(to_path_list[:-idx])
                        if transfer_indicator:
                            to_node = find_path(to_tree, to_path_parent)
                        else:
                            to_node = find_path(tree, to_path_parent)
                    if not to_node:
                        raise NotFoundError(
                            f"Unable to find to_path {to_path}\n"
                            f"Please specify valid path to shift node to"
                        )
                    # Create the missing intermediate nodes down to the destination's parent
                    for depth in range(len(to_path_list) - idx, len(to_path_list) - 1):
                        intermediate_child_node = node_type(to_path_list[depth])
                        intermediate_child_node.parent = to_node
                        to_node = intermediate_child_node
            # Reassign from_node to new parent
            if copy:
                logging.debug(f"Copying {from_node.node_name}")
                from_node = from_node.copy()
            if merge_children:
                logging.debug(
                    f"Reassigning children from {from_node.node_name} to {to_node.node_name}"
                )
                for children in from_node.children:
                    if delete_children:
                        del children.children
                    children.parent = to_node
                from_node.parent = None
            elif merge_leaves:
                logging.debug(
                    f"Reassigning leaf nodes from {from_node.node_name} to {to_node.node_name}"
                )
                for children in from_node.leaves:
                    children.parent = to_node
            else:
                if delete_children:
                    del from_node.children
                from_node.parent = to_node

View File

@@ -0,0 +1,316 @@
from typing import Any, Callable, Iterable
from bigtree.node.basenode import BaseNode
from bigtree.node.node import Node
from bigtree.utils.exceptions import CorruptedTreeError, SearchError
from bigtree.utils.iterators import preorder_iter
__all__ = [
"findall",
"find",
"find_name",
"find_names",
"find_full_path",
"find_path",
"find_paths",
"find_attr",
"find_attrs",
"find_children",
]
def findall(
    tree: BaseNode,
    condition: Callable,
    max_depth: int = None,
    min_count: int = None,
    max_count: int = None,
) -> tuple:
    """
    Search tree for nodes matching condition (callable function).

    >>> from bigtree import Node, findall
    >>> root = Node("a", age=90)
    >>> b = Node("b", age=65, parent=root)
    >>> c = Node("c", age=60, parent=root)
    >>> d = Node("d", age=40, parent=c)
    >>> findall(root, lambda node: node.age > 62)
    (Node(/a, age=90), Node(/a/b, age=65))

    Args:
        tree (BaseNode): tree to search
        condition (Callable): function that takes in node as argument, returns node if condition evaluates to `True`
        max_depth (int): maximum depth to search for, based on `depth` attribute, defaults to None
        min_count (int): checks for minimum number of occurrence,
            raise SearchError if number of results is below min_count, defaults to None
        max_count (int): checks for maximum number of occurrence,
            raise SearchError if number of results exceeds max_count, defaults to None

    Returns:
        (tuple)
    """
    # Pre-order traversal gives a deterministic, parent-before-child result order.
    result = tuple(preorder_iter(tree, filter_condition=condition, max_depth=max_depth))
    # NOTE: a count of 0 or None disables the corresponding check (falsy test).
    if min_count and len(result) < min_count:
        raise SearchError(
            f"Expected more than {min_count} element(s), found {len(result)} elements\n{result}"
        )
    if max_count and len(result) > max_count:
        raise SearchError(
            f"Expected less than {max_count} element(s), found {len(result)} elements\n{result}"
        )
    return result
def find(tree: BaseNode, condition: Callable, max_depth: int = None) -> BaseNode:
    """
    Search tree for the *single* node matching condition (callable function).

    Delegates to `findall` with ``max_count=1``: more than one match raises
    a SearchError, zero matches returns None.

    >>> from bigtree import Node, find
    >>> root = Node("a", age=90)
    >>> b = Node("b", age=65, parent=root)
    >>> c = Node("c", age=60, parent=root)
    >>> d = Node("d", age=40, parent=c)
    >>> find(root, lambda node: node.age == 65)
    Node(/a/b, age=65)
    >>> find(root, lambda node: node.age > 5)
    Traceback (most recent call last):
    ...
    bigtree.utils.exceptions.SearchError: Expected less than 1 element(s), found 4 elements
    (Node(/a, age=90), Node(/a/b, age=65), Node(/a/c, age=60), Node(/a/c/d, age=40))

    Args:
        tree (BaseNode): tree to search
        condition (Callable): predicate taking a node; truthy return selects the node
        max_depth (int): maximum depth to search for, based on `depth` attribute, defaults to None

    Returns:
        (BaseNode)
    """
    matches = findall(tree, condition, max_depth, max_count=1)
    return matches[0] if matches else None
def find_name(tree: Node, name: str, max_depth: int = None) -> Node:
    """
    Search tree for the single node whose name matches `name`.

    >>> from bigtree import Node, find_name
    >>> root = Node("a", age=90)
    >>> b = Node("b", age=65, parent=root)
    >>> c = Node("c", age=60, parent=root)
    >>> d = Node("d", age=40, parent=c)
    >>> find_name(root, "c")
    Node(/a/c, age=60)

    Args:
        tree (Node): tree to search
        name (str): value to match against each node's name attribute
        max_depth (int): maximum depth to search for, based on `depth` attribute, defaults to None

    Returns:
        (Node)
    """
    def name_matches(node) -> bool:
        return node.node_name == name

    return find(tree, name_matches, max_depth)
def find_names(tree: Node, name: str, max_depth: int = None) -> Iterable[Node]:
    """
    Search tree for every node whose name matches `name`.

    >>> from bigtree import Node, find_names
    >>> root = Node("a", age=90)
    >>> b = Node("b", age=65, parent=root)
    >>> c = Node("c", age=60, parent=root)
    >>> d = Node("b", age=40, parent=c)
    >>> find_names(root, "c")
    (Node(/a/c, age=60),)
    >>> find_names(root, "b")
    (Node(/a/b, age=65), Node(/a/c/b, age=40))

    Args:
        tree (Node): tree to search
        name (str): value to match against each node's name attribute
        max_depth (int): maximum depth to search for, based on `depth` attribute, defaults to None

    Returns:
        (Iterable[Node])
    """
    def name_matches(node) -> bool:
        return node.node_name == name

    return findall(tree, name_matches, max_depth)
def find_full_path(tree: Node, path_name: str) -> Node:
    """
    Search tree for the single node whose full path matches `path_name`.

    - Path name can be with or without leading tree path separator symbol.
    - Path name must be the full path; works similar to `find_path` but faster.

    >>> from bigtree import Node, find_full_path
    >>> root = Node("a", age=90)
    >>> b = Node("b", age=65, parent=root)
    >>> c = Node("c", age=60, parent=root)
    >>> d = Node("d", age=40, parent=c)
    >>> find_full_path(root, "/a/c/d")
    Node(/a/c/d, age=40)

    Args:
        tree (Node): tree to search
        path_name (str): value to match (full path) of path_name attribute

    Returns:
        (Node)
    """
    # Normalize: drop leading and trailing separators before splitting.
    path_name = path_name.strip(tree.sep)
    components = path_name.split(tree.sep)
    root = tree.root
    if components[0] != root.node_name:
        raise ValueError(
            f"Path {path_name} does not match the root node name {root.node_name}"
        )
    # Walk down one level per path component; a missing child yields None.
    current = root
    for component in components[1:]:
        current = find_children(current, component)
        if not current:
            break
    return current
def find_path(tree: Node, path_name: str) -> Node:
    """
    Search tree for the single node whose path ends with `path_name`.

    - Path name can be with or without leading tree path separator symbol.
    - Path name can be full path or partial path (trailing part of path) or node name.

    >>> from bigtree import Node, find_path
    >>> root = Node("a", age=90)
    >>> b = Node("b", age=65, parent=root)
    >>> c = Node("c", age=60, parent=root)
    >>> d = Node("d", age=40, parent=c)
    >>> find_path(root, "c")
    Node(/a/c, age=60)
    >>> find_path(root, "/c")
    Node(/a/c, age=60)

    Args:
        tree (Node): tree to search
        path_name (str): value to match (full path) or trailing part (partial path) of path_name attribute

    Returns:
        (Node)
    """
    target = path_name.rstrip(tree.sep)

    def path_matches(node) -> bool:
        return node.path_name.endswith(target)

    return find(tree, path_matches)
def find_paths(tree: Node, path_name: str) -> tuple:
    """
    Search tree for every node whose path ends with `path_name`.

    - Path name can be with or without leading tree path separator symbol.
    - Path name can be partial path (trailing part of path) or node name.

    >>> from bigtree import Node, find_paths
    >>> root = Node("a", age=90)
    >>> b = Node("b", age=65, parent=root)
    >>> c = Node("c", age=60, parent=root)
    >>> d = Node("c", age=40, parent=c)
    >>> find_paths(root, "/a/c")
    (Node(/a/c, age=60),)
    >>> find_paths(root, "/c")
    (Node(/a/c, age=60), Node(/a/c/c, age=40))

    Args:
        tree (Node): tree to search
        path_name (str): value to match (full path) or trailing part (partial path) of path_name attribute

    Returns:
        (tuple)
    """
    target = path_name.rstrip(tree.sep)

    def path_matches(node) -> bool:
        return node.path_name.endswith(target)

    return findall(tree, path_matches)
def find_attr(
    tree: BaseNode, attr_name: str, attr_value: Any, max_depth: int = None
) -> BaseNode:
    """
    Search tree for single node matching custom attribute.

    >>> from bigtree import Node, find_attr
    >>> root = Node("a", age=90)
    >>> b = Node("b", age=65, parent=root)
    >>> c = Node("c", age=60, parent=root)
    >>> d = Node("d", age=40, parent=c)
    >>> find_attr(root, "age", 65)
    Node(/a/b, age=65)

    Args:
        tree (BaseNode): tree to search
        attr_name (str): attribute name to perform matching
        attr_value (Any): value to match for attr_name attribute
        max_depth (int): maximum depth to search for, based on `depth` attribute, defaults to None

    Returns:
        (BaseNode)
    """
    # getattr is the idiomatic spelling; unlike calling __getattribute__
    # directly it also honours any __getattr__ fallback a node class defines.
    return find(
        tree, lambda node: getattr(node, attr_name) == attr_value, max_depth
    )
def find_attrs(
    tree: BaseNode, attr_name: str, attr_value: Any, max_depth: int = None
) -> tuple:
    """
    Search tree for node(s) matching custom attribute.

    >>> from bigtree import Node, find_attrs
    >>> root = Node("a", age=90)
    >>> b = Node("b", age=65, parent=root)
    >>> c = Node("c", age=65, parent=root)
    >>> d = Node("d", age=40, parent=c)
    >>> find_attrs(root, "age", 65)
    (Node(/a/b, age=65), Node(/a/c, age=65))

    Args:
        tree (BaseNode): tree to search
        attr_name (str): attribute name to perform matching
        attr_value (Any): value to match for attr_name attribute
        max_depth (int): maximum depth to search for, based on `depth` attribute, defaults to None

    Returns:
        (tuple)
    """
    # getattr is the idiomatic spelling; unlike calling __getattribute__
    # directly it also honours any __getattr__ fallback a node class defines.
    return findall(
        tree, lambda node: getattr(node, attr_name) == attr_value, max_depth
    )
def find_children(tree: Node, name: str) -> Node:
    """
    Return the direct child of `tree` whose name matches `name`.

    Returns None when no child matches; raises CorruptedTreeError if more
    than one child carries the same name.

    >>> from bigtree import Node, find_children
    >>> root = Node("a", age=90)
    >>> b = Node("b", age=65, parent=root)
    >>> c = Node("c", age=60, parent=root)
    >>> d = Node("d", age=40, parent=c)
    >>> find_children(root, "c")
    Node(/a/c, age=60)
    >>> find_children(c, "d")
    Node(/a/c/d, age=40)

    Args:
        tree (Node): tree to search, parent node
        name (str): value to match for name attribute, child node

    Returns:
        (Node)
    """
    matches = [child for child in tree.children if child and child.node_name == name]
    if len(matches) > 1:  # pragma: no cover
        raise CorruptedTreeError(
            f"There are more than one path for {matches[0].path_name}, check {matches}"
        )
    if matches:
        return matches[0]

View File

@@ -0,0 +1,32 @@
class TreeError(Exception):
    """Base exception for bigtree; every other exception in this module
    subclasses it, so callers can catch all tree errors with one clause."""

    pass
class LoopError(TreeError):
    """Error during node creation"""

    # NOTE(review): name suggests it guards against parent/child assignments
    # that would create a cycle -- confirm against the raise sites.
    pass
class CorruptedTreeError(TreeError):
    """Error during node creation or tree creation"""

    # Raised e.g. by the search helper find_children when a parent holds
    # more than one child with the same name.
    pass
class DuplicatedNodeError(TreeError):
    """Error during tree creation"""

    # NOTE(review): name suggests construction rejects duplicate node
    # names/paths -- confirm against the tree-construction raise sites.
    pass
class NotFoundError(TreeError):
    """Error during tree creation or tree search"""

    # Raised e.g. by node-shifting helpers when a destination path cannot
    # be resolved in the tree.
    pass
class SearchError(TreeError):
    """Error during tree search"""

    # Raised by findall/find when the number of matches violates the
    # requested min_count/max_count bounds.
    pass

View File

@@ -0,0 +1,371 @@
from typing import Callable, Iterable, List, Tuple
__all__ = [
"inorder_iter",
"preorder_iter",
"postorder_iter",
"levelorder_iter",
"levelordergroup_iter",
"dag_iterator",
]
def inorder_iter(
    tree,
    filter_condition: Callable = None,
    max_depth: int = None,
) -> Iterable:
    """Iterate through all children of a tree.

    In-Order Iteration Algorithm, LNR
    1. Recursively traverse the current node's left subtree.
    2. Visit the current node.
    3. Recursively traverse the current node's right subtree.

    >>> from bigtree import BinaryNode, list_to_binarytree, inorder_iter, print_tree
    >>> num_list = [1, 2, 3, 4, 5, 6, 7, 8]
    >>> root = list_to_binarytree(num_list)
    >>> [node.node_name for node in inorder_iter(root)]
    ['8', '4', '2', '5', '1', '6', '3', '7']
    >>> [node.node_name for node in inorder_iter(root, filter_condition=lambda x: x.node_name in ["1", "4", "3", "6", "7"])]
    ['4', '1', '6', '3', '7']
    >>> [node.node_name for node in inorder_iter(root, max_depth=3)]
    ['4', '2', '5', '1', '6', '3', '7']

    Args:
        tree (BaseNode): input tree (binary: nodes expose `left` and `right`)
        filter_condition (Callable): function that takes in node as argument, optional
            Node is yielded only if condition evaluates to `True`
        max_depth (int): maximum depth of iteration, based on `depth` attribute, optional

    Returns:
        (Iterable[BaseNode])
    """
    # Guard clauses: empty subtree, or subtree beyond the depth limit.
    if not tree:
        return
    if max_depth and tree.depth > max_depth:
        return
    yield from inorder_iter(tree.left, filter_condition, max_depth)
    if not filter_condition or filter_condition(tree):
        yield tree
    yield from inorder_iter(tree.right, filter_condition, max_depth)
def preorder_iter(
    tree,
    filter_condition: Callable = None,
    stop_condition: Callable = None,
    max_depth: int = None,
) -> Iterable:
    """Iterate through all children of a tree.

    Pre-Order Iteration Algorithm, NLR
    1. Visit the current node.
    2. Recursively traverse each subtree, left to right.

    Topologically sorted: a parent node is yielded before its children.

    >>> from bigtree import Node, list_to_tree, preorder_iter, print_tree
    >>> path_list = ["a/b/d", "a/b/e/g", "a/b/e/h", "a/c/f"]
    >>> root = list_to_tree(path_list)
    >>> [node.node_name for node in preorder_iter(root)]
    ['a', 'b', 'd', 'e', 'g', 'h', 'c', 'f']
    >>> [node.node_name for node in preorder_iter(root, filter_condition=lambda x: x.node_name in ["a", "d", "e", "f", "g"])]
    ['a', 'd', 'e', 'g', 'f']
    >>> [node.node_name for node in preorder_iter(root, stop_condition=lambda x: x.node_name=="e")]
    ['a', 'b', 'd', 'c', 'f']
    >>> [node.node_name for node in preorder_iter(root, max_depth=3)]
    ['a', 'b', 'd', 'e', 'c', 'f']

    Args:
        tree (BaseNode): input tree
        filter_condition (Callable): function that takes in node as argument, optional
            Node is yielded only if condition evaluates to `True`
        stop_condition (Callable): function that takes in node as argument, optional
            Subtree is skipped entirely if condition evaluates to `True`
        max_depth (int): maximum depth of iteration, based on `depth` attribute, optional

    Returns:
        (Iterable[BaseNode])
    """
    # Guard clauses: empty subtree, beyond depth limit, or stopped subtree.
    if not tree:
        return
    if max_depth and tree.depth > max_depth:
        return
    if stop_condition and stop_condition(tree):
        return
    if not filter_condition or filter_condition(tree):
        yield tree
    for child in tree.children:
        yield from preorder_iter(child, filter_condition, stop_condition, max_depth)
def postorder_iter(
    tree,
    filter_condition: Callable = None,
    stop_condition: Callable = None,
    max_depth: int = None,
) -> Iterable:
    """Iterate through all children of a tree.

    Post-Order Iteration Algorithm, LRN
    1. Recursively traverse each subtree, left to right.
    2. Visit the current node.

    >>> from bigtree import Node, list_to_tree, postorder_iter, print_tree
    >>> path_list = ["a/b/d", "a/b/e/g", "a/b/e/h", "a/c/f"]
    >>> root = list_to_tree(path_list)
    >>> [node.node_name for node in postorder_iter(root)]
    ['d', 'g', 'h', 'e', 'b', 'f', 'c', 'a']
    >>> [node.node_name for node in postorder_iter(root, filter_condition=lambda x: x.node_name in ["a", "d", "e", "f", "g"])]
    ['d', 'g', 'e', 'f', 'a']
    >>> [node.node_name for node in postorder_iter(root, stop_condition=lambda x: x.node_name=="e")]
    ['d', 'b', 'f', 'c', 'a']
    >>> [node.node_name for node in postorder_iter(root, max_depth=3)]
    ['d', 'e', 'b', 'f', 'c', 'a']

    Args:
        tree (BaseNode): input tree
        filter_condition (Callable): function that takes in node as argument, optional
            Node is yielded only if condition evaluates to `True`
        stop_condition (Callable): function that takes in node as argument, optional
            Subtree is skipped entirely if condition evaluates to `True`
        max_depth (int): maximum depth of iteration, based on `depth` attribute, optional

    Returns:
        (Iterable[BaseNode])
    """
    # Guard clauses: empty subtree, beyond depth limit, or stopped subtree.
    if not tree:
        return
    if max_depth and tree.depth > max_depth:
        return
    if stop_condition and stop_condition(tree):
        return
    for child in tree.children:
        yield from postorder_iter(child, filter_condition, stop_condition, max_depth)
    if not filter_condition or filter_condition(tree):
        yield tree
def levelorder_iter(
    tree,
    filter_condition: Callable = None,
    stop_condition: Callable = None,
    max_depth: int = None,
) -> Iterable:
    """Iterate through all children of a tree.

    Level Order Algorithm
    1. Recursively traverse the nodes on same level.

    >>> from bigtree import Node, list_to_tree, levelorder_iter, print_tree
    >>> path_list = ["a/b/d", "a/b/e/g", "a/b/e/h", "a/c/f"]
    >>> root = list_to_tree(path_list)
    >>> [node.node_name for node in levelorder_iter(root)]
    ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h']
    >>> [node.node_name for node in levelorder_iter(root, filter_condition=lambda x: x.node_name in ["a", "d", "e", "f", "g"])]
    ['a', 'd', 'e', 'f', 'g']
    >>> [node.node_name for node in levelorder_iter(root, stop_condition=lambda x: x.node_name=="e")]
    ['a', 'b', 'c', 'd', 'f']
    >>> [node.node_name for node in levelorder_iter(root, max_depth=3)]
    ['a', 'b', 'c', 'd', 'e', 'f']

    Args:
        tree (BaseNode): input tree
        filter_condition (Callable): function that takes in node as argument, optional
            Node is yielded only if condition evaluates to `True`
        stop_condition (Callable): function that takes in node as argument, optional
            Node and its subtree are skipped if condition evaluates to `True`
        max_depth (int): maximum depth of iteration, based on `depth` attribute, defaults to None

    Returns:
        (Iterable[BaseNode])
    """
    # `tree` is a single node on the first call and a list (one level of
    # nodes) on recursive calls.  Check against the builtin `list`:
    # isinstance() against typing.List is a deprecated idiom (PEP 585).
    if not isinstance(tree, list):
        tree = [tree]
    next_level = []
    for _tree in tree:
        if _tree:
            if (not max_depth or not _tree.depth > max_depth) and (
                not stop_condition or not stop_condition(_tree)
            ):
                if not filter_condition or filter_condition(_tree):
                    yield _tree
                # extend() accepts any iterable; no intermediate list copy needed.
                next_level.extend(_tree.children)
    if next_level:
        yield from levelorder_iter(
            next_level, filter_condition, stop_condition, max_depth
        )
def levelordergroup_iter(
    tree,
    filter_condition: Callable = None,
    stop_condition: Callable = None,
    max_depth: int = None,
) -> Iterable[Iterable]:
    """Iterate through all children of a tree.

    Level Order Group Algorithm
    1. Recursively traverse the nodes on same level, returns nodes level by level in a nested list.

    >>> from bigtree import Node, list_to_tree, levelordergroup_iter, print_tree
    >>> path_list = ["a/b/d", "a/b/e/g", "a/b/e/h", "a/c/f"]
    >>> root = list_to_tree(path_list)
    >>> [[node.node_name for node in group] for group in levelordergroup_iter(root)]
    [['a'], ['b', 'c'], ['d', 'e', 'f'], ['g', 'h']]
    >>> [[node.node_name for node in group] for group in levelordergroup_iter(root, filter_condition=lambda x: x.node_name in ["a", "d", "e", "f", "g"])]
    [['a'], [], ['d', 'e', 'f'], ['g']]
    >>> [[node.node_name for node in group] for group in levelordergroup_iter(root, stop_condition=lambda x: x.node_name=="e")]
    [['a'], ['b', 'c'], ['d', 'f']]
    >>> [[node.node_name for node in group] for group in levelordergroup_iter(root, max_depth=3)]
    [['a'], ['b', 'c'], ['d', 'e', 'f']]

    Args:
        tree (BaseNode): input tree
        filter_condition (Callable): function that takes in node as argument, optional
            Node is included in its level's tuple only if condition evaluates to `True`
        stop_condition (Callable): function that takes in node as argument, optional
            Node and its subtree are skipped if condition evaluates to `True`
        max_depth (int): maximum depth of iteration, based on `depth` attribute, defaults to None

    Returns:
        (Iterable[Iterable])
    """
    # `tree` is a single node on the first call and a list (one level of
    # nodes) on recursive calls.  Check against the builtin `list`:
    # isinstance() against typing.List is a deprecated idiom (PEP 585).
    if not isinstance(tree, list):
        tree = [tree]
    current_tree = []
    next_tree = []
    for _tree in tree:
        if (not max_depth or not _tree.depth > max_depth) and (
            not stop_condition or not stop_condition(_tree)
        ):
            if not filter_condition or filter_condition(_tree):
                current_tree.append(_tree)
            # Drop falsy children here so the recursive call never sees them.
            next_tree.extend([_child for _child in _tree.children if _child])
    yield tuple(current_tree)
    # Recurse only while there is a next level within the depth limit.
    if next_tree and (not max_depth or not next_tree[0].depth > max_depth):
        yield from levelordergroup_iter(
            next_tree, filter_condition, stop_condition, max_depth
        )
def dag_iterator(dag) -> Iterable[Tuple]:
    """Iterate through all nodes of a Directed Acyclic Graph (DAG), yielding
    (parent, child) edge pairs.

    Note that node names must be unique.
    Note that DAG must at least have two nodes to be shown on graph.

    1. Visit the current node.
    2. Recursively traverse the current node's parents.
    3. Recursively traverse the current node's children.

    >>> from bigtree import DAGNode, dag_iterator
    >>> a = DAGNode("a", step=1)
    >>> b = DAGNode("b", step=1)
    >>> c = DAGNode("c", step=2, parents=[a, b])
    >>> d = DAGNode("d", step=2, parents=[a, c])
    >>> e = DAGNode("e", step=3, parents=[d])
    >>> [(parent.node_name, child.node_name) for parent, child in dag_iterator(a)]
    [('a', 'c'), ('a', 'd'), ('b', 'c'), ('c', 'd'), ('d', 'e')]

    Args:
        dag (DAGNode): input dag

    Returns:
        (Iterable[Tuple[DAGNode, DAGNode]])
    """
    # Names (not node objects) of nodes already expanded; relies on the
    # "node names must be unique" contract above.
    visited_nodes = set()
    def recursively_parse_dag(node):
        node_name = node.node_name
        # Mark BEFORE yielding/recursing so edges back to this node are not
        # re-emitted from the neighbour's point of view.
        visited_nodes.add(node_name)
        # First pass: emit every edge between `node` and a not-yet-visited
        # neighbour.  Doing this before any recursion matters: once a
        # neighbour is expanded it becomes "visited" and its edges to
        # `node` would otherwise be lost.
        # Parse upwards
        for parent in node.parents:
            parent_name = parent.node_name
            if parent_name not in visited_nodes:
                yield parent, node
        # Parse downwards
        for child in node.children:
            child_name = child.node_name
            if child_name not in visited_nodes:
                yield node, child
        # Second pass: expand the unvisited neighbours themselves.
        # Parse upwards
        for parent in node.parents:
            parent_name = parent.node_name
            if parent_name not in visited_nodes:
                yield from recursively_parse_dag(parent)
        # Parse downwards
        for child in node.children:
            child_name = child.node_name
            if child_name not in visited_nodes:
                yield from recursively_parse_dag(child)
    yield from recursively_parse_dag(dag)

Some files were not shown because too many files have changed in this diff Show More