Compare commits
2 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 96a2bfb362 | |||
| 6b59fe16ce |
@@ -250,7 +250,7 @@ QGroupBox::title#MACHINE2 {
|
||||
border: 2px solid #98c998;
|
||||
border-radius: 3px;
|
||||
background-color: qlineargradient(x1: 0, y1: 0, x2: 0, y2: 1,
|
||||
stop: 0 #ffffff, stop: 1#98c998);
|
||||
stop: 0 #ffffff, stop: 1 #98c998);
|
||||
}
|
||||
|
||||
QGroupBox#Machine::disabled
|
||||
@@ -271,6 +271,102 @@ QGroupBox#Porthos::disabled
|
||||
}
|
||||
|
||||
|
||||
|
||||
QWidget#INJECTOR, QTabWidget#INJECTOR
|
||||
{
|
||||
background-color: qlineargradient(x1: 0, y1: 0, x2: 0, y2: 1,
|
||||
stop: 0 #FFFFFF, stop: 1 #008b8b);
|
||||
color: black;
|
||||
font-size: 10pt;
|
||||
font-style: normal;
|
||||
font-weight: 600;
|
||||
font-family: "Sans Serif";
|
||||
border-radius: 0px;
|
||||
margin-top: 0.0ex;
|
||||
margin-left: 0.0ex;
|
||||
padding-top: 2px;
|
||||
padding-bottom: 4px;
|
||||
}
|
||||
|
||||
QGroupBox#INJECTOR
|
||||
{
|
||||
background-color: qlineargradient(x1: 0, y1: 0, x2: 0, y2: 1,
|
||||
stop: 0 #008b8b, stop: 1 #ffffff);
|
||||
color: #008b8b;
|
||||
font-size: 10pt;
|
||||
font-style: normal;
|
||||
font-weight: 600;
|
||||
font-family: "Sans Serif";
|
||||
border: 2px solid #008b8b;
|
||||
border-radius: 5px;
|
||||
margin-top: 1.5ex;
|
||||
margin-left: 0.0ex;
|
||||
margin-bottom: 0.0ex;
|
||||
padding-top: 2px;
|
||||
padding-bottom: 4px;
|
||||
qproperty-alignment: 'AlignCenter | AlignVCenter';
|
||||
}
|
||||
|
||||
QGroupBox::title#INJECTOR {
|
||||
subcontrol-origin: margin;
|
||||
subcontrol-position: top center;
|
||||
padding: 2px 2px 2px 2px;
|
||||
margin: 0px 0px 0px 0px;
|
||||
border: 2px solid #008b8b;
|
||||
border-radius: 3px;
|
||||
background-color: qlineargradient(x1: 0, y1: 0, x2: 0, y2: 1,
|
||||
stop: 0 #ffffff , stop: 1 #008b8b);
|
||||
}
|
||||
|
||||
|
||||
QWidget#CYCLOTRON, QTabWidget#CYCLOTRON
|
||||
{
|
||||
background-color: qlineargradient(x1: 0, y1: 0, x2: 0, y2: 1,
|
||||
stop: 0 #FFFFFF, stop: 1 #000047ab);
|
||||
color: black;
|
||||
font-size: 10pt;
|
||||
font-style: normal;
|
||||
font-weight: 600;
|
||||
font-family: "Sans Serif";
|
||||
border-radius: 0px;
|
||||
margin-top: 0.0ex;
|
||||
margin-left: 0.0ex;
|
||||
padding-top: 2px;
|
||||
padding-bottom: 4px;
|
||||
}
|
||||
|
||||
QGroupBox#CYCLOTRON
|
||||
{
|
||||
background-color: qlineargradient(x1: 0, y1: 0, x2: 0, y2: 1,
|
||||
stop: 0 #0047ab, stop: 1 #ffffff);
|
||||
color: #0047ab;
|
||||
font-size: 10pt;
|
||||
font-style: normal;
|
||||
font-weight: 600;
|
||||
font-family: "Sans Serif";
|
||||
border: 2px solid #0047ab;
|
||||
border-radius: 5px;
|
||||
margin-top: 1.5ex;
|
||||
margin-left: 0.0ex;
|
||||
margin-bottom: 0.0ex;
|
||||
padding-top: 2px;
|
||||
padding-bottom: 4px;
|
||||
qproperty-alignment: 'AlignCenter | AlignVCenter';
|
||||
}
|
||||
|
||||
QGroupBox::title#CYCLOTRON {
|
||||
subcontrol-origin: margin;
|
||||
subcontrol-position: top center;
|
||||
padding: 2px 2px 2px 2px;
|
||||
margin: 0px 0px 0px 0px;
|
||||
border: 2px solid #0047ab;
|
||||
border-radius: 3px;
|
||||
background-color: qlineargradient(x1: 0, y1: 0, x2: 0, y2: 1,
|
||||
stop: 0 #ffffff , stop: 1 #0047ab);
|
||||
}
|
||||
|
||||
|
||||
|
||||
QWidget#MACHINE, QTabWidget#MACHINE
|
||||
{
|
||||
background-color: qlineargradient(x1: 0, y1: 0, x2: 0, y2: 1,
|
||||
|
||||
13
packages/EGG-INFO/PKG-INFO
Normal file
13
packages/EGG-INFO/PKG-INFO
Normal file
@@ -0,0 +1,13 @@
|
||||
Metadata-Version: 1.1
|
||||
Name: pyscan
|
||||
Version: 2.8.0
|
||||
Summary: PyScan is a python class that performs a scan for single or multiple given knobs.
|
||||
Home-page: UNKNOWN
|
||||
Author: Paul Scherrer Institute
|
||||
Author-email: UNKNOWN
|
||||
License: UNKNOWN
|
||||
Description: UNKNOWN
|
||||
Platform: UNKNOWN
|
||||
Requires: numpy
|
||||
Requires: pcaspy
|
||||
Requires: requests
|
||||
32
packages/EGG-INFO/SOURCES.txt
Normal file
32
packages/EGG-INFO/SOURCES.txt
Normal file
@@ -0,0 +1,32 @@
|
||||
README.md
|
||||
setup.py
|
||||
pyscan/__init__.py
|
||||
pyscan/config.py
|
||||
pyscan/scan.py
|
||||
pyscan/scan_actions.py
|
||||
pyscan/scan_parameters.py
|
||||
pyscan/scanner.py
|
||||
pyscan/utils.py
|
||||
pyscan.egg-info/PKG-INFO
|
||||
pyscan.egg-info/SOURCES.txt
|
||||
pyscan.egg-info/dependency_links.txt
|
||||
pyscan.egg-info/top_level.txt
|
||||
pyscan/dal/__init__.py
|
||||
pyscan/dal/bsread_dal.py
|
||||
pyscan/dal/epics_dal.py
|
||||
pyscan/dal/function_dal.py
|
||||
pyscan/dal/pshell_dal.py
|
||||
pyscan/interface/__init__.py
|
||||
pyscan/interface/pshell.py
|
||||
pyscan/interface/pyScan/__init__.py
|
||||
pyscan/interface/pyScan/scan.py
|
||||
pyscan/interface/pyScan/utils.py
|
||||
pyscan/positioner/__init__.py
|
||||
pyscan/positioner/area.py
|
||||
pyscan/positioner/bsread.py
|
||||
pyscan/positioner/compound.py
|
||||
pyscan/positioner/line.py
|
||||
pyscan/positioner/serial.py
|
||||
pyscan/positioner/static.py
|
||||
pyscan/positioner/time.py
|
||||
pyscan/positioner/vector.py
|
||||
1
packages/EGG-INFO/dependency_links.txt
Normal file
1
packages/EGG-INFO/dependency_links.txt
Normal file
@@ -0,0 +1 @@
|
||||
|
||||
1
packages/EGG-INFO/top_level.txt
Normal file
1
packages/EGG-INFO/top_level.txt
Normal file
@@ -0,0 +1 @@
|
||||
pyscan
|
||||
1
packages/EGG-INFO/zip-safe
Normal file
1
packages/EGG-INFO/zip-safe
Normal file
@@ -0,0 +1 @@
|
||||
|
||||
111
packages/EGG-INFO_elog/PKG-INFO
Normal file
111
packages/EGG-INFO_elog/PKG-INFO
Normal file
@@ -0,0 +1,111 @@
|
||||
Metadata-Version: 1.0
|
||||
Name: elog
|
||||
Version: 1.3.4
|
||||
Summary: Python library to access Elog.
|
||||
Home-page: https://github.com/paulscherrerinstitute/py_elog
|
||||
Author: Paul Scherrer Institute (PSI)
|
||||
Author-email: UNKNOWN
|
||||
License: UNKNOWN
|
||||
Description: [](https://travis-ci.org/paulscherrerinstitute/py_elog) [](https://ci.appveyor.com/project/simongregorebner/py-elog)
|
||||
|
||||
# Overview
|
||||
This Python module provides a native interface to [electronic logbooks](https://midas.psi.ch/elog/). It is compatible with Python versions 3.5 and higher.
|
||||
|
||||
# Usage
|
||||
|
||||
For accessing a logbook at ```http[s]://<hostename>:<port>/[<subdir>/]<logbook>/[<msg_id>]``` a logbook handle must be retrieved.
|
||||
|
||||
```python
|
||||
import elog
|
||||
|
||||
# Open GFA SwissFEL test logbook
|
||||
logbook = elog.open('https://elog-gfa.psi.ch/SwissFEL+test/')
|
||||
|
||||
# Constructor using detailed arguments
|
||||
# Open demo logbook on local host: http://localhost:8080/demo/
|
||||
logbook = elog.open('localhost', 'demo', port=8080, use_ssl=False)
|
||||
```
|
||||
|
||||
Once you have hold of the logbook handle one of its public methods can be used to read, create, reply to, edit or delete the message.
|
||||
|
||||
## Get Existing Message Ids
|
||||
Get all the existing message ids of a logbook
|
||||
|
||||
```python
|
||||
message_ids = logbook.get_message_ids()
|
||||
```
|
||||
|
||||
To get the id of the last inserted message
|
||||
```python
|
||||
last_message_id = logbook.get_last_message_id()
|
||||
```
|
||||
|
||||
## Read Message
|
||||
|
||||
```python
|
||||
# Read message with with message ID = 23
|
||||
message, attributes, attachments = logbook.read(23)
|
||||
```
|
||||
|
||||
## Create Message
|
||||
|
||||
```python
|
||||
# Create new message with some text, attributes (dict of attributes + kwargs) and attachments
|
||||
new_msg_id = logbook.post('This is message text', attributes=dict_of_attributes, attachments=list_of_attachments, attribute_as_param='value')
|
||||
```
|
||||
|
||||
What attributes are required is determined by the configuration of the elog server (keyword `Required Attributes`).
|
||||
If the configuration looks like this:
|
||||
|
||||
```
|
||||
Required Attributes = Author, Type
|
||||
```
|
||||
|
||||
You have to provide author and type when posting a message.
|
||||
|
||||
In case type need to be specified, the supported keywords can as well be found in the elog configuration with the key `Options Type`.
|
||||
|
||||
If the config looks like this:
|
||||
```
|
||||
Options Type = Routine, Software Installation, Problem Fixed, Configuration, Other
|
||||
```
|
||||
|
||||
A working create call would look like this:
|
||||
|
||||
```python
|
||||
new_msg_id = logbook.post('This is message text', author='me', type='Routine')
|
||||
```
|
||||
|
||||
|
||||
|
||||
## Reply to Message
|
||||
|
||||
```python
|
||||
# Reply to message with ID=23
|
||||
new_msg_id = logbook.post('This is a reply', msg_id=23, reply=True, attributes=dict_of_attributes, attachments=list_of_attachments, attribute_as_param='value')
|
||||
```
|
||||
|
||||
## Edit Message
|
||||
|
||||
```python
|
||||
# Edit message with ID=23. Changed message text, some attributes (dict of edited attributes + kwargs) and new attachments
|
||||
edited_msg_id = logbook.post('This is new message text', msg_id=23, attributes=dict_of_changed_attributes, attachments=list_of_new_attachments, attribute_as_param='new value')
|
||||
```
|
||||
|
||||
## Delete Message (and all its replies)
|
||||
|
||||
```python
|
||||
# Delete message with ID=23. All its replies will also be deleted.
|
||||
logbook.delete(23)
|
||||
```
|
||||
|
||||
__Note:__ Due to the way elog implements delete this function is only supported on english logbooks.
|
||||
|
||||
# Installation
|
||||
The Elog module only depends on the `passlib` and `requests` libraries, used for password encryption and http(s) communication. It is packed as an [anaconda package](https://anaconda.org/paulscherrerinstitute/elog) and can be installed as follows:
|
||||
|
||||
```bash
|
||||
conda install -c paulscherrerinstitute elog
|
||||
```
|
||||
Keywords: elog,electronic,logbook
|
||||
Platform: UNKNOWN
|
||||
8
packages/EGG-INFO_elog/SOURCES.txt
Normal file
8
packages/EGG-INFO_elog/SOURCES.txt
Normal file
@@ -0,0 +1,8 @@
|
||||
setup.py
|
||||
elog/__init__.py
|
||||
elog/logbook.py
|
||||
elog/logbook_exceptions.py
|
||||
elog.egg-info/PKG-INFO
|
||||
elog.egg-info/SOURCES.txt
|
||||
elog.egg-info/dependency_links.txt
|
||||
elog.egg-info/top_level.txt
|
||||
1
packages/EGG-INFO_elog/dependency_links.txt
Normal file
1
packages/EGG-INFO_elog/dependency_links.txt
Normal file
@@ -0,0 +1 @@
|
||||
|
||||
1
packages/EGG-INFO_elog/top_level.txt
Normal file
1
packages/EGG-INFO_elog/top_level.txt
Normal file
@@ -0,0 +1 @@
|
||||
elog
|
||||
1
packages/EGG-INFO_elog/zip-safe
Normal file
1
packages/EGG-INFO_elog/zip-safe
Normal file
@@ -0,0 +1 @@
|
||||
|
||||
1
packages/elog.pth
Normal file
1
packages/elog.pth
Normal file
@@ -0,0 +1 @@
|
||||
./elog-1.3.4-py3.7.egg
|
||||
13
packages/elog/__init__.py
Normal file
13
packages/elog/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
from elog.logbook import Logbook
|
||||
from elog.logbook import LogbookError, LogbookAuthenticationError, LogbookServerProblem, LogbookMessageRejected, \
|
||||
LogbookInvalidMessageID, LogbookInvalidAttachmentType
|
||||
|
||||
|
||||
def open(*args, **kwargs):
    """
    Create and return a Logbook handle.

    All positional and keyword arguments are forwarded unchanged to the
    Logbook constructor.

    :param args: positional arguments for Logbook()
    :param kwargs: keyword arguments for Logbook()
    :return: Logbook() instance
    """
    return Logbook(*args, **kwargs)
|
||||
571
packages/elog/logbook.py
Normal file
571
packages/elog/logbook.py
Normal file
@@ -0,0 +1,571 @@
|
||||
import requests
|
||||
import urllib.parse
|
||||
import os
|
||||
import builtins
|
||||
import re
|
||||
from elog.logbook_exceptions import *
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
class Logbook(object):
|
||||
"""
|
||||
Logbook provides methods to interface with logbook on location: "server:port/subdir/logbook". User can create,
|
||||
edit, delete logbook messages.
|
||||
"""
|
||||
|
||||
def __init__(self, hostname, logbook='', port=None, user=None, password=None, subdir='', use_ssl=True,
             encrypt_pwd=True):
    """
    :param hostname: elog server hostname. If whole url is specified here, it will be parsed and arguments:
                     "logbook, port, subdir, use_ssl" will be overwritten by parsed values.
    :param logbook: name of the logbook on the elog server
    :param port: elog server port (if not specified will default to '80' if use_ssl=False or '443' if
                 use_ssl=True)
    :param user: username (if authentication needed)
    :param password: password (if authentication needed) Password will be encrypted with sha256 unless
                     encrypt_pwd=False (default: True)
    :param subdir: subdirectory of logbooks locations
    :param use_ssl: connect using ssl (ignored if url starts with 'http://' or 'https://')
    :param encrypt_pwd: To avoid exposing password in the code, this flag can be set to False and password
                        will then be handled as it is (user needs to provide sha256 encrypted password with
                        salt= '' and rounds=5000)
    :return:
    """
    hostname = hostname.strip()

    # parse url to see if some parameters are defined with url
    parsed_url = urllib.parse.urlsplit(hostname)

    # ---- handle SSL -----
    # hostname must be modified according to use_ssl flag. If hostname starts with https:// or http://
    # the use_ssl flag is ignored
    url_scheme = parsed_url.scheme
    if url_scheme == 'http':
        use_ssl = False

    elif url_scheme == 'https':
        use_ssl = True

    elif not url_scheme:
        # add http or https
        if use_ssl:
            url_scheme = 'https'
        else:
            url_scheme = 'http'

    # ---- handle port -----
    # 1) by default use port defined in the url
    # 2) remove any 'default' ports such as 80 for http and 443 for https
    # 3) if port not defined in url and not 'default' add it to netloc

    netloc = parsed_url.netloc
    if netloc == "" and "localhost" in hostname:
        netloc = 'localhost'
    netloc_split = netloc.split(':')
    if len(netloc_split) > 1:
        # port defined in url --> remove if needed
        # BUG FIX: the port parsed out of the netloc is a string; the original code compared
        # it directly with the integers 80/443, which never matched, so default ports were
        # never stripped from the netloc. Normalize to int before comparing.
        port = int(netloc_split[1])
        if (port == 80 and not use_ssl) or (port == 443 and use_ssl):
            netloc = netloc_split[0]

    else:
        # add port info if needed
        if port is not None and not (port == 80 and not use_ssl) and not (port == 443 and use_ssl):
            netloc += ':{}'.format(port)

    # ---- handle subdir and logbook -----
    # parsed_url.path = /<subdir>/<logbook>/

    # Remove last '/' for easier parsing
    url_path = parsed_url.path
    if url_path.endswith('/'):
        url_path = url_path[:-1]

    splitted_path = url_path.split('/')
    if url_path and len(splitted_path) > 1:
        # If here ... then at least some part of path is defined.

        # If logbook defined --> treat current path as subdir and add logbook at the end
        # to define the full path. Else treat existing path as <subdir>/<logbook>.
        if logbook:
            url_path += '/{}'.format(logbook)
        else:
            logbook = splitted_path[-1]

    else:
        # There is nothing. Use arguments.
        url_path = subdir + '/' + logbook

    # urllib.parse.quote replaces special characters with %xx escapes
    # self._logbook_path = urllib.parse.quote('/' + url_path + '/').replace('//', '/')
    self._logbook_path = ('/' + url_path + '/').replace('//', '/')

    self._url = url_scheme + '://' + netloc + self._logbook_path
    self.logbook = logbook
    self._user = user
    self._password = _handle_pswd(password, encrypt_pwd)
|
||||
|
||||
def post(self, message, msg_id=None, reply=False, attributes=None, attachments=None, encoding=None,
         **kwargs):
    """
    Posts message to the logbook. If msg_id is not specified new message will be created, otherwise existing
    message will be edited, or a reply (if reply=True) to it will be created. This method returns the msg_id
    of the newly created message.

    :param message: string with message text
    :param msg_id: ID number of message to edit or reply. If not specified new message is created.
    :param reply: If 'True' reply to existing message is created instead of editing it
    :param attributes: Dictionary of attributes. Following attributes are used internally by the elog and will be
                       ignored: Text, Date, Encoding, Reply to, In reply to, Locked by, Attachment
    :param attachments: list of:
                          - file like objects which read() will return bytes (if file_like_object.name is not
                            defined, default name "attachment<i>" will be used.
                          - paths to the files
                        All items will be appended as attachment to the elog entry. In case of unknown
                        attachment an exception LogbookInvalidAttachment will be raised.
    :param encoding: Defines encoding of the message. Can be: 'plain' -> plain text, 'html'->html-text,
                     'ELCode' --> elog formatting syntax
    :param kwargs: Anything in the kwargs will be interpreted as attribute. e.g.: logbook.post('Test text',
                   Author='Rok Vintar), "Author" will be sent as an attribute. If named same as one of the
                   attributes defined in "attributes", kwargs will have priority.

    :return: msg_id
    """

    attributes = attributes or {}
    attributes = {**attributes, **kwargs}  # kwargs as attributes with higher priority

    attachments = attachments or []

    if encoding is not None:
        # NOTE(review): the docstring advertises 'plain'/'html'/'ELCode', but the check
        # accepts 'HTML' (uppercase) only — confirm intended casing against the elog server.
        if encoding not in ['plain', 'HTML', 'ELCode']:
            raise LogbookMessageRejected('Invalid message encoding. Valid options: plain, HTML, ELCode.')
        attributes['Encoding'] = encoding

    attributes_to_edit = dict()
    if msg_id:
        # Message exists, we can continue
        if reply:
            # Verify that there is a message on the server, otherwise do not reply to it!
            self._check_if_message_on_server(msg_id)  # raises exception in case of none existing message

            attributes['reply_to'] = str(msg_id)

        else:  # Edit existing
            attributes['edit_id'] = str(msg_id)
            attributes['skiplock'] = '1'

            # Handle existing attachments
            msg_to_edit, attributes_to_edit, attach_to_edit = self.read(msg_id)

            i = 0
            for attachment in attach_to_edit:
                if attachment:
                    # Existing attachments must be passed as regular arguments attachment<i> with value= file name
                    # Read message returnes full urls to existing attachments:
                    # <hostname>:[<port>][/<subdir]/<logbook>/<msg_id>/<file_name>
                    attributes['attachment' + str(i)] = os.path.basename(attachment)
                    i += 1

            # Overlay the caller-supplied attributes on top of the ones read back from the
            # server, skipping None values.
            # NOTE(review): new_data is always identical to data here (same dict is read);
            # the .get() call looks redundant — confirm before simplifying.
            for attribute, data in attributes.items():
                new_data = attributes.get(attribute)
                if new_data is not None:
                    attributes_to_edit[attribute] = new_data
    else:
        # As we create a new message, specify creation time if not already specified in attributes
        if 'When' not in attributes:
            attributes['When'] = int(datetime.now().timestamp())

    if not attributes_to_edit:
        attributes_to_edit = attributes
    # Remove any attributes that should not be sent
    _remove_reserved_attributes(attributes_to_edit)

    if attachments:
        files_to_attach, objects_to_close = self._prepare_attachments(attachments)
    else:
        objects_to_close = list()
        files_to_attach = list()

    # Make requests module think that Text is a "file". This is the only way to force requests to send data as
    # multipart/form-data even if there are no attachments. Elog understands only multipart/form-data
    files_to_attach.append(('Text', ('', message)))

    # Base attributes are common to all messages
    self._add_base_msg_attributes(attributes_to_edit)

    # Keys in attributes cannot have certain characters like whitespaces or dashes for the http request
    attributes_to_edit = _replace_special_characters_in_attribute_keys(attributes_to_edit)

    try:
        response = requests.post(self._url, data=attributes_to_edit, files=files_to_attach, allow_redirects=False,
                                 verify=False)
        # Validate response. Any problems will raise an Exception.
        resp_message, resp_headers, resp_msg_id = _validate_response(response)

        # Close file like objects that were opened by _prepare_attachments (paths given by the caller)
        for file_like_object in objects_to_close:
            if hasattr(file_like_object, 'close'):
                file_like_object.close()

    except requests.RequestException as e:
        # Check if message on server.
        self._check_if_message_on_server(msg_id)  # raises exceptions if no message or no response from server

        # If here: message is on server but cannot be downloaded (should never happen)
        raise LogbookServerProblem('Cannot access logbook server to post a message, ' + 'because of:\n' +
                                   '{0}'.format(e))

    # Any error before here should raise an exception, but check again for any case.
    if not resp_msg_id or resp_msg_id < 1:
        raise LogbookInvalidMessageID('Invalid message ID: ' + str(resp_msg_id) + ' returned')
    return resp_msg_id
|
||||
|
||||
def read(self, msg_id):
    """
    Reads message from the logbook server and returns tuple of (message, attributes, attachments) where:
        message: string with message body
        attributes: dictionary of all attributes returned by the logbook
        attachments: list of urls to attachments on the logbook server

    :param msg_id: ID of the message to be read
    :raises LogbookServerProblem: if the server cannot be reached
    :return: message, attributes, attachments
    """

    request_headers = dict()
    if self._user or self._password:
        request_headers['Cookie'] = self._make_user_and_pswd_cookie()

    try:
        self._check_if_message_on_server(msg_id)  # raises exceptions if no message or no response from server
        response = requests.get(self._url + str(msg_id) + '?cmd=download', headers=request_headers,
                                allow_redirects=False, verify=False)

        # Validate response. If problems Exception will be thrown.
        resp_message, resp_headers, resp_msg_id = _validate_response(response)

    except requests.RequestException as e:
        # If here: message is on server but cannot be downloaded (should never happen)
        raise LogbookServerProblem('Cannot access logbook server to read the message with ID: ' + str(msg_id) +
                                   'because of:\n' + '{0}'.format(e))

    # Parse message to separate message body, attributes and attachments
    attributes = dict()
    attachments = list()

    returned_msg = resp_message.decode('utf-8', 'ignore').splitlines()
    # The elog download format separates the "Key: value" attribute header from the
    # message body with a line of 40 '=' characters.
    delimiter_idx = returned_msg.index('========================================')

    message = '\n'.join(returned_msg[delimiter_idx + 1:])
    for line in returned_msg[0:delimiter_idx]:
        line = line.split(': ')
        data = ''.join(line[1:])
        if line[0] == 'Attachment':
            attachments = data.split(',')
            # Here are only attachment names, make a full url out of it, so they could be
            # recognisable by others, and downloaded if needed
            attachments = [self._url + '{0}'.format(i) for i in attachments]
        else:
            attributes[line[0]] = data

    return message, attributes, attachments
|
||||
|
||||
def delete(self, msg_id):
    """
    Deletes message thread (!!!message + all replies!!!) from logbook.
    It also deletes all of attachments of corresponding messages from the server.

    :param msg_id: message to be deleted
    :raises LogbookServerProblem: if the server cannot be reached, or the logbook
                                  is not in English (delete unsupported)
    :return:
    """

    request_headers = dict()
    if self._user or self._password:
        request_headers['Cookie'] = self._make_user_and_pswd_cookie()

    try:
        self._check_if_message_on_server(msg_id)  # check if something to delete

        response = requests.get(self._url + str(msg_id) + '?cmd=Delete&confirm=Yes', headers=request_headers,
                                allow_redirects=False, verify=False)

        _validate_response(response)  # raises exception if any other error identified

    except requests.RequestException as e:
        # If here: message is on server but cannot be downloaded (should never happen)
        raise LogbookServerProblem('Cannot access logbook server to delete the message with ID: ' + str(msg_id) +
                                   'because of:\n' + '{0}'.format(e))

    # Additional validation: If successfully deleted then status_code = 302. In case command was not executed at
    # all (not English language --> no download command supported) status_code = 200 and the content is just a
    # html page of this whole message.
    if response.status_code == 200:
        raise LogbookServerProblem('Cannot process delete command (only logbooks in English supported).')
|
||||
|
||||
def search(self, search_term, n_results = 20, scope="subtext"):
    """
    Searches the logbook and returns the message ids.

    :param search_term: text to look for
    :param n_results: maximum number of results to request (clamped to >= 1 below,
                      because npp=0 crashes the elog server)
    :param scope: elog query field the search term is matched against
                  (default "subtext" — presumably searches message bodies; confirm
                  against the elog query documentation)
    :return: list of matching message ids
    """
    request_headers = dict()
    if self._user or self._password:
        request_headers['Cookie'] = self._make_user_and_pswd_cookie()

    # Putting n_results = 0 crashes the elog. also in the web-gui.
    n_results = 1 if n_results < 1 else n_results

    params = {
        "mode": "full",
        "reverse": "1",
        "npp": n_results,
        scope: search_term
    }

    try:
        response = requests.get(self._url, params=params, headers=request_headers,
                                allow_redirects=False, verify=False)

        # Validate response. If problems Exception will be thrown.
        _validate_response(response)
        resp_message = response

    except requests.RequestException as e:
        # If here: message is on server but cannot be downloaded (should never happen)
        raise LogbookServerProblem('Cannot access logbook server to read message ids '
                                   'because of:\n' + '{0}'.format(e))

    # lxml is a third-party dependency imported lazily, so the rest of the module
    # works without it.
    from lxml import html
    tree = html.fromstring(resp_message.content)
    message_ids = tree.xpath('(//tr/td[@class="list1" or @class="list2"][1])/a/@href')
    message_ids = [int(m.split("/")[-1]) for m in message_ids]
    return message_ids
|
||||
|
||||
|
||||
def get_last_message_id(self):
    """Return the id of the most recently listed message, or None if there are none."""
    ids = self.get_message_ids()
    return ids[0] if ids else None
||||
|
||||
def get_message_ids(self):
    """
    Return the ids of the messages shown on the logbook's listing page.

    :raises LogbookServerProblem: if the server cannot be reached
    :return: list of message ids parsed from the html listing
    """
    request_headers = dict()
    if self._user or self._password:
        request_headers['Cookie'] = self._make_user_and_pswd_cookie()

    try:
        response = requests.get(self._url + 'page', headers=request_headers,
                                allow_redirects=False, verify=False)

        # Validate response. If problems Exception will be thrown.
        _validate_response(response)
        resp_message = response

    except requests.RequestException as e:
        # If here: message is on server but cannot be downloaded (should never happen)
        raise LogbookServerProblem('Cannot access logbook server to read message ids '
                                   'because of:\n' + '{0}'.format(e))

    # lxml is a third-party dependency imported lazily, so the rest of the module
    # works without it.
    from lxml import html
    tree = html.fromstring(resp_message.content)
    message_ids = tree.xpath('(//tr/td[@class="list1" or @class="list2"][1])/a/@href')
    message_ids = [int(m.split("/")[-1]) for m in message_ids]
    return message_ids
|
||||
|
||||
def _check_if_message_on_server(self, msg_id):
    """Try to load page for specific message. If there is a htm tag like <td class="errormsg"> then there is no
    such message.

    :param msg_id: ID of message to be checked
    :raises LogbookInvalidMessageID: if the logbook page reports an error for this id
    :raises LogbookServerProblem: if there is no response from the server
    :return:
    """

    request_headers = dict()
    if self._user or self._password:
        request_headers['Cookie'] = self._make_user_and_pswd_cookie()
    try:
        response = requests.get(self._url + str(msg_id), headers=request_headers, allow_redirects=False,
                                verify=False)

        # If there is no message, code 200 will be returned (OK) and _validate_response will
        # not recognise the problem, but there will be some error indication in the html code.
        resp_message, resp_headers, resp_msg_id = _validate_response(response)
        if re.findall('<td.*?class="errormsg".*?>.*?</td>',
                      resp_message.decode('utf-8', 'ignore'),
                      flags=re.DOTALL):
            raise LogbookInvalidMessageID('Message with ID: ' + str(msg_id) + ' does not exist on logbook.')

    except requests.RequestException as e:
        raise LogbookServerProblem('No response from the logbook server.\nDetails: ' + '{0}'.format(e))
|
||||
|
||||
def _add_base_msg_attributes(self, data):
|
||||
"""
|
||||
Adds base message attributes which are used by all messages.
|
||||
:param data: dict of current attributes
|
||||
:return: content string
|
||||
"""
|
||||
data['cmd'] = 'Submit'
|
||||
data['exp'] = self.logbook
|
||||
if self._user:
|
||||
data['unm'] = self._user
|
||||
if self._password:
|
||||
data['upwd'] = self._password
|
||||
|
||||
def _prepare_attachments(self, files):
|
||||
"""
|
||||
Parses attachments to content objects. Attachments can be:
|
||||
- file like objects: must have method read() which returns bytes. If it has attribute .name it will be used
|
||||
for attachment name, otherwise generic attribute<i> name will be used.
|
||||
- path to the file on disk
|
||||
|
||||
Note that if attachment is is an url pointing to the existing Logbook server it will be ignored and no
|
||||
exceptions will be raised. This can happen if attachments returned with read_method are resend.
|
||||
|
||||
:param files: list of file like objects or paths
|
||||
:return: content string
|
||||
"""
|
||||
prepared = list()
|
||||
i = 0
|
||||
objects_to_close = list() # objects that are created (opened) by elog must be later closed
|
||||
for file_obj in files:
|
||||
if hasattr(file_obj, 'read'):
|
||||
i += 1
|
||||
attribute_name = 'attfile' + str(i)
|
||||
|
||||
filename = attribute_name # If file like object has no name specified use this one
|
||||
candidate_filename = os.path.basename(file_obj.name)
|
||||
|
||||
if filename: # use only if not empty string
|
||||
filename = candidate_filename
|
||||
|
||||
elif isinstance(file_obj, str):
|
||||
# Check if it is:
|
||||
# - a path to the file --> open file and append
|
||||
# - an url pointing to the existing Logbook server --> ignore
|
||||
|
||||
filename = ""
|
||||
attribute_name = ""
|
||||
|
||||
if os.path.isfile(file_obj):
|
||||
i += 1
|
||||
attribute_name = 'attfile' + str(i)
|
||||
|
||||
file_obj = builtins.open(file_obj, 'rb')
|
||||
filename = os.path.basename(file_obj.name)
|
||||
|
||||
objects_to_close.append(file_obj)
|
||||
|
||||
elif not file_obj.startswith(self._url):
|
||||
raise LogbookInvalidAttachmentType('Invalid type of attachment: \"' + file_obj + '\".')
|
||||
else:
|
||||
raise LogbookInvalidAttachmentType('Invalid type of attachment[' + str(i) + '].')
|
||||
|
||||
prepared.append((attribute_name, (filename, file_obj)))
|
||||
|
||||
return prepared, objects_to_close
|
||||
|
||||
def _make_user_and_pswd_cookie(self):
|
||||
"""
|
||||
prepares user name and password cookie. It is sent in header when posting a message.
|
||||
:return: user name and password value for the Cookie header
|
||||
"""
|
||||
cookie = ''
|
||||
if self._user:
|
||||
cookie += 'unm=' + self._user + ';'
|
||||
if self._password:
|
||||
cookie += 'upwd=' + self._password + ';'
|
||||
|
||||
return cookie
|
||||
|
||||
|
||||
def _remove_reserved_attributes(attributes):
|
||||
"""
|
||||
Removes elog reserved attributes (from the attributes dict) that can not be sent.
|
||||
|
||||
:param attributes: dictionary of attributes to be cleaned.
|
||||
:return:
|
||||
"""
|
||||
|
||||
if attributes:
|
||||
attributes.get('$@MID@$', None)
|
||||
attributes.pop('Date', None)
|
||||
attributes.pop('Attachment', None)
|
||||
attributes.pop('Text', None) # Remove this one because it will be send attachment like
|
||||
|
||||
|
||||
def _replace_special_characters_in_attribute_keys(attributes):
|
||||
"""
|
||||
Replaces special characters in elog attribute keys by underscore, otherwise attribute values will be erased in
|
||||
the http request. This is using the same replacement elog itself is using to handle these cases
|
||||
|
||||
:param attributes: dictionary of attributes to be cleaned.
|
||||
:return: attributes with replaced keys
|
||||
"""
|
||||
return {re.sub('[^0-9a-zA-Z]', '_', key): value for key, value in attributes.items()}
|
||||
|
||||
|
||||
def _validate_response(response):
    """
    Validate the HTTP response of a logbook request.

    :param response: HTTP response object (presumably a requests.Response --
                     uses .status_code, .content and .headers).
    :return: tuple (response content bytes, response headers, message id or None)
    :raises LogbookMessageRejected: if the server rejected the message.
    :raises LogbookServerProblem: if the server reports it has moved.
    :raises LogbookAuthenticationError: on invalid username or password.
    """

    msg_id = None

    if response.status_code not in [200, 302]:
        # 200 --> OK; 302 --> Found
        # Html page is returned with error description (handling errors same way as on original client. Looks
        # like there is no other way.

        err = re.findall('<td.*?class="errormsg".*?>.*?</td>',
                         response.content.decode('utf-8', 'ignore'),
                         flags=re.DOTALL)

        if len(err) > 0:
            # Remove html tags from the first matched error cell.
            err = re.sub('(?:<.*?>)', '', err[0])
            if err:
                raise LogbookMessageRejected('Rejected because of: ' + err)
            else:
                # The error cell existed but was empty after stripping tags.
                raise LogbookMessageRejected('Rejected because of unknown error.')

        # Non-OK status but no recognizable error message on the page.
        raise LogbookMessageRejected('Rejected because of unknown error.')
    else:
        location = response.headers.get('Location')
        if location is not None:
            if 'has moved' in location:
                raise LogbookServerProblem('Logbook server has moved to another location.')
            elif 'fail' in location:
                raise LogbookAuthenticationError('Invalid username or password.')
            else:
                # returned locations is something like: '<host>/<sub_dir>/<logbook>/<msg_id><query>
                # with urllib.parse.urlparse returns attribute path=<sub_dir>/<logbook>/<msg_id>
                msg_id = int(urllib.parse.urlsplit(location).path.split('/')[-1])

        if b'form name=form1' in response.content or b'type=password' in response.content:
            # Not too smart to check this way, but no other indication of this kind of error.
            # C client does it the same way.
            raise LogbookAuthenticationError('Invalid username or password.')

    return response.content, response.headers, msg_id
|
||||
def _handle_pswd(password, encrypt=True):
|
||||
"""
|
||||
Takes password string and returns password as needed by elog. If encrypt=True then password will be
|
||||
sha256 encrypted (salt='', rounds=5000). Before returning password, any trailing $5$$ will be removed
|
||||
independent off encrypt flag.
|
||||
|
||||
:param password: password string
|
||||
:param encrypt: encrypt password?
|
||||
:return: elog prepared password
|
||||
"""
|
||||
if encrypt and password:
|
||||
from passlib.hash import sha256_crypt
|
||||
return sha256_crypt.encrypt(password, salt='', rounds=5000)[4:]
|
||||
elif password and password.startswith('$5$$'):
|
||||
return password[4:]
|
||||
else:
|
||||
return password
|
||||
28
packages/elog/logbook_exceptions.py
Normal file
28
packages/elog/logbook_exceptions.py
Normal file
@@ -0,0 +1,28 @@
|
||||
class LogbookError(Exception):
    """ Base class for all logbook related exceptions."""
||||
class LogbookAuthenticationError(LogbookError):
    """ Raised on problems with the username or password."""
||||
class LogbookServerProblem(LogbookError):
    """ Raised on problems accessing the logbook server."""
||||
class LogbookMessageRejected(LogbookError):
    """ Raised when manipulating/creating a message was rejected by the server, or composing the message failed."""
||||
class LogbookInvalidMessageID(LogbookMessageRejected):
    """ Raised when no message with the specified ID exists on the server."""
||||
class LogbookInvalidAttachmentType(LogbookMessageRejected):
    """ Raised when a passed attachment has an invalid type."""
1
packages/pyscan.pth
Executable file
1
packages/pyscan.pth
Executable file
@@ -0,0 +1 @@
|
||||
./pyscan-2.8.0-py3.7.egg
|
||||
19
packages/pyscan/__init__.py
Normal file
19
packages/pyscan/__init__.py
Normal file
@@ -0,0 +1,19 @@
|
||||
# Import the scan part.
|
||||
from .scan import *
|
||||
from .scan_parameters import *
|
||||
from .scan_actions import *
|
||||
from .scanner import *
|
||||
|
||||
# Import DALs
|
||||
from .dal.epics_dal import *
|
||||
from .dal.bsread_dal import *
|
||||
from .dal.pshell_dal import *
|
||||
|
||||
# Import positioners.
|
||||
from .positioner.line import *
|
||||
from .positioner.serial import *
|
||||
from .positioner.vector import *
|
||||
from .positioner.area import *
|
||||
from .positioner.compound import *
|
||||
from .positioner.time import *
|
||||
from .positioner.static import *
|
||||
58
packages/pyscan/config.py
Normal file
58
packages/pyscan/config.py
Normal file
@@ -0,0 +1,58 @@
|
||||
#########################
# General configuration #
#########################

# Minimum tolerance for comparing floats.
# NOTE(review): despite the 'max_' prefix, this value is used as a lower bound
# (user-provided tolerances are clamped up to at least this value).
max_float_tolerance = 0.00001
# Time tolerance, in seconds, for time critical measurements.
# NOTE(review): the original comment said '1ms', but the value is 0.05 s (50 ms).
max_time_tolerance = 0.05

######################
# Scan configuration #
######################

# Default number of scans.
scan_default_n_measurements = 1
# Default interval between multiple measurements in a single position. Taken into account when n_measurements > 1.
scan_default_measurement_interval = 0
# Interval to sleep while the scan is paused.
scan_pause_sleep_interval = 0.1
# Maximum number of retries to read the channels to get valid data.
scan_acquisition_retry_limit = 3
# Delay between acquisition retries.
scan_acquisition_retry_delay = 1

############################
# BSREAD DAL configuration #
############################

# Queue size for collecting messages from bs_read.
bs_queue_size = 20
# Max time to wait until the bs read message we need arrives.
bs_read_timeout = 5
# Max time to wait for a message (if there is none). Important for stopping threads etc.
bs_receive_timeout = 1

# Default bs_read connection address.
bs_default_host = None
# Default bs_read connection port.
bs_default_port = None
# Default bs_read connection mode ('sub' or 'pull').
bs_connection_mode = "sub"
# Default property value for bs properties missing in stream. Exception means to raise an Exception when this happens.
bs_default_missing_property_value = Exception

###########################
# EPICS DAL configuration #
###########################

# Default set and match timeout - how much time a PV has to reach the target value.
epics_default_set_and_match_timeout = 3
# After all motors have reached their destination (set_and_match), extra time to wait.
epics_default_settling_time = 0

############################
# PShell DAL configuration #
############################

# Base URL of the PShell REST server used by the PShell DAL.
pshell_default_server_url = "http://sf-daq-mgmt:8090"
# Whether PShell scans run in the background by default.
pshell_default_scan_in_background = False
0
packages/pyscan/dal/__init__.py
Normal file
0
packages/pyscan/dal/__init__.py
Normal file
186
packages/pyscan/dal/bsread_dal.py
Normal file
186
packages/pyscan/dal/bsread_dal.py
Normal file
@@ -0,0 +1,186 @@
|
||||
import math
|
||||
from time import time
|
||||
|
||||
from bsread import Source, mflow
|
||||
|
||||
from pyscan import config
|
||||
from pyscan.utils import convert_to_list
|
||||
|
||||
|
||||
class ReadGroupInterface(object):
    """
    Provide a beam synchronous acquisition for PV data.
    """

    def __init__(self, properties, conditions=None, host=None, port=None, filter_function=None):
        """
        Create the bsread group read interface.
        :param properties: List of PVs to read for processing.
        :param conditions: List of PVs to read as conditions.
        :param host: Optional bs_read stream host; falls back to config.bs_default_host.
        :param port: Optional bs_read stream port; falls back to config.bs_default_port.
        :param filter_function: Filter the BS stream with a custom function.
        """
        self.host = host
        self.port = port
        self.properties = convert_to_list(properties)
        self.conditions = convert_to_list(conditions)
        self.filter = filter_function

        # Cache of the last accepted message; used to serve condition reads that
        # belong to the same acquisition as the last property read.
        self._message_cache = None
        self._message_cache_timestamp = None
        self._message_cache_position_index = None

        # Bug fix: the constructor previously connected with the config defaults
        # unconditionally, silently ignoring the host/port arguments above.
        # Explicit arguments now take precedence over the config defaults.
        self._connect_bsread(self.host or config.bs_default_host,
                             self.port or config.bs_default_port)

    def _connect_bsread(self, host, port):
        """
        Connect to the bs_read stream, either by address or by channel names.
        :param host: Stream host; when falsy, connect by channel names instead.
        :param port: Stream port.
        :raises ValueError: if config.bs_connection_mode is not 'sub' or 'pull'.
        """
        # Configure the connection type.
        if config.bs_connection_mode.lower() == "sub":
            mode = mflow.SUB
        elif config.bs_connection_mode.lower() == "pull":
            mode = mflow.PULL
        else:
            # Bug fix: an unsupported mode previously left 'mode' unbound and
            # produced a confusing NameError below.
            raise ValueError("Unsupported bs_connection_mode '%s'. Use 'sub' or 'pull'."
                             % config.bs_connection_mode)

        if host and port:
            self.stream = Source(host=host,
                                 port=port,
                                 queue_size=config.bs_queue_size,
                                 receive_timeout=config.bs_receive_timeout,
                                 mode=mode)
        else:
            # No explicit address: let bsread resolve the source from the channel names.
            channels = [x.identifier for x in self.properties] + [x.identifier for x in self.conditions]
            self.stream = Source(channels=channels,
                                 queue_size=config.bs_queue_size,
                                 receive_timeout=config.bs_receive_timeout,
                                 mode=mode)
        self.stream.connect()

    @staticmethod
    def is_message_after_timestamp(message, timestamp):
        """
        Check if the received message was captured after the provided timestamp.
        :param message: Message to inspect.
        :param timestamp: Timestamp (float seconds) to compare the message to.
        :return: True if the message is after the timestamp, False otherwise.
        """
        # Receive might timeout, in this case we have nothing to compare.
        if not message:
            return False

        # This is how BSread encodes the timestamp.
        current_sec = int(timestamp)
        current_ns = int(math.modf(timestamp)[0] * 1e9)

        message_sec = message.data.global_timestamp
        message_ns = message.data.global_timestamp_offset

        # If the seconds are the same, the nanoseconds must be equal or larger.
        if message_sec == current_sec:
            return message_ns >= current_ns
        # If the seconds are not the same, the message seconds need to be larger than the current seconds.
        else:
            return message_sec > current_sec

    @staticmethod
    def _get_missing_property_default(property_definition):
        """
        In case a bs read value is missing, either return the default value or raise an Exception.
        :param property_definition: Property definition; provides 'default_value' and 'identifier'.
        :return: The property's default value.
        :raises Exception: when the default value is the Exception class itself.
        """
        # Exception is defined, raise it.
        if Exception == property_definition.default_value:
            raise property_definition.default_value("Property '%s' missing in bs stream."
                                                    % property_definition.identifier)
        # Else just return the default value.
        else:
            return property_definition.default_value

    def _read_pvs_from_cache(self, properties):
        """
        Read the requested properties from the cached message.
        :param properties: List of properties to read.
        :return: List with PV values.
        :raises ValueError: if no message has been cached yet.
        """
        if not self._message_cache:
            raise ValueError("Message cache is empty, cannot read PVs %s." % properties)

        pv_values = []
        for property_name, property_definition in ((x.identifier, x) for x in properties):
            if property_name in self._message_cache.data.data:
                value = self._message_cache.data.data[property_name].value
            else:
                value = self._get_missing_property_default(property_definition)

            # TODO: Check if the python conversion works in every case?
            # BS read always returns numpy, and we always convert it to Python.
            pv_values.append(value)

        return pv_values

    def read(self, current_position_index=None, retry=False):
        """
        Reads the PV values from BSread. It uses the first PVs data sampled after the invocation of this method.
        :param current_position_index: Scan position index, cached together with the message.
        :param retry: Is this the first read attempt or a retry (unused here).
        :return: List of values for read pvs. Note: Condition PVs are excluded.
        :raises Exception: if no suitable message arrives within config.bs_read_timeout.
        """
        # Perform the actual read, discarding messages sampled before this call.
        read_timestamp = time()
        while time() - read_timestamp < config.bs_read_timeout:

            message = self.stream.receive(filter=self.filter)

            if self.is_message_after_timestamp(message, read_timestamp):
                self._message_cache = message
                self._message_cache_position_index = current_position_index
                self._message_cache_timestamp = read_timestamp

                return self._read_pvs_from_cache(self.properties)

        raise Exception("Read timeout exceeded for BS read stream. Could not find the desired package in time.")

    def read_cached_conditions(self):
        """
        Returns the conditions associated with the last read command.
        :return: List of condition values.
        """
        return self._read_pvs_from_cache(self.conditions)

    def close(self):
        """
        Disconnect from the stream and clear the message cache.
        """
        if self.stream:
            self.stream.disconnect()

        self._message_cache = None
        self._message_cache_timestamp = None
||||
class ImmediateReadGroupInterface(ReadGroupInterface):
    """
    Variant of ReadGroupInterface that accepts the first message it receives
    instead of waiting for one sampled after the read call.
    """

    @staticmethod
    def is_message_after_timestamp(message, timestamp):
        """
        Accept every non-null message, regardless of its timestamp.
        :param message: Message to inspect.
        :param timestamp: Unused; kept for interface compatibility.
        :return: True if the message is not None/empty.
        """
        # Receive might time out, in which case there is no message at all.
        return bool(message)

    def read(self, current_position_index=None, retry=False):
        """
        Read property values, serving them from the cached message when it
        belongs to the same position index (never on a retry).
        """
        if retry:
            # A retry must never be answered from the cache.
            self._message_cache_position_index = None
        elif current_position_index is not None and current_position_index == self._message_cache_position_index:
            # Message for this position already cached.
            return self._read_pvs_from_cache(self.properties)

        return super(ImmediateReadGroupInterface, self).read(current_position_index=current_position_index,
                                                             retry=retry)
208
packages/pyscan/dal/epics_dal.py
Normal file
208
packages/pyscan/dal/epics_dal.py
Normal file
@@ -0,0 +1,208 @@
|
||||
import time
|
||||
from itertools import count
|
||||
|
||||
from pyscan import config
|
||||
from pyscan.utils import convert_to_list, validate_lists_length, connect_to_pv, compare_channel_value
|
||||
|
||||
|
||||
class PyEpicsDal(object):
    """
    Provide a high level abstraction over PyEpics with group support.
    """

    def __init__(self):
        # Mapping of group name -> group interface instance.
        self.groups = {}
        self.pvs = {}

    def add_group(self, group_name, group_interface):
        """
        Register a group interface under a unique name.
        :param group_name: Handle for the group. Must not already be registered.
        :param group_interface: Group interface instance to register.
        :return: The group name (handle).
        :raises ValueError: if a group with this name already exists.
        """
        # Do not allow to overwrite the group.
        if group_name in self.groups:
            # Bug fix: message previously read "name of close"; corrected to "or".
            raise ValueError("Group with name %s already exists. "
                             "Use different name or close existing group first." % group_name)

        self.groups[group_name] = group_interface
        return group_name

    def add_reader_group(self, group_name, pv_names):
        """Register a read-only PV group under the given name."""
        self.add_group(group_name, ReadGroupInterface(pv_names))

    def add_writer_group(self, group_name, pv_names, readback_pv_names=None, tolerances=None, timeout=None):
        """Register a writable PV group with optional readbacks, tolerances and timeout."""
        self.add_group(group_name, WriteGroupInterface(pv_names, readback_pv_names, tolerances, timeout))

    def get_group(self, handle):
        """Return the group registered under 'handle', or None if unknown."""
        return self.groups.get(handle)

    def close_group(self, group_name):
        """
        Close a group's PV connections and remove it from the registry.
        :param group_name: Handle of the group to close.
        :raises ValueError: if the group does not exist.
        """
        if group_name not in self.groups:
            raise ValueError("Group does not exist. Available groups:\n%s" % self.groups.keys())

        # Close the PV connection.
        self.groups[group_name].close()
        del self.groups[group_name]

    def close_all_groups(self):
        """Close every registered group and clear the registry."""
        for group in self.groups.values():
            group.close()
        self.groups.clear()
|
||||
class WriteGroupInterface(object):
    """
    Manage a group of Write PVs.

    Each write PV can have an optional readback PV and a tolerance;
    set_and_match writes target values and polls the readbacks until every
    PV is within tolerance or the timeout expires.
    """
    # Default seconds allowed for PVs to reach their targets in set_and_match.
    default_timeout = 5
    # Polling interval, in seconds, between readback checks in set_and_match.
    default_get_sleep = 0.1

    def __init__(self, pv_names, readback_pv_names=None, tolerances=None, timeout=None):
        """
        Initialize the write group.
        :param pv_names: PV names (or name, list or single string) to connect to.
        :param readback_pv_names: PV names (or name, list or single string) of readback PVs to connect to.
        :param tolerances: Tolerances to be used for set_and_match. You can also specify them on the set_and_match
        :param timeout: Timeout to reach the destination.
        :raises ValueError: if list lengths differ or timeout is not numeric.
        """
        self.pv_names = convert_to_list(pv_names)
        self.pvs = [self.connect(pv_name) for pv_name in self.pv_names]

        if readback_pv_names:
            self.readback_pv_name = convert_to_list(readback_pv_names)
            self.readback_pvs = [self.connect(pv_name) for pv_name in self.readback_pv_name]
        else:
            # No separate readbacks: read back through the write PVs themselves.
            self.readback_pv_name = self.pv_names
            self.readback_pvs = self.pvs

        self.tolerances = self._setup_tolerances(tolerances)

        # We also do not allow timeout to be zero ('or' replaces falsy values).
        self.timeout = timeout or self.default_timeout

        # Verify if all provided lists are of same size.
        validate_lists_length(self.pvs, self.readback_pvs, self.tolerances)

        # Check if timeout is int or float.
        if not isinstance(self.timeout, (int, float)):
            raise ValueError("Timeout must be int or float, but %s was provided." % self.timeout)

    def _setup_tolerances(self, tolerances):
        """
        Construct the list of tolerances. No tolerance can be less than the minimal tolerance.
        :param tolerances: Input tolerances (single value, list, or None).
        :return: Tolerances adjusted to the minimum value, if needed.
        """
        # If the provided tolerances are empty, substitute them with a list of default tolerances.
        tolerances = convert_to_list(tolerances) or [config.max_float_tolerance] * len(self.pvs)
        # Each tolerance needs to be at least the size of the minimum tolerance.
        tolerances = [max(config.max_float_tolerance, tolerance) for tolerance in tolerances]

        return tolerances

    def set_and_match(self, values, tolerances=None, timeout=None):
        """
        Set the value and wait for the PV to reach it, within tolerance.
        :param values: Values to set (Must match the number of PVs in this group)
        :param tolerances: Tolerances for each PV (Must match the number of PVs in this group)
        :param timeout: Timeout, single value, to wait until the value is reached.
        :raise ValueError if any position cannot be reached.
        """
        values = convert_to_list(values)
        if not tolerances:
            tolerances = self.tolerances
        else:
            # We do not allow tolerances to be less than the default tolerance.
            tolerances = self._setup_tolerances(tolerances)
        if not timeout:
            timeout = self.timeout

        # Verify if all provided lists are of same size.
        validate_lists_length(self.pvs, values, tolerances)

        # Check if timeout is int or float.
        if not isinstance(timeout, (int, float)):
            raise ValueError("Timeout must be int or float, but %s was provided." % timeout)

        # Write all the PV values.
        for pv, value in zip(self.pvs, values):
            pv.put(value)

        # Boolean array to represent which PVs have reached their target values.
        within_tolerance = [False] * len(self.pvs)
        initial_timestamp = time.time()

        # Read values until all PVs have reached the desired value or time has run out.
        while (not all(within_tolerance)) and (time.time() - initial_timestamp < timeout):
            # Get only the PVs that have not yet reached the final position.
            for index, pv, tolerance in ((index, pv, tolerance) for index, pv, tolerance, values_reached
                                         in zip(count(), self.readback_pvs, tolerances, within_tolerance)
                                         if not values_reached):

                current_value = pv.get()
                expected_value = values[index]

                if compare_channel_value(current_value, expected_value, tolerance):
                    within_tolerance[index] = True

            time.sleep(self.default_get_sleep)

        if not all(within_tolerance):
            error_message = ""
            # Get the indexes that did not reach the supposed values.
            for index in [index for index, reached_value in enumerate(within_tolerance) if not reached_value]:
                expected_value = values[index]
                pv_name = self.pv_names[index]
                tolerance = tolerances[index]

                error_message += "Cannot achieve value %s, on PV %s, with tolerance %s.\n" % \
                                 (expected_value, pv_name, tolerance)

            raise ValueError(error_message)

    @staticmethod
    def connect(pv_name):
        # Delegate to the shared pyscan helper so connection setup stays uniform.
        return connect_to_pv(pv_name)

    def close(self):
        """
        Close all PV connections.
        """
        for pv in self.pvs:
            pv.disconnect()
|
||||
class ReadGroupInterface(object):
    """
    Manage group of read PVs.
    """

    def __init__(self, pv_names):
        """
        Initialize the group.
        :param pv_names: PV names (or name, list or single string) to connect to.
        """
        self.pv_names = convert_to_list(pv_names)
        self.pvs = [self.connect(pv_name) for pv_name in self.pv_names]

    def read(self, current_position_index=None, retry=None):
        """
        Read PVs one by one.
        :param current_position_index: Index of the current scan.
        :param retry: Is this the first read attempt or a retry.
        :return: List of current PV values, in the same order as pv_names.
        """
        return [channel.get() for channel in self.pvs]

    @staticmethod
    def connect(pv_name):
        """Open the connection to a single PV."""
        return connect_to_pv(pv_name)

    def close(self):
        """
        Close all PV connections.
        """
        for channel in self.pvs:
            channel.disconnect()
|
||||
40
packages/pyscan/dal/function_dal.py
Normal file
40
packages/pyscan/dal/function_dal.py
Normal file
@@ -0,0 +1,40 @@
|
||||
from pyscan.utils import convert_to_list
|
||||
|
||||
|
||||
class FunctionProxy(object):
    """
    Provide an interface for using external methods as DAL.
    """

    def __init__(self, functions):
        """
        Initialize the function dal.
        :param functions: List (or single item) of FUNCTION_VALUE type.
        """
        self.functions = convert_to_list(functions)

    def read(self, current_position_index=None, retry=False):
        """
        Read the results from all the provided functions.
        :return: Read results.
        """
        outputs = []
        for function in self.functions:
            # The wrapped function either takes no arguments or accepts the
            # current position index; probe the no-argument form first.
            try:
                value = function.call_function()
            except TypeError:
                value = function.call_function(current_position_index)
            outputs.append(value)
        return outputs

    def write(self, values):
        """
        Write the values to the provided functions.
        :param values: Values to write.
        """
        for function, value in zip(self.functions, convert_to_list(values)):
            function.call_function(value)
118
packages/pyscan/dal/pshell_dal.py
Normal file
118
packages/pyscan/dal/pshell_dal.py
Normal file
@@ -0,0 +1,118 @@
|
||||
import json
|
||||
from collections import OrderedDict
|
||||
import requests
|
||||
from bsread.data.helpers import get_channel_reader
|
||||
from pyscan import config
|
||||
|
||||
# REST endpoint paths on the PShell server, appended to the server base URL.
SERVER_URL_PATHS = {
    "run": "/run",       # submit a script execution request
    "data": "/data-bs"   # fetch stored scan data (bsread-serialized)
}
|
||||
|
||||
class PShellFunction(object):
    """
    DAL that executes a PShell script over the server's REST interface and
    returns (or loads) its results.
    """

    def __init__(self, script_name, parameters, server_url=None, scan_in_background=None, multiple_parameters=False,
                 return_values=None):
        """
        :param script_name: Name of the script to run on the PShell server.
        :param parameters: Script parameters. When multiple_parameters is True this is
                           a sequence indexed by the scan position index.
        :param server_url: Base URL of the PShell server. Defaults to config value.
        :param scan_in_background: Run the scan in background. Defaults to config value.
        :param multiple_parameters: Select a parameter set per position index.
        :param return_values: Names of values to return (stored; kept for interface compatibility).
        """
        if server_url is None:
            server_url = config.pshell_default_server_url

        if scan_in_background is None:
            scan_in_background = config.pshell_default_scan_in_background

        self.script_name = script_name
        self.parameters = parameters
        self.server_url = server_url.rstrip("/")
        self.scan_in_background = scan_in_background
        self.multiple_parameters = multiple_parameters
        self.return_values = return_values

    @staticmethod
    def _load_raw_data(server_url, data_path):
        # Download the raw (length-prefixed, bsread-serialized) data blob.
        load_data_url = server_url + SERVER_URL_PATHS["data"] + "/" + data_path

        raw_data = requests.get(url=load_data_url, stream=True).raw.read()

        return raw_data

    @classmethod
    def read_raw_data(cls, data_path, server_url=None):
        """
        Load and deserialize scan data from the server.
        :param data_path: Server-side path of the data to load.
        :param server_url: Base URL of the PShell server. Defaults to config value.
        :return: Dict mapping channel name -> deserialized channel data.
        """
        if server_url is None:
            server_url = config.pshell_default_server_url

        raw_data_bytes = cls._load_raw_data(server_url, data_path)

        offset = 0

        def read_chunk():
            # Chunks are length-prefixed with a 4-byte big-endian unsigned size.
            nonlocal offset
            nonlocal raw_data_bytes

            size = int.from_bytes(raw_data_bytes[offset:offset + 4], byteorder='big', signed=False)

            # Offset for the size of the length.
            offset += 4

            data_chunk = raw_data_bytes[offset:offset + size]

            offset += size

            return data_chunk

        # First chunk is main header (read to advance the offset).
        main_header = json.loads(read_chunk().decode(), object_pairs_hook=OrderedDict)

        # Second chunk is data header.
        data_header = json.loads(read_chunk().decode(), object_pairs_hook=OrderedDict)

        result_data = {}

        for channel in data_header["channels"]:
            raw_channel_data = read_chunk()
            # Timestamp chunk is read to advance the offset but not returned.
            raw_channel_timestamp = read_chunk()

            channel_name = channel["name"]
            # Default encoding is little endian; the only other valid value is 'big'.
            # Bug fix: the previous expression tested the raw string's truthiness
            # (always True), so big endian channels were decoded as little endian.
            channel["encoding"] = ">" if channel.get("encoding", "little") == "big" else "<"

            channel_value_reader = get_channel_reader(channel)

            result_data[channel_name] = channel_value_reader(raw_channel_data)

        return result_data

    def read(self, current_position_index=None, retry=False):
        """
        Execute the script for the current position and return the parsed JSON result.
        """
        parameters = self.get_scan_parameters(current_position_index)

        run_request = {"script": self.script_name,
                       "pars": parameters,
                       "background": self.scan_in_background}

        raw_scan_result = self._execute_scan(run_request)
        scan_result = json.loads(raw_scan_result)

        return scan_result

    def get_scan_parameters(self, current_position_index):
        """
        Return the script parameters to use for this position.
        :param current_position_index: Index into self.parameters when multiple_parameters is set.
        :raises ValueError: if multiple_parameters is set and no entry exists for the index.
        """
        if not self.multiple_parameters:
            return self.parameters

        try:
            return self.parameters[current_position_index]
        except IndexError:
            # Bug fix: the original format string had a single %s for two values,
            # which raised a TypeError instead of this ValueError.
            raise ValueError("Cannot find parameters for position index %s. Parameters: %s" %
                             (current_position_index, self.parameters))

    def _execute_scan(self, execution_parameters):
        # Submit the run request; the server replies with the scan result as text.
        run_url = self.server_url + SERVER_URL_PATHS["run"]

        result = requests.put(url=run_url, json=execution_parameters)

        if result.status_code != 200:
            raise Exception(result.text)

        return result.text
0
packages/pyscan/interface/__init__.py
Normal file
0
packages/pyscan/interface/__init__.py
Normal file
385
packages/pyscan/interface/pshell.py
Normal file
385
packages/pyscan/interface/pshell.py
Normal file
@@ -0,0 +1,385 @@
|
||||
from pyscan import scan, action_restore, ZigZagVectorPositioner, VectorPositioner, CompoundPositioner
|
||||
from pyscan.scan import EPICS_READER
|
||||
from pyscan.positioner.area import AreaPositioner, ZigZagAreaPositioner
|
||||
from pyscan.positioner.line import ZigZagLinePositioner, LinePositioner
|
||||
from pyscan.positioner.time import TimePositioner
|
||||
from pyscan.scan_parameters import scan_settings
|
||||
from pyscan.utils import convert_to_list
|
||||
|
||||
|
||||
def _generate_scan_parameters(relative, writables, latency):
    """
    Build the common (offsets, finalization actions, settings) triple for a scan.

    For a relative scan the current writable positions are captured as offsets,
    and an action is queued to restore them once the scan finishes.

    :param relative: True if start/end positions are relative to the current ones.
    :param writables: Writables participating in the scan.
    :param latency: Settling time applied before each readout.
    :return: Tuple (offsets, finalization actions, scan settings).
    """
    offsets = None
    finalization_actions = []

    if relative:
        # Capture the present positions so the positioner can offset against them.
        pv_names = [writable.pv_name for writable in convert_to_list(writables) or []]
        reader = EPICS_READER(pv_names)
        offsets = reader.read()
        reader.close()

        # Restore the original positions at the end of the scan.
        finalization_actions.append(action_restore(writables))

    return offsets, finalization_actions, scan_settings(settling_time=latency)
||||
|
||||
def _convert_steps_parameter(steps):
    """
    Interpret the 'steps' argument as either a step count or step size(s).

    A float (or list of floats) is treated as step size(s); a plain int is
    treated as the number of steps. At most one of the returned values is set.

    :param steps: int (number of steps) or float / list of floats (step sizes).
    :return: Tuple (n_steps, step_size); the unused element is None.
    """
    as_list = convert_to_list(steps)

    # Floats (single value or list) are step sizes.
    if isinstance(as_list[0], float):
        return None, as_list

    # A bare int is the step count.
    if isinstance(steps, int):
        return steps, None

    return None, None
||||
|
||||
def lscan(writables, readables, start, end, steps, latency=0.0, relative=False,
          passes=1, zigzag=False, before_read=None, after_read=None, title=None):
    """Line Scan: positioners change together, linearly from start to end positions.

    Args:
        writables(list of Writable): Positioners set on each step.
        readables(list of Readable): Sensors to be sampled on each step.
        start(list of float): start positions of writables.
        end(list of float): final positions of writables.
        steps(int or float or list of float): number of scan steps (int) or step size (float).
        relative (bool, optional): if true, start and end positions are relative to
            current at start of the scan
        latency(float, optional): settling time for each step before readout, defaults to 0.0.
        passes(int, optional): number of passes
        zigzag(bool, optional): if true writables invert direction on each pass.
        before_read (function, optional): callback on each step, before each readout. Callback may have as
            optional parameters list of positions.
        after_read (function, optional): callback on each step, after each readout. Callback may have as
            optional parameters a ScanRecord object.
        title(str, optional): plotting window name (unused in this implementation).

    Returns:
        ScanResult object.
    """
    offsets, finalization_actions, settings = _generate_scan_parameters(relative, writables, latency)
    n_steps, step_size = _convert_steps_parameter(steps)

    # Zigzag scans invert the movement direction on every pass.
    positioner_class = ZigZagLinePositioner if zigzag else LinePositioner
    positioner = positioner_class(start=start, end=end, step_size=step_size,
                                  n_steps=n_steps, offsets=offsets, passes=passes)

    return scan(positioner, readables, writables, before_read=before_read, after_read=after_read,
                settings=settings, finalization=finalization_actions)
||||
|
||||
def ascan(writables, readables, start, end, steps, latency=0.0, relative=False,
          passes=1, zigzag=False, before_read=None, after_read=None, title=None):
    """Area Scan: multi-dimensional scan where each positioner is one dimension.

    :param writables: list of identifiers to write to at each step.
    :param readables: list of identifiers to read from at each step.
    :param start: start positions for writables.
    :param end: stop positions for writables.
    :param steps: number of scan steps (int) or step size (float).
    :param latency: settling time before each readout, default 0.
    :param relative: if True, start and stop positions are relative to the current position.
    :param passes: number of passes for each scan.
    :param zigzag: if True and passes > 1, invert moving direction on each pass.
    :param before_read: callback functions invoked on each step before readback.
    :param after_read: callback functions invoked on each step after readback.
    :param title: not used in this implementation - legacy.
    :return: data from the scan.
    """
    offsets, finalization_actions, settings = _generate_scan_parameters(relative, writables, latency)
    n_steps, step_size = _convert_steps_parameter(steps)

    # A zigzag scan reverses the travel direction on every pass.
    positioner_type = ZigZagAreaPositioner if zigzag else AreaPositioner
    positioner = positioner_type(start=start, end=end, step_size=step_size,
                                 n_steps=n_steps, offsets=offsets, passes=passes)

    return scan(positioner, readables, writables,
                before_read=before_read, after_read=after_read,
                settings=settings, finalization=finalization_actions)
|
||||
|
||||
|
||||
def vscan(writables, readables, vector, line=False, latency=0.0, relative=False, passes=1, zigzag=False,
          before_read=None, after_read=None, title=None):
    """Vector Scan: positioners follow the values provided in a vector.

    :param writables: list of Writable positioners set on each step.
    :param readables: list of Readable sensors sampled on each step.
    :param vector: table (list of lists of float) of positioner values.
    :param line: if True, process as a line scan (1d); otherwise as an area scan.
    :param latency: settling time before each readout in seconds, default 0.0.
    :param relative: if True, positions are relative to those at scan start.
    :param passes: number of passes.
    :param zigzag: if True, writables invert direction on each pass (line mode only).
    :param before_read: optional per-step callback invoked before each readout.
    :param after_read: optional per-step callback invoked after each readout.
    :param title: plotting window name.
    :return: ScanResult object.
    :raises ValueError: on zigzag area scans, or when the area vector is not a list of lists.
    """
    offsets, finalization_actions, settings = _generate_scan_parameters(relative, writables, latency)

    # The compound positioner used for area mode cannot do zigzag positioning.
    if zigzag and not line:
        raise ValueError("Area vector scan cannot use zigzag positioning.")

    if line:
        # Line mode: every motor advances to its next vector entry at the same time.
        positioner_type = ZigZagVectorPositioner if zigzag else VectorPositioner
        positioner = positioner_type(positions=vector, passes=passes, offsets=offsets)
    else:
        # Area mode: move motors one by one, covering all position combinations.
        vector = convert_to_list(vector)
        if not all(isinstance(element, list) for element in vector):
            raise ValueError("In case of area scan, a list of lists is required for a vector.")

        positioner = CompoundPositioner(
            [VectorPositioner(positions=element, passes=passes, offsets=offsets) for element in vector])

    return scan(positioner, readables, writables,
                before_read=before_read, after_read=after_read,
                settings=settings, finalization=finalization_actions)
|
||||
|
||||
|
||||
def rscan(writable, readables, regions, latency=0.0, relative=False, passes=1, zigzag=False, before_read=None,
          after_read=None, title=None):
    """Region Scan: scan the positioner linearly over multiple regions.

    :param writable: Writable positioner set on each step, for each region.
    :param readables: list of Readable sensors sampled on each step.
    :param regions: list of (start, stop, steps) or (start, stop, step_size) tuples,
        one per scan region.
    :param latency: settling time before each readout in seconds, default 0.0.
    :param relative: if True, positions are relative to those at scan start.
    :param passes: number of passes.
    :param zigzag: if True, the writable inverts direction on each pass.
    :param before_read: optional per-step callback invoked before each readout;
        may accept the list of positions.
    :param after_read: optional per-step callback invoked after each readout;
        may accept a ScanRecord object.
    :param title: plotting window name.
    :return: ScanResult object.
    :raises NotImplementedError: always - region scans are not available in this backend.
    """
    raise NotImplementedError("Region scan not supported.")
|
||||
|
||||
|
||||
def cscan(writables, readables, start, end, steps, latency=0.0, time=None, relative=False, passes=1, zigzag=False,
          before_read=None, after_read=None, title=None):
    """Continuous Scan: positioners move continuously while readables are sampled on the fly.

    :param writables: a Speedable (has a getSpeed method) or a list of Motors.
    :param readables: list of Readable sensors to be sampled.
    :param start: start position(s) of writables.
    :param end: final position(s) of writables.
    :param steps: number of scan steps (int) or step size (float).
    :param latency: sleep time in each step before readout, default 0.0.
    :param time: if not None, writables is a Motor array and speeds are derived
        from this duration in seconds.
    :param relative: if True, positions are relative to those at scan start.
    :param passes: number of passes.
    :param zigzag: if True, writables invert direction on each pass.
    :param before_read: optional per-step callback invoked before each readout;
        may accept the list of positions.
    :param after_read: optional per-step callback invoked after each readout;
        may accept a ScanRecord object.
    :param title: plotting window name.
    :return: ScanResult object.
    :raises NotImplementedError: always - continuous scans are not available in this backend.
    """
    raise NotImplementedError("Continuous scan not supported.")
|
||||
|
||||
|
||||
def hscan(config, writable, readables, start, end, steps, passes=1, zigzag=False, before_stream=None,
          after_stream=None, after_read=None, title=None):
    """Hardware Scan: values are sampled by external hardware and received asynchronously.

    :param config: dict configuring the hardware scan; the "class" key names the
        implementation class, remaining keys are implementation specific.
    :param writable: Writable positioner appropriate to the hardware scan type.
    :param readables: list of Readable sensors appropriate to the hardware scan type.
    :param start: start position of the writable.
    :param end: final position of the writable.
    :param steps: number of scan steps (int) or step size (float).
    :param passes: number of passes.
    :param zigzag: if True, the writable inverts direction on each pass.
    :param before_stream: optional callback invoked just before the positioner move starts.
    :param after_stream: optional callback invoked just after the positioner move stops.
    :param after_read: optional callback on each readout; may accept a ScanRecord object.
    :param title: plotting window name.
    :return: ScanResult object.
    :raises NotImplementedError: always - hardware scans are not available in this backend.
    """
    raise NotImplementedError("Hardware scan not supported.")
|
||||
|
||||
|
||||
def bscan(stream, records, before_read=None, after_read=None, title=None):
    """BS Scan: record all values from a beam-synchronous stream.

    :param stream: Stream object to record from.
    :param records: number of records to store.
    :param before_read: optional per-step callback invoked before each readout;
        may accept the list of positions.
    :param after_read: optional per-step callback invoked after each readout;
        may accept a ScanRecord object.
    :param title: plotting window name.
    :return: ScanResult object.
    :raises NotImplementedError: always - BS scans are not available in this backend.
    """
    raise NotImplementedError("BS scan not supported.")
|
||||
|
||||
|
||||
def tscan(readables, points, interval, before_read=None, after_read=None, title=None):
    """Time Scan: sample the sensors at fixed time intervals.

    :param readables: list of Readable sensors sampled on each step.
    :param points: number of samples to take.
    :param interval: time between readouts in seconds (minimum temporization 0.001 s).
    :param before_read: optional callback invoked before each readout.
    :param after_read: optional callback invoked after each readout.
    :param title: plotting window name.
    :return: ScanResult object.
    """
    time_positioner = TimePositioner(interval, points)
    return scan(time_positioner, readables, before_read=before_read, after_read=after_read)
|
||||
|
||||
|
||||
def mscan(trigger, readables, points, timeout=None, async_=True, take_initial=False, before_read=None,
          after_read=None, title=None):
    """Monitor Scan: sensors are sampled when a change event of the trigger device is received.

    NOTE(fix): the fifth parameter was originally named ``async``, which has been a
    reserved keyword since Python 3.7 and made this module a SyntaxError to import.
    It is renamed to ``async_`` (positional callers are unaffected; keyword callers
    could not have worked on Python 3.7+ anyway).

    :param trigger: Device that is the source of the sampling trigger.
    :param readables: list of Readable sensors sampled on each step. If the trigger
        has a cache and is included in readables, it is not read on each step -
        the change-event value is used instead.
    :param points: number of samples.
    :param timeout: maximum scan time in seconds, or None.
    :param async_: if True, records are sampled and stored in the change-event
        callback (cached sensor values only). If False, the scan loop waits for
        trigger cache updates - no cache-only access, but change events may be lost.
    :param take_initial: if True, include the current values as the first record
        (before the first trigger).
    :param before_read: optional per-step callback invoked before each readout.
    :param after_read: optional per-step callback invoked after each readout.
    :param title: plotting window name.
    :return: ScanResult object.
    :raises NotImplementedError: always - monitor scans are not available in this backend.
    """
    raise NotImplementedError("Monitor scan not supported.")
|
||||
|
||||
|
||||
def escan(name, title=None):
    """Epics Scan: execute an EPICS Scan Record.

    :param name: name of the scan record.
    :param title: plotting window name.
    :return: ScanResult object.
    :raises NotImplementedError: always - EPICS scans are not available in this backend.
    """
    raise NotImplementedError("Epics scan not supported.")
|
||||
|
||||
|
||||
def bsearch(writables, readable, start, end, steps, maximum=True, strategy="Normal", latency=0.0, relative=False,
            before_read=None, after_read=None, title=None):
    """Binary search: drive the writables in binary-search fashion to find a local
    extremum of the readable.

    :param writables: list of Writable positioners set on each step.
    :param readable: the Readable sensor to be sampled.
    :param start: list of start positions of writables.
    :param end: list of final positions of writables.
    :param steps: search resolution for each writable (float or list of float).
    :param maximum: if True (default) search for a maximum, otherwise a minimum.
    :param strategy: "Normal" starts midway into the scan range and advances in the
        best direction using the orthogonal neighborhood (4-neighborhood in 2d);
        "Boundary" starts on the scan range; "FullNeighborhood" uses the complete
        neighborhood (8-neighborhood in 2d).
    :param latency: settling time before each readout in seconds, default 0.0.
    :param relative: if True, start/end are relative to positions at scan start.
    :param before_read: optional per-step callback invoked before each readout.
    :param after_read: optional per-step callback invoked after each readout.
    :param title: plotting window name.
    :return: SearchResult object.
    :raises NotImplementedError: always - binary search is not available in this backend.
    """
    raise NotImplementedError("Binary search scan not supported.")
|
||||
|
||||
|
||||
def hsearch(writables, readable, range_min, range_max, initial_step, resolution, noise_filtering_steps=1,
            maximum=True, latency=0.0, relative=False, before_read=None, after_read=None, title=None):
    """Hill-climbing search: drive the writables in decreasing steps to find a local
    extremum of the readable.

    :param writables: list of Writable positioners set on each step.
    :param readable: the Readable sensor to be sampled.
    :param range_min: list of minimum positions of writables.
    :param range_max: list of maximum positions of writables.
    :param initial_step: initial step size for each writable (float or list of float).
    :param resolution: search resolution (minimum step size) for each writable.
    :param noise_filtering_steps: number of additional steps used to filter noise.
    :param maximum: if True (default) search for a maximum, otherwise a minimum.
    :param latency: settling time before each readout in seconds, default 0.0.
    :param relative: if True, range_min/range_max are relative to positions at scan start.
    :param before_read: optional per-step callback invoked before each readout.
    :param after_read: optional per-step callback invoked after each readout.
    :param title: plotting window name.
    :return: SearchResult object.
    :raises NotImplementedError: always - hill climbing is not available in this backend.
    """
    raise NotImplementedError("Hill climbing scan not supported.")
|
||||
1
packages/pyscan/interface/pyScan/__init__.py
Normal file
1
packages/pyscan/interface/pyScan/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from .scan import *
|
||||
713
packages/pyscan/interface/pyScan/scan.py
Normal file
713
packages/pyscan/interface/pyScan/scan.py
Normal file
@@ -0,0 +1,713 @@
|
||||
import traceback
|
||||
from copy import deepcopy
|
||||
from datetime import datetime
|
||||
from time import sleep
|
||||
|
||||
import numpy as np
|
||||
|
||||
from pyscan.dal.epics_dal import PyEpicsDal
|
||||
from pyscan.interface.pyScan.utils import PyScanDataProcessor
|
||||
from pyscan.positioner.compound import CompoundPositioner
|
||||
from pyscan.positioner.serial import SerialPositioner
|
||||
from pyscan.positioner.vector import VectorPositioner
|
||||
from pyscan.scan_parameters import scan_settings
|
||||
from pyscan.scanner import Scanner
|
||||
from pyscan.utils import convert_to_list, convert_to_position_list, compare_channel_value
|
||||
|
||||
READ_GROUP = "Measurements"
|
||||
WRITE_GROUP = "Knobs"
|
||||
MONITOR_GROUP = "Monitors"
|
||||
|
||||
|
||||
class Scan(object):
|
||||
    def execute_scan(self):
        """Run the configured scan and merge its results into ``self.outdict``.

        Builds a Scanner from the prepared dimensions and EPICS groups, wires in
        the pre/post action executors and monitor validation, then executes a
        discrete scan. Assumes initializeScan() has been called successfully.
        """
        after_executor = self.get_action_executor("In-loopPostAction")

        # Wrap the post action executor to update the number of completed scans.
        def progress_after_executor(scanner_instance, data):
            # Execute other post actions.
            after_executor(scanner_instance)

            # Update progress.
            # NOTE(review): on Python 2 this division is integer division unless
            # true division is imported - presumably intended as a percentage; verify.
            self.n_done_measurements += 1
            self.ProgDisp.Progress = 100.0 * (self.n_done_measurements /
                                              self.n_total_positions)

        def prepare_monitors(reader):
            # If there are no monitors defined we have nothing to validate.
            if not self.dimensions[-1]["Monitor"]:
                return None

            def validate_monitors(position, data):
                # Returns True when all monitors match, False to signal
                # wait-and-retry ("WaitAndAbort"), and raises ValueError
                # to abort the scan ("Abort" or unknown action).
                monitor_values = reader.read()
                combined_data = zip(self.dimensions[-1]['Monitor'],
                                    self.dimensions[-1]['MonitorValue'],
                                    self.dimensions[-1]['MonitorTolerance'],
                                    self.dimensions[-1]['MonitorAction'],
                                    self.dimensions[-1]['MonitorTimeout'],
                                    monitor_values)

                for pv, expected_value, tolerance, action, timeout, value in combined_data:
                    # Monitor value does not match.
                    if not compare_channel_value(value, expected_value, tolerance):
                        if action == "Abort":
                            raise ValueError("Monitor %s, expected value %s, tolerance %s, has value %s. Aborting."
                                             % (pv, expected_value, tolerance, value))
                        elif action == "WaitAndAbort":
                            return False
                        else:
                            raise ValueError("MonitorAction %s, on PV %s, is not supported." % (pv, action))

                return True

            return validate_monitors

        # Setup scan settings.
        settings = scan_settings(settling_time=self.dimensions[-1]["KnobWaitingExtra"],
                                 n_measurements=self.dimensions[-1]["NumberOfMeasurements"],
                                 measurement_interval=self.dimensions[-1]["Waiting"])

        data_processor = PyScanDataProcessor(self.outdict,
                                             n_readbacks=self.n_readbacks,
                                             n_validations=self.n_validations,
                                             n_observables=self.n_observables,
                                             n_measurements=settings.n_measurements)

        self.scanner = Scanner(positioner=self.get_positioner(), data_processor=data_processor,
                               reader=self.epics_dal.get_group(READ_GROUP).read,
                               writer=self.epics_dal.get_group(WRITE_GROUP).set_and_match,
                               before_measurement_executor=self.get_action_executor("In-loopPreAction"),
                               after_measurement_executor=progress_after_executor,
                               initialization_executor=self.get_action_executor("PreAction"),
                               finalization_executor=self.get_action_executor("PostAction"),
                               data_validator=prepare_monitors(self.epics_dal.get_group(MONITOR_GROUP)),
                               settings=settings)

        self.outdict.update(self.scanner.discrete_scan())
|
||||
|
||||
    def get_positioner(self):
        """
        Generate a positioner for the provided dimensions.

        Each dimension yields either a SerialPositioner (Series scan) or a
        VectorPositioner (line scan); the per-dimension positioners are then
        combined into a single CompoundPositioner.
        :return: Positioner object.
        """
        # Read all the initial positions - in case we need to do an additive scan.
        # NOTE(review): READ_GROUP also contains Validation/Observable PVs at the
        # end, but the offsets below only slice the leading KnobReadback span.
        initial_positions = self.epics_dal.get_group(READ_GROUP).read()

        positioners = []
        # Running index of this dimension's readbacks within initial_positions.
        knob_readback_offset = 0
        for dimension in self.dimensions:
            is_additive = bool(dimension.get("Additive", 0))
            is_series = bool(dimension.get("Series", 0))
            n_knob_readbacks = len(dimension["KnobReadback"])

            # This dimension uses relative positions, read the PVs initial state.
            # We also need initial positions for the series scan.
            if is_additive or is_series:
                offsets = convert_to_list(
                    initial_positions[knob_readback_offset:knob_readback_offset + n_knob_readbacks])
            else:
                offsets = None

            # Series scan in this dimension, use StepByStepVectorPositioner.
            if is_series:
                # In the StepByStep positioner, the initial values need to be added to the steps.
                positions = convert_to_list(dimension["ScanValues"])
                positioners.append(SerialPositioner(positions, initial_positions=offsets,
                                                    offsets=offsets if is_additive else None))
            # Line scan in this dimension, use VectorPositioner.
            else:
                positions = convert_to_position_list(convert_to_list(dimension["KnobExpanded"]))
                positioners.append(VectorPositioner(positions, offsets=offsets))

            # Increase the knob readback offset.
            knob_readback_offset += n_knob_readbacks

        # Assemble all individual positioners together.
        positioner = CompoundPositioner(positioners)
        return positioner
|
||||
|
||||
    def get_action_executor(self, entry_name):
        """Build a callable that runs every ``entry_name`` action of every dimension.

        :param entry_name: Key of the action list in each dimension dict
            (e.g. "PreAction", "In-loopPostAction"). Each action is a tuple
            (set_pv, read_pv, value, tolerance, timeout).
        :return: Function taking the scanner instance; it writes each action value
            through its pre-created EPICS writer group, then sleeps for the longest
            "<entry_name>Waiting" requested by any dimension.
        :raises NotImplementedError: if an action uses the "match" set_pv marker.
        """
        actions = []
        # Longest settle time requested by any dimension for this action type.
        max_waiting = 0
        for dim_index, dim in enumerate(self.dimensions):
            for action_index, action in enumerate(dim[entry_name]):
                set_pv, read_pv, value, tolerance, timeout = action
                if set_pv == "match":
                    raise NotImplementedError("match not yet implemented for PreAction.")

                # Initialize the write group, to speed up in loop stuff.
                group_name = "%s_%d_%d" % (entry_name, dim_index, action_index)
                self.epics_dal.add_writer_group(group_name, set_pv, read_pv, tolerance, timeout)
                actions.append((group_name, value))

            if entry_name + "Waiting" in dim:
                max_waiting = max(max_waiting, dim[entry_name + "Waiting"])

        def execute(scanner):
            # Runs inside the scan loop: replay every recorded (group, value) pair.
            for action in actions:
                name = action[0]
                value = action[1]
                # Retrieve the epics group and write the value.
                self.epics_dal.get_group(name).set_and_match(value)

            sleep(max_waiting)

        return execute
|
||||
|
||||
    class DummyProgress(object):
        """Minimal progress holder kept for backward compatibility with old GUI clients."""

        def __init__(self):
            # For Thomas?
            # Progress percentage written by execute_scan; abortScan mirrors the legacy flag.
            self.Progress = 1
            self.abortScan = 0
|
||||
|
||||
    def __init__(self):
        """Create an uninitialized Scan; call initializeScan() before executing."""
        # Per-dimension configuration dicts, populated by initializeScan.
        self.dimensions = None
        # Data access layer for the EPICS PVs (defaults to PyEpicsDal later).
        self.epics_dal = None
        # Underlying Scanner instance, created in execute_scan.
        self.scanner = None
        # Result dictionary returned to the caller.
        self.outdict = None

        # PV bookkeeping, filled in by _setup_epics_dal.
        self.all_read_pvs = None
        self.n_readbacks = None
        self.n_validations = None
        self.n_observables = None
        self.n_total_positions = None
        self.n_measurements = None

        # Accessed by some clients.
        self.ProgDisp = Scan.DummyProgress()
        self._pauseScan = 0

        # Just to make old GUI work.
        self._abortScan = 0
        self.n_done_measurements = 0
|
||||
|
||||
@property
|
||||
def abortScan(self):
|
||||
return self._abort_scan
|
||||
|
||||
@abortScan.setter
|
||||
def abortScan(self, value):
|
||||
self._abortScan = value
|
||||
|
||||
if self._abortScan:
|
||||
self.scanner.abort_scan()
|
||||
|
||||
@property
|
||||
def pauseScan(self):
|
||||
return self._pauseScan
|
||||
|
||||
@pauseScan.setter
|
||||
def pauseScan(self, value):
|
||||
self._pauseScan = value
|
||||
|
||||
if self._pauseScan:
|
||||
self.scanner.pause_scan()
|
||||
else:
|
||||
self.scanner.resume_scan()
|
||||
|
||||
    def initializeScan(self, inlist, dal=None):
        """
        Initialize and verify the provided scan values.

        Validates and normalizes every dimension dict in place (filling defaults,
        wrapping scalars into lists), sets up the EPICS groups and monitors, and
        pre-allocates the output structure. Validation failures are reported via
        the returned dict's "ErrorMessage" entry rather than raised.
        :param inlist: List of dictionaries for each dimension.
        :param dal: Which reader should be used to access the PVs. Default: PyEpicsDal.
        :return: Dictionary with results.
        """
        if not inlist:
            raise ValueError("Provided inlist is empty.")

        if dal is not None:
            self.epics_dal = dal
        else:
            self.epics_dal = PyEpicsDal()

        # Prepare the scan dimensions.
        if isinstance(inlist, list):
            self.dimensions = inlist
        # In case it is a simple one dimensional scan.
        else:
            self.dimensions = [inlist]

        try:
            for index, dic in enumerate(self.dimensions):
                # We read most of the PVs only if declared in the last dimension.
                is_last_dimension = index == (len(self.dimensions) - 1)

                # Just in case there are identical input dictionaries. (Normally, it may not happen.)
                dic['ID'] = index

                # Waiting time.
                if is_last_dimension and ('Waiting' not in dic.keys()):
                    raise ValueError('Waiting for the scan was not given.')

                # Validation channels - values just added to the results.
                if 'Validation' in dic.keys():
                    if not isinstance(dic['Validation'], list):
                        raise ValueError('Validation should be a list of channels. Input dictionary %d.' % index)
                else:
                    dic['Validation'] = []

                # Relative scan.
                if 'Additive' not in dic.keys():
                    dic['Additive'] = 0

                # Step back when pause is invoked.
                if is_last_dimension and ('StepbackOnPause' not in dic.keys()):
                    dic['StepbackOnPause'] = 1

                # Number of measurments per position.
                if is_last_dimension and ('NumberOfMeasurements' not in dic.keys()):
                    dic['NumberOfMeasurements'] = 1

                # PVs to sample.
                if is_last_dimension and ('Observable' not in dic.keys()):
                    raise ValueError('The observable is not given.')
                elif is_last_dimension:
                    if not isinstance(dic['Observable'], list):
                        dic['Observable'] = [dic['Observable']]

                self._setup_knobs(index, dic)

                self._setup_knob_scan_values(index, dic)

                self._setup_pre_actions(index, dic)

                self._setup_inloop_pre_actions(index, dic)

                self._setup_post_action(index, dic)

                self._setup_inloop_post_action(index, dic)

            # Total number of measurements
            self.n_total_positions = 1
            for dic in self.dimensions:
                if not dic['Series']:
                    self.n_total_positions = self.n_total_positions * dic['Nstep']
                else:
                    # Series scan: Nstep is a list with one entry per knob.
                    self.n_total_positions = self.n_total_positions * sum(dic['Nstep'])

            self._setup_epics_dal()
            # Monitors only in the last dimension.
            self._setup_monitors(self.dimensions[-1])

            # Prealocating the place for the output
            self.outdict = {"ErrorMessage": None,
                            "KnobReadback": self.allocateOutput(),
                            "Validation": self.allocateOutput(),
                            "Observable": self.allocateOutput()}

        except ValueError:
            # Validation errors are reported through the result dict, not raised.
            self.outdict = {"ErrorMessage": traceback.format_exc()}

        # Backward compatibility.
        self.ProgDisp.Progress = 0
        self.ProgDisp.abortScan = 0

        self._pauseScan = 0
        self.abortScan = 0
        self.n_done_measurements = 0

        return self.outdict
|
||||
|
||||
def allocateOutput(self):
|
||||
root_list = []
|
||||
for dimension in reversed(self.dimensions):
|
||||
n_steps = dimension['Nstep']
|
||||
|
||||
if dimension['Series']:
|
||||
# For Series scan, each step of each knob represents another result.
|
||||
current_dimension_list = []
|
||||
for n_steps_in_knob in n_steps:
|
||||
current_knob_list = []
|
||||
for _ in range(n_steps_in_knob):
|
||||
current_knob_list.append(deepcopy(root_list))
|
||||
|
||||
current_dimension_list.append(deepcopy(current_knob_list))
|
||||
root_list = current_dimension_list
|
||||
else:
|
||||
# For line scan, each step represents another result.
|
||||
current_dimension_list = []
|
||||
for _ in range(n_steps):
|
||||
current_dimension_list.append(deepcopy(root_list))
|
||||
root_list = current_dimension_list
|
||||
|
||||
return root_list
|
||||
|
||||
    def _setup_epics_dal(self):
        """Create the EPICS reader/writer groups from the collected dimension PVs.

        Builds the flat read list (all KnobReadbacks, then the last dimension's
        Validation and Observable PVs) and the flat write list (all Knobs with
        their readbacks and tolerances), then registers both groups on the DAL.
        Also records n_readbacks / n_validations / n_observables for the data
        processor.
        """
        # Collect all PVs that need to be read at each scan step.
        self.all_read_pvs = []
        all_write_pvs = []
        all_readback_pvs = []
        all_tolerances = []
        max_knob_waiting = -1

        self.n_readbacks = 0
        for d in self.dimensions:
            self.all_read_pvs.append(d['KnobReadback'])
            self.n_readbacks += len(d['KnobReadback'])

            # Collect all data need to write to PVs.
            all_write_pvs.append(d["Knob"])
            all_readback_pvs.append(d["KnobReadback"])
            all_tolerances.append(d["KnobTolerance"])
            # The writer group uses a single timeout: the longest KnobWaiting anywhere.
            max_knob_waiting = max(max_knob_waiting, max(d["KnobWaiting"]))

        self.all_read_pvs.append(self.dimensions[-1]['Validation'])
        self.n_validations = len(self.dimensions[-1]['Validation'])
        self.all_read_pvs.append(self.dimensions[-1]['Observable'])
        self.n_observables = len(self.dimensions[-1]['Observable'])
        # Expand all read PVs
        self.all_read_pvs = [item for sublist in self.all_read_pvs for item in sublist]

        # Expand Knobs and readbacks PVs.
        all_write_pvs = [item for sublist in all_write_pvs for item in sublist]
        all_readback_pvs = [item for sublist in all_readback_pvs for item in sublist]
        all_tolerances = [item for sublist in all_tolerances for item in sublist]

        # Initialize PV connections and check if all PV names are valid.
        self.epics_dal.add_reader_group(READ_GROUP, self.all_read_pvs)
        self.epics_dal.add_writer_group(WRITE_GROUP, all_write_pvs, all_readback_pvs, all_tolerances, max_knob_waiting)
|
||||
|
||||
def _setup_knobs(self, index, dic):
|
||||
"""
|
||||
Setup the values for moving knobs in the scan.
|
||||
:param index: Index in the dictionary.
|
||||
:param dic: The dictionary.
|
||||
"""
|
||||
if 'Knob' not in dic.keys():
|
||||
raise ValueError('Knob for the scan was not given for the input dictionary %d.' % index)
|
||||
else:
|
||||
if not isinstance(dic['Knob'], list):
|
||||
dic['Knob'] = [dic['Knob']]
|
||||
|
||||
if 'KnobReadback' not in dic.keys():
|
||||
dic['KnobReadback'] = dic['Knob']
|
||||
if not isinstance(dic['KnobReadback'], list):
|
||||
dic['KnobReadback'] = [dic['KnobReadback']]
|
||||
if len(dic['KnobReadback']) != len(dic['Knob']):
|
||||
raise ValueError('The number of KnobReadback does not meet to the number of Knobs.')
|
||||
|
||||
if 'KnobTolerance' not in dic.keys():
|
||||
dic['KnobTolerance'] = [1.0] * len(dic['Knob'])
|
||||
if not isinstance(dic['KnobTolerance'], list):
|
||||
dic['KnobTolerance'] = [dic['KnobTolerance']]
|
||||
if len(dic['KnobTolerance']) != len(dic['Knob']):
|
||||
raise ValueError('The number of KnobTolerance does not meet to the number of Knobs.')
|
||||
|
||||
if 'KnobWaiting' not in dic.keys():
|
||||
dic['KnobWaiting'] = [10.0] * len(dic['Knob'])
|
||||
if not isinstance(dic['KnobWaiting'], list):
|
||||
dic['KnobWaiting'] = [dic['KnobWaiting']]
|
||||
if len(dic['KnobWaiting']) != len(dic['Knob']):
|
||||
raise ValueError('The number of KnobWaiting does not meet to the number of Knobs.')
|
||||
|
||||
if 'KnobWaitingExtra' not in dic.keys():
|
||||
dic['KnobWaitingExtra'] = 0.0
|
||||
else:
|
||||
try:
|
||||
dic['KnobWaitingExtra'] = float(dic['KnobWaitingExtra'])
|
||||
except:
|
||||
raise ValueError('KnobWaitingExtra is not a number in the input dictionary %d.' % index)
|
||||
|
||||
# Originally dic["Knob"] values were saved. I'm supposing this was a bug - readback values needed to be saved.
|
||||
|
||||
# TODO: We can optimize this by moving the initialization in the epics_dal init
|
||||
# but pre actions need to be moved after the epics_dal init than
|
||||
self.epics_dal.add_reader_group("KnobReadback", dic['KnobReadback'])
|
||||
dic['KnobSaved'] = self.epics_dal.get_group("KnobReadback").read()
|
||||
self.epics_dal.close_group("KnobReadback")
|
||||
|
||||
def _setup_knob_scan_values(self, index, dic):
    """
    Validate and expand the scan-value specification of one input dictionary, in place.

    Three modes are supported, selected by dic['Series'] and the presence of 'ScanValues':
      * SKS/MKS with 'ScanRange' + 'Nstep' or 'StepSize': ranges are expanded into
        explicit value lists stored in dic['KnobExpanded'].
      * SKS/MKS with explicit 'ScanValues': values are normalized into dic['KnobExpanded'].
      * Series scan (dic['Series'] truthy): 'ScanValues' must be a list of lists, one per
        knob; dic['Nstep'] becomes a list of per-knob lengths.

    :param index: Index of the input dictionary, used only in error messages.
    :param dic: Scan input dictionary; mutated in place.
    :raises ValueError: If the specification is missing or inconsistently formatted.
    """
    if 'Series' not in dic.keys():
        # Default is a regular (non-series) scan.
        dic['Series'] = 0

    if not dic['Series']:  # Setting up scan values for SKS and MKS
        if 'ScanValues' not in dic.keys():
            if 'ScanRange' not in dic.keys():
                raise ValueError('Neither ScanRange nor ScanValues is given '
                                 'in the input dictionary %d.' % index)
            elif not isinstance(dic['ScanRange'], list):
                raise ValueError('ScanRange is not given in the right format. '
                                 'Input dictionary %d.' % index)
            elif not isinstance(dic['ScanRange'][0], list):
                # A single [min, max] pair is normalized to a list of pairs.
                dic['ScanRange'] = [dic['ScanRange']]

            if ('Nstep' not in dic.keys()) and ('StepSize' not in dic.keys()):
                raise ValueError('Neither Nstep nor StepSize is given.')

            if 'Nstep' in dic.keys():  # StepSize is ignored when Nstep is given
                if not isinstance(dic['Nstep'], int):
                    raise ValueError('Nstep should be an integer. Input dictionary %d.' % index)
                ran = []
                for r in dic['ScanRange']:
                    # Step chosen so that Nstep points span [r[0], r[1]] inclusive.
                    s = (r[1] - r[0]) / (dic['Nstep'] - 1)
                    f = np.arange(r[0], r[1], s)
                    # Append the end point explicitly; np.arange excludes the stop value.
                    # NOTE(review): with float steps np.arange may already land very close
                    # to r[1], which can duplicate the end point — confirm acceptable.
                    f = np.append(f, np.array(r[1]))
                    ran.append(f.tolist())
                dic['KnobExpanded'] = ran
            else:  # StepSize given
                if len(dic['Knob']) > 1:
                    raise ValueError('Give Nstep instead of StepSize for MKS. '
                                     'Input dictionary %d.' % index)
                # StepSize is only valid for SKS
                r = dic['ScanRange'][0]

                # TODO: THIS IS RECONSTRUCTED AND MIGHT BE WRONG, CHECK!
                s = dic['StepSize'][0]

                f = np.arange(r[0], r[1], s)
                f = np.append(f, np.array(r[1]))
                # Nstep is derived from the expanded value list length.
                dic['Nstep'] = len(f)
                dic['KnobExpanded'] = [f.tolist()]
        else:
            # Scan values explicitly defined.
            if not isinstance(dic['ScanValues'], list):
                raise ValueError('ScanValues is not given in the right fromat. '
                                 'Input dictionary %d.' % index)

            if len(dic['ScanValues']) != len(dic['Knob']) and len(dic['Knob']) != 1:
                raise ValueError('The length of ScanValues does not meet to the number of Knobs.')

            if len(dic['Knob']) > 1:
                # Find the shortest per-knob value list; all knobs are truncated to it.
                minlen = 100000
                for r in dic['ScanValues']:
                    if minlen > len(r):
                        minlen = len(r)
                ran = []
                for r in dic['ScanValues']:
                    ran.append(r[0:minlen])  # Cut at the length of the shortest list.
                dic['KnobExpanded'] = ran
                dic['Nstep'] = minlen
            else:
                dic['KnobExpanded'] = [dic['ScanValues']]
                dic['Nstep'] = len(dic['ScanValues'])
    else:  # Setting up scan values for Series scan
        if 'ScanValues' not in dic.keys():
            raise ValueError('ScanValues should be given for Series '
                             'scan in the input dictionary %d.' % index)

        if not isinstance(dic['ScanValues'], list):
            raise ValueError('ScanValues should be given as a list (of lists) '
                             'for Series scan in the input dictionary %d.' % index)

        if len(dic['Knob']) != len(dic['ScanValues']):
            raise ValueError('Scan values length does not match to the '
                             'number of knobs in the input dictionary %d.' % index)

        # For series scans, Nstep is a list: one step count per knob.
        Nstep = []
        for vl in dic['ScanValues']:
            if not isinstance(vl, list):
                raise ValueError('ScanValue element should be given as a list for '
                                 'Series scan in the input dictionary %d.' % index)
            Nstep.append(len(vl))
        dic['Nstep'] = Nstep
def _setup_pre_actions(self, index, dic):
    """
    Validate and default the 'PreAction' related entries of one input dictionary, in place.

    Each pre-action is a 5-element list [Ch-set, Ch-read, Value, Tolerance, Timeout],
    or a special entry whose first element is 'SpecialAction'. When 'PreAction' is
    absent, all related keys are initialized to empty/zero defaults.

    :param index: Index of the input dictionary, used only in error messages.
    :param dic: Scan input dictionary; mutated in place.
    :raises ValueError: If any pre-action entry is malformed.
    """
    if 'PreAction' in dic.keys():
        if not isinstance(dic['PreAction'], list):
            raise ValueError('PreAction should be a list. Input dictionary %d.' % index)
        for l in dic['PreAction']:
            if not isinstance(l, list):
                raise ValueError('Every PreAction should be a list. Input dictionary %d.' % index)
            if len(l) != 5:
                # Only 'SpecialAction' entries are exempt from the 5-element form.
                if not l[0] == 'SpecialAction':
                    raise ValueError('Every PreAction should be in a form of '
                                     '[Ch-set, Ch-read, Value, Tolerance, Timeout]. '
                                     'Input dictionary ' + str(index) + '.')

        if 'PreActionWaiting' not in dic.keys():
            dic['PreActionWaiting'] = 0.0
        if not isinstance(dic['PreActionWaiting'], float) and not isinstance(dic['PreActionWaiting'], int):
            # NOTE(review): 'PreActionWating' in the message below is a typo of
            # 'PreActionWaiting' — left unchanged here because it is a runtime string.
            raise ValueError('PreActionWating should be a float. Input dictionary %d.' % index)

        if 'PreActionOrder' not in dic.keys():
            # Default: all pre-actions share order slot 0 (executed together).
            dic['PreActionOrder'] = [0] * len(dic['PreAction'])
        if not isinstance(dic['PreActionOrder'], list):
            raise ValueError('PreActionOrder should be a list. Input dictionary %d.' % index)

    else:
        # No pre-actions requested: install neutral defaults.
        dic['PreAction'] = []
        dic['PreActionWaiting'] = 0.0
        dic['PreActionOrder'] = [0] * len(dic['PreAction'])
def _setup_inloop_pre_actions(self, index, dic):
    """
    Validate and default the 'In-loopPreAction' related entries of one input
    dictionary, in place. Mirrors _setup_pre_actions, but for actions executed
    inside the scan loop.

    :param index: Index of the input dictionary, used only in error messages.
    :param dic: Scan input dictionary; mutated in place.
    :raises ValueError: If any in-loop pre-action entry is malformed.
    """
    if 'In-loopPreAction' in dic.keys():
        if not isinstance(dic['In-loopPreAction'], list):
            raise ValueError('In-loopPreAction should be a list. Input dictionary %d.' % index)
        for l in dic['In-loopPreAction']:
            if not isinstance(l, list):
                raise ValueError('Every In-loopPreAction should be a list. '
                                 'Input dictionary ' + str(index) + '.')
            if len(l) != 5:
                # Only 'SpecialAction' entries are exempt from the 5-element form.
                if not l[0] == 'SpecialAction':
                    raise ValueError('Every In-loopPreAction should be in a form of '
                                     '[Ch-set, Ch-read, Value, Tolerance, Timeout]. '
                                     'Input dictionary ' + str(index) + '.')

        if 'In-loopPreActionWaiting' not in dic.keys():
            dic['In-loopPreActionWaiting'] = 0.0
        if not isinstance(dic['In-loopPreActionWaiting'], float) and not isinstance(
                dic['In-loopPreActionWaiting'], int):
            # NOTE(review): 'Wating' in the message below is a typo of 'Waiting' —
            # left unchanged here because it is a runtime string.
            raise ValueError('In-loopPreActionWating should be a float. Input dictionary %d.' % index)

        if 'In-loopPreActionOrder' not in dic.keys():
            # Default: all in-loop pre-actions share order slot 0.
            dic['In-loopPreActionOrder'] = [0] * len(dic['In-loopPreAction'])
        if not isinstance(dic['In-loopPreActionOrder'], list):
            raise ValueError('In-loopPreActionOrder should be a list. Input dictionary %d.' % index)

    else:
        # No in-loop pre-actions requested: install neutral defaults.
        dic['In-loopPreAction'] = []
        dic['In-loopPreActionWaiting'] = 0.0
        dic['In-loopPreActionOrder'] = [0] * len(dic['In-loopPreAction'])
def _setup_post_action(self, index, dic):
    """
    Validate and expand the 'PostAction' entry of one input dictionary, in place.

    The special value 'Restore' (either as the whole entry or as a list element)
    is expanded into one [knob, knob, saved_value, 1.0, 10] action per knob, using
    the values captured in dic['KnobSaved'].

    :param index: Index of the input dictionary, used only in error messages.
    :param dic: Scan input dictionary; mutated in place.
    :raises ValueError: If any post-action entry is malformed.
    """
    if 'PostAction' in dic.keys():
        if dic['PostAction'] == 'Restore':
            # Whole entry is the 'Restore' shorthand: build one restore action per knob.
            PA = []
            # NOTE(review): this loop reuses the name `index`, shadowing the
            # dictionary-index parameter; later error messages in this method
            # therefore report a knob/action position instead of the dictionary index.
            for index in range(0, len(dic['Knob'])):
                k = dic['Knob'][index]
                v = dic['KnobSaved'][index]
                PA.append([k, k, v, 1.0, 10])
            dic['PostAction'] = PA
        elif not isinstance(dic['PostAction'], list):
            raise ValueError('PostAction should be a list. Input dictionary %d.' % index)
        Restore = 0
        for index in range(0, len(dic['PostAction'])):
            l = dic['PostAction'][index]
            if l == 'Restore':
                # Remember that a 'Restore' element was seen; it is replaced after the loop.
                Restore = 1
                PA = []
                for j in range(0, len(dic['Knob'])):
                    k = dic['Knob'][j]
                    v = dic['KnobSaved'][j]
                    PA.append([k, k, v, 1.0, 10])
            elif not isinstance(l, list):
                raise ValueError('Every PostAction should be a list. Input dictionary %d.' % index)
            elif len(l) != 5:
                # Only 'SpecialAction' entries are exempt from the 5-element form.
                if not l[0] == 'SpecialAction':
                    raise ValueError('Every PostAction should be in a form of '
                                     '[Ch-set, Ch-read, Value, Tolerance, Timeout]. '
                                     'Input dictionary %d.' % index)
        if Restore:
            # Replace the 'Restore' placeholder with the expanded restore actions,
            # appended at the end of the post-action list.
            dic['PostAction'].remove('Restore')
            dic['PostAction'] = dic['PostAction'] + PA

    else:
        # No post-actions requested.
        dic['PostAction'] = []
def _setup_inloop_post_action(self, index, dic):
|
||||
if 'In-loopPostAction' in dic.keys():
|
||||
if dic['In-loopPostAction'] == 'Restore':
|
||||
PA = []
|
||||
for index in range(0, len(dic['Knob'])):
|
||||
k = dic['Knob'][index]
|
||||
v = dic['KnobSaved'][index]
|
||||
PA.append([k, k, v, 1.0, 10])
|
||||
dic['In-loopPostAction'] = PA
|
||||
elif not isinstance(dic['In-loopPostAction'], list):
|
||||
raise ValueError('In-loopPostAction should be a list. Input dictionary %d.' % index)
|
||||
Restore = 0
|
||||
for index in range(0, len(dic['In-loopPostAction'])):
|
||||
l = dic['In-loopPostAction'][index]
|
||||
if l == 'Restore':
|
||||
Restore = 1
|
||||
PA = []
|
||||
for j in range(0, len(dic['Knob'])):
|
||||
k = dic['Knob'][j]
|
||||
v = dic['KnobSaved'][j]
|
||||
PA.append([k, k, v, 1.0, 10])
|
||||
dic['In-loopPostAction'][index] = PA
|
||||
elif not isinstance(l, list):
|
||||
raise ValueError('Every In-loopPostAction should be a list. '
|
||||
'Input dictionary %d.' % index)
|
||||
elif len(l) != 5:
|
||||
raise ValueError('Every In-loopPostAction should be in a form of '
|
||||
'[Ch-set, Ch-read, Value, Tolerance, Timeout]. '
|
||||
'Input dictionary %d.' % index)
|
||||
if Restore:
|
||||
dic['In-loopPostAction'].remove('Restore')
|
||||
dic['In-loopPostAction'] = dic['In-loopPostAction'] + PA
|
||||
else:
|
||||
dic['In-loopPostAction'] = []
|
||||
|
||||
def _setup_monitors(self, dic):
|
||||
if ('Monitor' in dic.keys()) and (dic['Monitor']):
|
||||
if isinstance(dic['Monitor'], str):
|
||||
dic['Monitor'] = [dic['Monitor']]
|
||||
|
||||
# Initialize monitor group and check if all monitor PVs are valid.
|
||||
self.epics_dal.add_reader_group(MONITOR_GROUP, dic["Monitor"])
|
||||
|
||||
if 'MonitorValue' not in dic.keys():
|
||||
dic["MonitorValue"] = self.epics_dal.get_group(MONITOR_GROUP).read()
|
||||
elif not isinstance(dic['MonitorValue'], list):
|
||||
dic['MonitorValue'] = [dic['MonitorValue']]
|
||||
if len(dic['MonitorValue']) != len(dic['Monitor']):
|
||||
raise ValueError('The length of MonitorValue does not meet to the length of Monitor.')
|
||||
|
||||
# Try to construct the monitor tolerance, if not given.
|
||||
if 'MonitorTolerance' not in dic.keys():
|
||||
dic['MonitorTolerance'] = []
|
||||
for value in self.epics_dal.get_group(MONITOR_GROUP).read():
|
||||
if isinstance(value, str):
|
||||
# No tolerance for string values.
|
||||
dic['MonitorTolerance'].append(None)
|
||||
elif value == 0:
|
||||
# Default tolerance for unknown values is 0.1.
|
||||
dic['MonitorTolerance'].append(0.1)
|
||||
else:
|
||||
# 10% of the current value will be the torelance when not given
|
||||
dic['MonitorTolerance'].append(abs(value * 0.1))
|
||||
|
||||
elif not isinstance(dic['MonitorTolerance'], list):
|
||||
dic['MonitorTolerance'] = [dic['MonitorTolerance']]
|
||||
if len(dic['MonitorTolerance']) != len(dic['Monitor']):
|
||||
raise ValueError('The length of MonitorTolerance does not meet to the length of Monitor.')
|
||||
|
||||
if 'MonitorAction' not in dic.keys():
|
||||
raise ValueError('MonitorAction is not give though Monitor is given.')
|
||||
|
||||
if not isinstance(dic['MonitorAction'], list):
|
||||
dic['MonitorAction'] = [dic['MonitorAction']]
|
||||
for m in dic['MonitorAction']:
|
||||
if m != 'Abort' and m != 'Wait' and m != 'WaitAndAbort':
|
||||
raise ValueError('MonitorAction shold be Wait, Abort, or WaitAndAbort.')
|
||||
|
||||
if 'MonitorTimeout' not in dic.keys():
|
||||
dic['MonitorTimeout'] = [30.0] * len(dic['Monitor'])
|
||||
elif not isinstance(dic['MonitorTimeout'], list):
|
||||
dic['MonitorValue'] = [dic['MonitorValue']]
|
||||
if len(dic['MonitorValue']) != len(dic['Monitor']):
|
||||
raise ValueError('The length of MonitorValue does not meet to the length of Monitor.')
|
||||
for m in dic['MonitorTimeout']:
|
||||
try:
|
||||
float(m)
|
||||
except:
|
||||
raise ValueError('MonitorTimeout should be a list of float(or int).')
|
||||
|
||||
else:
|
||||
dic['Monitor'] = []
|
||||
dic['MonitorValue'] = []
|
||||
dic['MonitorTolerance'] = []
|
||||
dic['MonitorAction'] = []
|
||||
dic['MonitorTimeout'] = []
|
||||
|
||||
def startScan(self):
    """
    Execute the prepared scan and return the output dictionary.

    If a non-recoverable error message is already pending from initialization,
    no scan is performed and the output dictionary is returned immediately.
    On a normal run the start/end timestamps are recorded, the scan is executed
    and all channel groups are closed afterwards.

    :return: The scan output dictionary (self.outdict).
    """
    pending_error = self.outdict['ErrorMessage']
    if pending_error:
        # Only the "After the last scan," message allows a new initialization hint
        # to be kept; anything else is reported as a failed initialization.
        if 'After the last scan,' not in pending_error:
            self.outdict['ErrorMessage'] = ('It seems that the initialization was not successful... '
                                            'No scan was performed.')
        return self.outdict

    # Execute the scan, bracketed by timestamps.
    self.outdict['TimeStampStart'] = datetime.now()
    self.execute_scan()
    self.outdict['TimeStampEnd'] = datetime.now()

    self.outdict['ErrorMessage'] = ('Measurement finalized (finished/aborted) normally. '
                                    'Need initialisation before next measurement.')

    # Cleanup after the scan: release every channel group opened for it.
    self.epics_dal.close_all_groups()

    return self.outdict
41
packages/pyscan/interface/pyScan/utils.py
Normal file
41
packages/pyscan/interface/pyScan/utils.py
Normal file
@@ -0,0 +1,41 @@
|
||||
from pyscan.utils import flat_list_generator
|
||||
|
||||
|
||||
class PyScanDataProcessor(object):
    """
    Split flat per-position measurement data into the pyScan output structure.

    Each measurement is a flat sequence laid out as
    [readbacks..., validations..., observables...]; this processor slices it and
    writes the pieces into the pre-allocated 'KnobReadback', 'Validation' and
    'Observable' cells of the output dictionary, advancing through the nested
    output structure with flat_list_generator cursors.
    """

    def __init__(self, output, n_readbacks, n_validations, n_observables, n_measurements):
        """
        :param output: Pre-allocated output dictionary with 'KnobReadback',
            'Validation' and 'Observable' nested list structures.
        :param n_readbacks: Number of readback values per measurement.
        :param n_validations: Number of validation values per measurement.
        :param n_observables: Number of observable values per measurement.
        :param n_measurements: Number of measurements acquired per position.
        """
        self.n_readbacks = n_readbacks
        self.n_validations = n_validations
        self.n_observables = n_observables
        self.n_measurements = n_measurements
        self.output = output
        # One cursor per output section; each next() call yields the next empty
        # cell to fill, in scan order.
        self.KnobReadback_output_position = flat_list_generator(self.output["KnobReadback"])
        self.Validation_output_position = flat_list_generator(self.output["Validation"])
        self.Observable_output_position = flat_list_generator(self.output["Observable"])

    def process(self, position, data):
        """
        Slice the acquired data for one scan position and store it in the output.

        :param position: Scan position the data belongs to (not used for indexing;
            the cursors already advance in scan order).
        :param data: One flat measurement, or a list of them when
            n_measurements > 1.
        """
        # Just we can always iterate over it.
        if self.n_measurements == 1:
            data = [data]

        # Cells for each measurement are already prepared.
        readback_result = [measurement[0:self.n_readbacks]
                           for measurement in data]
        validation_result = [measurement[self.n_readbacks:self.n_readbacks + self.n_validations]
                             for measurement in data]

        # Observables occupy the tail of each flat measurement.
        interval_start = self.n_readbacks + self.n_validations
        interval_end = self.n_readbacks + self.n_validations + self.n_observables
        observable_result = [measurement[interval_start:interval_end]
                             for measurement in data]

        # Single measurements are unwrapped; multiple measurements are stored as lists.
        if self.n_measurements == 1:
            next(self.KnobReadback_output_position).extend(readback_result[0])
            next(self.Validation_output_position).extend(validation_result[0])
            next(self.Observable_output_position).extend(observable_result[0])
        else:
            next(self.KnobReadback_output_position).extend(readback_result)
            next(self.Validation_output_position).extend(validation_result)
            next(self.Observable_output_position).extend(observable_result)

    def get_data(self):
        """Return the (possibly partially filled) output dictionary."""
        return self.output
0
packages/pyscan/positioner/__init__.py
Normal file
0
packages/pyscan/positioner/__init__.py
Normal file
184
packages/pyscan/positioner/area.py
Normal file
184
packages/pyscan/positioner/area.py
Normal file
@@ -0,0 +1,184 @@
|
||||
import math
|
||||
from copy import copy
|
||||
|
||||
from pyscan.utils import convert_to_list
|
||||
|
||||
|
||||
class AreaPositioner(object):
    """
    Scan a multi-dimensional area: the first axis is the slowest, the last the
    fastest. Positions are generated recursively, so every combination of axis
    values in the (inclusive) start..end area is visited once per pass.
    """

    def _validate_parameters(self):
        """Check constructor arguments for consistency; raise ValueError otherwise."""
        if not len(self.start) == len(self.end):
            raise ValueError("Number of start %s and end %s positions do not match." %
                             (self.start, self.end))

        # Exactly one of n_steps / step_size must be provided.
        if (self.n_steps and self.step_size) or (not self.n_steps and not self.step_size):
            raise ValueError("N_steps (%s) or step_sizes (%s) must be set, but not none "
                             "or both of them at the same time." % (self.step_size, self.n_steps))

        if self.n_steps and (not len(self.n_steps) == len(self.start)):
            raise ValueError("The number of n_steps %s does not match the number of start positions %s." %
                             (self.n_steps, self.start))

        if self.n_steps and not all(isinstance(x, int) for x in self.n_steps):
            raise ValueError("The n_steps %s must have only integers." % self.n_steps)

        if self.step_size and (not len(self.step_size) == len(self.start)):
            raise ValueError("The number of step sizes %s does not match the number of start positions %s." %
                             (self.step_size, self.start))

        if not isinstance(self.passes, int) or self.passes < 1:
            raise ValueError("Passes must be a positive integer value, but %s was given." % self.passes)

        if self.offsets and (not len(self.offsets) == len(self.start)):
            raise ValueError("Number of offsets %s does not match the number of start positions %s." %
                             (self.offsets, self.start))

    def __init__(self, start, end, n_steps=None, step_size=None, passes=1, offsets=None):
        """
        :param start: Start position per axis (scalar or list).
        :param end: End position per axis (scalar or list).
        :param n_steps: Number of steps per axis (exclusive with step_size).
        :param step_size: Step size per axis (exclusive with n_steps).
        :param passes: How many times to repeat the whole area scan.
        :param offsets: Optional per-axis offset added to both start and end.
        :raises ValueError: If the parameters are inconsistent.
        """
        self.start = convert_to_list(start)
        self.end = convert_to_list(end)
        self.n_steps = convert_to_list(n_steps)
        self.step_size = convert_to_list(step_size)
        self.passes = passes
        self.offsets = convert_to_list(offsets)

        self._validate_parameters()

        # Get the number of axis to scan.
        self.n_axis = len(self.start)

        # Fix the offsets if provided.
        if self.offsets:
            self.start = [offset + original_value for original_value, offset in zip(self.start, self.offsets)]
            self.end = [offset + original_value for original_value, offset in zip(self.end, self.offsets)]

        # Number of steps case.
        if self.n_steps:
            self.step_size = [(end - start) / steps for start, end, steps
                              in zip(self.start, self.end, self.n_steps)]
        # Step size case.
        elif self.step_size:
            self.n_steps = [math.floor((end - start) / step_size) for start, end, step_size
                            in zip(self.start, self.end, self.step_size)]

    def get_generator(self):
        """
        Yield one position (list of per-axis values) at a time.

        The recursion shares a single mutable `positions` list; each yielded
        value is a copy, so callers may keep references safely.
        """
        for _ in range(self.passes):
            positions = copy(self.start)
            # Return the initial state.
            yield copy(positions)

            # Recursive call to print all axis values.
            def scan_axis(axis_number):
                # We should not scan axis that do not exist.
                if not axis_number < self.n_axis:
                    return

                # Output all position on the next axis while this axis is still at the start position.
                yield from scan_axis(axis_number + 1)

                # Move axis step by step.
                for _ in range(self.n_steps[axis_number]):
                    positions[axis_number] = positions[axis_number] + self.step_size[axis_number]
                    yield copy(positions)
                    # Output all positions from the next axis for each value of this axis.
                    yield from scan_axis(axis_number + 1)

                # Clean up after the loop - return the axis value back to the start value.
                positions[axis_number] = self.start[axis_number]

            yield from scan_axis(0)
|
||||
class ZigZagAreaPositioner(AreaPositioner):
    """
    Area positioner that sweeps each axis in alternating directions: after a
    forward sweep of an axis the direction is inverted, so the next sweep starts
    from where the previous one ended instead of jumping back to the start.
    Construction and validation are inherited from AreaPositioner.
    """

    def get_generator(self):
        """
        Yield one position (list of per-axis values) at a time, zig-zagging
        every axis. Each yielded value is a copy of the shared positions list.
        """
        for pass_number in range(self.passes):
            # Directions (positive ascending, negative descending) for each axis.
            directions = [1] * self.n_axis
            positions = copy(self.start)

            # Return the initial state.
            yield copy(positions)

            # Recursive call to print all axis values.
            def scan_axis(axis_number):
                # We should not scan axis that do not exist.
                if not axis_number < self.n_axis:
                    return

                # Output all position on the next axis while this axis is still at the start position.
                yield from scan_axis(axis_number + 1)

                # Move axis step by step.
                for _ in range(self.n_steps[axis_number]):
                    positions[axis_number] = positions[axis_number] + (self.step_size[axis_number]
                                                                       * directions[axis_number])
                    yield copy(positions)
                    # Output all positions from the next axis for each value of this axis.
                    yield from scan_axis(axis_number + 1)

                # Invert the direction for the next iteration on this axis.
                directions[axis_number] *= -1

            yield from scan_axis(0)
||||
|
||||
class MultiAreaPositioner(object):
    """
    Area positioner where every axis moves a group of values at once: start, end
    and steps are lists of lists (one inner list per axis).

    Fix over the previous revision: invalid `steps` input (neither int step counts
    nor float step sizes) used to be silently ignored, leaving a broken instance
    without n_steps/step_size; it now raises ValueError (resolving the old TODO).
    """

    def __init__(self, start, end, steps, passes=1, offsets=None):
        """
        :param start: Per-axis lists of start values.
        :param end: Per-axis lists of end values.
        :param steps: Per-axis lists of either step counts (int) or step sizes (float).
        :param passes: How many times to repeat the whole scan.
        :param offsets: Optional per-axis lists of offsets added to start and end.
        :raises ValueError: If steps is neither int counts nor float sizes.
        """
        self.offsets = offsets
        self.passes = passes
        self.end = end
        self.start = start

        # Get the number of axis to scan.
        self.n_axis = len(self.start)

        # Fix the offsets if provided.
        if self.offsets:
            self.start = [[original_value + offset for original_value, offset in zip(original_values, offsets)]
                          for original_values, offsets in zip(self.start, self.offsets)]
            self.end = [[original_value + offset for original_value, offset in zip(original_values, offsets)]
                        for original_values, offsets in zip(self.end, self.offsets)]

        # Number of steps case.
        if isinstance(steps[0][0], int):
            # TODO: Verify that each axis has positive steps and that all are ints (all steps or step_size)
            self.n_steps = steps
            self.step_size = [[(end - start) / steps for start, end, steps in zip(starts, ends, line_steps)]
                              for starts, ends, line_steps in zip(self.start, self.end, steps)]
        # Step size case.
        elif isinstance(steps[0][0], float):
            # TODO: Verify that each axis has the same number of steps and that the step_size is correct (positive etc.)
            self.n_steps = [[math.floor((end - start) / step) for start, end, step in zip(starts, ends, line_steps)]
                            for starts, ends, line_steps in zip(self.start, self.end, steps)]
            self.step_size = steps
        else:
            # Neither step counts nor step sizes - the instance would be unusable.
            raise ValueError("Steps must be given either as step counts (int) or step "
                             "sizes (float), but %s was given." % steps)

    def get_generator(self):
        """
        Yield one position (list of per-axis value lists) at a time.

        The recursion shares a single mutable `positions` list; each yielded
        value is a shallow copy of it.
        """
        for _ in range(self.passes):
            positions = copy(self.start)
            # Return the initial state.
            yield copy(positions)

            # Recursive call to print all axis values.
            def scan_axis(axis_number):
                # We should not scan axis that do not exist.
                if not axis_number < self.n_axis:
                    return

                # Output all position on the next axis while this axis is still at the start position.
                yield from scan_axis(axis_number + 1)

                # Move axis step by step.
                # TODO: Figure out what to do with this steps.
                for _ in range(self.n_steps[axis_number][0]):
                    positions[axis_number] = [position + step_size for position, step_size
                                              in zip(positions[axis_number], self.step_size[axis_number])]
                    yield copy(positions)
                    # Output all positions from the next axis for each value of this axis.
                    yield from scan_axis(axis_number + 1)

                # Clean up after the loop - return the axis value back to the start value.
                positions[axis_number] = self.start[axis_number]

            yield from scan_axis(0)
21
packages/pyscan/positioner/bsread.py
Normal file
21
packages/pyscan/positioner/bsread.py
Normal file
@@ -0,0 +1,21 @@
|
||||
|
||||
class BsreadPositioner(object):
    """Positioner that acquires N consecutive messages from a bsread stream."""

    def __init__(self, n_messages):
        """
        Acquire N consecutive messages from the stream.
        :param n_messages: Number of messages to acquire.
        """
        self.n_messages = n_messages
        self.bs_reader = None

    def set_bs_reader(self, bs_reader):
        """Inject the stream reader; must be called before get_generator is consumed."""
        self.bs_reader = bs_reader

    def get_generator(self):
        """
        Yield the message index after reading each message from the stream.

        :raises RuntimeError: If no reader was injected via set_bs_reader.
        """
        if self.bs_reader is None:
            raise RuntimeError("Set bs_reader before using this generator.")

        for message_index in range(self.n_messages):
            # Pull the next message from the stream, then report its index.
            self.bs_reader.read(message_index)
            yield message_index
21
packages/pyscan/positioner/compound.py
Normal file
21
packages/pyscan/positioner/compound.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from copy import copy
|
||||
from pyscan.utils import convert_to_list
|
||||
|
||||
|
||||
class CompoundPositioner(object):
    """
    Combine several positioners into one, nesting them in the given order:
    for every position of the first positioner the full sequence of the second
    is traversed, and so on. Each yielded position is the concatenation of the
    current positions of all compounded positioners.
    """

    def __init__(self, positioners):
        """
        :param positioners: Positioners to compound, outermost first.
        """
        self.positioners = positioners
        self.n_positioners = len(positioners)

    def get_generator(self):
        """Yield concatenated position lists, iterating the last positioner fastest."""
        def expand(depth, collected):
            # All positioners consumed: emit the fully assembled position.
            if depth == self.n_positioners:
                yield copy(collected)
                return

            # Iterate the positioner at this depth and recurse into the remaining ones.
            for axis_positions in self.positioners[depth].get_generator():
                yield from expand(depth + 1, collected + convert_to_list(axis_positions))

        yield from expand(0, [])
91
packages/pyscan/positioner/line.py
Normal file
91
packages/pyscan/positioner/line.py
Normal file
@@ -0,0 +1,91 @@
|
||||
import math
|
||||
from copy import copy
|
||||
|
||||
from pyscan.utils import convert_to_list
|
||||
|
||||
|
||||
class LinePositioner(object):
    """
    Move all axes together along a straight line from start to end, in n_steps
    equidistant steps (or steps of the given step_size), repeating for the given
    number of passes.

    Fix over the previous revision: the "same number of steps" error message used
    a broken format string ('end % pair' with three arguments for two
    placeholders), which raised a formatting error instead of the intended
    ValueError message; the format string now has three '%s' placeholders.
    """

    def _validate_parameters(self):
        """Check constructor arguments for consistency; raise ValueError otherwise."""
        if not len(self.start) == len(self.end):
            raise ValueError("Number of start %s and end %s positions do not match." %
                             (self.start, self.end))

        # Only 1 among n_steps and step_sizes must be set.
        if (self.n_steps is not None and self.step_size) or (self.n_steps is None and not self.step_size):
            raise ValueError("N_steps (%s) or step_sizes (%s) must be set, but not none "
                             "or both of them at the same time." % (self.step_size, self.n_steps))

        # If n_steps is set, than it must be an integer greater than 0.
        if (self.n_steps is not None) and (not isinstance(self.n_steps, int) or self.n_steps < 1):
            raise ValueError("Steps must be a positive integer value, but %s was given." % self.n_steps)

        if self.step_size and (not len(self.step_size) == len(self.start)):
            raise ValueError("The number of step sizes %s does not match the number of start positions %s." %
                             (self.step_size, self.start))

        if not isinstance(self.passes, int) or self.passes < 1:
            raise ValueError("Passes must be a positive integer value, but %s was given." % self.passes)

        if self.offsets and (not len(self.offsets) == len(self.start)):
            raise ValueError("Number of offsets %s does not match the number of start positions %s." %
                             (self.offsets, self.start))

    def __init__(self, start, end, n_steps=None, step_size=None, passes=1, offsets=None):
        """
        :param start: Start position per axis (scalar or list).
        :param end: End position per axis (scalar or list).
        :param n_steps: Number of steps for all axes (exclusive with step_size).
        :param step_size: Step size per axis (exclusive with n_steps); must yield
            the same number of steps on every axis.
        :param passes: How many times to repeat the line scan.
        :param offsets: Optional per-axis offset added to both start and end.
        :raises ValueError: If the parameters are inconsistent.
        """
        self.start = convert_to_list(start)
        self.end = convert_to_list(end)
        self.n_steps = n_steps
        self.step_size = convert_to_list(step_size)
        self.passes = passes
        self.offsets = convert_to_list(offsets)

        self._validate_parameters()

        # Fix the offsets if provided.
        if self.offsets:
            self.start = [offset + original_value for original_value, offset in zip(self.start, self.offsets)]
            self.end = [offset + original_value for original_value, offset in zip(self.end, self.offsets)]

        # Number of steps case.
        if self.n_steps:
            self.step_size = [(end - start) / self.n_steps for start, end in zip(self.start, self.end)]
        # Step size case.
        elif self.step_size:
            n_steps_per_axis = [math.floor((end - start) / step_size) for start, end, step_size
                                in zip(self.start, self.end, self.step_size)]
            # Verify that all axis do the same number of steps.
            if not all(x == n_steps_per_axis[0] for x in n_steps_per_axis):
                raise ValueError("The step sizes %s must give the same number of steps for each start %s "
                                 "and end %s pair." % (self.step_size, self.start, self.end))

            # All the elements in n_steps_per_axis must be the same anyway.
            self.n_steps = n_steps_per_axis[0]

    def get_generator(self):
        """Yield one position (list of per-axis values) per step, start point included."""
        for _ in range(self.passes):
            # The initial position is always the start position.
            current_positions = copy(self.start)
            yield current_positions

            for __ in range(self.n_steps):
                current_positions = [position + step_size for position, step_size
                                     in zip(current_positions, self.step_size)]

                yield current_positions
||||
class ZigZagLinePositioner(LinePositioner):
    """
    Line positioner that alternates direction between passes: even passes move
    from start towards end, odd passes move back, so consecutive passes continue
    from the last reached position instead of jumping back to the start.
    Construction and validation are inherited from LinePositioner.
    """

    def get_generator(self):
        """Yield one position per step; the very first yield is the start position."""
        # The initial position is emitted exactly once, before any pass.
        state = copy(self.start)
        yield state

        for pass_index in range(self.passes):
            # Even passes move forward (+1), odd passes move backward (-1).
            if pass_index % 2 == 0:
                direction = 1
            else:
                direction = -1

            for _ in range(self.n_steps):
                state = [axis_position + (axis_step * direction)
                         for axis_position, axis_step in zip(state, self.step_size)]
                yield state
40
packages/pyscan/positioner/serial.py
Normal file
40
packages/pyscan/positioner/serial.py
Normal file
@@ -0,0 +1,40 @@
|
||||
from copy import copy
|
||||
|
||||
from pyscan.utils import convert_to_list
|
||||
|
||||
|
||||
class SerialPositioner(object):
    """
    Scan over all provided points, one by one, returning the previous to the initial state.
    Each axis is treated as a separate line.
    """

    def __init__(self, positions, initial_positions, passes=1, offsets=None):
        """
        :param positions: Points to visit, per axis (a single list for one axis).
        :param initial_positions: Resting value for every axis.
        :param passes: How many times to repeat the whole scan.
        :param offsets: Optional per-axis offset added to every point.
        :raises ValueError: If passes is smaller than 1.
        """
        self.positions = positions
        self.passes = passes
        self.offsets = offsets

        if passes < 1:
            raise ValueError("Number of passes cannot be less than 1, but %d was provided." % passes)

        self.initial_positions = initial_positions
        self.n_axis = len(self.initial_positions)

        # In case only 1 axis is provided, still wrap it in a list, because it makes the generator code easier.
        if self.n_axis == 1:
            self.positions = [positions]

        # Fix the offset if provided.
        if self.offsets:
            for axis_points, axis_offset in zip(self.positions, self.offsets):
                axis_points[:] = [point + axis_offset for point in axis_points]

    def get_generator(self):
        """
        Yield one position list per point: the moving axis takes its next point
        while all other axes stay at their initial values.
        """
        for _ in range(self.passes):
            # For each axis.
            for axis_index in range(self.n_axis):
                # Every axis sweep starts from a fresh copy of the initial state.
                state = copy(self.initial_positions)

                axis_points = self.positions[axis_index]
                for point_index in range(len(axis_points)):
                    state[axis_index] = convert_to_list(axis_points)[point_index]
                    yield copy(state)
12
packages/pyscan/positioner/static.py
Normal file
12
packages/pyscan/positioner/static.py
Normal file
@@ -0,0 +1,12 @@
|
||||
|
||||
class StaticPositioner(object):
    """Positioner that never moves: it only triggers a fixed number of acquisitions."""

    def __init__(self, n_images):
        """
        Acquire N consecutive images in a static position.
        :param n_images: Number of images to acquire.
        """
        self.n_images = n_images

    def get_generator(self):
        """Yield the running acquisition index; the position itself never changes."""
        yield from range(self.n_images)
||||
52
packages/pyscan/positioner/time.py
Normal file
52
packages/pyscan/positioner/time.py
Normal file
@@ -0,0 +1,52 @@
|
||||
from time import time, sleep
|
||||
|
||||
from pyscan.config import max_time_tolerance
|
||||
|
||||
smoothing_factor = 0.95
|
||||
|
||||
|
||||
class TimePositioner(object):
    """
    Trigger acquisitions at a fixed wall-clock interval.

    The sleep time is smoothed with an exponential factor to attenuate jitter in
    the measurement duration; if an iteration overruns the interval by more than
    max_time_tolerance, the scan is aborted with a ValueError.
    """

    def __init__(self, time_interval, n_intervals, tolerance=None):
        """
        Time interval at which to read data.
        :param time_interval: Time interval in seconds.
        :param n_intervals: How many intervals to measure.
        :param tolerance: Allowed overrun in seconds; clamped to at least
            max_time_tolerance from pyscan.config.
        """
        self.time_interval = time_interval
        # Tolerance cannot be less than the min set tolerance.
        if tolerance is None or tolerance < max_time_tolerance:
            tolerance = max_time_tolerance
        self.tolerance = tolerance

        # Minimum one measurement.
        if n_intervals < 1:
            n_intervals = 1
        self.n_intervals = n_intervals

    def get_generator(self):
        """
        Yield one timestamp per interval, sleeping between yields so that
        consecutive measurements start time_interval seconds apart.

        :raises ValueError: If a measurement overruns the interval beyond
            max_time_tolerance.
        """
        measurement_time_start = time()
        last_time_to_sleep = 0

        for _ in range(self.n_intervals):
            measurement_time_stop = time()
            # How much time did the measurement take.
            measurement_time = measurement_time_stop - measurement_time_start

            time_to_sleep = self.time_interval - measurement_time
            # Use the smoothing factor to attenuate variations in the measurement time.
            time_to_sleep = (smoothing_factor * time_to_sleep) + ((1-smoothing_factor) * last_time_to_sleep)

            # Time to sleep is negative (more time has elapsed, we cannot achieve the requested time interval.
            if time_to_sleep < (-1 * max_time_tolerance):
                raise ValueError("The requested time interval cannot be achieved. Last iteration took %.2f seconds, "
                                 "but a %.2f seconds time interval was set." % (measurement_time, self.time_interval))

            # Sleep only if time to sleep is positive.
            if time_to_sleep > 0:
                sleep(time_to_sleep)

            last_time_to_sleep = time_to_sleep
            measurement_time_start = time()

            # Return the timestamp at which the measurement should begin.
            yield measurement_time_start
||||
52
packages/pyscan/positioner/vector.py
Normal file
52
packages/pyscan/positioner/vector.py
Normal file
@@ -0,0 +1,52 @@
|
||||
from itertools import cycle, chain
|
||||
|
||||
from pyscan.utils import convert_to_list
|
||||
|
||||
|
||||
class VectorPositioner(object):
    """
    Moves over the provided positions, repeating the whole vector once per pass.
    """

    def __init__(self, positions, passes=1, offsets=None):
        """
        :param positions: Positions to visit, one entry per step.
        :param passes: How many times to iterate over the positions; positive integer.
        :param offsets: Optional per-axis offsets added to every position in place.
        """
        self.positions = convert_to_list(positions)
        self.passes = passes
        self.offsets = convert_to_list(offsets)

        self._validate_parameters()

        # Number of positions to move to.
        self.n_positions = len(self.positions)

        # Apply the offsets in place, if they were provided.
        if self.offsets:
            for step_values in self.positions:
                step_values[:] = [value + shift for value, shift in zip(step_values, self.offsets)]

    def _validate_parameters(self):
        """Raise ValueError when positions, passes or offsets are inconsistent."""
        # Every position must span the same number of axes as the first one.
        if any(len(convert_to_list(position)) != len(convert_to_list(self.positions[0]))
               for position in self.positions):
            raise ValueError("All positions %s must have the same number of axis." % self.positions)

        if not isinstance(self.passes, int) or self.passes < 1:
            raise ValueError("Passes must be a positive integer value, but %s was given." % self.passes)

        if self.offsets and len(self.offsets) != len(self.positions[0]):
            raise ValueError("Number of offsets %s does not match the number of positions %s." %
                             (self.offsets, self.positions[0]))

    def get_generator(self):
        """Yield each position in order, once per pass."""
        for _ in range(self.passes):
            yield from self.positions
|
||||
|
||||
|
||||
class ZigZagVectorPositioner(VectorPositioner):
    """VectorPositioner variant that reverses direction on every pass."""

    def get_generator(self):
        """Yield positions forward, then backward, without repeating the turning points."""
        # Endless index sequence [0, 1, ..., n-1, n-2, ..., 1, 0, 1, ...].
        zigzag_indexes = cycle(chain(range(0, self.n_positions, 1), range(self.n_positions - 2, 0, -1)))

        # First pass has the full number of items, each subsequent has one less (extreme sequence item).
        total_steps = self.n_positions + ((self.passes - 1) * (self.n_positions - 1))

        emitted = 0
        while emitted < total_steps:
            yield self.positions[next(zigzag_indexes)]
            emitted += 1
|
||||
260
packages/pyscan/scan.py
Normal file
260
packages/pyscan/scan.py
Normal file
@@ -0,0 +1,260 @@
|
||||
import logging
|
||||
|
||||
from pyscan.dal import epics_dal, bsread_dal, function_dal
|
||||
from pyscan.dal.function_dal import FunctionProxy
|
||||
from pyscan.positioner.bsread import BsreadPositioner
|
||||
from pyscan.scanner import Scanner
|
||||
from pyscan.scan_parameters import EPICS_PV, EPICS_CONDITION, BS_PROPERTY, BS_CONDITION, scan_settings, convert_input, \
|
||||
FUNCTION_VALUE, FUNCTION_CONDITION, convert_conditions, ConditionAction, ConditionComparison
|
||||
from pyscan.utils import convert_to_list, SimpleDataProcessor, ActionExecutor, compare_channel_value
|
||||
|
||||
# Instances to use.
|
||||
EPICS_WRITER = epics_dal.WriteGroupInterface
|
||||
EPICS_READER = epics_dal.ReadGroupInterface
|
||||
BS_READER = bsread_dal.ReadGroupInterface
|
||||
FUNCTION_PROXY = function_dal.FunctionProxy
|
||||
DATA_PROCESSOR = SimpleDataProcessor
|
||||
ACTION_EXECUTOR = ActionExecutor
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def scan(positioner, readables, writables=None, conditions=None, before_read=None, after_read=None, initialization=None,
         finalization=None, settings=None, data_processor=None, before_move=None, after_move=None):
    """
    Configure a scanner from the given parameters and run a discrete scan.

    All arguments are forwarded unchanged to scanner(); see that function
    for their meaning.
    :return: Result of the discrete scan.
    """
    configured_scanner = scanner(positioner=positioner, readables=readables, writables=writables,
                                 conditions=conditions, before_read=before_read, after_read=after_read,
                                 initialization=initialization, finalization=finalization, settings=settings,
                                 data_processor=data_processor, before_move=before_move, after_move=after_move)
    return configured_scanner.discrete_scan()
|
||||
|
||||
|
||||
def scanner(positioner, readables, writables=None, conditions=None, before_read=None, after_read=None,
            initialization=None, finalization=None, settings=None, data_processor=None,
            before_move=None, after_move=None):
    """
    Assemble a Scanner instance from user-supplied scan parameters.

    Inputs may be single values or lists; they are normalized to lists and
    converted into the internal named-tuple representations. Epics, bsread
    and function based channels are initialized separately; the nested
    write/read/validate closures merge their values back in the order in
    which the user listed the channels.

    :param positioner: Provides a generator of positions to move to.
    :param readables: Channels to read at each position (epics PV, bs property or function).
    :param writables: Channels to write positions to (epics PV or function).
    :param conditions: Conditions that must hold for a data point to be valid.
    :param before_read: Actions executed before each measurement.
    :param after_read: Actions executed after each measurement.
    :param initialization: Actions executed once before the scan starts.
    :param finalization: Actions executed after the scan ends (also on error).
    :param settings: Scan settings named tuple; defaults to scan_settings().
    :param data_processor: How to store and handle the data; defaults to DATA_PROCESSOR().
    :param before_move: Actions executed before each move.
    :param after_move: Actions executed after each move.
    :return: Configured Scanner instance (call discrete_scan() to run it).
    """
    # Allow a list or a single value to be passed. Initialize None values.
    writables = convert_input(convert_to_list(writables) or [])
    readables = convert_input(convert_to_list(readables) or [])
    conditions = convert_conditions(convert_to_list(conditions) or [])
    before_read = convert_to_list(before_read) or []
    after_read = convert_to_list(after_read) or []
    before_move = convert_to_list(before_move) or []
    after_move = convert_to_list(after_move) or []
    initialization = convert_to_list(initialization) or []
    finalization = convert_to_list(finalization) or []
    settings = settings or scan_settings()

    # TODO: Ugly. The scanner should not depend on a particular positioner implementation.
    if isinstance(positioner, BsreadPositioner) and settings.n_measurements > 1:
        raise ValueError("When using BsreadPositioner the maximum number of n_measurements = 1.")

    bs_reader = _initialize_bs_dal(readables, conditions, settings.bs_read_filter, positioner)

    epics_writer, epics_pv_reader, epics_condition_reader = _initialize_epics_dal(writables,
                                                                                  readables,
                                                                                  conditions,
                                                                                  settings)

    function_writer, function_reader, function_condition = _initialize_function_dal(writables,
                                                                                    readables,
                                                                                    conditions)

    # Order of value sources, needed to split positions between PV and function writers.
    writables_order = [type(writable) for writable in writables]

    # Write function needs to merge PV and function proxy data.
    def write_data(positions):
        positions = convert_to_list(positions)
        pv_values = [x for x, source in zip(positions, writables_order) if source == EPICS_PV]
        function_values = [x for x, source in zip(positions, writables_order) if source == FUNCTION_VALUE]

        if epics_writer:
            epics_writer.set_and_match(pv_values)

        if function_writer:
            function_writer.write(function_values)

    # Order of value sources, needed to reconstruct the correct order of the result.
    readables_order = [type(readable) for readable in readables]

    # Read function needs to merge BS, PV, and function proxy data.
    def read_data(current_position_index, retry=False):
        _logger.debug("Reading data for position index %s." % current_position_index)

        # Each reader is optional; iterate over an empty list when it is absent.
        bs_values = iter(bs_reader.read(current_position_index, retry) if bs_reader else [])
        epics_values = iter(epics_pv_reader.read(current_position_index) if epics_pv_reader else [])
        function_values = iter(function_reader.read(current_position_index) if function_reader else [])

        # Interleave the values correctly.
        result = []
        for source in readables_order:
            if source == BS_PROPERTY:
                next_result = next(bs_values)
            elif source == EPICS_PV:
                next_result = next(epics_values)
            elif source == FUNCTION_VALUE:
                next_result = next(function_values)
            else:
                raise ValueError("Unknown type of readable %s used." % source)

            # We flatten the result, whenever possible.
            if isinstance(next_result, list) and source != FUNCTION_VALUE:
                result.extend(next_result)
            else:
                result.append(next_result)

        return result

    # Order of value sources, needed to reconstruct the correct order of the result.
    conditions_order = [type(condition) for condition in conditions]

    # Validate function needs to validate both BS, PV, and function proxy data.
    def validate_data(current_position_index, data):
        # NOTE(review): this debug message says "Reading data" - looks copy-pasted
        # from read_data above; consider rewording.
        _logger.debug("Reading data for position index %s." % current_position_index)

        bs_values = iter(bs_reader.read_cached_conditions() if bs_reader else [])
        epics_values = iter(epics_condition_reader.read(current_position_index) if epics_condition_reader else [])
        function_values = iter(function_condition.read(current_position_index) if function_condition else [])

        for index, source in enumerate(conditions_order):

            if source == BS_CONDITION:
                value = next(bs_values)
            elif source == EPICS_CONDITION:
                value = next(epics_values)
            elif source == FUNCTION_CONDITION:
                value = next(function_values)
            else:
                raise ValueError("Unknown type of condition %s used." % source)

            value_valid = False

            # Function conditions are self contained.
            if source == FUNCTION_CONDITION:
                if value:
                    value_valid = True

            else:
                expected_value = conditions[index].value
                tolerance = conditions[index].tolerance
                operation = conditions[index].operation

                if compare_channel_value(value, expected_value, tolerance, operation):
                    value_valid = True

            if not value_valid:

                # Retry conditions signal "acquire again"; everything else aborts.
                if conditions[index].action == ConditionAction.Retry:
                    return False

                if source == FUNCTION_CONDITION:
                    raise ValueError("Function condition %s returned False." % conditions[index].identifier)

                else:
                    raise ValueError("Condition %s failed, expected value %s, actual value %s, "
                                     "tolerance %s, operation %s." %
                                     (conditions[index].identifier,
                                      conditions[index].value,
                                      value,
                                      conditions[index].tolerance,
                                      conditions[index].operation))

        return True

    if not data_processor:
        data_processor = DATA_PROCESSOR()

    # Before acquisition hook.
    before_measurement_executor = None
    if before_read:
        before_measurement_executor = ACTION_EXECUTOR(before_read).execute

    # After acquisition hook.
    after_measurement_executor = None
    if after_read:
        after_measurement_executor = ACTION_EXECUTOR(after_read).execute

    # Executor before each move.
    before_move_executor = None
    if before_move:
        before_move_executor = ACTION_EXECUTOR(before_move).execute

    # Executor after each move.
    after_move_executor = None
    if after_move:
        after_move_executor = ACTION_EXECUTOR(after_move).execute

    # Initialization (before move to first position) hook.
    initialization_executor = None
    if initialization:
        initialization_executor = ACTION_EXECUTOR(initialization).execute

    # Finalization (after last acquisition AND on error) hook.
    finalization_executor = None
    if finalization:
        finalization_executor = ACTION_EXECUTOR(finalization).execute

    scanner = Scanner(positioner=positioner, data_processor=data_processor, reader=read_data,
                      writer=write_data, before_measurement_executor=before_measurement_executor,
                      after_measurement_executor=after_measurement_executor,
                      initialization_executor=initialization_executor,
                      finalization_executor=finalization_executor, data_validator=validate_data, settings=settings,
                      before_move_executor=before_move_executor, after_move_executor=after_move_executor)

    return scanner
|
||||
|
||||
|
||||
def _initialize_epics_dal(writables, readables, conditions, settings):
    """
    Build the epics writer and readers needed for the scan.

    :param writables: Converted writable channels (may contain non-epics entries).
    :param readables: Converted readable channels.
    :param conditions: Converted conditions.
    :param settings: Scan settings (write_timeout is used for the writer).
    :return: Tuple (epics_writer, epics_pv_reader, epics_condition_reader);
             each element is None when that category is not used.
    """
    epics_writer = None
    if writables:
        epics_writables = [pv for pv in writables if isinstance(pv, EPICS_PV)]
        if epics_writables:
            # Instantiate the PVs to move the motors.
            epics_writer = EPICS_WRITER(pv_names=[pv.pv_name for pv in epics_writables],
                                        readback_pv_names=[pv.readback_pv_name for pv in epics_writables],
                                        tolerances=[pv.tolerance for pv in epics_writables],
                                        timeout=settings.write_timeout)

    readable_pv_names = [item.pv_name for item in readables if isinstance(item, EPICS_PV)]
    condition_pv_names = [item.pv_name for item in conditions if isinstance(item, EPICS_CONDITION)]

    # Reader for epics PV values.
    epics_pv_reader = EPICS_READER(pv_names=readable_pv_names) if readable_pv_names else None

    # Reader for epics condition values.
    epics_condition_reader = EPICS_READER(pv_names=condition_pv_names) if condition_pv_names else None

    return epics_writer, epics_pv_reader, epics_condition_reader
|
||||
|
||||
|
||||
def _initialize_bs_dal(readables, conditions, filter_function, positioner):
    """
    Build the bsread reader when bs properties or conditions are requested.

    :param readables: Converted readable channels.
    :param conditions: Converted conditions.
    :param filter_function: Filter applied to incoming bs messages.
    :param positioner: The scan positioner (BsreadPositioner gets special handling).
    :return: The bs reader instance, or None when no bs channels are used.
    """
    bs_readables = [item for item in readables if isinstance(item, BS_PROPERTY)]
    bs_conditions = [item for item in conditions if isinstance(item, BS_CONDITION)]

    # Nothing bs-related requested: no reader is needed.
    if not (bs_readables or bs_conditions):
        return None

    # TODO: The scanner should not depend on a particular positioner. Refactor.
    if isinstance(positioner, BsreadPositioner):
        # The bsread positioner drives the acquisition itself, so it receives the reader.
        bs_reader = bsread_dal.ImmediateReadGroupInterface(properties=bs_readables,
                                                           conditions=bs_conditions,
                                                           filter_function=filter_function)
        positioner.set_bs_reader(bs_reader)
        return bs_reader

    return BS_READER(properties=bs_readables, conditions=bs_conditions, filter_function=filter_function)
|
||||
|
||||
|
||||
def _initialize_function_dal(writables, readables, conditions):
    """
    Wrap function-based writables, readables and conditions into proxies.

    :return: Tuple (function_writer, function_reader, function_condition) of FunctionProxy instances.
    """
    def _proxy_for(items, wanted_type):
        # One proxy per category, holding only entries of the requested type.
        return FunctionProxy([item for item in items if isinstance(item, wanted_type)])

    function_writer = _proxy_for(writables, FUNCTION_VALUE)
    function_reader = _proxy_for(readables, FUNCTION_VALUE)
    function_condition = _proxy_for(conditions, FUNCTION_CONDITION)

    return function_writer, function_reader, function_condition
|
||||
58
packages/pyscan/scan_actions.py
Normal file
58
packages/pyscan/scan_actions.py
Normal file
@@ -0,0 +1,58 @@
|
||||
from collections import namedtuple
|
||||
from pyscan import config, convert_input
|
||||
from pyscan.scan import EPICS_WRITER, EPICS_READER
|
||||
from pyscan.scan_parameters import epics_pv
|
||||
from pyscan.utils import convert_to_list
|
||||
|
||||
# Descriptor for a "set an epics PV and verify via readback" action.
# NOTE(review): not referenced in this file chunk - presumably consumed elsewhere; verify before changing.
SET_EPICS_PV = namedtuple("SET_EPICS_PV", ["pv_name", "value", "readback_pv_name", "tolerance", "timeout"])
# Descriptor for a "restore all writable PVs to their captured initial values" action.
RESTORE_WRITABLE_PVS = namedtuple("RESTORE_WRITABLE_PVS", [])
|
||||
|
||||
|
||||
def action_set_epics_pv(pv_name, value, readback_pv_name=None, tolerance=None, timeout=None):
    """
    Build an action that sets an epics PV and verifies it via its readback.

    :param pv_name: Name of the PV.
    :param value: Value to set the PV to.
    :param readback_pv_name: Name of the readback PV.
    :param tolerance: Tolerance if the PV is writable.
    :param timeout: Timeout for setting the pv value.
    :return: Zero-argument callable that performs the set-and-match when invoked.
    """
    # Normalize the name/readback/tolerance through the standard epics_pv factory.
    _, pv_name, readback_pv_name, tolerance, readback_pv_value = epics_pv(pv_name, readback_pv_name, tolerance)

    if value is None:
        raise ValueError("pv value not specified.")

    # Fall back to the configured timeout when missing or invalid.
    if not timeout or timeout < 0:
        timeout = config.epics_default_set_and_match_timeout

    def execute():
        # Open the channel, write and verify, then release it.
        pv_writer = EPICS_WRITER(pv_name, readback_pv_name, tolerance, timeout)
        pv_writer.set_and_match(value)
        pv_writer.close()

    return execute
|
||||
|
||||
|
||||
def action_restore(writables):
    """
    Build an action that restores the initial state of the writable PVs.

    The current PV values are captured when this factory is called; the
    returned callable writes them back when invoked.
    :param writables: Writable PVs whose values should be captured and restored.
    :return: Zero-argument callable that writes the captured initial values back.
    """
    writables = convert_input(convert_to_list(writables))
    pv_names = [pv.pv_name for pv in writables]
    readback_pv_names = [pv.readback_pv_name for pv in writables]
    tolerances = [pv.tolerance for pv in writables]

    # Capture the current values before the scan changes them.
    initial_value_reader = EPICS_READER(pv_names)
    initial_values = initial_value_reader.read()
    initial_value_reader.close()

    def execute():
        restore_writer = EPICS_WRITER(pv_names, readback_pv_names, tolerances)
        restore_writer.set_and_match(initial_values)
        restore_writer.close()

    return execute
|
||||
|
||||
280
packages/pyscan/scan_parameters.py
Normal file
280
packages/pyscan/scan_parameters.py
Normal file
@@ -0,0 +1,280 @@
|
||||
from collections import namedtuple
|
||||
from enum import Enum
|
||||
|
||||
from pyscan import config
|
||||
|
||||
# Epics process variable: written (with readback verification) or read during the scan.
EPICS_PV = namedtuple("EPICS_PV", ["identifier", "pv_name", "readback_pv_name", "tolerance", "readback_pv_value"])
# Epics based condition checked at each data point.
EPICS_CONDITION = namedtuple("EPICS_CONDITION", ["identifier", "pv_name", "value", "action", "tolerance", "operation"])
# bsread stream property read during the scan.
BS_PROPERTY = namedtuple("BS_PROPERTY", ["identifier", "property", "default_value"])
# bsread based condition checked at each data point.
BS_CONDITION = namedtuple("BS_CONDITION", ["identifier", "property", "value", "action", "tolerance", "operation",
                                           "default_value"])
# Global scan settings; see scan_settings() for defaults and semantics.
SCAN_SETTINGS = namedtuple("SCAN_SETTINGS", ["measurement_interval", "n_measurements",
                                             "write_timeout", "settling_time", "progress_callback", "bs_read_filter"])
# Callable used as a writable/readable value source.
FUNCTION_VALUE = namedtuple("FUNCTION_VALUE", ["identifier", "call_function"])
# Callable used as a condition; its truthiness decides validity.
FUNCTION_CONDITION = namedtuple("FUNCTION_CONDITION", ["identifier", "call_function", "action"])
|
||||
|
||||
|
||||
class ConditionComparison(Enum):
    """Comparison operators for checking a condition value against its expected value."""
    EQUAL = 0
    NOT_EQUAL = 1
    LOWER = 2
    LOWER_OR_EQUAL = 3
    HIGHER = 4
    HIGHER_OR_EQUAL = 5
|
||||
|
||||
|
||||
class ConditionAction(Enum):
    """What to do when a condition check fails."""
    # Stop the scan by raising an error.
    Abort = 1
    # Signal the scanner to acquire the data point again.
    Retry = 2
|
||||
|
||||
|
||||
# Used to determine if a parameter was passed or the default value is used.
# Sentinel object: identity comparison ("is") distinguishes "not passed" from None.
_default_value_placeholder = object()
|
||||
|
||||
|
||||
def function_value(call_function, name=None):
    """
    Construct a tuple for function representation.
    :param call_function: Function to invoke.
    :param name: Name to assign to this function.
    :return: Tuple of ("identifier", "call_function")
    """
    if name:
        identifier = name
    else:
        # Auto-generate a unique name when none was provided.
        identifier = "function_%d" % function_value.function_count
        function_value.function_count += 1

    return FUNCTION_VALUE(identifier, call_function)
# Counter used to generate unique default function names.
function_value.function_count = 0
|
||||
|
||||
|
||||
def function_condition(call_function, name=None, action=None):
    """
    Construct a tuple for condition checking function representation.
    :param call_function: Function to invoke.
    :param name: Name to assign to this function.
    :param action: What to do then the return value is False.
                   ('ConditionAction.Abort' and 'ConditionAction.Retry' supported)
    :return: Tuple of ("identifier", "call_function", "action")
    """
    if name:
        identifier = name
    else:
        # Auto-generate a unique name when none was provided.
        identifier = "function_condition_%d" % function_condition.function_count
        function_condition.function_count += 1

    # The default action is Abort - used for conditions.
    return FUNCTION_CONDITION(identifier, call_function, action or ConditionAction.Abort)
# Counter used to generate unique default condition names.
function_condition.function_count = 0
|
||||
|
||||
|
||||
def epics_pv(pv_name, readback_pv_name=None, tolerance=None, readback_pv_value=None):
    """
    Construct a tuple for PV representation
    :param pv_name: Name of the PV.
    :param readback_pv_name: Name of the readback PV; defaults to the PV itself.
    :param tolerance: Tolerance if the PV is writable; clamped to the configured minimum.
    :param readback_pv_value: If the readback_pv_value is set, the readback is compared against this instead of
                              comparing it to the setpoint.
    :return: Tuple of (identifier, pv_name, pv_readback, tolerance, readback_pv_value)
    """
    if not pv_name:
        raise ValueError("pv_name not specified.")

    # Use the PV itself for readback when no dedicated readback PV is given.
    readback_pv_name = readback_pv_name or pv_name

    # The tolerance can never be tighter than the configured minimum.
    if not tolerance or tolerance < config.max_float_tolerance:
        tolerance = config.max_float_tolerance

    # The identifier is the PV name itself.
    return EPICS_PV(pv_name, pv_name, readback_pv_name, tolerance, readback_pv_value)
|
||||
|
||||
|
||||
def epics_condition(pv_name, value, action=None, tolerance=None, operation=ConditionComparison.EQUAL):
    """
    Construct a tuple for an epics condition representation.
    :param pv_name: Name of the PV to monitor.
    :param value: Value we expect the PV to be in.
    :param action: What to do when the condition fails.
                   ('ConditionAction.Abort' and 'ConditionAction.Retry' supported)
    :param tolerance: Tolerance within which the condition needs to be.
    :param operation: How to compare the received value with the expected value.
                      Allowed values: ConditionComparison.[EQUAL,NOT_EQUAL, LOWER, LOWER_OR_EQUAL, HIGHER, HIGHER_OR_EQUAL]
    :return: EPICS_CONDITION named tuple.
    """
    if not pv_name:
        raise ValueError("pv_name not specified.")

    if value is None:
        raise ValueError("pv value not specified.")

    # Failing conditions abort the scan unless stated otherwise.
    action = action or ConditionAction.Abort

    # The tolerance can never be tighter than the configured minimum.
    if not tolerance or tolerance < config.max_float_tolerance:
        tolerance = config.max_float_tolerance

    # The identifier is the PV name itself.
    return EPICS_CONDITION(pv_name, pv_name, value, action, tolerance, operation)
|
||||
|
||||
|
||||
def bs_property(name, default_value=_default_value_placeholder):
    """
    Construct a tuple for bs read property representation.
    :param name: Complete property name.
    :param default_value: The default value that is assigned to the property if it is missing.
    :return: Tuple of ("identifier", "property", "default_value")
    """
    if not name:
        raise ValueError("name not specified.")

    # Resolve the configured default lazily, so runtime config changes are honored.
    if default_value is _default_value_placeholder:
        default_value = config.bs_default_missing_property_value

    # The identifier is the property name itself.
    return BS_PROPERTY(name, name, default_value)
|
||||
|
||||
|
||||
def bs_condition(name, value, action=None, tolerance=None, operation=ConditionComparison.EQUAL,
                 default_value=_default_value_placeholder):
    """
    Construct a tuple for bs condition property representation.
    :param name: Complete property name.
    :param value: Expected value.
    :param action: What to do when the condition fails.
                   ('ConditionAction.Abort' and 'ConditionAction.Retry' supported)
    :param tolerance: Tolerance within which the condition needs to be.
    :param operation: How to compare the received value with the expected value.
                      Allowed values: ConditionComparison.[EQUAL,NOT_EQUAL, LOWER, LOWER_OR_EQUAL, HIGHER, HIGHER_OR_EQUAL]
    :param default_value: Default value of a condition, if not present in the bs stream.
    :return: Tuple of ("identifier", "property", "value", "action", "tolerance", "operation", "default_value")
    """
    if not name:
        raise ValueError("name not specified.")

    if value is None:
        raise ValueError("value not specified.")

    # The tolerance can never be tighter than the configured minimum.
    if not tolerance or tolerance < config.max_float_tolerance:
        tolerance = config.max_float_tolerance

    # Failing conditions abort the scan unless stated otherwise.
    action = action or ConditionAction.Abort

    # Resolve the configured default lazily, so runtime config changes are honored.
    if default_value is _default_value_placeholder:
        default_value = config.bs_default_missing_property_value

    # The identifier is the property name itself.
    return BS_CONDITION(name, name, value, action, tolerance, operation, default_value)
|
||||
|
||||
|
||||
def scan_settings(measurement_interval=None, n_measurements=None, write_timeout=None, settling_time=None,
                  progress_callback=None, bs_read_filter=None):
    """
    Set the scan settings.
    :param measurement_interval: Default 0. Interval between each measurement, in case n_measurements is more than 1.
    :param n_measurements: Default 1. How many measurements to make at each position.
    :param write_timeout: How much time to wait in seconds for set_and_match operations on epics PVs.
    :param settling_time: How much time to wait in seconds after the motors have reached the desired destination.
    :param progress_callback: Function to call after each scan step is completed.
                              Signature: def callback(current_position, total_positions)
    :param bs_read_filter: Filter to apply to the bs read receive function, to filter incoming messages.
                           Signature: def callback(message)
    :return: Scan settings named tuple.
    """
    def _or_default(value, minimum, default):
        # Fall back to the configured default when the value is missing or below the allowed minimum.
        return default if (not value or value < minimum) else value

    measurement_interval = _or_default(measurement_interval, 0, config.scan_default_measurement_interval)
    n_measurements = _or_default(n_measurements, 1, config.scan_default_n_measurements)
    write_timeout = _or_default(write_timeout, 0, config.epics_default_set_and_match_timeout)
    settling_time = _or_default(settling_time, 0, config.epics_default_settling_time)

    if not progress_callback:
        # Default: print the completion percentage to stdout.
        def default_progress_callback(current_position, total_positions):
            completed_percentage = 100.0 * (current_position / total_positions)
            print("Scan: %.2f %% completed (%d/%d)" % (completed_percentage, current_position, total_positions))

        progress_callback = default_progress_callback

    return SCAN_SETTINGS(measurement_interval, n_measurements, write_timeout, settling_time, progress_callback,
                         bs_read_filter)
|
||||
|
||||
|
||||
def convert_input(input_parameters):
    """
    Convert any type of input parameter into appropriate named tuples.
    :param input_parameters: Parameter input from the user.
    :return: Inputs converted into named tuples.
    :raises ValueError: For empty strings, unknown protocols or unsupported types.
    """
    converted_inputs = []
    # Note: "parameter" avoids shadowing the builtin input().
    for parameter in input_parameters:
        if isinstance(parameter, (EPICS_PV, BS_PROPERTY, FUNCTION_VALUE)):
            # Input already of correct type.
            converted_inputs.append(parameter)

        elif isinstance(parameter, str):
            # Check if the string is valid.
            if not parameter:
                raise ValueError("Input cannot be an empty string.")

            if "://" not in parameter:
                # No protocol specified, default is epics.
                converted_inputs.append(epics_pv(parameter))
            elif parameter.lower().startswith("ca://"):
                # Epics PV!
                converted_inputs.append(epics_pv(parameter[5:]))
            elif parameter.lower().startswith("bs://"):
                # bs_read property.
                converted_inputs.append(bs_property(parameter[5:]))
            else:
                # A new protocol we don't know about?
                raise ValueError("Readable %s uses an unexpected protocol. "
                                 "'ca://' and 'bs://' are supported." % parameter)

        elif callable(parameter):
            # Plain callables become function values.
            converted_inputs.append(function_value(parameter))

        else:
            # Supported named tuples or string, we cannot interpret the rest.
            raise ValueError("Input of unexpected type %s. Value: '%s'." % (type(parameter), parameter))

    return converted_inputs
|
||||
|
||||
|
||||
def convert_conditions(input_conditions):
    """
    Convert any type of condition input parameter into appropriate named tuples.
    :param input_conditions: Condition input from the user.
    :return: Input conditions converted into named tuples.
    :raises ValueError: For unsupported condition types.
    """
    converted_conditions = []
    for condition in input_conditions:
        if isinstance(condition, (EPICS_CONDITION, BS_CONDITION, FUNCTION_CONDITION)):
            # Input already of correct type.
            converted_conditions.append(condition)
        elif callable(condition):
            # Plain callables become function conditions.
            converted_conditions.append(function_condition(condition))
        else:
            # Unknown.
            raise ValueError("Condition of unexpected type %s. Value: '%s'." % (type(condition), condition))

    return converted_conditions
|
||||
202
packages/pyscan/scanner.py
Normal file
202
packages/pyscan/scanner.py
Normal file
@@ -0,0 +1,202 @@
|
||||
from itertools import count
|
||||
from time import sleep
|
||||
|
||||
from pyscan import config
|
||||
from pyscan.scan_parameters import scan_settings
|
||||
|
||||
# Lifecycle states reported by Scanner.get_status().
STATUS_INITIALIZED = "INITIALIZED"
STATUS_RUNNING = "RUNNING"
STATUS_FINISHED = "FINISHED"
STATUS_PAUSED = "PAUSED"
STATUS_ABORTED = "ABORTED"
|
||||
|
||||
|
||||
class Scanner(object):
|
||||
"""
|
||||
Perform discrete and continues scans.
|
||||
"""
|
||||
|
||||
    def __init__(self, positioner, data_processor, reader, writer=None, before_measurement_executor=None,
                 after_measurement_executor=None, initialization_executor=None, finalization_executor=None,
                 data_validator=None, settings=None, before_move_executor=None, after_move_executor=None):
        """
        Initialize scanner.
        :param positioner: Positioner should provide a generator to get the positions to move to.
        :param data_processor: How to store and handle the data.
        :param reader: Object that implements the read() method to return data to the data_processor.
        :param writer: Object that implements the write(position) method and sets the positions.
        :param before_measurement_executor: Callbacks executor that executed before measurements.
        :param after_measurement_executor: Callbacks executor that executed after measurements.
        :param initialization_executor: Callbacks executor run once before the scan starts.
        :param finalization_executor: Callbacks executor run when the scan ends (also on error).
        :param data_validator: Callable (position, data) -> bool; when omitted, all data is considered valid.
        :param settings: Scan settings named tuple; defaults to scan_settings().
        :param before_move_executor: Callbacks executor that executes before each move.
        :param after_move_executor: Callbacks executor that executes after each move.
        """
        self.positioner = positioner
        self.writer = writer
        self.data_processor = data_processor
        self.reader = reader
        self.before_measurement_executor = before_measurement_executor
        self.after_measurement_executor = after_measurement_executor
        self.initialization_executor = initialization_executor
        self.finalization_executor = finalization_executor
        self.settings = settings or scan_settings()
        self.before_move_executor = before_move_executor
        self.after_move_executor = after_move_executor

        # If no data validator is provided, data is always valid.
        self.data_validator = data_validator or (lambda position, data: True)

        # Flags set by abort_scan()/pause_scan() and checked in _verify_scan_status().
        self._user_abort_scan_flag = False
        self._user_pause_scan_flag = False

        self._status = STATUS_INITIALIZED
|
||||
|
||||
def abort_scan(self):
    """
    Request the scan to abort after the next measurement completes.
    """
    self._user_abort_scan_flag = True
|
||||
|
||||
def pause_scan(self):
    """
    Request the scan to pause after the next measurement completes.
    """
    self._user_pause_scan_flag = True
|
||||
|
||||
def get_status(self):
    """
    Return the current scan status.
    """
    return self._status
|
||||
|
||||
def resume_scan(self):
    """
    Clear the pause flag so a paused scan continues.
    """
    self._user_pause_scan_flag = False
|
||||
|
||||
def _verify_scan_status(self):
|
||||
"""
|
||||
Check if the conditions to pause or abort the scan are met.
|
||||
:raise Exception in case the conditions are met.
|
||||
"""
|
||||
# Check if the abort flag is set.
|
||||
if self._user_abort_scan_flag:
|
||||
self._status = STATUS_ABORTED
|
||||
raise Exception("User aborted scan.")
|
||||
|
||||
# If the scan is in pause, wait until it is resumed or the user aborts the scan.
|
||||
if self._user_pause_scan_flag:
|
||||
self._status = STATUS_PAUSED
|
||||
|
||||
while self._user_pause_scan_flag:
|
||||
if self._user_abort_scan_flag:
|
||||
self._status = STATUS_ABORTED
|
||||
raise Exception("User aborted scan in pause.")
|
||||
sleep(config.scan_pause_sleep_interval)
|
||||
# Once the pause flag is cleared, the scanning continues.
|
||||
self._status = STATUS_RUNNING
|
||||
|
||||
def _perform_single_read(self, current_position_index):
    """
    Acquire a single result (all channels), retrying until the data validates.

    :param current_position_index: Current position, passed to the reader and the validator.
    :return: Single valid measurement.
    :raise Exception: when no valid data can be read within the retry limit.
    """
    for n_attempt in range(config.scan_acquisition_retry_limit):
        # Tell the reader whether this acquisition is a retry.
        single_measurement = self.reader(current_position_index, retry=(n_attempt != 0))

        # Valid data ends the retry loop.
        if self.data_validator(current_position_index, single_measurement):
            return single_measurement

        sleep(config.scan_acquisition_retry_delay)

    # Retry budget exhausted without valid data.
    raise Exception("Number of maximum read attempts (%d) exceeded. Cannot read valid data at position %s."
                    % (config.scan_acquisition_retry_limit, current_position_index))
|
||||
|
||||
def _read_and_process_data(self, current_position):
|
||||
"""
|
||||
Read the data and pass it on only if valid.
|
||||
:param current_position: Current position reached by the scan.
|
||||
:return: Current position scan data.
|
||||
"""
|
||||
# We do a single acquisition per position.
|
||||
if self.settings.n_measurements == 1:
|
||||
result = self._perform_single_read(current_position)
|
||||
|
||||
# Multiple acquisitions.
|
||||
else:
|
||||
result = []
|
||||
for n_measurement in range(self.settings.n_measurements):
|
||||
result.append(self._perform_single_read(current_position))
|
||||
sleep(self.settings.measurement_interval)
|
||||
|
||||
# Process only valid data.
|
||||
self.data_processor.process(current_position, result)
|
||||
|
||||
return result
|
||||
|
||||
def discrete_scan(self):
    """
    Perform a discrete scan: move to each position, measure, continue.

    :return: All data collected by the data processor.
    """
    try:
        self._status = STATUS_RUNNING

        # Total number of positions, used for progress reporting.
        n_of_positions = sum(1 for _ in self.positioner.get_generator())
        # Report 0% completed.
        self.settings.progress_callback(0, n_of_positions)

        # Set up the experiment.
        if self.initialization_executor:
            self.initialization_executor(self)

        for position_index, next_positions in enumerate(self.positioner.get_generator(), start=1):
            # Pre-move callbacks.
            if self.before_move_executor:
                self.before_move_executor(next_positions)

            # Move into place before reading.
            if self.writer:
                self.writer(next_positions)

            # Settling time: wait after the position has been reached.
            sleep(self.settings.settling_time)

            # Post-move callbacks.
            if self.after_move_executor:
                self.after_move_executor(next_positions)

            # Pre-measurement callbacks.
            if self.before_measurement_executor:
                self.before_measurement_executor(next_positions)

            # Acquire and process the data at this position.
            position_data = self._read_and_process_data(next_positions)

            # Post-measurement callbacks.
            if self.after_measurement_executor:
                self.after_measurement_executor(next_positions, position_data)

            # Report progress.
            self.settings.progress_callback(position_index, n_of_positions)

            # Honor pending pause/abort requests.
            self._verify_scan_status()
    finally:
        # Clean up after yourself.
        if self.finalization_executor:
            self.finalization_executor(self)

        # An aborted scan keeps its aborted status.
        if self._status != STATUS_ABORTED:
            self._status = STATUS_FINISHED

    return self.data_processor.get_data()
|
||||
|
||||
def continuous_scan(self):
    """
    Perform a continuous scan.

    TODO: Needs implementation.
    """
    pass
|
||||
216
packages/pyscan/utils.py
Normal file
216
packages/pyscan/utils.py
Normal file
@@ -0,0 +1,216 @@
|
||||
import inspect
|
||||
from collections import OrderedDict
|
||||
from time import sleep
|
||||
|
||||
from epics.pv import PV
|
||||
|
||||
from pyscan import config
|
||||
from pyscan.scan_parameters import convert_input, ConditionComparison
|
||||
|
||||
|
||||
def compare_channel_value(current_value, expected_value, tolerance=0.0, operation=ConditionComparison.EQUAL):
    """
    Check if the pv value matches the expected value, within tolerance for int and float.

    :param current_value: Current value to compare. If a list, the comparison passes
                          when any element matches.
    :param expected_value: Expected value of the PV.
    :param tolerance: Tolerance for number comparison. Raised to the configured minimum
                      if smaller.
    :param operation: Operation to perform on the current and expected value - works for
                      int and floats; other types use their own comparison operators.
    :return: True if the value matches.
    """
    # Enforce the minimum tolerance allowed.
    # NOTE(review): the config attribute is named max_float_tolerance but acts as a
    # lower bound here - confirm the naming against the config module.
    tolerance = max(tolerance, config.max_float_tolerance)

    def compare_value(value):
        # For numbers, compare within tolerance.
        if isinstance(value, (float, int)):
            difference = value - expected_value

            if operation == ConditionComparison.EQUAL:
                return abs(difference) <= tolerance
            elif operation == ConditionComparison.HIGHER:
                return difference > tolerance
            elif operation == ConditionComparison.HIGHER_OR_EQUAL:
                return difference >= tolerance
            elif operation == ConditionComparison.LOWER:
                return difference < 0 or abs(difference) < tolerance
            elif operation == ConditionComparison.LOWER_OR_EQUAL:
                return difference <= 0 or abs(difference) <= tolerance
            elif operation == ConditionComparison.NOT_EQUAL:
                return abs(difference) > tolerance

        # Otherwise fall back to the objects' own comparison operators.
        else:
            try:
                if operation == ConditionComparison.EQUAL:
                    return value == expected_value
                elif operation == ConditionComparison.HIGHER:
                    return value > expected_value
                elif operation == ConditionComparison.HIGHER_OR_EQUAL:
                    return value >= expected_value
                elif operation == ConditionComparison.LOWER:
                    return value < expected_value
                elif operation == ConditionComparison.LOWER_OR_EQUAL:
                    return value <= expected_value
                elif operation == ConditionComparison.NOT_EQUAL:
                    return value != expected_value
            # Fix: narrowed the original bare `except:` so system-exiting exceptions
            # (KeyboardInterrupt, SystemExit) are not swallowed into a ValueError.
            except Exception:
                raise ValueError("Do not know how to compare current_value %s with expected_value %s and action %s."
                                 % (current_value, expected_value, operation))

        # Unknown operation for this value type.
        return False

    if isinstance(current_value, list):
        # In case of a list, any matching element will do.
        # Fix: the original iterated expected_value while the closure ignored its
        # argument entirely (the iterated values were never used). Iterating
        # current_value matches the isinstance guard above - confirm intended
        # semantics against callers.
        return any(compare_value(value) for value in current_value)
    return compare_value(current_value)
|
||||
|
||||
|
||||
def connect_to_pv(pv_name, n_connection_attempts=3):
    """
    Connect to a PV, retrying a few times before giving up.

    :param pv_name: Name of the PV to connect to.
    :param n_connection_attempts: Number of attempts before raising an exception.
    :return: Connected PV object.
    :raises ValueError: if the PV cannot be reached.
    """
    pv = PV(pv_name, auto_monitor=False)
    attempts_left = n_connection_attempts
    while attempts_left > 0:
        if pv.connect():
            return pv
        sleep(0.1)
        attempts_left -= 1
    raise ValueError("Cannot connect to PV '%s'." % pv_name)
|
||||
|
||||
|
||||
def validate_lists_length(*args):
    """
    Check that all provided lists have the same length.

    :param args: Lists to compare.
    :raise ValueError: if no lists are given or their lengths differ.
    """
    if not args:
        raise ValueError("Cannot compare lengths of None.")

    expected_length = len(args[0])
    if any(len(element) != expected_length for element in args):
        # List every offending input in the error message.
        listing = "\n".join("%s" % (element,) for element in args)
        raise ValueError("The provided lists must be of same length.\n%s\n" % listing)
|
||||
|
||||
|
||||
def convert_to_list(value):
    """
    Wrap a non-list value into a single-element list.

    :param value: Any value.
    :return: The value unchanged if it is None or already a list, otherwise [value].
    """
    if value is None or isinstance(value, list):
        return value
    return [value]
|
||||
|
||||
|
||||
def convert_to_position_list(axis_list):
    """
    Transpose a PER KNOB list of positions into a PER INDEX list.

    :param axis_list: PER KNOB list of positions.
    :return: PER INDEX list of positions.
    """
    return [list(index_positions) for index_positions in zip(*axis_list)]
|
||||
|
||||
|
||||
def flat_list_generator(list_to_flatten):
    """
    Yield the innermost lists of an arbitrarily nested list structure.

    :param list_to_flatten: Possibly nested list.
    """
    if list_to_flatten and isinstance(list_to_flatten[0], list):
        # Still nested: recurse into each sub-list.
        for inner_list in list_to_flatten:
            yield from flat_list_generator(inner_list)
    else:
        # Innermost level (or an empty list): yield it as-is.
        yield list_to_flatten
|
||||
|
||||
|
||||
class ActionExecutor(object):
    """
    Execute all callbacks sequentially in the caller's thread.

    Each callback may accept (position, sampled_values), (position,), or no arguments.
    """

    def __init__(self, actions):
        """
        Initialize the action executor.

        :param actions: Single action or list of actions to execute.
        """
        self.actions = convert_to_list(actions)

    def execute(self, position, position_data=None):
        """
        Invoke every action, passing as many arguments as its signature accepts.

        :param position: Current position.
        :param position_data: Data sampled at the position (optional).
        """
        for action in self.actions:
            arity = len(inspect.signature(action).parameters)
            if arity == 2:
                action(position, position_data)
            elif arity == 1:
                action(position)
            else:
                action()
|
||||
|
||||
|
||||
class SimpleDataProcessor(object):
    """
    Store each visited position together with the data received there.
    """

    def __init__(self, positions=None, data=None):
        """
        Initialize the simple data processor.

        :param positions: List used to store visited positions. Default: new internal list.
        :param data: List used to store per-position data. Default: new internal list.
        """
        self.positions = [] if positions is None else positions
        self.data = [] if data is None else data

    def process(self, position, data):
        # Record the position and its data side by side.
        self.positions.append(position)
        self.data.append(data)

    def get_data(self):
        return self.data

    def get_positions(self):
        return self.positions
|
||||
|
||||
|
||||
class DictionaryDataProcessor(SimpleDataProcessor):
    """
    Store positions and, for each position, a dictionary mapping readable ids to values.
    """

    def __init__(self, readables, positions=None, data=None):
        """
        Initialize the dictionary data processor.

        :param readables: Same readables that were passed to the scan function.
        :param positions: Optional external list for positions.
        :param data: Optional external list for data.
        """
        super(DictionaryDataProcessor, self).__init__(positions=positions, data=data)

        # Cache the readable identifiers used as dictionary keys.
        self.readable_ids = [readable.identifier for readable in convert_input(readables)]

    def process(self, position, data):
        self.positions.append(position)
        # Pair each readable id with its measured value, preserving order.
        self.data.append(OrderedDict(zip(self.readable_ids, data)))
|
||||
111
python310/packages/EGG-INFO/PKG-INFO
Normal file
111
python310/packages/EGG-INFO/PKG-INFO
Normal file
@@ -0,0 +1,111 @@
|
||||
Metadata-Version: 1.0
|
||||
Name: elog
|
||||
Version: 1.3.4
|
||||
Summary: Python library to access Elog.
|
||||
Home-page: https://github.com/paulscherrerinstitute/py_elog
|
||||
Author: Paul Scherrer Institute (PSI)
|
||||
Author-email: UNKNOWN
|
||||
License: UNKNOWN
|
||||
Description: [](https://travis-ci.org/paulscherrerinstitute/py_elog) [](https://ci.appveyor.com/project/simongregorebner/py-elog)
|
||||
|
||||
# Overview
|
||||
This Python module provides a native interface to [electronic logbooks](https://midas.psi.ch/elog/). It is compatible with Python versions 3.5 and higher.
|
||||
|
||||
# Usage
|
||||
|
||||
For accessing a logbook at ```http[s]://<hostname>:<port>/[<subdir>/]<logbook>/[<msg_id>]``` a logbook handle must be retrieved.
|
||||
|
||||
```python
|
||||
import elog
|
||||
|
||||
# Open GFA SwissFEL test logbook
|
||||
logbook = elog.open('https://elog-gfa.psi.ch/SwissFEL+test/')
|
||||
|
||||
# Constructor using detailed arguments
|
||||
# Open demo logbook on local host: http://localhost:8080/demo/
|
||||
logbook = elog.open('localhost', 'demo', port=8080, use_ssl=False)
|
||||
```
|
||||
|
||||
Once you have hold of the logbook handle one of its public methods can be used to read, create, reply to, edit or delete the message.
|
||||
|
||||
## Get Existing Message Ids
|
||||
Get all the existing message ids of a logbook
|
||||
|
||||
```python
|
||||
message_ids = logbook.get_message_ids()
|
||||
```
|
||||
|
||||
To get the ID of the last inserted message
|
||||
```python
|
||||
last_message_id = logbook.get_last_message_id()
|
||||
```
|
||||
|
||||
## Read Message
|
||||
|
||||
```python
|
||||
# Read message with message ID = 23
|
||||
message, attributes, attachments = logbook.read(23)
|
||||
```
|
||||
|
||||
## Create Message
|
||||
|
||||
```python
|
||||
# Create new message with some text, attributes (dict of attributes + kwargs) and attachments
|
||||
new_msg_id = logbook.post('This is message text', attributes=dict_of_attributes, attachments=list_of_attachments, attribute_as_param='value')
|
||||
```
|
||||
|
||||
What attributes are required is determined by the configuration of the elog server (keyword `Required Attributes`).
|
||||
If the configuration looks like this:
|
||||
|
||||
```
|
||||
Required Attributes = Author, Type
|
||||
```
|
||||
|
||||
You have to provide author and type when posting a message.
|
||||
|
||||
In case type need to be specified, the supported keywords can as well be found in the elog configuration with the key `Options Type`.
|
||||
|
||||
If the config looks like this:
|
||||
```
|
||||
Options Type = Routine, Software Installation, Problem Fixed, Configuration, Other
|
||||
```
|
||||
|
||||
A working create call would look like this:
|
||||
|
||||
```python
|
||||
new_msg_id = logbook.post('This is message text', author='me', type='Routine')
|
||||
```
|
||||
|
||||
|
||||
|
||||
## Reply to Message
|
||||
|
||||
```python
|
||||
# Reply to message with ID=23
|
||||
new_msg_id = logbook.post('This is a reply', msg_id=23, reply=True, attributes=dict_of_attributes, attachments=list_of_attachments, attribute_as_param='value')
|
||||
```
|
||||
|
||||
## Edit Message
|
||||
|
||||
```python
|
||||
# Edit message with ID=23. Changed message text, some attributes (dict of edited attributes + kwargs) and new attachments
|
||||
edited_msg_id = logbook.post('This is new message text', msg_id=23, attributes=dict_of_changed_attributes, attachments=list_of_new_attachments, attribute_as_param='new value')
|
||||
```
|
||||
|
||||
## Delete Message (and all its replies)
|
||||
|
||||
```python
|
||||
# Delete message with ID=23. All its replies will also be deleted.
|
||||
logbook.delete(23)
|
||||
```
|
||||
|
||||
__Note:__ Due to the way elog implements delete this function is only supported on english logbooks.
|
||||
|
||||
# Installation
|
||||
The Elog module only depends on the `passlib` and `requests` libraries, used for password encryption and http(s) communication. It is packaged as an [anaconda package](https://anaconda.org/paulscherrerinstitute/elog) and can be installed as follows:
|
||||
|
||||
```bash
|
||||
conda install -c paulscherrerinstitute elog
|
||||
```
|
||||
Keywords: elog,electronic,logbook
|
||||
Platform: UNKNOWN
|
||||
8
python310/packages/EGG-INFO/SOURCES.txt
Normal file
8
python310/packages/EGG-INFO/SOURCES.txt
Normal file
@@ -0,0 +1,8 @@
|
||||
setup.py
|
||||
elog/__init__.py
|
||||
elog/logbook.py
|
||||
elog/logbook_exceptions.py
|
||||
elog.egg-info/PKG-INFO
|
||||
elog.egg-info/SOURCES.txt
|
||||
elog.egg-info/dependency_links.txt
|
||||
elog.egg-info/top_level.txt
|
||||
1
python310/packages/EGG-INFO/dependency_links.txt
Normal file
1
python310/packages/EGG-INFO/dependency_links.txt
Normal file
@@ -0,0 +1 @@
|
||||
|
||||
1
python310/packages/EGG-INFO/top_level.txt
Normal file
1
python310/packages/EGG-INFO/top_level.txt
Normal file
@@ -0,0 +1 @@
|
||||
elog
|
||||
1
python310/packages/EGG-INFO/zip-safe
Normal file
1
python310/packages/EGG-INFO/zip-safe
Normal file
@@ -0,0 +1 @@
|
||||
|
||||
78
python310/packages/bigtree/__init__.py
Normal file
78
python310/packages/bigtree/__init__.py
Normal file
@@ -0,0 +1,78 @@
|
||||
__version__ = "0.16.2"
|
||||
|
||||
from bigtree.binarytree.construct import list_to_binarytree
|
||||
from bigtree.dag.construct import dataframe_to_dag, dict_to_dag, list_to_dag
|
||||
from bigtree.dag.export import dag_to_dataframe, dag_to_dict, dag_to_dot, dag_to_list
|
||||
from bigtree.node.basenode import BaseNode
|
||||
from bigtree.node.binarynode import BinaryNode
|
||||
from bigtree.node.dagnode import DAGNode
|
||||
from bigtree.node.node import Node
|
||||
from bigtree.tree.construct import (
|
||||
add_dataframe_to_tree_by_name,
|
||||
add_dataframe_to_tree_by_path,
|
||||
add_dict_to_tree_by_name,
|
||||
add_dict_to_tree_by_path,
|
||||
add_path_to_tree,
|
||||
dataframe_to_tree,
|
||||
dataframe_to_tree_by_relation,
|
||||
dict_to_tree,
|
||||
list_to_tree,
|
||||
list_to_tree_by_relation,
|
||||
nested_dict_to_tree,
|
||||
newick_to_tree,
|
||||
str_to_tree,
|
||||
)
|
||||
from bigtree.tree.export import (
|
||||
hprint_tree,
|
||||
hyield_tree,
|
||||
print_tree,
|
||||
tree_to_dataframe,
|
||||
tree_to_dict,
|
||||
tree_to_dot,
|
||||
tree_to_mermaid,
|
||||
tree_to_nested_dict,
|
||||
tree_to_newick,
|
||||
tree_to_pillow,
|
||||
yield_tree,
|
||||
)
|
||||
from bigtree.tree.helper import clone_tree, get_subtree, get_tree_diff, prune_tree
|
||||
from bigtree.tree.modify import (
|
||||
copy_and_replace_nodes_from_tree_to_tree,
|
||||
copy_nodes,
|
||||
copy_nodes_from_tree_to_tree,
|
||||
copy_or_shift_logic,
|
||||
replace_logic,
|
||||
shift_and_replace_nodes,
|
||||
shift_nodes,
|
||||
)
|
||||
from bigtree.tree.search import (
|
||||
find,
|
||||
find_attr,
|
||||
find_attrs,
|
||||
find_child,
|
||||
find_child_by_name,
|
||||
find_children,
|
||||
find_full_path,
|
||||
find_name,
|
||||
find_names,
|
||||
find_path,
|
||||
find_paths,
|
||||
find_relative_path,
|
||||
findall,
|
||||
)
|
||||
from bigtree.utils.groot import speak_like_groot, whoami
|
||||
from bigtree.utils.iterators import (
|
||||
dag_iterator,
|
||||
inorder_iter,
|
||||
levelorder_iter,
|
||||
levelordergroup_iter,
|
||||
postorder_iter,
|
||||
preorder_iter,
|
||||
zigzag_iter,
|
||||
zigzaggroup_iter,
|
||||
)
|
||||
from bigtree.utils.plot import reingold_tilford
|
||||
from bigtree.workflows.app_calendar import Calendar
|
||||
from bigtree.workflows.app_todo import AppToDo
|
||||
|
||||
sphinx_versions = ["latest", "0.16.2", "0.15.7", "0.14.8"]
|
||||
0
python310/packages/bigtree/binarytree/__init__.py
Normal file
0
python310/packages/bigtree/binarytree/__init__.py
Normal file
53
python310/packages/bigtree/binarytree/construct.py
Normal file
53
python310/packages/bigtree/binarytree/construct.py
Normal file
@@ -0,0 +1,53 @@
|
||||
from typing import List, Type
|
||||
|
||||
from bigtree.node.binarynode import BinaryNode
|
||||
|
||||
__all__ = ["list_to_binarytree"]
|
||||
|
||||
|
||||
def list_to_binarytree(
    heapq_list: List[int], node_type: Type[BinaryNode] = BinaryNode
) -> BinaryNode:
    """Construct a binary tree from a list of numbers in heapq (level) order.

    Element ``i``'s parent is element ``(i - 1) // 2``, as in a binary heap.

    Args:
        heapq_list (List[int]): list containing integer node names, ordered in heapq fashion
        node_type (Type[BinaryNode]): node type of tree to be created, defaults to ``BinaryNode``

    Returns:
        (BinaryNode)
    """
    if not heapq_list:
        raise ValueError("Input list does not contain any data, check `heapq_list`")

    root_node = node_type(heapq_list[0])
    nodes = [root_node]
    for idx, name in enumerate(heapq_list[1:], start=1):
        # Odd and even indices both reduce to the same heap-parent formula.
        parent = nodes[(idx - 1) // 2]
        nodes.append(node_type(name, parent=parent))  # type: ignore
    return root_node
|
||||
0
python310/packages/bigtree/dag/__init__.py
Normal file
0
python310/packages/bigtree/dag/__init__.py
Normal file
206
python310/packages/bigtree/dag/construct.py
Normal file
206
python310/packages/bigtree/dag/construct.py
Normal file
@@ -0,0 +1,206 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List, Tuple, Type
|
||||
|
||||
from bigtree.node.dagnode import DAGNode
|
||||
from bigtree.utils.exceptions import optional_dependencies_pandas
|
||||
|
||||
try:
|
||||
import pandas as pd
|
||||
except ImportError: # pragma: no cover
|
||||
pd = None
|
||||
|
||||
__all__ = ["list_to_dag", "dict_to_dag", "dataframe_to_dag"]
|
||||
|
||||
|
||||
@optional_dependencies_pandas
def list_to_dag(
    relations: List[Tuple[str, str]],
    node_type: Type[DAGNode] = DAGNode,
) -> DAGNode:
    """Construct DAG from a list of (parent, child) name tuples.

    Note that node names must be unique.

    Args:
        relations (List[Tuple[str, str]]): list containing tuple of parent-child names
        node_type (Type[DAGNode]): node type of DAG to be created, defaults to ``DAGNode``

    Returns:
        (DAGNode)
    """
    if not relations:
        raise ValueError("Input list does not contain any data, check `relations`")

    # Delegate to the dataframe-based constructor.
    relation_frame = pd.DataFrame(relations, columns=["parent", "child"])
    return dataframe_to_dag(
        relation_frame, child_col="child", parent_col="parent", node_type=node_type
    )
|
||||
|
||||
|
||||
def dict_to_dag(
    relation_attrs: Dict[str, Any],
    parent_key: str = "parents",
    node_type: Type[DAGNode] = DAGNode,
) -> DAGNode:
    """Construct DAG from a nested dictionary: ``key`` is the child name, ``value`` is a
    dictionary of parent names plus attribute name/value pairs.

    Note that node names must be unique.

    Args:
        relation_attrs (Dict[str, Any]): dictionary containing node, node parents, and node
            attribute information; key: child name, value: dictionary of parent names,
            node attribute, and attribute value
        parent_key (str): key of dictionary to retrieve list of parents name, defaults to 'parents'
        node_type (Type[DAGNode]): node type of DAG to be created, defaults to ``DAGNode``

    Returns:
        (DAGNode)
    """
    if not relation_attrs:
        raise ValueError("Dictionary does not contain any data, check `relation_attrs`")

    # Convert the dictionary to a dataframe with one row per child.
    relation_frame = pd.DataFrame(relation_attrs).T.rename_axis("_tmp_child").reset_index()
    if parent_key not in relation_frame:
        raise ValueError(
            f"Parent key {parent_key} not in dictionary, check `relation_attrs` and `parent_key`"
        )

    # Explode parent lists so each row is one (child, parent) pair.
    relation_frame = relation_frame.explode(parent_key)
    return dataframe_to_dag(
        relation_frame,
        child_col="_tmp_child",
        parent_col=parent_key,
        node_type=node_type,
    )
|
||||
|
||||
|
||||
@optional_dependencies_pandas
def dataframe_to_dag(
    data: pd.DataFrame,
    child_col: str = "",
    parent_col: str = "",
    attribute_cols: List[str] = [],
    node_type: Type[DAGNode] = DAGNode,
) -> DAGNode:
    """Construct DAG from pandas DataFrame.

    Note that node names must be unique.

    - `child_col` and `parent_col` specify columns for child name and parent name to construct DAG.
    - `attribute_cols` specify columns for node attribute for child name.
    - If columns are not specified, `child_col` takes the first column, `parent_col` takes the
      second column, and all other columns are `attribute_cols`.

    Args:
        data (pd.DataFrame): data containing path and node attribute information
        child_col (str): column of data containing child name information, defaults to ''
            (first column of data)
        parent_col (str): column of data containing parent name information, defaults to ''
            (second column of data)
        attribute_cols (List[str]): columns of data containing child node attribute information,
            defaults to all columns except `child_col` and `parent_col`
        node_type (Type[DAGNode]): node type of DAG to be created, defaults to ``DAGNode``

    Returns:
        (DAGNode)
    """
    data = data.copy()

    if not len(data.columns):
        raise ValueError("Data does not contain any columns, check `data`")
    if not len(data):
        raise ValueError("Data does not contain any rows, check `data`")

    # Resolve column defaults and validate that requested columns exist.
    if not child_col:
        child_col = data.columns[0]
    elif child_col not in data.columns:
        raise ValueError(f"Child column not in data, check `child_col`: {child_col}")
    if not parent_col:
        parent_col = data.columns[1]
    elif parent_col not in data.columns:
        raise ValueError(f"Parent column not in data, check `parent_col`: {parent_col}")
    if not len(attribute_cols):
        attribute_cols = [col for col in data.columns if col not in (child_col, parent_col)]
    elif any(col not in data.columns for col in attribute_cols):
        raise ValueError(
            f"One or more attribute column(s) not in data, check `attribute_cols`: {attribute_cols}"
        )

    # Reject children that appear multiple times with conflicting attributes.
    data_check = data.copy()[[child_col, parent_col] + attribute_cols].drop_duplicates(
        subset=[child_col] + attribute_cols
    )
    duplicate_check = (
        data_check[child_col]
        .value_counts()
        .to_frame("counts")
        .rename_axis(child_col)
        .reset_index()
    )
    duplicate_check = duplicate_check[duplicate_check["counts"] > 1]
    if len(duplicate_check):
        raise ValueError(
            f"There exists duplicate child name with different attributes\n"
            f"Check {duplicate_check}"
        )
    if sum(data[child_col].isnull()):
        raise ValueError(f"Child name cannot be empty, check column: {child_col}")

    node_dict: Dict[str, DAGNode] = {}
    parent_node = DAGNode()

    for row in data.reset_index(drop=True).to_dict(orient="index").values():
        child_name = row[child_col]
        parent_name = row[parent_col]
        # Attributes are all non-null values outside the child/parent columns.
        node_attrs = {
            key: value
            for key, value in row.items()
            if key not in (child_col, parent_col) and not pd.isnull(value)
        }
        child_node = node_dict.setdefault(child_name, node_type(child_name))
        child_node.set_attrs(node_attrs)

        if not pd.isnull(parent_name):
            parent_node = node_dict.setdefault(parent_name, node_type(parent_name))
            child_node.parents = [parent_node]

    # NOTE(review): returns the last parent node processed (or an empty DAGNode if no
    # row had a parent) - this mirrors the original behavior.
    return parent_node
|
||||
298
python310/packages/bigtree/dag/export.py
Normal file
298
python310/packages/bigtree/dag/export.py
Normal file
@@ -0,0 +1,298 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List, Tuple, TypeVar, Union
|
||||
|
||||
from bigtree.node.dagnode import DAGNode
|
||||
from bigtree.utils.exceptions import (
|
||||
optional_dependencies_image,
|
||||
optional_dependencies_pandas,
|
||||
)
|
||||
from bigtree.utils.iterators import dag_iterator
|
||||
|
||||
try:
|
||||
import pandas as pd
|
||||
except ImportError: # pragma: no cover
|
||||
pd = None
|
||||
|
||||
try:
|
||||
import pydot
|
||||
except ImportError: # pragma: no cover
|
||||
pydot = None
|
||||
|
||||
__all__ = ["dag_to_list", "dag_to_dict", "dag_to_dataframe", "dag_to_dot"]
|
||||
|
||||
|
||||
T = TypeVar("T", bound=DAGNode)
|
||||
|
||||
|
||||
def dag_to_list(
    dag: T,
) -> List[Tuple[str, str]]:
    """Export DAG to a list of (parent name, child name) tuples, one per edge.

    Examples:
        >>> from bigtree import DAGNode, dag_to_list
        >>> a = DAGNode("a", step=1)
        >>> b = DAGNode("b", step=1)
        >>> c = DAGNode("c", step=2, parents=[a, b])
        >>> d = DAGNode("d", step=2, parents=[a, c])
        >>> e = DAGNode("e", step=3, parents=[d])
        >>> dag_to_list(a)
        [('a', 'c'), ('a', 'd'), ('b', 'c'), ('c', 'd'), ('d', 'e')]

    Args:
        dag (DAGNode): DAG to be exported

    Returns:
        (List[Tuple[str, str]])
    """
    # dag_iterator yields one (parent, child) pair per edge of the DAG.
    return [
        (parent_node.node_name, child_node.node_name)
        for parent_node, child_node in dag_iterator(dag)
    ]
|
||||
|
||||
|
||||
def dag_to_dict(
    dag: T,
    parent_key: str = "parents",
    attr_dict: Dict[str, str] | None = None,
    all_attrs: bool = False,
) -> Dict[str, Any]:
    """Export DAG to dictionary.

    Exported dictionary will have key as child name, and parent names and node attributes as a nested dictionary.

    Examples:
        >>> from bigtree import DAGNode, dag_to_dict
        >>> a = DAGNode("a", step=1)
        >>> b = DAGNode("b", step=1)
        >>> c = DAGNode("c", step=2, parents=[a, b])
        >>> d = DAGNode("d", step=2, parents=[a, c])
        >>> e = DAGNode("e", step=3, parents=[d])
        >>> dag_to_dict(a, parent_key="parent", attr_dict={"step": "step no."})
        {'a': {'step no.': 1}, 'c': {'parent': ['a', 'b'], 'step no.': 2}, 'd': {'parent': ['a', 'c'], 'step no.': 2}, 'b': {'step no.': 1}, 'e': {'parent': ['d'], 'step no.': 3}}

    Args:
        dag (DAGNode): DAG to be exported
        parent_key (str): dictionary key for `node.parent.node_name`, defaults to `parents`
        attr_dict (Dict[str, str] | None): dictionary mapping node attributes to dictionary key,
            key: node attributes, value: corresponding dictionary key, optional
        all_attrs (bool): indicator whether to retrieve all `Node` attributes, defaults to False

    Returns:
        (Dict[str, Any])
    """
    # Fix: the previous `attr_dict={}` mutable default is shared across calls;
    # use a None sentinel instead (backward-compatible for all callers).
    if attr_dict is None:
        attr_dict = {}
    # Work on a copy so traversal never mutates the caller's DAG.
    dag = dag.copy()
    data_dict: Dict[str, Any] = {}

    for parent_node, child_node in dag_iterator(dag):
        # Root parents never appear as a child of any edge, so emit their
        # entry (attributes only, no parent list) when first seen.
        if parent_node.is_root:
            data_parent: Dict[str, Any] = {}
            if all_attrs:
                data_parent.update(
                    parent_node.describe(
                        exclude_attributes=["name"], exclude_prefix="_"
                    )
                )
            else:
                for k, v in attr_dict.items():
                    data_parent[v] = parent_node.get_attr(k)
            data_dict[parent_node.node_name] = data_parent

        if data_dict.get(child_node.node_name):
            # Child already recorded from a previous edge: append this parent.
            data_dict[child_node.node_name][parent_key].append(parent_node.node_name)
        else:
            data_child: Dict[str, Any] = {parent_key: [parent_node.node_name]}
            if all_attrs:
                data_child.update(
                    child_node.describe(exclude_attributes=["name"], exclude_prefix="_")
                )
            else:
                for k, v in attr_dict.items():
                    data_child[v] = child_node.get_attr(k)
            data_dict[child_node.node_name] = data_child
    return data_dict
|
||||
|
||||
|
||||
@optional_dependencies_pandas
def dag_to_dataframe(
    dag: T,
    name_col: str = "name",
    parent_col: str = "parent",
    attr_dict: Dict[str, str] | None = None,
    all_attrs: bool = False,
) -> pd.DataFrame:
    """Export DAG to pandas DataFrame, one row per (parent, child) edge.

    Examples:
        >>> from bigtree import DAGNode, dag_to_dataframe
        >>> a = DAGNode("a", step=1)
        >>> b = DAGNode("b", step=1)
        >>> c = DAGNode("c", step=2, parents=[a, b])
        >>> d = DAGNode("d", step=2, parents=[a, c])
        >>> e = DAGNode("e", step=3, parents=[d])
        >>> dag_to_dataframe(a, name_col="name", parent_col="parent", attr_dict={"step": "step no."})
          name parent  step no.
        0    a   None         1
        1    c      a         2
        2    d      a         2
        3    b   None         1
        4    c      b         2
        5    d      c         2
        6    e      d         3

    Args:
        dag (DAGNode): DAG to be exported
        name_col (str): column name for `node.node_name`, defaults to 'name'
        parent_col (str): column name for `node.parent.node_name`, defaults to 'parent'
        attr_dict (Dict[str, str] | None): dictionary mapping node attributes to column name,
            key: node attributes, value: corresponding column in dataframe, optional
        all_attrs (bool): indicator whether to retrieve all `Node` attributes, defaults to False

    Returns:
        (pd.DataFrame)
    """
    # Fix: the previous `attr_dict={}` mutable default is shared across calls;
    # use a None sentinel instead (backward-compatible for all callers).
    if attr_dict is None:
        attr_dict = {}
    # Work on a copy so traversal never mutates the caller's DAG.
    dag = dag.copy()
    data_list: List[Dict[str, Any]] = []

    for parent_node, child_node in dag_iterator(dag):
        # Root parents never appear as a child of any edge; give them a row
        # with no parent so they are not dropped from the export.
        if parent_node.is_root:
            data_parent: Dict[str, Any] = {
                name_col: parent_node.node_name,
                parent_col: None,
            }
            if all_attrs:
                data_parent.update(
                    parent_node.describe(
                        exclude_attributes=["name"], exclude_prefix="_"
                    )
                )
            else:
                for k, v in attr_dict.items():
                    data_parent[v] = parent_node.get_attr(k)
            data_list.append(data_parent)

        data_child: Dict[str, Any] = {
            name_col: child_node.node_name,
            parent_col: parent_node.node_name,
        }
        if all_attrs:
            data_child.update(
                child_node.describe(exclude_attributes=["name"], exclude_prefix="_")
            )
        else:
            for k, v in attr_dict.items():
                data_child[v] = child_node.get_attr(k)
        data_list.append(data_child)
    # Multi-parent children produce one row per edge; duplicates can only
    # arise from repeated identical edges, so drop them.
    return pd.DataFrame(data_list).drop_duplicates().reset_index(drop=True)
|
||||
|
||||
|
||||
@optional_dependencies_image("pydot")
def dag_to_dot(
    dag: Union[T, List[T]],
    rankdir: str = "TB",
    bg_colour: str = "",
    node_colour: str = "",
    node_shape: str = "",
    edge_colour: str = "",
    node_attr: str = "",
    edge_attr: str = "",
) -> pydot.Dot:
    r"""Export DAG or list of DAGs to image.
    Note that node names must be unique.
    Possible node attributes include style, fillcolor, shape.

    Examples:
        >>> from bigtree import DAGNode, dag_to_dot
        >>> a = DAGNode("a", step=1)
        >>> b = DAGNode("b", step=1)
        >>> c = DAGNode("c", step=2, parents=[a, b])
        >>> d = DAGNode("d", step=2, parents=[a, c])
        >>> e = DAGNode("e", step=3, parents=[d])
        >>> dag_graph = dag_to_dot(a)

        Display image directly without saving (requires IPython)

        >>> from IPython.display import Image, display
        >>> plt = Image(dag_graph.create_png())
        >>> display(plt)
        <IPython.core.display.Image object>

        Export to image, dot file, etc.

        >>> dag_graph.write_png("assets/docstr/tree_dag.png")
        >>> dag_graph.write_dot("assets/docstr/tree_dag.dot")

        Export to string

        >>> dag_graph.to_string()
        'strict digraph G {\nrankdir=TB;\nc [label=c];\na [label=a];\na -> c;\nd [label=d];\na [label=a];\na -> d;\nc [label=c];\nb [label=b];\nb -> c;\nd [label=d];\nc [label=c];\nc -> d;\ne [label=e];\nd [label=d];\nd -> e;\n}\n'

    Args:
        dag (Union[DAGNode, List[DAGNode]]): DAG or list of DAGs to be exported
        rankdir (str): set direction of graph layout, defaults to 'TB', can be 'BT', 'LR', 'RL'
        bg_colour (str): background color of image, defaults to ''
        node_colour (str): fill colour of nodes, defaults to ''
        node_shape (str): shape of nodes, defaults to ''
            Possible node_shape include "circle", "square", "diamond", "triangle"
        edge_colour (str): colour of edges, defaults to ''
        node_attr (str): node attribute for style, overrides node_colour, defaults to ''
            Possible node attributes include {"style": "filled", "fillcolor": "gold"}
        edge_attr (str): edge attribute for style, overrides edge_colour, defaults to ''
            Possible edge attributes include {"style": "bold", "label": "edge label", "color": "black"}

    Returns:
        (pydot.Dot)
    """
    # Base styles derived from the simple colour/shape arguments; empty
    # arguments contribute nothing so pydot falls back to its defaults.
    graph_style: Dict[str, str] = {"bgcolor": bg_colour} if bg_colour else {}
    node_style: Dict[str, str] = (
        {"style": "filled", "fillcolor": node_colour} if node_colour else {}
    )
    if node_shape:
        node_style["shape"] = node_shape
    edge_style: Dict[str, str] = {"color": edge_colour} if edge_colour else {}

    # strict=True deduplicates repeated node/edge declarations in the output.
    _graph = pydot.Dot(
        graph_type="digraph", strict=True, rankdir=rankdir, **graph_style
    )

    if not isinstance(dag, list):
        dag = [dag]

    for _dag in dag:
        if not isinstance(_dag, DAGNode):
            raise TypeError(
                "Tree should be of type `DAGNode`, or inherit from `DAGNode`"
            )
        # Copy so iteration cannot mutate the caller's DAG.
        _dag = _dag.copy()

        for parent_node, child_node in dag_iterator(_dag):
            _node_style = node_style.copy()
            _edge_style = edge_style.copy()

            # Per-node/per-edge overrides are read off the child node itself.
            child_name = child_node.name
            if node_attr and child_node.get_attr(node_attr):
                _node_style.update(child_node.get_attr(node_attr))
            if edge_attr and child_node.get_attr(edge_attr):
                _edge_style.update(child_node.get_attr(edge_attr))
            pydot_child = pydot.Node(name=child_name, label=child_name, **_node_style)
            _graph.add_node(pydot_child)

            parent_name = parent_node.name
            parent_node_style = node_style.copy()
            if node_attr and parent_node.get_attr(node_attr):
                parent_node_style.update(parent_node.get_attr(node_attr))
            pydot_parent = pydot.Node(
                name=parent_name, label=parent_name, **parent_node_style
            )
            _graph.add_node(pydot_parent)

            edge = pydot.Edge(parent_name, child_name, **_edge_style)
            _graph.add_edge(edge)

    return _graph
|
||||
3
python310/packages/bigtree/globals.py
Normal file
3
python310/packages/bigtree/globals.py
Normal file
@@ -0,0 +1,3 @@
|
||||
import os

# Global switch for bigtree's internal integrity assertions (e.g. parent/child
# type and loop checks); read once from the environment at import time.
# NOTE(review): `bool()` of any non-empty string is True, so setting
# BIGTREE_CONF_ASSERTIONS=0 or =False still ENABLES assertions; only an empty
# string disables them — confirm whether string parsing was intended here.
ASSERTIONS: bool = bool(os.environ.get("BIGTREE_CONF_ASSERTIONS", True))
|
||||
0
python310/packages/bigtree/node/__init__.py
Normal file
0
python310/packages/bigtree/node/__init__.py
Normal file
780
python310/packages/bigtree/node/basenode.py
Normal file
780
python310/packages/bigtree/node/basenode.py
Normal file
@@ -0,0 +1,780 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
from typing import Any, Dict, Generator, Iterable, List, Optional, Set, Tuple, TypeVar
|
||||
|
||||
from bigtree.globals import ASSERTIONS
|
||||
from bigtree.utils.exceptions import CorruptedTreeError, LoopError, TreeError
|
||||
from bigtree.utils.iterators import preorder_iter
|
||||
|
||||
|
||||
class BaseNode:
|
||||
"""
|
||||
BaseNode extends any Python class to a tree node.
|
||||
Nodes can have attributes if they are initialized from `Node`, *dictionary*, or *pandas DataFrame*.
|
||||
|
||||
Nodes can be linked to each other with `parent` and `children` setter methods,
|
||||
or using bitshift operator with the convention `parent_node >> child_node` or `child_node << parent_node`.
|
||||
|
||||
Examples:
|
||||
>>> from bigtree import Node, print_tree
|
||||
>>> root = Node("a", age=90)
|
||||
>>> b = Node("b", age=65)
|
||||
>>> c = Node("c", age=60)
|
||||
>>> d = Node("d", age=40)
|
||||
>>> root.children = [b, c]
|
||||
>>> d.parent = b
|
||||
>>> print_tree(root, attr_list=["age"])
|
||||
a [age=90]
|
||||
├── b [age=65]
|
||||
│ └── d [age=40]
|
||||
└── c [age=60]
|
||||
|
||||
>>> from bigtree import Node
|
||||
>>> root = Node("a", age=90)
|
||||
>>> b = Node("b", age=65)
|
||||
>>> c = Node("c", age=60)
|
||||
>>> d = Node("d", age=40)
|
||||
>>> root >> b
|
||||
>>> root >> c
|
||||
>>> d << b
|
||||
>>> print_tree(root, attr_list=["age"])
|
||||
a [age=90]
|
||||
├── b [age=65]
|
||||
│ └── d [age=40]
|
||||
└── c [age=60]
|
||||
|
||||
Directly passing `parent` argument.
|
||||
|
||||
>>> from bigtree import Node
|
||||
>>> root = Node("a")
|
||||
>>> b = Node("b", parent=root)
|
||||
>>> c = Node("c", parent=root)
|
||||
>>> d = Node("d", parent=b)
|
||||
|
||||
Directly passing `children` argument.
|
||||
|
||||
>>> from bigtree import Node
|
||||
>>> d = Node("d")
|
||||
>>> c = Node("c")
|
||||
>>> b = Node("b", children=[d])
|
||||
>>> a = Node("a", children=[b, c])
|
||||
|
||||
**BaseNode Creation**
|
||||
|
||||
Node can be created by instantiating a `BaseNode` class or by using a *dictionary*.
|
||||
If node is created with dictionary, all keys of dictionary will be stored as class attributes.
|
||||
|
||||
>>> from bigtree import Node
|
||||
>>> root = Node.from_dict({"name": "a", "age": 90})
|
||||
|
||||
**BaseNode Attributes**
|
||||
|
||||
These are node attributes that have getter and/or setter methods.
|
||||
|
||||
Get and set other `BaseNode`
|
||||
|
||||
1. ``parent``: Get/set parent node
|
||||
2. ``children``: Get/set child nodes
|
||||
|
||||
Get other `BaseNode`
|
||||
|
||||
1. ``ancestors``: Get ancestors of node excluding self, iterator
|
||||
2. ``descendants``: Get descendants of node excluding self, iterator
|
||||
3. ``leaves``: Get all leaf node(s) from self, iterator
|
||||
4. ``siblings``: Get siblings of self
|
||||
5. ``left_sibling``: Get sibling left of self
|
||||
6. ``right_sibling``: Get sibling right of self
|
||||
|
||||
Get `BaseNode` configuration
|
||||
|
||||
1. ``node_path``: Get tuple of nodes from root
|
||||
2. ``is_root``: Get indicator if self is root node
|
||||
3. ``is_leaf``: Get indicator if self is leaf node
|
||||
4. ``root``: Get root node of tree
|
||||
5. ``depth``: Get depth of self
|
||||
6. ``max_depth``: Get maximum depth from root to leaf node
|
||||
|
||||
**BaseNode Methods**
|
||||
|
||||
These are methods available to be performed on `BaseNode`.
|
||||
|
||||
Constructor methods
|
||||
|
||||
1. ``from_dict()``: Create BaseNode from dictionary
|
||||
|
||||
`BaseNode` methods
|
||||
|
||||
1. ``describe()``: Get node information sorted by attributes, return list of tuples
|
||||
2. ``get_attr(attr_name: str)``: Get value of node attribute
|
||||
3. ``set_attrs(attrs: dict)``: Set node attribute name(s) and value(s)
|
||||
4. ``go_to(node: Self)``: Get a path from own node to another node from same tree
|
||||
5. ``append(node: Self)``: Add child to node
|
||||
6. ``extend(nodes: List[Self])``: Add multiple children to node
|
||||
7. ``copy()``: Deep copy self
|
||||
8. ``sort()``: Sort child nodes
|
||||
|
||||
----
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
parent: Optional[T] = None,
|
||||
children: Optional[List[T]] = None,
|
||||
**kwargs: Any,
|
||||
):
|
||||
self.__parent: Optional[T] = None
|
||||
self.__children: List[T] = []
|
||||
if children is None:
|
||||
children = []
|
||||
self.parent = parent
|
||||
self.children = children # type: ignore
|
||||
if "parents" in kwargs:
|
||||
raise AttributeError(
|
||||
"Attempting to set `parents` attribute, do you mean `parent`?"
|
||||
)
|
||||
self.__dict__.update(**kwargs)
|
||||
|
||||
@staticmethod
|
||||
def __check_parent_type(new_parent: T) -> None:
|
||||
"""Check parent type
|
||||
|
||||
Args:
|
||||
new_parent (Self): parent node
|
||||
"""
|
||||
if not (isinstance(new_parent, BaseNode) or new_parent is None):
|
||||
raise TypeError(
|
||||
f"Expect parent to be BaseNode type or NoneType, received input type {type(new_parent)}"
|
||||
)
|
||||
|
||||
def __check_parent_loop(self, new_parent: T) -> None:
|
||||
"""Check parent type
|
||||
|
||||
Args:
|
||||
new_parent (Self): parent node
|
||||
"""
|
||||
if new_parent is not None:
|
||||
if new_parent is self:
|
||||
raise LoopError("Error setting parent: Node cannot be parent of itself")
|
||||
if any(
|
||||
ancestor is self
|
||||
for ancestor in new_parent.ancestors
|
||||
if new_parent.ancestors
|
||||
):
|
||||
raise LoopError(
|
||||
"Error setting parent: Node cannot be ancestor of itself"
|
||||
)
|
||||
|
||||
@property
|
||||
def parent(self: T) -> Optional[T]:
|
||||
"""Get parent node
|
||||
|
||||
Returns:
|
||||
(Optional[Self])
|
||||
"""
|
||||
return self.__parent
|
||||
|
||||
    @parent.setter
    def parent(self: T, new_parent: T) -> None:
        """Set parent node, detaching self from its current parent first.

        The re-link is transactional: if any hook or bookkeeping step raises,
        the previous parent/child links are restored before re-raising.

        Args:
            new_parent (Self): parent node, or None to detach self
        """
        if ASSERTIONS:
            self.__check_parent_type(new_parent)
            self.__check_parent_loop(new_parent)

        # Remember the current link so it can be restored on failure.
        current_parent = self.parent
        current_child_idx = None

        # Assign new parent - rollback if error
        self.__pre_assign_parent(new_parent)
        try:
            # Remove self from old parent
            if current_parent is not None:
                if not any(
                    child is self for child in current_parent.children
                ):  # pragma: no cover
                    # Parent/child links disagree — tree invariant broken.
                    raise CorruptedTreeError(
                        "Error setting parent: Node does not exist as children of its parent"
                    )
                # Record position so a rollback can reinsert at the same index.
                current_child_idx = current_parent.__children.index(self)
                current_parent.__children.remove(self)

            # Assign self to new parent
            self.__parent = new_parent
            if new_parent is not None:
                new_parent.__children.append(self)

            self.__post_assign_parent(new_parent)

        except Exception as exc_info:
            # Remove self from new parent
            if new_parent is not None:
                new_parent.__children.remove(self)

            # Reassign self to old parent (at the original position)
            self.__parent = current_parent
            if current_child_idx is not None:
                current_parent.__children.insert(current_child_idx, self)
            raise TreeError(exc_info)
|
||||
|
||||
def __pre_assign_parent(self, new_parent: T) -> None:
|
||||
"""Custom method to check before attaching parent
|
||||
Can be overridden with `_BaseNode__pre_assign_parent()`
|
||||
|
||||
Args:
|
||||
new_parent (Self): new parent to be added
|
||||
"""
|
||||
pass
|
||||
|
||||
def __post_assign_parent(self, new_parent: T) -> None:
|
||||
"""Custom method to check after attaching parent
|
||||
Can be overridden with `_BaseNode__post_assign_parent()`
|
||||
|
||||
Args:
|
||||
new_parent (Self): new parent to be added
|
||||
"""
|
||||
pass
|
||||
|
||||
@property
|
||||
def parents(self) -> None:
|
||||
"""Do not allow `parents` attribute to be accessed
|
||||
|
||||
Raises:
|
||||
AttributeError: No such attribute
|
||||
"""
|
||||
raise AttributeError(
|
||||
"Attempting to access `parents` attribute, do you mean `parent`?"
|
||||
)
|
||||
|
||||
@parents.setter
|
||||
def parents(self, new_parent: T) -> None:
|
||||
"""Do not allow `parents` attribute to be set
|
||||
|
||||
Args:
|
||||
new_parent (Self): parent node
|
||||
|
||||
Raises:
|
||||
AttributeError: No such attribute
|
||||
"""
|
||||
raise AttributeError(
|
||||
"Attempting to set `parents` attribute, do you mean `parent`?"
|
||||
)
|
||||
|
||||
def __check_children_type(
|
||||
self: T, new_children: List[T] | Tuple[T] | Set[T]
|
||||
) -> None:
|
||||
"""Check child type
|
||||
|
||||
Args:
|
||||
new_children (Iterable[Self]): child node
|
||||
"""
|
||||
if (
|
||||
not isinstance(new_children, list)
|
||||
and not isinstance(new_children, tuple)
|
||||
and not isinstance(new_children, set)
|
||||
):
|
||||
raise TypeError(
|
||||
f"Expect children to be List or Tuple or Set type, received input type {type(new_children)}"
|
||||
)
|
||||
|
||||
def __check_children_loop(self: T, new_children: Iterable[T]) -> None:
|
||||
"""Check child loop
|
||||
|
||||
Args:
|
||||
new_children (Iterable[Self]): child node
|
||||
"""
|
||||
seen_children = []
|
||||
for new_child in new_children:
|
||||
# Check type
|
||||
if not isinstance(new_child, BaseNode):
|
||||
raise TypeError(
|
||||
f"Expect children to be BaseNode type, received input type {type(new_child)}"
|
||||
)
|
||||
|
||||
# Check for loop and tree structure
|
||||
if new_child is self:
|
||||
raise LoopError("Error setting child: Node cannot be child of itself")
|
||||
if any(child is new_child for child in self.ancestors):
|
||||
raise LoopError(
|
||||
"Error setting child: Node cannot be ancestor of itself"
|
||||
)
|
||||
|
||||
# Check for duplicate children
|
||||
if id(new_child) in seen_children:
|
||||
raise TreeError(
|
||||
"Error setting child: Node cannot be added multiple times as a child"
|
||||
)
|
||||
else:
|
||||
seen_children.append(id(new_child))
|
||||
|
||||
@property
|
||||
def children(self: T) -> Tuple[T, ...]:
|
||||
"""Get child nodes
|
||||
|
||||
Returns:
|
||||
(Tuple[Self, ...])
|
||||
"""
|
||||
return tuple(self.__children)
|
||||
|
||||
    @children.setter
    def children(self: T, new_children: List[T] | Tuple[T] | Set[T]) -> None:
        """Replace this node's children, detaching each new child from its
        current parent. Transactional: on any failure all previous
        parent/child links (old children of self, old parents of the new
        children) are restored before re-raising.

        Args:
            new_children (List[Self]): child nodes
        """
        if ASSERTIONS:
            self.__check_children_type(new_children)
            self.__check_children_loop(new_children)
        new_children = list(new_children)

        # Snapshot current links so a rollback can restore exact positions.
        current_new_children = {
            new_child: (new_child.parent.__children.index(new_child), new_child.parent)
            for new_child in new_children
            if new_child.parent is not None
        }
        current_new_orphan = [
            new_child for new_child in new_children if new_child.parent is None
        ]
        current_children = list(self.children)

        # Assign new children - rollback if error
        self.__pre_assign_children(new_children)
        try:
            # Remove old children from self
            del self.children

            # Assign new children to self
            self.__children = new_children
            for new_child in new_children:
                if new_child.parent:
                    new_child.parent.__children.remove(new_child)
                new_child.__parent = self
            self.__post_assign_children(new_children)
        except Exception as exc_info:
            # Reassign new children to their original parent (at original index)
            for child, idx_parent in current_new_children.items():
                child_idx, parent = idx_parent
                child.__parent = parent
                parent.__children.insert(child_idx, child)
            for child in current_new_orphan:
                child.__parent = None

            # Reassign old children to self
            self.__children = current_children
            for child in current_children:
                child.__parent = self
            raise TreeError(exc_info)
|
||||
|
||||
@children.deleter
|
||||
def children(self) -> None:
|
||||
"""Delete child node(s)"""
|
||||
for child in self.children:
|
||||
child.parent.__children.remove(child) # type: ignore
|
||||
child.__parent = None
|
||||
|
||||
def __pre_assign_children(self: T, new_children: Iterable[T]) -> None:
|
||||
"""Custom method to check before attaching children
|
||||
Can be overridden with `_BaseNode__pre_assign_children()`
|
||||
|
||||
Args:
|
||||
new_children (Iterable[Self]): new children to be added
|
||||
"""
|
||||
pass
|
||||
|
||||
def __post_assign_children(self: T, new_children: Iterable[T]) -> None:
|
||||
"""Custom method to check after attaching children
|
||||
Can be overridden with `_BaseNode__post_assign_children()`
|
||||
|
||||
Args:
|
||||
new_children (Iterable[Self]): new children to be added
|
||||
"""
|
||||
pass
|
||||
|
||||
@property
|
||||
def ancestors(self: T) -> Iterable[T]:
|
||||
"""Get iterator to yield all ancestors of self, does not include self
|
||||
|
||||
Returns:
|
||||
(Iterable[Self])
|
||||
"""
|
||||
node = self.parent
|
||||
while node is not None:
|
||||
yield node
|
||||
node = node.parent
|
||||
|
||||
@property
|
||||
def descendants(self: T) -> Iterable[T]:
|
||||
"""Get iterator to yield all descendants of self, does not include self
|
||||
|
||||
Returns:
|
||||
(Iterable[Self])
|
||||
"""
|
||||
yield from preorder_iter(self, filter_condition=lambda _node: _node != self)
|
||||
|
||||
@property
|
||||
def leaves(self: T) -> Iterable[T]:
|
||||
"""Get iterator to yield all leaf nodes from self
|
||||
|
||||
Returns:
|
||||
(Iterable[Self])
|
||||
"""
|
||||
yield from preorder_iter(self, filter_condition=lambda _node: _node.is_leaf)
|
||||
|
||||
@property
|
||||
def siblings(self: T) -> Iterable[T]:
|
||||
"""Get siblings of self
|
||||
|
||||
Returns:
|
||||
(Iterable[Self])
|
||||
"""
|
||||
if self.parent is None:
|
||||
return ()
|
||||
return tuple(child for child in self.parent.children if child is not self)
|
||||
|
||||
@property
|
||||
def left_sibling(self: T) -> T:
|
||||
"""Get sibling left of self
|
||||
|
||||
Returns:
|
||||
(Self)
|
||||
"""
|
||||
if self.parent:
|
||||
children = self.parent.children
|
||||
child_idx = children.index(self)
|
||||
if child_idx:
|
||||
return self.parent.children[child_idx - 1]
|
||||
|
||||
@property
|
||||
def right_sibling(self: T) -> T:
|
||||
"""Get sibling right of self
|
||||
|
||||
Returns:
|
||||
(Self)
|
||||
"""
|
||||
if self.parent:
|
||||
children = self.parent.children
|
||||
child_idx = children.index(self)
|
||||
if child_idx + 1 < len(children):
|
||||
return self.parent.children[child_idx + 1]
|
||||
|
||||
@property
|
||||
def node_path(self: T) -> Iterable[T]:
|
||||
"""Get tuple of nodes starting from root
|
||||
|
||||
Returns:
|
||||
(Iterable[Self])
|
||||
"""
|
||||
if self.parent is None:
|
||||
return [self]
|
||||
return tuple(list(self.parent.node_path) + [self])
|
||||
|
||||
@property
|
||||
def is_root(self) -> bool:
|
||||
"""Get indicator if self is root node
|
||||
|
||||
Returns:
|
||||
(bool)
|
||||
"""
|
||||
return self.parent is None
|
||||
|
||||
@property
|
||||
def is_leaf(self) -> bool:
|
||||
"""Get indicator if self is leaf node
|
||||
|
||||
Returns:
|
||||
(bool)
|
||||
"""
|
||||
return not len(list(self.children))
|
||||
|
||||
@property
|
||||
def root(self: T) -> T:
|
||||
"""Get root node of tree
|
||||
|
||||
Returns:
|
||||
(Self)
|
||||
"""
|
||||
if self.parent is None:
|
||||
return self
|
||||
return self.parent.root
|
||||
|
||||
@property
|
||||
def depth(self) -> int:
|
||||
"""Get depth of self, indexing starts from 1
|
||||
|
||||
Returns:
|
||||
(int)
|
||||
"""
|
||||
if self.parent is None:
|
||||
return 1
|
||||
return self.parent.depth + 1
|
||||
|
||||
@property
|
||||
def max_depth(self) -> int:
|
||||
"""Get maximum depth from root to leaf node
|
||||
|
||||
Returns:
|
||||
(int)
|
||||
"""
|
||||
return max(
|
||||
[self.root.depth] + [node.depth for node in list(self.root.descendants)]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, input_dict: Dict[str, Any]) -> BaseNode:
|
||||
"""Construct node from dictionary, all keys of dictionary will be stored as class attributes
|
||||
Input dictionary must have key `name` if not `Node` will not have any name
|
||||
|
||||
Examples:
|
||||
>>> from bigtree import Node
|
||||
>>> a = Node.from_dict({"name": "a", "age": 90})
|
||||
|
||||
Args:
|
||||
input_dict (Dict[str, Any]): dictionary with node information, key: attribute name, value: attribute value
|
||||
|
||||
Returns:
|
||||
(BaseNode)
|
||||
"""
|
||||
return cls(**input_dict)
|
||||
|
||||
def describe(
|
||||
self, exclude_attributes: List[str] = [], exclude_prefix: str = ""
|
||||
) -> List[Tuple[str, Any]]:
|
||||
"""Get node information sorted by attribute name, returns list of tuples
|
||||
|
||||
Examples:
|
||||
>>> from bigtree.node.node import Node
|
||||
>>> a = Node('a', age=90)
|
||||
>>> a.describe()
|
||||
[('_BaseNode__children', []), ('_BaseNode__parent', None), ('_sep', '/'), ('age', 90), ('name', 'a')]
|
||||
>>> a.describe(exclude_prefix="_")
|
||||
[('age', 90), ('name', 'a')]
|
||||
>>> a.describe(exclude_prefix="_", exclude_attributes=["name"])
|
||||
[('age', 90)]
|
||||
|
||||
Args:
|
||||
exclude_attributes (List[str]): list of attributes to exclude
|
||||
exclude_prefix (str): prefix of attributes to exclude
|
||||
|
||||
Returns:
|
||||
(List[Tuple[str, Any]])
|
||||
"""
|
||||
return [
|
||||
item
|
||||
for item in sorted(self.__dict__.items(), key=lambda item: item[0])
|
||||
if (item[0] not in exclude_attributes)
|
||||
and (not len(exclude_prefix) or not item[0].startswith(exclude_prefix))
|
||||
]
|
||||
|
||||
def get_attr(self, attr_name: str, default_value: Any = None) -> Any:
|
||||
"""Get value of node attribute
|
||||
Returns default value if attribute name does not exist
|
||||
|
||||
Examples:
|
||||
>>> from bigtree.node.node import Node
|
||||
>>> a = Node('a', age=90)
|
||||
>>> a.get_attr("age")
|
||||
90
|
||||
|
||||
Args:
|
||||
attr_name (str): attribute name
|
||||
default_value (Any): default value if attribute does not exist, defaults to None
|
||||
|
||||
Returns:
|
||||
(Any)
|
||||
"""
|
||||
try:
|
||||
return getattr(self, attr_name)
|
||||
except AttributeError:
|
||||
return default_value
|
||||
|
||||
def set_attrs(self, attrs: Dict[str, Any]) -> None:
|
||||
"""Set node attributes
|
||||
|
||||
Examples:
|
||||
>>> from bigtree.node.node import Node
|
||||
>>> a = Node('a')
|
||||
>>> a.set_attrs({"age": 90})
|
||||
>>> a
|
||||
Node(/a, age=90)
|
||||
|
||||
Args:
|
||||
attrs (Dict[str, Any]): attribute dictionary,
|
||||
key: attribute name, value: attribute value
|
||||
"""
|
||||
self.__dict__.update(attrs)
|
||||
|
||||
    def go_to(self: T, node: T) -> Iterable[T]:
        """Get path from current node to specified node from same tree,
        travelling via the lowest common ancestor (inclusive of both ends).

        Examples:
            >>> from bigtree import Node, print_tree
            >>> a = Node(name="a")
            >>> b = Node(name="b", parent=a)
            >>> c = Node(name="c", parent=a)
            >>> d = Node(name="d", parent=b)
            >>> e = Node(name="e", parent=b)
            >>> f = Node(name="f", parent=c)
            >>> g = Node(name="g", parent=e)
            >>> h = Node(name="h", parent=e)
            >>> print_tree(a)
            a
            ├── b
            │   ├── d
            │   └── e
            │       ├── g
            │       └── h
            └── c
                └── f
            >>> d.go_to(d)
            [Node(/a/b/d, )]
            >>> d.go_to(g)
            [Node(/a/b/d, ), Node(/a/b, ), Node(/a/b/e, ), Node(/a/b/e/g, )]
            >>> d.go_to(f)
            [Node(/a/b/d, ), Node(/a/b, ), Node(/a, ), Node(/a/c, ), Node(/a/c/f, )]

        Args:
            node (Self): node to travel to from current node, inclusive of start and end node

        Returns:
            (Iterable[Self])
        """
        if not isinstance(node, BaseNode):
            raise TypeError(
                f"Expect node to be BaseNode type, received input type {type(node)}"
            )
        if self.root != node.root:
            raise TreeError(
                f"Nodes are not from the same tree. Check {self} and {node}"
            )
        if self == node:
            return [self]
        # Path from self upward to the root, and from the root down to target.
        self_path = [self] + list(self.ancestors)
        node_path = ([node] + list(node.ancestors))[::-1]
        # Nodes appearing on both paths are shared ancestors; the one closest
        # to self (smallest index in self_path) is the lowest common ancestor.
        common_nodes = set(self_path).intersection(set(node_path))
        self_min_index, min_common_node = sorted(
            [(self_path.index(_node), _node) for _node in common_nodes]
        )[0]
        node_min_index = node_path.index(min_common_node)
        # Climb from self to just below the LCA, then descend to the target.
        return self_path[:self_min_index] + node_path[node_min_index:]
|
||||
|
||||
def append(self: T, other: T) -> None:
|
||||
"""Add other as child of self
|
||||
|
||||
Args:
|
||||
other (Self): other node, child to be added
|
||||
"""
|
||||
other.parent = self
|
||||
|
||||
def extend(self: T, others: List[T]) -> None:
|
||||
"""Add others as children of self
|
||||
|
||||
Args:
|
||||
others (Self): other nodes, children to be added
|
||||
"""
|
||||
for child in others:
|
||||
child.parent = self
|
||||
|
||||
def copy(self: T) -> T:
|
||||
"""Deep copy self; clone self
|
||||
|
||||
Examples:
|
||||
>>> from bigtree.node.node import Node
|
||||
>>> a = Node('a')
|
||||
>>> a_copy = a.copy()
|
||||
|
||||
Returns:
|
||||
(Self)
|
||||
"""
|
||||
return copy.deepcopy(self)
|
||||
|
||||
def sort(self: T, **kwargs: Any) -> None:
    """Sort children in place. Keyword arguments are forwarded to the
    built-in sort, e.g. ``key=lambda node: node.name`` or ``reverse=True``.

    Examples:
        >>> from bigtree import Node, print_tree
        >>> a = Node('a')
        >>> c = Node("c", parent=a)
        >>> b = Node("b", parent=a)
        >>> a.sort(key=lambda node: node.name)
        >>> print_tree(a)
        a
        ├── b
        └── c
    """
    # sorted() produces a fresh list, so the private child list is
    # replaced atomically without mutating the exposed tuple.
    self.__children = sorted(self.children, **kwargs)
|
||||
|
||||
def __copy__(self: T) -> T:
    """Shallow copy self

    Fix: the docstring example previously showed ``copy.deepcopy(a)``,
    which invokes ``__deepcopy__`` semantics — the correct demonstration
    of this hook is ``copy.copy``.

    Examples:
        >>> import copy
        >>> from bigtree.node.node import Node
        >>> a = Node('a')
        >>> a_copy = copy.copy(a)

    Returns:
        (Self)
    """
    # Bypass __init__ and share the attribute values, so parent/child
    # references point at the same objects (that is what "shallow" means).
    obj: T = type(self).__new__(self.__class__)
    obj.__dict__.update(self.__dict__)
    return obj
|
||||
|
||||
def __repr__(self) -> str:
    """Debug representation: class name followed by public attributes.

    Returns:
        (str)
    """
    public_attrs = self.describe(exclude_prefix="_")
    rendered = ", ".join(f"{key}={value}" for key, value in public_attrs)
    return f"{type(self).__name__}({rendered})"
|
||||
|
||||
def __rshift__(self: T, other: T) -> None:
    """Support ``parent >> child``: attach ``other`` as a child of self.

    Args:
        other (Self): node to become a child of self
    """
    other.parent = self
|
||||
|
||||
def __lshift__(self: T, other: T) -> None:
    """Support ``child << parent``: make ``other`` the parent of self.

    Args:
        other (Self): node to become the parent of self
    """
    self.parent = other
|
||||
|
||||
def __iter__(self) -> Generator[T, None, None]:
    """Iterate over the direct children of this node.

    Returns:
        (Self): child node
    """
    for child in self.children:  # type: ignore
        yield child
|
||||
|
||||
def __contains__(self, other_node: T) -> bool:
    """Check whether ``other_node`` is a direct child of self.

    Args:
        other_node (T): node to look for among the children

    Returns:
        (bool)
    """
    current_children = tuple(self.children)
    return other_node in current_children
|
||||
|
||||
|
||||
# Type variable bound to BaseNode so methods annotated with T preserve the
# concrete subclass type for callers.
T = TypeVar("T", bound=BaseNode)
|
||||
418
python310/packages/bigtree/node/binarynode.py
Normal file
418
python310/packages/bigtree/node/binarynode.py
Normal file
@@ -0,0 +1,418 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, List, Optional, Tuple, TypeVar, Union
|
||||
|
||||
from bigtree.globals import ASSERTIONS
|
||||
from bigtree.node.node import Node
|
||||
from bigtree.utils.exceptions import CorruptedTreeError, LoopError, TreeError
|
||||
|
||||
|
||||
class BinaryNode(Node):
    """
    BinaryNode is an extension of Node, and is able to extend to any Python class for Binary Tree implementation.
    Nodes can have attributes if they are initialized from `BinaryNode`, *dictionary*, or *pandas DataFrame*.

    BinaryNode can be linked to each other with `children`, `left`, or `right` setter methods.
    If initialized with `children`, it must be length 2, denoting left and right child.

    Examples:
        >>> from bigtree import BinaryNode, print_tree
        >>> a = BinaryNode(1)
        >>> b = BinaryNode(2)
        >>> c = BinaryNode(3)
        >>> d = BinaryNode(4)
        >>> a.children = [b, c]
        >>> b.right = d
        >>> print_tree(a)
        1
        ├── 2
        │   └── 4
        └── 3

        Directly passing `left`, `right`, or `children` argument.

        >>> from bigtree import BinaryNode
        >>> d = BinaryNode(4)
        >>> c = BinaryNode(3)
        >>> b = BinaryNode(2, right=d)
        >>> a = BinaryNode(1, children=[b, c])

    **BinaryNode Creation**

    Node can be created by instantiating a `BinaryNode` class or by using a *dictionary*.
    If node is created with dictionary, all keys of dictionary will be stored as class attributes.

        >>> from bigtree import BinaryNode
        >>> a = BinaryNode.from_dict({"name": "1"})
        >>> a
        BinaryNode(name=1, val=1)

    **BinaryNode Attributes**

    These are node attributes that have getter and/or setter methods.

    Get `BinaryNode` configuration

    1. ``left``: Get left children
    2. ``right``: Get right children

    ----

    """

    def __init__(
        self,
        name: Union[str, int] = "",
        left: Optional[T] = None,
        right: Optional[T] = None,
        parent: Optional[T] = None,
        children: Optional[List[Optional[T]]] = None,
        **kwargs: Any,
    ):
        # ``val`` keeps the numeric form of the name when it parses as int;
        # ``name`` is always stored as a string.
        try:
            self.val: Union[str, int] = int(name)
        except ValueError:
            self.val = str(name)
        self.name = str(name)
        self._sep = "/"
        # Private (name-mangled) links; a binary node always holds exactly
        # two child slots, either of which may be None.
        self.__parent: Optional[T] = None
        self.__children: List[Optional[T]] = [None, None]
        if not children:
            children = []
        if len(children):
            # ``children`` takes precedence, but must agree with ``left`` /
            # ``right`` when those are supplied as well.
            if len(children) and len(children) != 2:
                raise ValueError("Children input must have length 2")
            if left and left != children[0]:
                raise ValueError(
                    f"Error setting child: Attempting to set both left and children with mismatched values\n"
                    f"Check left {left} and children {children}"
                )
            if right and right != children[1]:
                raise ValueError(
                    f"Error setting child: Attempting to set both right and children with mismatched values\n"
                    f"Check right {right} and children {children}"
                )
        else:
            children = [left, right]
        self.parent = parent
        self.children = children  # type: ignore
        if "parents" in kwargs:
            raise AttributeError(
                "Attempting to set `parents` attribute, do you mean `parent`?"
            )
        self.__dict__.update(**kwargs)

    @property
    def left(self: T) -> T:
        """Get left child (slot 0); may be None

        Returns:
            (Self)
        """
        return self.__children[0]

    @left.setter
    def left(self: T, left_child: Optional[T]) -> None:
        """Set left child, keeping the current right child

        Args:
            left_child (Optional[Self]): left child
        """
        self.children = [left_child, self.right]  # type: ignore

    @property
    def right(self: T) -> T:
        """Get right child (slot 1); may be None

        Returns:
            (Self)
        """
        return self.__children[1]

    @right.setter
    def right(self: T, right_child: Optional[T]) -> None:
        """Set right child, keeping the current left child

        Args:
            right_child (Optional[Self]): right child
        """
        self.children = [self.left, right_child]  # type: ignore

    @staticmethod
    def __check_parent_type(new_parent: Optional[T]) -> None:
        """Check parent type; None is allowed (detaches the node)

        Args:
            new_parent (Optional[Self]): parent node
        """
        if not (isinstance(new_parent, BinaryNode) or new_parent is None):
            raise TypeError(
                f"Expect parent to be BinaryNode type or NoneType, received input type {type(new_parent)}"
            )

    @property
    def parent(self: T) -> Optional[T]:
        """Get parent node

        Returns:
            (Optional[Self])
        """
        return self.__parent

    @parent.setter
    def parent(self: T, new_parent: Optional[T]) -> None:
        """Set parent node

        Args:
            new_parent (Optional[Self]): parent node
        """
        # Validation is skipped when assertions are globally disabled.
        if ASSERTIONS:
            self.__check_parent_type(new_parent)
            self._BaseNode__check_parent_loop(new_parent)  # type: ignore

        current_parent = self.parent
        current_child_idx = None

        # Assign new parent - rollback if error
        self.__pre_assign_parent(new_parent)
        try:
            # Remove self from old parent
            if current_parent is not None:
                if not any(
                    child is self for child in current_parent.children
                ):  # pragma: no cover
                    raise CorruptedTreeError(
                        "Error setting parent: Node does not exist as children of its parent"
                    )
                current_child_idx = current_parent.__children.index(self)
                current_parent.__children[current_child_idx] = None

            # Assign self to new parent
            self.__parent = new_parent
            if new_parent is not None:
                # Occupy the first empty (None) child slot of the new parent.
                inserted = False
                for child_idx, child in enumerate(new_parent.__children):
                    if not child and not inserted:
                        new_parent.__children[child_idx] = self
                        inserted = True
                if not inserted:
                    raise TreeError(f"Parent {new_parent} already has 2 children")

            self.__post_assign_parent(new_parent)

        except Exception as exc_info:
            # Remove self from new parent
            if new_parent is not None and self in new_parent.__children:
                child_idx = new_parent.__children.index(self)
                new_parent.__children[child_idx] = None

            # Reassign self to old parent
            self.__parent = current_parent
            if current_child_idx is not None:
                current_parent.__children[current_child_idx] = self
            raise TreeError(exc_info)

    def __pre_assign_parent(self: T, new_parent: Optional[T]) -> None:
        """Custom method to check before attaching parent
        Can be overridden with `_BinaryNode__pre_assign_parent()`

        Args:
            new_parent (Optional[Self]): new parent to be added
        """
        pass

    def __post_assign_parent(self: T, new_parent: Optional[T]) -> None:
        """Custom method to check after attaching parent
        Can be overridden with `_BinaryNode__post_assign_parent()`

        Args:
            new_parent (Optional[Self]): new parent to be added
        """
        pass

    def __check_children_type(
        self: T, new_children: List[Optional[T]]
    ) -> List[Optional[T]]:
        """Check child type; normalises an empty input to two empty slots

        Args:
            new_children (List[Optional[Self]]): child node

        Returns:
            (List[Optional[Self]])
        """
        if not len(new_children):
            new_children = [None, None]
        if len(new_children) != 2:
            raise ValueError("Children input must have length 2")
        return new_children

    def __check_children_loop(self: T, new_children: List[Optional[T]]) -> None:
        """Check child loop

        Args:
            new_children (List[Optional[Self]]): child node
        """
        seen_children = []
        for new_child in new_children:
            # Check type
            if new_child is not None and not isinstance(new_child, BinaryNode):
                raise TypeError(
                    f"Expect children to be BinaryNode type or NoneType, received input type {type(new_child)}"
                )

            # Check for loop and tree structure
            if new_child is self:
                raise LoopError("Error setting child: Node cannot be child of itself")
            if any(child is new_child for child in self.ancestors):
                raise LoopError(
                    "Error setting child: Node cannot be ancestor of itself"
                )

            # Check for duplicate children; id() comparison catches the same
            # object being listed in both slots.
            if new_child is not None:
                if id(new_child) in seen_children:
                    raise TreeError(
                        "Error setting child: Node cannot be added multiple times as a child"
                    )
                else:
                    seen_children.append(id(new_child))

    @property
    def children(self: T) -> Tuple[T, ...]:
        """Get child nodes as an immutable (left, right) pair

        Returns:
            (Tuple[Optional[Self]])
        """
        return tuple(self.__children)

    @children.setter
    def children(self: T, _new_children: List[Optional[T]]) -> None:
        """Set child nodes

        Args:
            _new_children (List[Optional[Self]]): child node
        """
        self._BaseNode__check_children_type(_new_children)  # type: ignore
        new_children = self.__check_children_type(_new_children)
        if ASSERTIONS:
            self.__check_children_loop(new_children)

        # Snapshot the new children's current positions so a failure can be
        # rolled back exactly.
        current_new_children = {
            new_child: (
                new_child.parent.__children.index(new_child),
                new_child.parent,
            )
            for new_child in new_children
            if new_child is not None and new_child.parent is not None
        }
        current_new_orphan = [
            new_child
            for new_child in new_children
            if new_child is not None and new_child.parent is None
        ]
        current_children = list(self.children)

        # Assign new children - rollback if error
        self.__pre_assign_children(new_children)
        try:
            # Remove old children from self
            del self.children

            # Assign new children to self
            self.__children = new_children
            for new_child in new_children:
                if new_child is not None:
                    if new_child.parent:
                        child_idx = new_child.parent.__children.index(new_child)
                        new_child.parent.__children[child_idx] = None
                    new_child.__parent = self
            self.__post_assign_children(new_children)
        except Exception as exc_info:
            # Reassign new children to their original parent
            for child, idx_parent in current_new_children.items():
                child_idx, parent = idx_parent
                child.__parent = parent
                parent.__children[child_idx] = child
            for child in current_new_orphan:
                child.__parent = None

            # Reassign old children to self
            self.__children = current_children
            for child in current_children:
                if child:
                    child.__parent = self
            raise TreeError(exc_info)

    @children.deleter
    def children(self) -> None:
        """Delete child node(s)"""
        for child in self.children:
            if child is not None:
                child.parent.__children.remove(child)  # type: ignore
                child.__parent = None

    def __pre_assign_children(self: T, new_children: List[Optional[T]]) -> None:
        """Custom method to check before attaching children
        Can be overridden with `_BinaryNode__pre_assign_children()`

        Args:
            new_children (List[Optional[Self]]): new children to be added
        """
        pass

    def __post_assign_children(self: T, new_children: List[Optional[T]]) -> None:
        """Custom method to check after attaching children
        Can be overridden with `_BinaryNode__post_assign_children()`

        Args:
            new_children (List[Optional[Self]]): new children to be added
        """
        pass

    @property
    def is_leaf(self) -> bool:
        """Get indicator if self is leaf node (both child slots empty)

        Returns:
            (bool)
        """
        return not len([child for child in self.children if child])

    def sort(self, **kwargs: Any) -> None:
        """Sort children, possible keyword arguments include ``key=lambda node: node.val``, ``reverse=True``

        Examples:
            >>> from bigtree import BinaryNode, print_tree
            >>> a = BinaryNode(1)
            >>> c = BinaryNode(3, parent=a)
            >>> b = BinaryNode(2, parent=a)
            >>> print_tree(a)
            1
            ├── 3
            └── 2
            >>> a.sort(key=lambda node: node.val)
            >>> print_tree(a)
            1
            ├── 2
            └── 3
        """
        # Only sorts when both slots are occupied: with a single child its
        # left/right position is meaningful and must be preserved.
        children = [child for child in self.children if child]
        if len(children) == 2:
            children.sort(**kwargs)
            self.__children = children  # type: ignore

    def __repr__(self) -> str:
        """Print format of BinaryNode

        Returns:
            (str)
        """
        class_name = self.__class__.__name__
        node_dict = self.describe(exclude_prefix="_", exclude_attributes=[])
        node_description = ", ".join([f"{k}={v}" for k, v in node_dict])
        return f"{class_name}({node_description})"
|
||||
|
||||
|
||||
# Type variable bound to BinaryNode so typed helpers return the subclass type.
T = TypeVar("T", bound=BinaryNode)
|
||||
672
python310/packages/bigtree/node/dagnode.py
Normal file
672
python310/packages/bigtree/node/dagnode.py
Normal file
@@ -0,0 +1,672 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple, TypeVar
|
||||
|
||||
from bigtree.globals import ASSERTIONS
|
||||
from bigtree.utils.exceptions import LoopError, TreeError
|
||||
from bigtree.utils.iterators import preorder_iter
|
||||
|
||||
|
||||
class DAGNode:
|
||||
"""
|
||||
Base DAGNode extends any Python class to a DAG node, for DAG implementation.
|
||||
In DAG implementation, a node can have multiple parents.
|
||||
|
||||
Parents and children cannot be reassigned once assigned, as Nodes are allowed to have multiple parents and children.
|
||||
If each node only has one parent, use `Node` class.
|
||||
DAGNodes can have attributes if they are initialized from `DAGNode` or dictionary.
|
||||
|
||||
DAGNode can be linked to each other with `parents` and `children` setter methods,
|
||||
or using bitshift operator with the convention `parent_node >> child_node` or `child_node << parent_node`.
|
||||
|
||||
Examples:
|
||||
>>> from bigtree import DAGNode
|
||||
>>> a = DAGNode("a")
|
||||
>>> b = DAGNode("b")
|
||||
>>> c = DAGNode("c")
|
||||
>>> d = DAGNode("d")
|
||||
>>> c.parents = [a, b]
|
||||
>>> c.children = [d]
|
||||
|
||||
>>> from bigtree import DAGNode
|
||||
>>> a = DAGNode("a")
|
||||
>>> b = DAGNode("b")
|
||||
>>> c = DAGNode("c")
|
||||
>>> d = DAGNode("d")
|
||||
>>> a >> c
|
||||
>>> b >> c
|
||||
>>> d << c
|
||||
|
||||
Directly passing `parents` argument.
|
||||
|
||||
>>> from bigtree import DAGNode
|
||||
>>> a = DAGNode("a")
|
||||
>>> b = DAGNode("b")
|
||||
>>> c = DAGNode("c", parents=[a, b])
|
||||
>>> d = DAGNode("d", parents=[c])
|
||||
|
||||
Directly passing `children` argument.
|
||||
|
||||
>>> from bigtree import DAGNode
|
||||
>>> d = DAGNode("d")
|
||||
>>> c = DAGNode("c", children=[d])
|
||||
>>> b = DAGNode("b", children=[c])
|
||||
>>> a = DAGNode("a", children=[c])
|
||||
|
||||
**DAGNode Creation**
|
||||
|
||||
Node can be created by instantiating a `DAGNode` class or by using a *dictionary*.
|
||||
If node is created with dictionary, all keys of dictionary will be stored as class attributes.
|
||||
|
||||
>>> from bigtree import DAGNode
|
||||
>>> a = DAGNode.from_dict({"name": "a", "age": 90})
|
||||
|
||||
**DAGNode Attributes**
|
||||
|
||||
These are node attributes that have getter and/or setter methods.
|
||||
|
||||
Get and set other `DAGNode`
|
||||
|
||||
1. ``parents``: Get/set parent nodes
|
||||
2. ``children``: Get/set child nodes
|
||||
|
||||
Get other `DAGNode`
|
||||
|
||||
1. ``ancestors``: Get ancestors of node excluding self, iterator
|
||||
2. ``descendants``: Get descendants of node excluding self, iterator
|
||||
3. ``siblings``: Get siblings of self
|
||||
|
||||
Get `DAGNode` configuration
|
||||
|
||||
1. ``node_name``: Get node name, without accessing `name` directly
|
||||
2. ``is_root``: Get indicator if self is root node
|
||||
3. ``is_leaf``: Get indicator if self is leaf node
|
||||
|
||||
**DAGNode Methods**
|
||||
|
||||
These are methods available to be performed on `DAGNode`.
|
||||
|
||||
Constructor methods
|
||||
|
||||
1. ``from_dict()``: Create DAGNode from dictionary
|
||||
|
||||
`DAGNode` methods
|
||||
|
||||
1. ``describe()``: Get node information sorted by attributes, return list of tuples
|
||||
2. ``get_attr(attr_name: str)``: Get value of node attribute
|
||||
3. ``set_attrs(attrs: dict)``: Set node attribute name(s) and value(s)
|
||||
4. ``go_to(node: Self)``: Get a path from own node to another node from same DAG
|
||||
5. ``copy()``: Deep copy self
|
||||
|
||||
----
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
    self,
    name: str = "",
    parents: Optional[List[T]] = None,
    children: Optional[List[T]] = None,
    **kwargs: Any,
):
    """Instantiate a DAG node and optionally link it to existing nodes.

    Args:
        name (str): node name
        parents (Optional[List[Self]]): parent nodes to link to
        children (Optional[List[Self]]): child nodes to link to
    """
    self.name = name
    # Private (name-mangled) adjacency lists; populated via the setters below.
    self.__parents: List[T] = []
    self.__children: List[T] = []
    if parents is None:
        parents = []
    if children is None:
        children = []
    self.parents = parents
    self.children = children
    # Guard against the common typo of the tree-style `parent` keyword.
    if "parent" in kwargs:
        raise AttributeError(
            "Attempting to set `parent` attribute, do you mean `parents`?"
        )
    self.__dict__.update(**kwargs)
|
||||
|
||||
@property
def parent(self) -> None:
    """Blocked accessor: DAG nodes can have multiple parents, so the singular
    tree-style `parent` attribute is intentionally disabled.

    Raises:
        AttributeError: always; use `parents` instead
    """
    raise AttributeError(
        "Attempting to access `parent` attribute, do you mean `parents`?"
    )
|
||||
|
||||
@parent.setter
def parent(self, new_parent: T) -> None:
    """Blocked setter: DAG nodes can have multiple parents, so the singular
    tree-style `parent` attribute is intentionally disabled.

    Args:
        new_parent (Self): parent node

    Raises:
        AttributeError: always; use `parents` instead
    """
    raise AttributeError(
        "Attempting to set `parent` attribute, do you mean `parents`?"
    )
|
||||
|
||||
@staticmethod
def __check_parent_type(new_parents: List[T]) -> None:
    """Validate that the parents argument is a list; element validation
    happens separately in `__check_parent_loop`

    Args:
        new_parents (List[Self]): parent nodes
    """
    if not isinstance(new_parents, list):
        raise TypeError(
            f"Parents input should be list type, received input type {type(new_parents)}"
        )
|
||||
|
||||
def __check_parent_loop(self: T, new_parents: List[T]) -> None:
    """Validate each prospective parent: correct type, no cycles, no duplicates

    Args:
        new_parents (List[Self]): parent nodes
    """
    seen_parent = []
    for new_parent in new_parents:
        # Check type
        if not isinstance(new_parent, DAGNode):
            raise TypeError(
                f"Expect parent to be DAGNode type, received input type {type(new_parent)}"
            )

        # Check for loop and tree structure
        if new_parent is self:
            raise LoopError("Error setting parent: Node cannot be parent of itself")
        if new_parent.ancestors:
            if any(ancestor is self for ancestor in new_parent.ancestors):
                raise LoopError(
                    "Error setting parent: Node cannot be ancestor of itself"
                )

        # Check for duplicate children; id() catches the same object being
        # listed more than once.
        if id(new_parent) in seen_parent:
            raise TreeError(
                "Error setting parent: Node cannot be added multiple times as a parent"
            )
        else:
            seen_parent.append(id(new_parent))
|
||||
|
||||
@property
def parents(self: T) -> Iterable[T]:
    """Get parent nodes

    Returns:
        (Iterable[Self]): tuple copy, so callers cannot mutate the internal list
    """
    return tuple(self.__parents)
|
||||
|
||||
@parents.setter
def parents(self: T, new_parents: List[T]) -> None:
    """Set parent node

    Note: parents accumulate — existing parents are kept and new ones appended.

    Args:
        new_parents (List[Self]): parent nodes
    """
    # Validation is skipped when assertions are globally disabled.
    if ASSERTIONS:
        self.__check_parent_type(new_parents)
        self.__check_parent_loop(new_parents)

    # Snapshot for rollback on failure.
    current_parents = self.__parents.copy()

    # Assign new parents - rollback if error
    self.__pre_assign_parents(new_parents)
    try:
        # Assign self to new parent
        for new_parent in new_parents:
            if new_parent not in self.__parents:
                self.__parents.append(new_parent)
                new_parent.__children.append(self)

        self.__post_assign_parents(new_parents)
    except Exception as exc_info:
        # Remove self from new parent
        for new_parent in new_parents:
            if new_parent not in current_parents:
                self.__parents.remove(new_parent)
                new_parent.__children.remove(self)
        raise TreeError(exc_info)
|
||||
|
||||
def __pre_assign_parents(self: T, new_parents: List[T]) -> None:
    """Extension hook: runs before parents are attached; default is a no-op
    Can be overridden with `_DAGNode__pre_assign_parent()`

    Args:
        new_parents (List[Self]): new parents to be added
    """
    pass
|
||||
|
||||
def __post_assign_parents(self: T, new_parents: List[T]) -> None:
    """Extension hook: runs after parents are attached; default is a no-op
    Can be overridden with `_DAGNode__post_assign_parent()`

    Args:
        new_parents (List[Self]): new parents to be added
    """
    pass
|
||||
|
||||
def __check_children_type(self: T, new_children: Iterable[T]) -> None:
    """Validate that the children argument is an iterable; element
    validation happens separately in `__check_children_loop`.

    Args:
        new_children (Iterable[Self]): candidate children collection
    """
    if isinstance(new_children, Iterable):
        return
    raise TypeError(
        f"Expect children to be Iterable type, received input type {type(new_children)}"
    )
|
||||
|
||||
def __check_children_loop(self: T, new_children: Iterable[T]) -> None:
    """Validate each prospective child: correct type, no cycles, no duplicates

    Args:
        new_children (Iterable[Self]): child node
    """
    seen_children = []
    for new_child in new_children:
        # Check type
        if not isinstance(new_child, DAGNode):
            raise TypeError(
                f"Expect children to be DAGNode type, received input type {type(new_child)}"
            )

        # Check for loop and tree structure
        if new_child is self:
            raise LoopError("Error setting child: Node cannot be child of itself")
        if any(child is new_child for child in self.ancestors):
            raise LoopError(
                "Error setting child: Node cannot be ancestor of itself"
            )

        # Check for duplicate children; id() catches the same object being
        # listed more than once.
        if id(new_child) in seen_children:
            raise TreeError(
                "Error setting child: Node cannot be added multiple times as a child"
            )
        else:
            seen_children.append(id(new_child))
|
||||
|
||||
@property
def children(self: T) -> Iterable[T]:
    """Get child nodes

    Returns:
        (Iterable[Self]): tuple copy, so callers cannot mutate the internal list
    """
    return tuple(self.__children)
|
||||
|
||||
@children.setter
def children(self: T, new_children: Iterable[T]) -> None:
    """Set child nodes

    Note: children accumulate — existing children are kept and new ones appended.

    Args:
        new_children (Iterable[Self]): child node
    """
    # Validation is skipped when assertions are globally disabled.
    if ASSERTIONS:
        self.__check_children_type(new_children)
        self.__check_children_loop(new_children)

    # Snapshot for rollback on failure.
    current_children = list(self.children)

    # Assign new children - rollback if error
    self.__pre_assign_children(new_children)
    try:
        # Assign new children to self
        for new_child in new_children:
            if self not in new_child.__parents:
                new_child.__parents.append(self)
                self.__children.append(new_child)
        self.__post_assign_children(new_children)
    except Exception as exc_info:
        # Reassign old children to self
        for new_child in new_children:
            if new_child not in current_children:
                new_child.__parents.remove(self)
                self.__children.remove(new_child)
        raise TreeError(exc_info)
|
||||
|
||||
@children.deleter
def children(self) -> None:
    """Delete child node(s)"""
    # self.children is a tuple snapshot, so removing from the underlying
    # lists while iterating is safe here.
    for child in self.children:
        self.__children.remove(child)  # type: ignore
        child.__parents.remove(self)  # type: ignore
|
||||
|
||||
def __pre_assign_children(self: T, new_children: Iterable[T]) -> None:
    """Extension hook: runs before children are attached; default is a no-op
    Can be overridden with `_DAGNode__pre_assign_children()`

    Args:
        new_children (List[Self]): new children to be added
    """
    pass
|
||||
|
||||
def __post_assign_children(self: T, new_children: Iterable[T]) -> None:
    """Extension hook: runs after children are attached; default is a no-op
    Can be overridden with `_DAGNode__post_assign_children()`

    Args:
        new_children (List[Self]): new children to be added
    """
    pass
|
||||
|
||||
@property
def ancestors(self: T) -> Iterable[T]:
    """Get iterator to yield all ancestors of self, does not include self

    Returns:
        (Iterable[Self])
    """
    if not len(list(self.parents)):
        return ()

    def _recursive_parent(node: T) -> Iterable[T]:
        """Recursively yield parent of current node, returns earliest to latest ancestor

        Args:
            node (DAGNode): current node

        Returns:
            (Iterable[DAGNode])
        """
        for _node in node.parents:
            yield from _recursive_parent(_node)
            yield _node

    # dict.fromkeys dedupes shared ancestors while preserving first-seen order.
    ancestors = list(_recursive_parent(self))
    return list(dict.fromkeys(ancestors))
|
||||
|
||||
@property
def descendants(self: T) -> Iterable[T]:
    """Get iterator to yield all descendants of self, does not include self

    Returns:
        (Iterable[Self])
    """
    # Preorder traversal; dict.fromkeys dedupes nodes reachable along
    # multiple paths while preserving first-seen order.
    descendants = preorder_iter(self, filter_condition=lambda _node: _node != self)
    return list(dict.fromkeys(descendants))
|
||||
|
||||
@property
def siblings(self: T) -> Iterable[T]:
    """Get siblings of self: other children of any of self's parents

    Returns:
        (Iterable[Self])
    """
    if self.is_root:
        return ()
    return tuple(
        child
        for parent in self.parents
        for child in parent.children
        if child is not self
    )
|
||||
|
||||
@property
def node_name(self) -> str:
    """Get node name (read-only alias for ``name``)

    Returns:
        (str)
    """
    return self.name
|
||||
|
||||
@property
def is_root(self) -> bool:
    """Get indicator if self is root node (has no parents)

    Returns:
        (bool)
    """
    return not len(list(self.parents))
|
||||
|
||||
@property
def is_leaf(self) -> bool:
    """Get indicator if self is leaf node (has no children)

    Returns:
        (bool)
    """
    return not len(list(self.children))
|
||||
|
||||
@classmethod
def from_dict(cls, input_dict: Dict[str, Any]) -> DAGNode:
    """Construct node from dictionary; all keys of dictionary will be stored as class attributes
    Input dictionary should have key `name`; otherwise the node will not have a name

    Examples:
        >>> from bigtree import DAGNode
        >>> a = DAGNode.from_dict({"name": "a", "age": 90})

    Args:
        input_dict (Dict[str, Any]): dictionary with node information, key: attribute name, value: attribute value

    Returns:
        (DAGNode)
    """
    return cls(**input_dict)
|
||||
|
||||
def describe(
    self, exclude_attributes: Optional[List[str]] = None, exclude_prefix: str = ""
) -> List[Tuple[str, Any]]:
    """Get node information sorted by attribute name, returns list of tuples

    Fix: the previous signature used a mutable default argument (``[]``),
    a shared-object hazard; a ``None`` sentinel preserves behaviour.

    Args:
        exclude_attributes (Optional[List[str]]): list of attribute names to exclude
        exclude_prefix (str): prefix of attributes to exclude; empty string excludes nothing

    Returns:
        (List[Tuple[str, Any]])
    """
    if exclude_attributes is None:
        exclude_attributes = []
    return [
        item
        for item in sorted(self.__dict__.items(), key=lambda item: item[0])
        if (item[0] not in exclude_attributes)
        # empty prefix must exclude nothing (startswith("") is always True)
        and (not len(exclude_prefix) or not item[0].startswith(exclude_prefix))
    ]
|
||||
|
||||
def get_attr(self, attr_name: str, default_value: Any = None) -> Any:
    """Get value of node attribute, falling back to ``default_value`` when
    the attribute does not exist.

    Args:
        attr_name (str): attribute name
        default_value (Any): default value if attribute does not exist, defaults to None

    Returns:
        (Any)
    """
    # Three-argument getattr swallows AttributeError exactly like the
    # explicit try/except form.
    return getattr(self, attr_name, default_value)
|
||||
|
||||
def set_attrs(self, attrs: Dict[str, Any]) -> None:
    """Assign every key/value pair in ``attrs`` as an attribute on self.

    Examples:
        >>> from bigtree.node.dagnode import DAGNode
        >>> a = DAGNode('a')
        >>> a.set_attrs({"age": 90})
        >>> a
        DAGNode(a, age=90)

    Args:
        attrs (Dict[str, Any]): attribute dictionary,
            key: attribute name, value: attribute value
    """
    # vars(self) is self.__dict__, so this bulk-update matches direct
    # attribute assignment.
    vars(self).update(attrs)
|
||||
|
||||
def go_to(self: T, node: T) -> List[List[T]]:
    """Get list of possible paths from current node to specified node from same tree

    Examples:
        >>> from bigtree import DAGNode
        >>> a = DAGNode("a")
        >>> b = DAGNode("b")
        >>> c = DAGNode("c")
        >>> d = DAGNode("d")
        >>> a >> c
        >>> b >> c
        >>> c >> d
        >>> a >> d
        >>> a.go_to(c)
        [[DAGNode(a, ), DAGNode(c, )]]
        >>> a.go_to(d)
        [[DAGNode(a, ), DAGNode(c, ), DAGNode(d, )], [DAGNode(a, ), DAGNode(d, )]]
        >>> a.go_to(b)
        Traceback (most recent call last):
            ...
        bigtree.utils.exceptions.TreeError: It is not possible to go to DAGNode(b, )

    Args:
        node (Self): node to travel to from current node, inclusive of start and end node

    Returns:
        (List[List[Self]])

    Raises:
        TypeError: if `node` is not a DAGNode
        TreeError: if `node` is not reachable from this node
    """
    if not isinstance(node, DAGNode):
        raise TypeError(
            f"Expect node to be DAGNode type, received input type {type(node)}"
        )
    # Trivial case: path from a node to itself is just the node.
    if self == node:
        return [[self]]
    # Fail fast when the target is not reachable at all.
    if node not in self.descendants:
        raise TreeError(f"It is not possible to go to {node}")

    # Accumulator for every complete path found (set on the instance so the
    # nested helper can append to it).
    self.__path: List[List[T]] = []

    def _recursive_path(_node: T, _path: List[T]) -> Optional[List[T]]:
        """Get path to specified node

        Args:
            _node (DAGNode): current node
            _path (List[DAGNode]): current path, from start node to current node, excluding current node

        Returns:
            (List[DAGNode])
        """
        if _node:  # pragma: no cover
            _path.append(_node)
            # Target reached: return the completed path to the caller.
            if _node == node:
                return _path
            # Branch into each child with an independent copy of the path so
            # sibling branches do not pollute each other.
            for _child in _node.children:
                ans = _recursive_path(_child, _path.copy())
                if ans:
                    self.__path.append(ans)
        return None

    _recursive_path(self, [])
    return self.__path
|
||||
|
||||
def copy(self: T) -> T:
    """Deep copy self; clone DAGNode

    Examples:
        >>> from bigtree.node.dagnode import DAGNode
        >>> a = DAGNode('a')
        >>> a_copy = a.copy()

    Returns:
        (Self)
    """
    # Deep copy also clones linked parents/children reachable from the
    # instance dictionary, so the returned node belongs to a cloned graph.
    return copy.deepcopy(self)
|
||||
|
||||
def __copy__(self: T) -> T:
    """Shallow copy self

    Creates a new instance of the same class without invoking ``__init__``,
    then populates its instance dictionary from this node, so attribute
    values (including parent/children references) are shared.

    Examples:
        >>> import copy
        >>> from bigtree.node.dagnode import DAGNode
        >>> a = DAGNode('a')
        >>> a_copy = copy.copy(a)

    Returns:
        (Self)
    """
    clone: T = self.__class__.__new__(self.__class__)
    clone.__dict__.update(self.__dict__)
    return clone
|
||||
|
||||
def __getitem__(self, child_name: str) -> T:
    """Get child by name identifier

    Args:
        child_name (str): name of child node

    Returns:
        (Self): child node
    """
    # Local import avoids a circular dependency between the node and search modules.
    from bigtree.tree.search import find_child_by_name

    return find_child_by_name(self, child_name)  # type: ignore
|
||||
|
||||
def __delitem__(self, child_name: str) -> None:
    """Delete child by name identifier, will not throw error if child does not exist

    Args:
        child_name (str): name of child node
    """
    # Local import avoids a circular dependency between the node and search modules.
    from bigtree.tree.search import find_child_by_name

    child = find_child_by_name(self, child_name)
    if child:
        # Unlink the edge in both directions: drop the child from this node's
        # private children list, and this node from the child's parents list.
        self.__children.remove(child)  # type: ignore
        child.__parents.remove(self)  # type: ignore
|
||||
|
||||
def __repr__(self) -> str:
    """Print format of DAGNode: ``ClassName(node_name, attr=value, ...)``.

    Private attributes (prefixed with ``_``) and ``name`` are omitted from
    the attribute listing.

    Returns:
        (str)
    """
    attrs = self.describe(exclude_attributes=["name"])
    public_attrs = (f"{k}={v}" for k, v in attrs if not k.startswith("_"))
    return f"{self.__class__.__name__}({self.node_name}, {', '.join(public_attrs)})"
|
||||
|
||||
def __rshift__(self: T, other: T) -> None:
    """Set children using >> bitshift operator for self >> children (other)

    Args:
        other (Self): other node, children
    """
    # Assigning to `parents` triggers the DAGNode parent-linking logic.
    other.parents = [self]
|
||||
|
||||
def __lshift__(self: T, other: T) -> None:
    """Set parent using << bitshift operator for self << parent (other)

    Args:
        other (Self): other node, parent
    """
    # Assigning to `parents` triggers the DAGNode parent-linking logic.
    self.parents = [other]
|
||||
|
||||
def __iter__(self) -> Generator[T, None, None]:
    """Iterate through child nodes

    Returns:
        (Self): child node
    """
    # Delegates directly to the children container; iteration order follows it.
    yield from self.children  # type: ignore
|
||||
|
||||
def __contains__(self, other_node: T) -> bool:
    """Check if child node exists

    Args:
        other_node (T): child node

    Returns:
        (bool)
    """
    # Membership uses the children container's equality semantics.
    return other_node in self.children
|
||||
|
||||
|
||||
T = TypeVar("T", bound=DAGNode)
|
||||
261
python310/packages/bigtree/node/node.py
Normal file
261
python310/packages/bigtree/node/node.py
Normal file
@@ -0,0 +1,261 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import Counter
|
||||
from typing import Any, List, TypeVar
|
||||
|
||||
from bigtree.node.basenode import BaseNode
|
||||
from bigtree.utils.exceptions import TreeError
|
||||
|
||||
|
||||
class Node(BaseNode):
    """
    Node is an extension of BaseNode, and is able to extend to any Python class.
    Nodes can have attributes if they are initialized from `Node`, *dictionary*, or *pandas DataFrame*.

    Nodes can be linked to each other with `parent` and `children` setter methods.

    Examples:
        >>> from bigtree import Node
        >>> a = Node("a")
        >>> b = Node("b")
        >>> c = Node("c")
        >>> d = Node("d")
        >>> b.parent = a
        >>> b.children = [c, d]

        Directly passing `parent` argument.

        >>> from bigtree import Node
        >>> a = Node("a")
        >>> b = Node("b", parent=a)
        >>> c = Node("c", parent=b)
        >>> d = Node("d", parent=b)

        Directly passing `children` argument.

        >>> from bigtree import Node
        >>> d = Node("d")
        >>> c = Node("c")
        >>> b = Node("b", children=[c, d])
        >>> a = Node("a", children=[b])

    **Node Creation**

    Node can be created by instantiating a `Node` class or by using a *dictionary*.
    If node is created with dictionary, all keys of dictionary will be stored as class attributes.

    >>> from bigtree import Node
    >>> a = Node.from_dict({"name": "a", "age": 90})

    **Node Attributes**

    These are node attributes that have getter and/or setter methods.

    Get and set `Node` configuration

    1. ``sep``: Get/set separator for path name

    Get `Node` configuration

    1. ``node_name``: Get node name, without accessing `name` directly
    2. ``path_name``: Get path name from root, separated by `sep`

    **Node Methods**

    These are methods available to be performed on `Node`.

    `Node` methods

    1. ``show()``: Print tree to console
    2. ``hshow()``: Print tree in horizontal orientation to console

    ----

    """

    def __init__(self, name: str = "", sep: str = "/", **kwargs: Any):
        """Initialize node with a name and path separator.

        Args:
            name (str): node name, must resolve to a non-empty string
            sep (str): path separator for `path_name`, defaults to "/"

        Raises:
            TreeError: if the node name is empty after initialization
        """
        self.name = name
        self._sep = sep
        super().__init__(**kwargs)
        # Reject empty names so every node has a usable path component.
        if not self.node_name:
            raise TreeError("Node must have a `name` attribute")

    @property
    def sep(self) -> str:
        """Get separator, gets from root node

        Returns:
            (str)
        """
        # Separator is resolved from the root so the whole tree stays consistent.
        if self.parent is None:
            return self._sep
        return self.parent.sep

    @sep.setter
    def sep(self, value: str) -> None:
        """Set separator, affects root node

        Args:
            value (str): separator to replace default separator
        """
        self.root._sep = value

    @property
    def node_name(self) -> str:
        """Get node name

        Returns:
            (str)
        """
        return self.name

    @property
    def path_name(self) -> str:
        """Get path name, separated by self.sep

        Returns:
            (str)
        """
        ancestors = [self] + list(self.ancestors)
        # The last ancestor is the root; its separator is authoritative.
        sep = ancestors[-1].sep
        return sep + sep.join([str(node.name) for node in reversed(ancestors)])

    def __pre_assign_children(self: T, new_children: List[T]) -> None:
        """Custom method to check before attaching children
        Can be overridden with `_Node__pre_assign_children()`

        Args:
            new_children (List[Self]): new children to be added
        """
        pass

    def __post_assign_children(self: T, new_children: List[T]) -> None:
        """Custom method to check after attaching children
        Can be overridden with `_Node__post_assign_children()`

        Args:
            new_children (List[Self]): new children to be added
        """
        pass

    def __pre_assign_parent(self: T, new_parent: T) -> None:
        """Custom method to check before attaching parent
        Can be overridden with `_Node__pre_assign_parent()`

        Args:
            new_parent (Self): new parent to be added
        """
        pass

    def __post_assign_parent(self: T, new_parent: T) -> None:
        """Custom method to check after attaching parent
        Can be overridden with `_Node__post_assign_parent()`

        Args:
            new_parent (Self): new parent to be added
        """
        pass

    # The `_BaseNode__*` methods below deliberately use BaseNode's
    # name-mangled hook names, so that BaseNode's assignment logic calls
    # these overrides; each one first invokes the user-overridable
    # `_Node__*` hook above, then enforces Node's own rules.

    def _BaseNode__pre_assign_parent(self: T, new_parent: T) -> None:
        """Do not allow duplicate nodes of same path

        Args:
            new_parent (Self): new parent to be added
        """
        self.__pre_assign_parent(new_parent)
        if new_parent is not None:
            # A sibling with the same name would produce an ambiguous path.
            if any(
                child.node_name == self.node_name and child is not self
                for child in new_parent.children
            ):
                raise TreeError(
                    f"Duplicate node with same path\n"
                    f"There exist a node with same path {new_parent.path_name}{new_parent.sep}{self.node_name}"
                )

    def _BaseNode__post_assign_parent(self: T, new_parent: T) -> None:
        """No rules

        Args:
            new_parent (Self): new parent to be added
        """
        self.__post_assign_parent(new_parent)

    def _BaseNode__pre_assign_children(self: T, new_children: List[T]) -> None:
        """Do not allow duplicate nodes of same path

        Args:
            new_children (List[Self]): new children to be added
        """
        self.__pre_assign_children(new_children)
        # Reject the assignment if any two incoming children share a name.
        children_names = [node.node_name for node in new_children]
        duplicate_names = [
            item[0] for item in Counter(children_names).items() if item[1] > 1
        ]
        if len(duplicate_names):
            duplicate_names_str = " and ".join(
                [f"{self.path_name}{self.sep}{name}" for name in duplicate_names]
            )
            raise TreeError(
                f"Duplicate node with same path\n"
                f"Attempting to add nodes with same path {duplicate_names_str}"
            )

    def _BaseNode__post_assign_children(self: T, new_children: List[T]) -> None:
        """No rules

        Args:
            new_children (List[Self]): new children to be added
        """
        self.__post_assign_children(new_children)

    def show(self, **kwargs: Any) -> None:
        """Print tree to console, takes in same keyword arguments as `print_tree` function"""
        # Local import avoids a circular dependency with the export module.
        from bigtree.tree.export import print_tree

        print_tree(self, **kwargs)

    def hshow(self, **kwargs: Any) -> None:
        """Print tree in horizontal orientation to console, takes in same keyword arguments as `hprint_tree` function"""
        from bigtree.tree.export import hprint_tree

        hprint_tree(self, **kwargs)

    def __getitem__(self, child_name: str) -> T:
        """Get child by name identifier

        Args:
            child_name (str): name of child node

        Returns:
            (Self): child node
        """
        # Local import avoids a circular dependency with the search module.
        from bigtree.tree.search import find_child_by_name

        return find_child_by_name(self, child_name)  # type: ignore

    def __delitem__(self, child_name: str) -> None:
        """Delete child by name identifier, will not throw error if child does not exist

        Args:
            child_name (str): name of child node
        """
        from bigtree.tree.search import find_child_by_name

        child = find_child_by_name(self, child_name)
        if child:
            # Detaching is done by clearing the child's parent reference.
            child.parent = None

    def __repr__(self) -> str:
        """Print format of Node

        Returns:
            (str)
        """
        class_name = self.__class__.__name__
        # Show only public, non-name attributes alongside the full path.
        node_dict = self.describe(exclude_prefix="_", exclude_attributes=["name"])
        node_description = ", ".join([f"{k}={v}" for k, v in node_dict])
        return f"{class_name}({self.path_name}, {node_description})"
|
||||
|
||||
|
||||
T = TypeVar("T", bound=Node)
|
||||
0
python310/packages/bigtree/py.typed
Normal file
0
python310/packages/bigtree/py.typed
Normal file
0
python310/packages/bigtree/tree/__init__.py
Normal file
0
python310/packages/bigtree/tree/__init__.py
Normal file
1327
python310/packages/bigtree/tree/construct.py
Normal file
1327
python310/packages/bigtree/tree/construct.py
Normal file
File diff suppressed because it is too large
Load Diff
1660
python310/packages/bigtree/tree/export.py
Normal file
1660
python310/packages/bigtree/tree/export.py
Normal file
File diff suppressed because it is too large
Load Diff
415
python310/packages/bigtree/tree/helper.py
Normal file
415
python310/packages/bigtree/tree/helper.py
Normal file
@@ -0,0 +1,415 @@
|
||||
from collections import deque
|
||||
from typing import Any, Deque, Dict, List, Set, Type, TypeVar, Union
|
||||
|
||||
from bigtree.node.basenode import BaseNode
|
||||
from bigtree.node.binarynode import BinaryNode
|
||||
from bigtree.node.node import Node
|
||||
from bigtree.tree.construct import add_dict_to_tree_by_path, dataframe_to_tree
|
||||
from bigtree.tree.export import tree_to_dataframe
|
||||
from bigtree.tree.search import find_path
|
||||
from bigtree.utils.exceptions import NotFoundError
|
||||
from bigtree.utils.iterators import levelordergroup_iter
|
||||
|
||||
__all__ = ["clone_tree", "get_subtree", "prune_tree", "get_tree_diff"]
|
||||
BaseNodeT = TypeVar("BaseNodeT", bound=BaseNode)
|
||||
BinaryNodeT = TypeVar("BinaryNodeT", bound=BinaryNode)
|
||||
NodeT = TypeVar("NodeT", bound=Node)
|
||||
|
||||
|
||||
def clone_tree(tree: BaseNode, node_type: Type[BaseNodeT]) -> BaseNodeT:
    """Clone tree to another ``Node`` type.
    If the same type is needed, simply do a tree.copy().

    Examples:
        >>> from bigtree import BaseNode, Node, clone_tree
        >>> root = BaseNode(name="a")
        >>> b = BaseNode(name="b", parent=root)
        >>> clone_tree(root, Node)
        Node(/a, )

    Args:
        tree (BaseNode): tree to be cloned, must inherit from BaseNode
        node_type (Type[BaseNode]): type of cloned tree

    Returns:
        (BaseNode)

    Raises:
        TypeError: if `tree` is not a BaseNode
    """
    if not isinstance(tree, BaseNode):
        raise TypeError("Tree should be of type `BaseNode`, or inherit from `BaseNode`")

    # Start from root
    # Public attributes (no "_" prefix) become constructor kwargs of the new type.
    root_info = dict(tree.root.describe(exclude_prefix="_"))
    root_node = node_type(**root_info)

    def _recursive_add_child(
        _new_parent_node: BaseNodeT, _parent_node: BaseNode
    ) -> None:
        """Recursively clone current node

        Args:
            _new_parent_node (BaseNode): cloned parent node
            _parent_node (BaseNode): parent node to be cloned
        """
        for _child in _parent_node.children:
            # Children may contain None placeholders (e.g. BinaryNode); skip them.
            if _child:
                child_info = dict(_child.describe(exclude_prefix="_"))
                child_node = node_type(**child_info)
                child_node.parent = _new_parent_node
                _recursive_add_child(child_node, _child)

    _recursive_add_child(root_node, tree.root)
    return root_node
|
||||
|
||||
|
||||
def get_subtree(
    tree: NodeT,
    node_name_or_path: str = "",
    max_depth: int = 0,
) -> NodeT:
    """Get subtree based on node name or node path, and/or maximum depth of tree.

    Examples:
        >>> from bigtree import Node, get_subtree
        >>> root = Node("a")
        >>> b = Node("b", parent=root)
        >>> c = Node("c", parent=b)
        >>> d = Node("d", parent=b)
        >>> e = Node("e", parent=root)
        >>> root.show()
        a
        ├── b
        │   ├── c
        │   └── d
        └── e

        Get subtree

        >>> root_subtree = get_subtree(root, "b")
        >>> root_subtree.show()
        b
        ├── c
        └── d

    Args:
        tree (Node): existing tree
        node_name_or_path (str): node name or path to get subtree, defaults to None
        max_depth (int): maximum depth of subtree, based on `depth` attribute, defaults to None

    Returns:
        (Node)

    Raises:
        ValueError: if `node_name_or_path` does not match any node
    """
    # Work on a copy so the original tree is never mutated.
    tree = tree.copy()

    if node_name_or_path:
        tree = find_path(tree, node_name_or_path)
        if not tree:
            raise ValueError(f"Node name or path {node_name_or_path} not found")

    # Detach the found node so it becomes the root of the returned subtree.
    if not tree.is_root:
        tree.parent = None

    if max_depth:
        tree = prune_tree(tree, max_depth=max_depth)
    return tree
|
||||
|
||||
|
||||
def prune_tree(
    tree: Union[BinaryNodeT, NodeT],
    prune_path: Union[List[str], str] = "",
    exact: bool = False,
    sep: str = "/",
    max_depth: int = 0,
) -> Union[BinaryNodeT, NodeT]:
    """Prune tree by path or depth, returns the root of a *copy* of the original tree.

    For pruning by `prune_path`,

    - All siblings along the prune path will be removed.
    - If ``exact=True``, all descendants of prune path will be removed.
    - Prune path can be string (only one path) or a list of strings (multiple paths).
    - Prune path name should be unique, can be full path, partial path (trailing part of path), or node name.

    For pruning by `max_depth`,

    - All nodes that are beyond `max_depth` will be removed.

    Path should contain ``Node`` name, separated by `sep`.

    - For example: Path string "a/b" refers to Node("b") with parent Node("a").

    Examples:
        >>> from bigtree import Node, prune_tree
        >>> root = Node("a")
        >>> b = Node("b", parent=root)
        >>> c = Node("c", parent=b)
        >>> d = Node("d", parent=b)
        >>> e = Node("e", parent=root)
        >>> root.show()
        a
        ├── b
        │   ├── c
        │   └── d
        └── e

        Prune (default is keep descendants)

        >>> root_pruned = prune_tree(root, "a/b")
        >>> root_pruned.show()
        a
        └── b
            ├── c
            └── d

        Prune exact path

        >>> root_pruned = prune_tree(root, "a/b", exact=True)
        >>> root_pruned.show()
        a
        └── b

        Prune multiple paths

        >>> root_pruned = prune_tree(root, ["a/b/d", "a/e"])
        >>> root_pruned.show()
        a
        ├── b
        │   └── d
        └── e

        Prune by depth

        >>> root_pruned = prune_tree(root, max_depth=2)
        >>> root_pruned.show()
        a
        ├── b
        └── e

    Args:
        tree (Union[BinaryNode, Node]): existing tree
        prune_path (List[str] | str): prune path(s), all siblings along the prune path(s) will be removed
        exact (bool): prune path(s) to be exactly the path, defaults to False (descendants of the path are retained)
        sep (str): path separator of `prune_path`
        max_depth (int): maximum depth of pruned tree, based on `depth` attribute, defaults to None

    Returns:
        (Union[BinaryNode, Node])

    Raises:
        ValueError: if neither `prune_path` nor `max_depth` is given
        NotFoundError: if a prune path does not match any node
    """
    # Normalize a single string path into a list; "" means no path pruning.
    if isinstance(prune_path, str):
        prune_path = [prune_path] if prune_path else []

    if not len(prune_path) and not max_depth:
        raise ValueError("Please specify either `prune_path` or `max_depth` or both.")

    # Operate on a copy so the original tree stays intact.
    tree_copy = tree.copy()

    # Prune by path (prune bottom-up)
    if len(prune_path):
        # Collect the nodes to keep: every matched node plus its ancestors;
        # anything hanging off those that is not on a kept path is detached.
        ancestors_to_prune: Set[Union[BinaryNodeT, NodeT]] = set()
        nodes_to_prune: Set[Union[BinaryNodeT, NodeT]] = set()
        for path in prune_path:
            # Translate caller's separator into the tree's own separator.
            path = path.replace(sep, tree.sep)
            child = find_path(tree_copy, path)
            if not child:
                raise NotFoundError(
                    f"Cannot find any node matching path_name ending with {path}"
                )
            nodes_to_prune.add(child)
            ancestors_to_prune.update(list(child.ancestors))

        # With exact=True the matched nodes themselves are treated as cut
        # points, so their descendants are removed too.
        if exact:
            ancestors_to_prune.update(nodes_to_prune)

        for node in ancestors_to_prune:
            for child in node.children:
                if (
                    child
                    and child not in ancestors_to_prune
                    and child not in nodes_to_prune
                ):
                    child.parent = None

    # Prune by depth (prune top-down)
    if max_depth:
        for depth, level_nodes in enumerate(levelordergroup_iter(tree_copy), 1):
            if depth == max_depth:
                # Drop everything below the depth limit in one sweep.
                for level_node in level_nodes:
                    del level_node.children
    return tree_copy
|
||||
|
||||
|
||||
def get_tree_diff(
    tree: Node,
    other_tree: Node,
    only_diff: bool = True,
    attr_list: Union[List[str], None] = None,
) -> Node:
    """Get difference of `tree` to `other_tree`, changes are relative to `tree`.

    Compares the difference in tree structure (default), but can also compare tree attributes using `attr_list`.
    Function can return only the differences (default), or all original tree nodes and differences.

    Comparing tree structure:

    - (+) and (-) will be added to node name relative to `tree`.
    - For example: (+) refers to nodes that are in `other_tree` but not `tree`.
    - For example: (-) refers to nodes that are in `tree` but not `other_tree`.

    Examples:
        >>> # Create original tree
        >>> from bigtree import Node, get_tree_diff, list_to_tree
        >>> root = list_to_tree(["Downloads/Pictures/photo1.jpg", "Downloads/file1.doc", "Downloads/photo2.jpg"])
        >>> root.show()
        Downloads
        ├── Pictures
        │   └── photo1.jpg
        ├── file1.doc
        └── photo2.jpg

        >>> # Create other tree
        >>> root_other = list_to_tree(["Downloads/Pictures/photo1.jpg", "Downloads/Pictures/photo2.jpg", "Downloads/file1.doc"])
        >>> root_other.show()
        Downloads
        ├── Pictures
        │   ├── photo1.jpg
        │   └── photo2.jpg
        └── file1.doc

        >>> # Get tree differences
        >>> tree_diff = get_tree_diff(root, root_other)
        >>> tree_diff.show()
        Downloads
        ├── photo2.jpg (-)
        └── Pictures
            └── photo2.jpg (+)

        >>> tree_diff = get_tree_diff(root, root_other, only_diff=False)
        >>> tree_diff.show()
        Downloads
        ├── Pictures
        │   ├── photo1.jpg
        │   └── photo2.jpg (+)
        ├── file1.doc
        └── photo2.jpg (-)

        Comparing tree attributes

        - (~) will be added to node name if there are differences in tree attributes defined in `attr_list`.
        - The node's attributes will be a list of [value in `tree`, value in `other_tree`]

        >>> # Create original tree
        >>> root = Node("Downloads")
        >>> picture_folder = Node("Pictures", parent=root)
        >>> photo2 = Node("photo1.jpg", tags="photo1", parent=picture_folder)
        >>> file1 = Node("file1.doc", tags="file1", parent=root)
        >>> root.show(attr_list=["tags"])
        Downloads
        ├── Pictures
        │   └── photo1.jpg [tags=photo1]
        └── file1.doc [tags=file1]

        >>> # Create other tree
        >>> root_other = Node("Downloads")
        >>> picture_folder = Node("Pictures", parent=root_other)
        >>> photo1 = Node("photo1.jpg", tags="photo1-edited", parent=picture_folder)
        >>> photo2 = Node("photo2.jpg", tags="photo2-new", parent=picture_folder)
        >>> file1 = Node("file1.doc", tags="file1", parent=root_other)
        >>> root_other.show(attr_list=["tags"])
        Downloads
        ├── Pictures
        │   ├── photo1.jpg [tags=photo1-edited]
        │   └── photo2.jpg [tags=photo2-new]
        └── file1.doc [tags=file1]

        >>> # Get tree differences
        >>> tree_diff = get_tree_diff(root, root_other, attr_list=["tags"])
        >>> tree_diff.show(attr_list=["tags"])
        Downloads
        └── Pictures
            ├── photo1.jpg (~) [tags=('photo1', 'photo1-edited')]
            └── photo2.jpg (+)

    Args:
        tree (Node): tree to be compared against
        other_tree (Node): tree to be compared with
        only_diff (bool): indicator to show all nodes or only nodes that are different (+/-), defaults to True
        attr_list (List[str]): tree attributes to check for difference, defaults to None (no attribute comparison)

    Returns:
        (Node)
    """
    # Avoid a shared mutable default; copy so the caller's list is never aliased.
    attr_list = [] if attr_list is None else list(attr_list)
    # Align separators so path strings from both trees are comparable.
    other_tree.sep = tree.sep
    name_col = "name"
    path_col = "PATH"
    indicator_col = "Exists"

    data, data_other = (
        tree_to_dataframe(
            _tree,
            name_col=name_col,
            path_col=path_col,
            attr_dict={k: k for k in attr_list},
        )
        for _tree in (tree, other_tree)
    )

    # Check tree structure difference: outer merge on path marks each row as
    # left_only (removed), right_only (added), or both.
    data_both = data[[path_col, name_col] + attr_list].merge(
        data_other[[path_col, name_col] + attr_list],
        how="outer",
        on=[path_col, name_col],
        indicator=indicator_col,
    )

    # Handle tree structure difference (reverse order so deeper paths are
    # annotated before their ancestors).
    nodes_removed = list(data_both[data_both[indicator_col] == "left_only"][path_col])[
        ::-1
    ]
    nodes_added = list(data_both[data_both[indicator_col] == "right_only"][path_col])[
        ::-1
    ]
    # regex=False: node paths frequently contain regex metacharacters (e.g.
    # "." in "photo1.jpg"); with regex=True those were interpreted as regex
    # patterns and could annotate the wrong rows.
    for node_removed in nodes_removed:
        data_both[path_col] = data_both[path_col].str.replace(
            node_removed, f"{node_removed} (-)", regex=False
        )
    for node_added in nodes_added:
        data_both[path_col] = data_both[path_col].str.replace(
            node_added, f"{node_added} (+)", regex=False
        )

    # Check tree attribute difference: a row differs when at least one side
    # has a value, the values differ, and the node exists in both trees.
    path_changes_list_of_dict: List[Dict[str, Dict[str, Any]]] = []
    path_changes_deque: Deque[str] = deque([])
    for attr_change in attr_list:
        condition_diff = (
            (
                ~data_both[f"{attr_change}_x"].isnull()
                | ~data_both[f"{attr_change}_y"].isnull()
            )
            & (data_both[f"{attr_change}_x"] != data_both[f"{attr_change}_y"])
            & (data_both[indicator_col] == "both")
        )
        data_diff = data_both[condition_diff]
        if len(data_diff):
            tuple_diff = zip(
                data_diff[f"{attr_change}_x"], data_diff[f"{attr_change}_y"]
            )
            dict_attr_diff = [{attr_change: v} for v in tuple_diff]
            dict_path_diff = dict(list(zip(data_diff[path_col], dict_attr_diff)))
            path_changes_list_of_dict.append(dict_path_diff)
            path_changes_deque.extend(list(data_diff[path_col]))

    if only_diff:
        data_both = data_both[
            (data_both[indicator_col] != "both")
            | (data_both[path_col].isin(path_changes_deque))
        ]
    data_both = data_both[[path_col]]
    if len(data_both):
        tree_diff = dataframe_to_tree(data_both, node_type=tree.__class__)
        # Handle tree attribute difference: rename changed nodes with (~) and
        # attach (old, new) attribute tuples.
        if len(path_changes_deque):
            path_changes_list = sorted(path_changes_deque, reverse=True)
            name_changes_list = [
                {k: {"name": f"{k.split(tree.sep)[-1]} (~)"} for k in path_changes_list}
            ]
            path_changes_list_of_dict.extend(name_changes_list)
            for attr_change_dict in path_changes_list_of_dict:
                tree_diff = add_dict_to_tree_by_path(tree_diff, attr_change_dict)
        return tree_diff
|
||||
1356
python310/packages/bigtree/tree/modify.py
Normal file
1356
python310/packages/bigtree/tree/modify.py
Normal file
File diff suppressed because it is too large
Load Diff
479
python310/packages/bigtree/tree/search.py
Normal file
479
python310/packages/bigtree/tree/search.py
Normal file
@@ -0,0 +1,479 @@
|
||||
from typing import Any, Callable, Iterable, List, Tuple, TypeVar, Union
|
||||
|
||||
from bigtree.node.basenode import BaseNode
|
||||
from bigtree.node.dagnode import DAGNode
|
||||
from bigtree.node.node import Node
|
||||
from bigtree.utils.exceptions import SearchError
|
||||
from bigtree.utils.iterators import preorder_iter
|
||||
|
||||
__all__ = [
|
||||
"findall",
|
||||
"find",
|
||||
"find_name",
|
||||
"find_names",
|
||||
"find_relative_path",
|
||||
"find_full_path",
|
||||
"find_path",
|
||||
"find_paths",
|
||||
"find_attr",
|
||||
"find_attrs",
|
||||
"find_children",
|
||||
"find_child",
|
||||
"find_child_by_name",
|
||||
]
|
||||
|
||||
|
||||
T = TypeVar("T", bound=BaseNode)
|
||||
NodeT = TypeVar("NodeT", bound=Node)
|
||||
DAGNodeT = TypeVar("DAGNodeT", bound=DAGNode)
|
||||
|
||||
|
||||
def findall(
    tree: T,
    condition: Callable[[T], bool],
    max_depth: int = 0,
    min_count: int = 0,
    max_count: int = 0,
) -> Tuple[T, ...]:
    """
    Search tree for nodes matching condition (callable function).

    Traverses the tree in pre-order and optionally enforces a minimum and/or
    maximum number of matches, raising SearchError when a bound is violated.

    Examples:
        >>> from bigtree import Node, findall
        >>> root = Node("a", age=90)
        >>> b = Node("b", age=65, parent=root)
        >>> c = Node("c", age=60, parent=root)
        >>> d = Node("d", age=40, parent=c)
        >>> findall(root, lambda node: node.age > 62)
        (Node(/a, age=90), Node(/a/b, age=65))

    Args:
        tree (BaseNode): tree to search
        condition (Callable): function that takes in node as argument, returns node if condition evaluates to `True`
        max_depth (int): maximum depth to search for, based on the `depth` attribute, defaults to None
        min_count (int): minimum number of matches expected, raises SearchError when unmet, defaults to None
        max_count (int): maximum number of matches allowed, raises SearchError when exceeded, defaults to None

    Returns:
        (Tuple[BaseNode, ...])
    """
    result = tuple(preorder_iter(tree, filter_condition=condition, max_depth=max_depth))
    found = len(result)
    if min_count and found < min_count:
        raise SearchError(
            f"Expected more than {min_count} element(s), found {found} elements\n{result}"
        )
    if max_count and found > max_count:
        raise SearchError(
            f"Expected less than {max_count} element(s), found {found} elements\n{result}"
        )
    return result
|
||||
|
||||
|
||||
def find(tree: T, condition: Callable[[T], bool], max_depth: int = 0) -> T:
    """
    Search tree for *single node* matching condition (callable function).

    Delegates to :func:`findall` with ``max_count=1``, so more than one match
    raises SearchError.

    Examples:
        >>> from bigtree import Node, find
        >>> root = Node("a", age=90)
        >>> b = Node("b", age=65, parent=root)
        >>> c = Node("c", age=60, parent=root)
        >>> d = Node("d", age=40, parent=c)
        >>> find(root, lambda node: node.age == 65)
        Node(/a/b, age=65)
        >>> find(root, lambda node: node.age > 5)
        Traceback (most recent call last):
            ...
        bigtree.utils.exceptions.SearchError: Expected less than 1 element(s), found 4 elements
        (Node(/a, age=90), Node(/a/b, age=65), Node(/a/c, age=60), Node(/a/c/d, age=40))

    Args:
        tree (BaseNode): tree to search
        condition (Callable): function that takes in node as argument, returns node if condition evaluates to `True`
        max_depth (int): maximum depth to search for, based on the `depth` attribute, defaults to None

    Returns:
        (BaseNode)
    """
    matches = findall(tree, condition, max_depth, max_count=1)
    # No match yields None (the annotation notwithstanding), matching the
    # original implicit-return behavior.
    return matches[0] if matches else None
|
||||
|
||||
|
||||
def find_name(tree: NodeT, name: str, max_depth: int = 0) -> NodeT:
    """
    Search tree for single node whose ``node_name`` equals ``name``.

    Args:
        tree (Node): tree to search
        name (str): value to match for name attribute
        max_depth (int): maximum depth to search for, based on the `depth` attribute, defaults to None

    Returns:
        (Node)
    """

    def _name_matches(node: NodeT) -> bool:
        # Single-result semantics (and duplicate detection) come from `find`.
        return node.node_name == name

    return find(tree, _name_matches, max_depth)
|
||||
|
||||
|
||||
def find_names(tree: NodeT, name: str, max_depth: int = 0) -> Iterable[NodeT]:
    """
    Search tree for all node(s) whose ``node_name`` equals ``name``.

    Args:
        tree (Node): tree to search
        name (str): value to match for name attribute
        max_depth (int): maximum depth to search for, based on the `depth` attribute, defaults to None

    Returns:
        (Iterable[Node])
    """

    def _name_matches(node: NodeT) -> bool:
        return node.node_name == name

    return findall(tree, _name_matches, max_depth)
|
||||
|
||||
|
||||
def find_relative_path(tree: NodeT, path_name: str) -> Iterable[NodeT]:
    r"""
    Search tree for single node matching relative path attribute.

    - Supports unix folder expression for relative path, i.e., '../../node_name'
    - Supports wildcards, i.e., '\*/node_name'
    - If path name starts with leading separator symbol, it will start at root node.

    Examples:
        >>> from bigtree import Node, find_relative_path
        >>> root = Node("a", age=90)
        >>> b = Node("b", age=65, parent=root)
        >>> c = Node("c", age=60, parent=root)
        >>> d = Node("d", age=40, parent=c)
        >>> find_relative_path(d, "..")
        (Node(/a/c, age=60),)
        >>> find_relative_path(d, "../../b")
        (Node(/a/b, age=65),)
        >>> find_relative_path(d, "../../*")
        (Node(/a/b, age=65), Node(/a/c, age=60))

    Args:
        tree (Node): tree to search
        path_name (str): value to match (relative path) of path_name attribute

    Returns:
        (Iterable[Node])

    Raises:
        SearchError: if the path climbs above the root, or (without wildcards)
            names a child that does not exist
    """
    sep = tree.sep
    # Absolute path (leading separator): delegate to full-path lookup from root.
    if path_name.startswith(sep):
        resolved_node = find_full_path(tree, path_name)
        return (resolved_node,)
    # Strip any leading/trailing separators, then split into path components.
    path_name = path_name.rstrip(sep).lstrip(sep)
    path_list = path_name.split(sep)
    # When a wildcard is present anywhere, unmatched names are silently
    # skipped instead of raising SearchError (see `resolve` below).
    wildcard_indicator = "*" in path_name
    resolved_nodes: List[NodeT] = []

    def resolve(node: NodeT, path_idx: int) -> None:
        """Resolve node based on path name

        Recursive walk: consumes one component of `path_list` per call and
        appends fully-resolved nodes to the enclosing `resolved_nodes`.

        Args:
            node (Node): current node
            path_idx (int): current index in path_list
        """
        if path_idx == len(path_list):
            # All components consumed: the current node is a match.
            resolved_nodes.append(node)
        else:
            path_component = path_list[path_idx]
            if path_component == ".":
                # "." stays at the current node.
                resolve(node, path_idx + 1)
            elif path_component == "..":
                # ".." moves to the parent; cannot climb above the root.
                if node.is_root:
                    raise SearchError("Invalid path name. Path goes beyond root node.")
                resolve(node.parent, path_idx + 1)
            elif path_component == "*":
                # "*" fans out to every child of the current node.
                for child in node.children:
                    resolve(child, path_idx + 1)
            else:
                # Literal component: descend into the named child, if any.
                node = find_child_by_name(node, path_component)
                if not node:
                    if not wildcard_indicator:
                        raise SearchError(
                            f"Invalid path name. Node {path_component} cannot be found."
                        )
                else:
                    resolve(node, path_idx + 1)

    resolve(tree, 0)

    return tuple(resolved_nodes)
|
||||
|
||||
|
||||
def find_full_path(tree: NodeT, path_name: str) -> NodeT:
    """
    Search tree for single node matching path attribute.

    - Path name can be with or without leading tree path separator symbol.
    - Path name must be full path, works similar to `find_path` but faster.

    Args:
        tree (Node): tree to search
        path_name (str): value to match (full path) of path_name attribute

    Returns:
        (Node)

    Raises:
        ValueError: if the first path component is not the root node's name
    """
    sep = tree.sep
    path_list = path_name.strip(sep).split(sep)
    # The first component must name the root; otherwise the path is invalid.
    if path_list[0] != tree.root.node_name:
        raise ValueError(
            f"Path {path_name} does not match the root node name {tree.root.node_name}"
        )
    # Walk down one named child at a time; stop early (returning a falsy
    # value) as soon as a component cannot be found.
    current = tree.root
    for child_name in path_list[1:]:
        next_node = find_child_by_name(current, child_name)
        if not next_node:
            return next_node
        current = next_node
    return current
|
||||
|
||||
|
||||
def find_path(tree: NodeT, path_name: str) -> NodeT:
    """
    Search tree for single node whose ``path_name`` ends with the given path.

    - Path name can be with or without leading tree path separator symbol.
    - Path name can be full path or partial path (trailing part of path) or node name.

    Args:
        tree (Node): tree to search
        path_name (str): value to match (full path) or trailing part (partial path) of path_name attribute

    Returns:
        (Node)
    """
    suffix = path_name.rstrip(tree.sep)
    return find(tree, lambda node: node.path_name.endswith(suffix))
|
||||
|
||||
|
||||
def find_paths(tree: NodeT, path_name: str) -> Tuple[NodeT, ...]:
    """
    Search tree for all nodes whose ``path_name`` ends with the given path.

    - Path name can be with or without leading tree path separator symbol.
    - Path name can be partial path (trailing part of path) or node name.

    Args:
        tree (Node): tree to search
        path_name (str): value to match (full path) or trailing part (partial path) of path_name attribute

    Returns:
        (Tuple[Node, ...])
    """
    suffix = path_name.rstrip(tree.sep)
    return findall(tree, lambda node: node.path_name.endswith(suffix))
|
||||
|
||||
|
||||
def find_attr(
    tree: BaseNode, attr_name: str, attr_value: Any, max_depth: int = 0
) -> BaseNode:
    """
    Search tree for single node whose ``attr_name`` attribute equals ``attr_value``.

    Args:
        tree (BaseNode): tree to search
        attr_name (str): attribute name to perform matching
        attr_value (Any): value to match for attr_name attribute
        max_depth (int): maximum depth to search for, based on the `depth` attribute, defaults to None

    Returns:
        (BaseNode)
    """

    def _attr_matches(node: BaseNode) -> bool:
        # Cast to bool to normalise whatever the equality comparison returns.
        return bool(node.get_attr(attr_name) == attr_value)

    return find(tree, _attr_matches, max_depth)
|
||||
|
||||
|
||||
def find_attrs(
    tree: BaseNode, attr_name: str, attr_value: Any, max_depth: int = 0
) -> Tuple[BaseNode, ...]:
    """
    Search tree for all node(s) whose ``attr_name`` attribute equals ``attr_value``.

    Args:
        tree (BaseNode): tree to search
        attr_name (str): attribute name to perform matching
        attr_value (Any): value to match for attr_name attribute
        max_depth (int): maximum depth to search for, based on the `depth` attribute, defaults to None

    Returns:
        (Tuple[BaseNode, ...])
    """

    def _attr_matches(node: BaseNode) -> bool:
        return bool(node.get_attr(attr_name) == attr_value)

    return findall(tree, _attr_matches, max_depth)
|
||||
|
||||
|
||||
def find_children(
    tree: Union[T, DAGNodeT],
    condition: Callable[[Union[T, DAGNodeT]], bool],
    min_count: int = 0,
    max_count: int = 0,
) -> Tuple[Union[T, DAGNodeT], ...]:
    """
    Search direct children for nodes matching condition (callable function).

    Examples:
        >>> from bigtree import Node, find_children
        >>> root = Node("a", age=90)
        >>> b = Node("b", age=65, parent=root)
        >>> c = Node("c", age=60, parent=root)
        >>> d = Node("d", age=40, parent=c)
        >>> find_children(root, lambda node: node.age > 30)
        (Node(/a/b, age=65), Node(/a/c, age=60))

    Args:
        tree (BaseNode/DAGNode): tree to search for its children
        condition (Callable): function that takes in node as argument, returns node if condition evaluates to `True`
        min_count (int): checks for minimum number of occurrences,
            raise SearchError if the number of results do not meet min_count, defaults to None
        max_count (int): checks for maximum number of occurrences,
            raise SearchError if the number of results do not meet max_count, defaults to None

    Returns:
        (Tuple[BaseNode, ...] / Tuple[DAGNode, ...])

    Raises:
        SearchError: if the number of matches violates min_count or max_count
    """
    # Only direct children are examined; falsy child slots (presumably empty
    # BinaryNode positions — confirm at call sites) are skipped before the
    # condition is applied.
    result = tuple(node for node in tree.children if node and condition(node))
    if min_count and len(result) < min_count:
        raise SearchError(
            f"Expected more than {min_count} element(s), found {len(result)} elements\n{result}"
        )
    if max_count and len(result) > max_count:
        raise SearchError(
            f"Expected less than {max_count} element(s), found {len(result)} elements\n{result}"
        )
    return result
|
||||
|
||||
|
||||
def find_child(
    tree: Union[T, DAGNodeT],
    condition: Callable[[Union[T, DAGNodeT]], bool],
) -> Union[T, DAGNodeT]:
    """
    Search direct children for *single node* matching condition (callable function).

    Raises SearchError (via `find_children` with ``max_count=1``) when more
    than one child matches; returns ``None`` when no child matches.

    Args:
        tree (BaseNode/DAGNode): tree to search for its child
        condition (Callable): function that takes in node as argument, returns node if condition evaluates to `True`

    Returns:
        (BaseNode/DAGNode)
    """
    matches = find_children(tree, condition, max_count=1)
    return matches[0] if matches else None
|
||||
|
||||
|
||||
def find_child_by_name(
    tree: Union[NodeT, DAGNodeT], name: str
) -> Union[NodeT, DAGNodeT]:
    """
    Search direct children for the single child whose ``node_name`` equals ``name``.

    Args:
        tree (Node/DAGNode): tree to search, parent node
        name (str): value to match for name attribute, child node

    Returns:
        (Node/DAGNode)
    """

    def _name_matches(node: Union[NodeT, DAGNodeT]) -> bool:
        return node.node_name == name

    return find_child(tree, _name_matches)
|
||||
0
python310/packages/bigtree/utils/__init__.py
Normal file
0
python310/packages/bigtree/utils/__init__.py
Normal file
53
python310/packages/bigtree/utils/assertions.py
Normal file
53
python310/packages/bigtree/utils/assertions.py
Normal file
@@ -0,0 +1,53 @@
|
||||
from typing import Any, Dict, List
|
||||
|
||||
|
||||
def assert_style_in_dict(
    parameter: Any,
    accepted_parameters: Dict[str, Any],
) -> None:
    """Raise ValueError if style parameter is not an accepted style.

    The special value ``"custom"`` is always accepted, allowing callers to
    define their own style.

    Args:
        parameter (Any): argument input for parameter
        accepted_parameters (Dict[str, Any]): dictionary of accepted styles

    Raises:
        ValueError: if parameter is neither a key of accepted_parameters nor "custom"
    """
    if parameter not in accepted_parameters and parameter != "custom":
        raise ValueError(
            f"Choose one of {accepted_parameters.keys()} style, use `custom` to define own style"
        )
|
||||
|
||||
|
||||
def assert_str_in_list(
    parameter_name: str,
    parameter: Any,
    accepted_parameters: List[Any],
) -> None:
    """Raise ValueError if parameter is not in list of accepted parameters.

    Args:
        parameter_name (str): parameter name for error message
        parameter (Any): argument input for parameter
        accepted_parameters (List[Any]): list of accepted parameters

    Raises:
        ValueError: if parameter is not one of accepted_parameters
    """
    if parameter not in accepted_parameters:
        raise ValueError(
            f"Invalid input, check `{parameter_name}` should be one of {accepted_parameters}"
        )
|
||||
|
||||
|
||||
def assert_key_in_dict(
    parameter_name: str,
    parameter: Any,
    accepted_parameters: Dict[Any, Any],
) -> None:
    """Raise ValueError if parameter is not a key of the accepted dictionary.

    Args:
        parameter_name (str): parameter name for error message
        parameter (Any): argument input for parameter
        accepted_parameters (Dict[Any, Any]): dictionary of accepted parameters

    Raises:
        ValueError: if parameter is not a key of accepted_parameters
    """
    if parameter not in accepted_parameters:
        raise ValueError(
            f"Invalid input, check `{parameter_name}` should be one of {accepted_parameters.keys()}"
        )
|
||||
165
python310/packages/bigtree/utils/constants.py
Normal file
165
python310/packages/bigtree/utils/constants.py
Normal file
@@ -0,0 +1,165 @@
|
||||
from enum import Enum, auto
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
|
||||
class ExportConstants:
    """Unicode box-drawing characters and style tables for tree export/printing."""

    # Light (single-line) box-drawing characters.
    DOWN_RIGHT = "\u250c"  # ┌
    VERTICAL_RIGHT = "\u251c"  # ├
    VERTICAL_LEFT = "\u2524"  # ┤
    VERTICAL_HORIZONTAL = "\u253c"  # ┼
    UP_RIGHT = "\u2514"  # └
    VERTICAL = "\u2502"  # │
    HORIZONTAL = "\u2500"  # ─

    # Rounded-corner variants (used by the "rounded" styles).
    DOWN_RIGHT_ROUNDED = "\u256D"  # ╭
    UP_RIGHT_ROUNDED = "\u2570"  # ╰

    # Bold (heavy-line) variants.
    DOWN_RIGHT_BOLD = "\u250F"  # ┏
    VERTICAL_RIGHT_BOLD = "\u2523"  # ┣
    VERTICAL_LEFT_BOLD = "\u252B"  # ┫
    VERTICAL_HORIZONTAL_BOLD = "\u254B"  # ╋
    UP_RIGHT_BOLD = "\u2517"  # ┗
    VERTICAL_BOLD = "\u2503"  # ┃
    HORIZONTAL_BOLD = "\u2501"  # ━

    # Double-line variants.
    DOWN_RIGHT_DOUBLE = "\u2554"  # ╔
    VERTICAL_RIGHT_DOUBLE = "\u2560"  # ╠
    VERTICAL_LEFT_DOUBLE = "\u2563"  # ╣
    VERTICAL_HORIZONTAL_DOUBLE = "\u256C"  # ╬
    UP_RIGHT_DOUBLE = "\u255a"  # ╚
    VERTICAL_DOUBLE = "\u2551"  # ║
    HORIZONTAL_DOUBLE = "\u2550"  # ═

    # Vertical tree print: style name -> 3-tuple of prefix strings
    # (continuation filler, mid-branch connector, last-branch connector).
    PRINT_STYLES: Dict[str, Tuple[str, str, str]] = {
        "ansi": ("| ", "|-- ", "`-- "),
        "ascii": ("| ", "|-- ", "+-- "),
        "const": (
            f"{VERTICAL} ",
            f"{VERTICAL_RIGHT}{HORIZONTAL}{HORIZONTAL} ",
            f"{UP_RIGHT}{HORIZONTAL}{HORIZONTAL} ",
        ),
        "const_bold": (
            f"{VERTICAL_BOLD} ",
            f"{VERTICAL_RIGHT_BOLD}{HORIZONTAL_BOLD}{HORIZONTAL_BOLD} ",
            f"{UP_RIGHT_BOLD}{HORIZONTAL_BOLD}{HORIZONTAL_BOLD} ",
        ),
        "rounded": (
            f"{VERTICAL} ",
            f"{VERTICAL_RIGHT}{HORIZONTAL}{HORIZONTAL} ",
            f"{UP_RIGHT_ROUNDED}{HORIZONTAL}{HORIZONTAL} ",
        ),
        "double": (
            f"{VERTICAL_DOUBLE} ",
            f"{VERTICAL_RIGHT_DOUBLE}{HORIZONTAL_DOUBLE}{HORIZONTAL_DOUBLE} ",
            f"{UP_RIGHT_DOUBLE}{HORIZONTAL_DOUBLE}{HORIZONTAL_DOUBLE} ",
        ),
    }

    # Horizontal tree print: style name -> 7-tuple of border characters
    # (order mirrors the "const" entry: top corner, mid-left, mid-right,
    # cross, bottom corner, vertical, horizontal).
    HPRINT_STYLES: Dict[str, Tuple[str, str, str, str, str, str, str]] = {
        "ansi": ("/", "+", "+", "+", "\\", "|", "-"),
        "ascii": ("+", "+", "+", "+", "+", "|", "-"),
        "const": (
            DOWN_RIGHT,
            VERTICAL_RIGHT,
            VERTICAL_LEFT,
            VERTICAL_HORIZONTAL,
            UP_RIGHT,
            VERTICAL,
            HORIZONTAL,
        ),
        "const_bold": (
            DOWN_RIGHT_BOLD,
            VERTICAL_RIGHT_BOLD,
            VERTICAL_LEFT_BOLD,
            VERTICAL_HORIZONTAL_BOLD,
            UP_RIGHT_BOLD,
            VERTICAL_BOLD,
            HORIZONTAL_BOLD,
        ),
        "rounded": (
            DOWN_RIGHT_ROUNDED,
            VERTICAL_RIGHT,
            VERTICAL_LEFT,
            VERTICAL_HORIZONTAL,
            UP_RIGHT_ROUNDED,
            VERTICAL,
            HORIZONTAL,
        ),
        "double": (
            DOWN_RIGHT_DOUBLE,
            VERTICAL_RIGHT_DOUBLE,
            VERTICAL_LEFT_DOUBLE,
            VERTICAL_HORIZONTAL_DOUBLE,
            UP_RIGHT_DOUBLE,
            VERTICAL_DOUBLE,
            HORIZONTAL_DOUBLE,
        ),
    }
|
||||
|
||||
|
||||
class MermaidConstants:
    """Constant lookup tables for exporting a tree as a Mermaid flowchart."""

    # Allowed flowchart directions: top-bottom, bottom-top, left-right, right-left.
    RANK_DIR: List[str] = ["TB", "BT", "LR", "RL"]
    # Allowed edge curve interpolations (Mermaid line-shape setting).
    LINE_SHAPES: List[str] = [
        "basis",
        "bumpX",
        "bumpY",
        "cardinal",
        "catmullRom",
        "linear",
        "monotoneX",
        "monotoneY",
        "natural",
        "step",
        "stepAfter",
        "stepBefore",
    ]
    # Node shape name -> Mermaid node markup template with a "{label}" placeholder.
    NODE_SHAPES: Dict[str, str] = {
        "rounded_edge": """("{label}")""",
        "stadium": """(["{label}"])""",
        "subroutine": """[["{label}"]]""",
        "cylindrical": """[("{label}")]""",
        "circle": """(("{label}"))""",
        "asymmetric": """>"{label}"]""",
        "rhombus": """{{"{label}"}}""",
        "hexagon": """{{{{"{label}"}}}}""",
        "parallelogram": """[/"{label}"/]""",
        "parallelogram_alt": """[\\"{label}"\\]""",
        "trapezoid": """[/"{label}"\\]""",
        "trapezoid_alt": """[\\"{label}"/]""",
        "double_circle": """((("{label}")))""",
    }
    # Edge style name -> Mermaid arrow/link syntax.
    EDGE_ARROWS: Dict[str, str] = {
        "normal": "-->",
        "bold": "==>",
        "dotted": "-.->",
        "open": "---",
        "bold_open": "===",
        "dotted_open": "-.-",
        "invisible": "~~~",
        "circle": "--o",
        "cross": "--x",
        "double_normal": "<-->",
        "double_circle": "o--o",
        "double_cross": "x--x",
    }
|
||||
|
||||
|
||||
class NewickState(Enum):
    """Parser states for reading a Newick string (what is currently being parsed)."""

    PARSE_STRING = auto()
    PARSE_ATTRIBUTE_NAME = auto()
    PARSE_ATTRIBUTE_VALUE = auto()
|
||||
|
||||
|
||||
class NewickCharacter(str, Enum):
    """Special characters of the Newick tree notation (str-valued enum)."""

    OPEN_BRACKET = "("
    CLOSE_BRACKET = ")"
    ATTR_START = "["
    ATTR_END = "]"
    ATTR_KEY_VALUE = "="
    ATTR_QUOTE = "'"
    SEP = ":"
    NODE_SEP = ","

    @classmethod
    def values(cls) -> List[str]:
        """Return the character values of all members."""
        return [c.value for c in cls]
|
||||
126
python310/packages/bigtree/utils/exceptions.py
Normal file
126
python310/packages/bigtree/utils/exceptions.py
Normal file
@@ -0,0 +1,126 @@
|
||||
from functools import wraps
|
||||
from typing import Any, Callable, TypeVar
|
||||
from warnings import simplefilter, warn
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class TreeError(Exception):
    """Generic tree exception; base class for all bigtree errors."""

    pass
|
||||
|
||||
|
||||
class LoopError(TreeError):
    """Error during node creation (name suggests a parent/child loop — confirm at raise sites)."""

    pass
|
||||
|
||||
|
||||
class CorruptedTreeError(TreeError):
    """Error during node creation (name suggests inconsistent tree state — confirm at raise sites)."""

    pass
|
||||
|
||||
|
||||
class DuplicatedNodeError(TreeError):
    """Error during tree creation (name suggests duplicate nodes — confirm at raise sites)."""

    pass
|
||||
|
||||
|
||||
class NotFoundError(TreeError):
    """Error during tree pruning or modification."""

    pass
|
||||
|
||||
|
||||
class SearchError(TreeError):
    """Error during tree search (e.g. min_count/max_count violations in find functions)."""

    pass
|
||||
|
||||
|
||||
def deprecated(
    alias: str,
) -> Callable[[Callable[..., T]], Callable[..., T]]:  # pragma: no cover
    # Decorator factory: `@deprecated("new_name")` wraps a function so every
    # call emits a DeprecationWarning pointing callers at `alias`.
    def decorator(func: Callable[..., T]) -> Callable[..., T]:
        """
        This is a decorator which can be used to mark functions as deprecated.
        It will raise a DeprecationWarning when the function is used.
        Source: https://stackoverflow.com/a/30253848
        """

        @wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> T:
            # Temporarily force DeprecationWarning to always be shown (it is
            # suppressed by default in many contexts), then restore defaults.
            simplefilter("always", DeprecationWarning)
            warn(
                "{old_func} is going to be deprecated, use {new_func} instead".format(
                    old_func=func.__name__,
                    new_func=alias,
                ),
                category=DeprecationWarning,
                stacklevel=2,
            )
            simplefilter("default", DeprecationWarning)  # reset filter
            return func(*args, **kwargs)

        return wrapper

    return decorator
|
||||
|
||||
|
||||
def optional_dependencies_pandas(
    func: Callable[..., T]
) -> Callable[..., T]:  # pragma: no cover
    """
    This is a decorator which can be used to import optional pandas dependency.
    It will raise a ImportError if the module is not found.
    """

    @wraps(func)
    def wrapper(*args: Any, **kwargs: Any) -> T:
        # Check availability at call time (not decoration time) so importing
        # the package that uses this decorator never fails on its own.
        try:
            import pandas as pd  # noqa: F401
        except ImportError:
            # `from None` drops the original traceback; the message carries
            # the install instructions instead.
            raise ImportError(
                "pandas not available. Please perform a\n\n"
                "pip install 'bigtree[pandas]'\n\nto install required dependencies"
            ) from None
        return func(*args, **kwargs)

    return wrapper
|
||||
|
||||
|
||||
def optional_dependencies_image(
    package_name: str = "",
) -> Callable[[Callable[..., T]], Callable[..., T]]:
    # Decorator factory: checks the optional image dependencies at call time.
    # An empty `package_name` checks both pydot and Pillow; otherwise only the
    # named package is checked.
    def decorator(func: Callable[..., T]) -> Callable[..., T]:
        """
        This is a decorator which can be used to import optional image dependency.
        It will raise a ImportError if the module is not found.
        """

        @wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> T:
            if not package_name or package_name == "pydot":
                try:
                    import pydot  # noqa: F401
                except ImportError:  # pragma: no cover
                    raise ImportError(
                        "pydot not available. Please perform a\n\n"
                        "pip install 'bigtree[image]'\n\nto install required dependencies"
                    ) from None
            if not package_name or package_name == "Pillow":
                try:
                    from PIL import Image, ImageDraw, ImageFont  # noqa: F401
                except ImportError:  # pragma: no cover
                    raise ImportError(
                        "Pillow not available. Please perform a\n\n"
                        "pip install 'bigtree[image]'\n\nto install required dependencies"
                    ) from None
            return func(*args, **kwargs)

        return wrapper

    return decorator
|
||||
19
python310/packages/bigtree/utils/groot.py
Normal file
19
python310/packages/bigtree/utils/groot.py
Normal file
@@ -0,0 +1,19 @@
|
||||
def whoami() -> str:
    """Groot utils: return Groot's catchphrase.

    Returns:
        (str)
    """
    groot_catchphrase = "I am Groot!"
    return groot_catchphrase
|
||||
|
||||
|
||||
def speak_like_groot(sentence: str) -> str:
    """Convert sentence into Groot language.

    Each whitespace-separated word of ``sentence`` is replaced by Groot's
    catchphrase, so the output has one "I am Groot!" per word.

    Args:
        sentence (str): Sentence to convert to groot language

    Returns:
        (str)
    """
    # Generator expression instead of a throwaway list inside join (C409).
    return " ".join(whoami() for _ in range(len(sentence.split())))
|
||||
587
python310/packages/bigtree/utils/iterators.py
Normal file
587
python310/packages/bigtree/utils/iterators.py
Normal file
@@ -0,0 +1,587 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Callable,
|
||||
Iterable,
|
||||
List,
|
||||
Optional,
|
||||
Tuple,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from bigtree.node.basenode import BaseNode
|
||||
from bigtree.node.binarynode import BinaryNode
|
||||
from bigtree.node.dagnode import DAGNode
|
||||
|
||||
BaseNodeT = TypeVar("BaseNodeT", bound=BaseNode)
|
||||
BinaryNodeT = TypeVar("BinaryNodeT", bound=BinaryNode)
|
||||
DAGNodeT = TypeVar("DAGNodeT", bound=DAGNode)
|
||||
T = TypeVar("T", bound=Union[BaseNode, DAGNode])
|
||||
|
||||
__all__ = [
|
||||
"inorder_iter",
|
||||
"preorder_iter",
|
||||
"postorder_iter",
|
||||
"levelorder_iter",
|
||||
"levelordergroup_iter",
|
||||
"zigzag_iter",
|
||||
"zigzaggroup_iter",
|
||||
"dag_iterator",
|
||||
]
|
||||
|
||||
|
||||
def inorder_iter(
    tree: BinaryNodeT,
    filter_condition: Optional[Callable[[BinaryNodeT], bool]] = None,
    max_depth: int = 0,
) -> Iterable[BinaryNodeT]:
    """Iterate through all children of a binary tree, in-order (LNR).

    1. Recursively traverse the current node's left subtree.
    2. Visit the current node.
    3. Recursively traverse the current node's right subtree.

    Args:
        tree (BinaryNode): input tree
        filter_condition (Optional[Callable[[BinaryNode], bool]]): function that takes in node as argument, optional
            Return node if condition evaluates to `True`
        max_depth (int): maximum depth of iteration, based on `depth` attribute, optional

    Returns:
        (Iterable[BinaryNode])
    """
    # Guard clauses: stop at empty subtrees and below the depth limit.
    if not tree:
        return
    if max_depth and tree.depth > max_depth:
        return
    yield from inorder_iter(tree.left, filter_condition, max_depth)
    if filter_condition is None or filter_condition(tree):
        yield tree
    yield from inorder_iter(tree.right, filter_condition, max_depth)
|
||||
|
||||
|
||||
def preorder_iter(
    tree: T,
    filter_condition: Optional[Callable[[T], bool]] = None,
    stop_condition: Optional[Callable[[T], bool]] = None,
    max_depth: int = 0,
) -> Iterable[T]:
    """Iterate through all children of a tree, pre-order (NLR).

    1. Visit the current node.
    2. Recursively traverse each of the current node's subtrees.

    It is topologically sorted because a parent node is processed before its child nodes.

    Args:
        tree (Union[BaseNode, DAGNode]): input tree
        filter_condition (Optional[Callable[[T], bool]]): function that takes in node as argument, optional
            Return node if condition evaluates to `True`
        stop_condition (Optional[Callable[[T], bool]]): function that takes in node as argument, optional
            Stops iteration if condition evaluates to `True`
        max_depth (int): maximum depth of iteration, based on `depth` attribute, optional

    Returns:
        (Union[Iterable[BaseNode], Iterable[DAGNode]])
    """
    # Guard clauses replace the original compound condition; a stopped node
    # prunes its entire subtree.
    if not tree:
        return
    if max_depth and tree.get_attr("depth") > max_depth:
        return
    if stop_condition is not None and stop_condition(tree):
        return
    if filter_condition is None or filter_condition(tree):
        yield tree
    for child in tree.children:
        yield from preorder_iter(child, filter_condition, stop_condition, max_depth)
|
||||
|
||||
|
||||
def postorder_iter(
    tree: BaseNodeT,
    filter_condition: Optional[Callable[[BaseNodeT], bool]] = None,
    stop_condition: Optional[Callable[[BaseNodeT], bool]] = None,
    max_depth: int = 0,
) -> Iterable[BaseNodeT]:
    """Iterate through all children of a tree, post-order (LRN).

    1. Recursively traverse each of the current node's subtrees.
    2. Visit the current node.

    Args:
        tree (BaseNode): input tree
        filter_condition (Optional[Callable[[BaseNode], bool]]): function that takes in node as argument, optional
            Return node if condition evaluates to `True`
        stop_condition (Optional[Callable[[BaseNode], bool]]): function that takes in node as argument, optional
            Stops iteration if condition evaluates to `True`
        max_depth (int): maximum depth of iteration, based on `depth` attribute, optional

    Returns:
        (Iterable[BaseNode])
    """
    # Guard clauses replace the original compound condition; a stopped node
    # prunes its entire subtree.
    if not tree:
        return
    if max_depth and tree.depth > max_depth:
        return
    if stop_condition is not None and stop_condition(tree):
        return
    for child in tree.children:
        yield from postorder_iter(child, filter_condition, stop_condition, max_depth)
    if filter_condition is None or filter_condition(tree):
        yield tree
|
||||
|
||||
|
||||
def levelorder_iter(
    tree: BaseNodeT,
    filter_condition: Optional[Callable[[BaseNodeT], bool]] = None,
    stop_condition: Optional[Callable[[BaseNodeT], bool]] = None,
    max_depth: int = 0,
) -> Iterable[BaseNodeT]:
    """Iterate through all children of a tree.

    Level-Order Iteration Algorithm
    1. Recursively traverse the nodes on same level.

    Examples:
        >>> from bigtree import Node, list_to_tree, levelorder_iter
        >>> path_list = ["a/b/d", "a/b/e/g", "a/b/e/h", "a/c/f"]
        >>> root = list_to_tree(path_list)
        >>> root.show()
        a
        ├── b
        │   ├── d
        │   └── e
        │       ├── g
        │       └── h
        └── c
            └── f

        >>> [node.node_name for node in levelorder_iter(root)]
        ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h']

        >>> [node.node_name for node in levelorder_iter(root, filter_condition=lambda x: x.node_name in ["a", "d", "e", "f", "g"])]
        ['a', 'd', 'e', 'f', 'g']

        >>> [node.node_name for node in levelorder_iter(root, stop_condition=lambda x: x.node_name == "e")]
        ['a', 'b', 'c', 'd', 'f']

        >>> [node.node_name for node in levelorder_iter(root, max_depth=3)]
        ['a', 'b', 'c', 'd', 'e', 'f']

    Args:
        tree (BaseNode): input tree
        filter_condition (Optional[Callable[[BaseNode], bool]]): function that takes in node as argument, optional
            Return node if condition evaluates to `True`
        stop_condition (Optional[Callable[[BaseNode], bool]]): function that takes in node as argument, optional
            Stops iteration if condition evaluates to `True`
        max_depth (int): maximum depth of iteration, based on `depth` attribute, defaults to 0 (no depth limit)

    Returns:
        (Iterable[BaseNode])
    """
    # Iterative breadth-first traversal; avoids the recursion-limit risk of a
    # recursive per-level implementation on very deep trees.
    current_level: List[BaseNodeT] = [tree]
    while current_level:
        next_level: List[BaseNodeT] = []
        for _tree in current_level:
            if _tree:
                # Node is emitted (and its children queued) only if it passes the
                # depth and stop checks; the filter gates emission alone.
                if (not max_depth or not _tree.depth > max_depth) and (
                    not stop_condition or not stop_condition(_tree)
                ):
                    if not filter_condition or filter_condition(_tree):
                        yield _tree
                    next_level.extend(list(_tree.children))
        current_level = next_level
|
||||
|
||||
|
||||
def levelordergroup_iter(
    tree: BaseNodeT,
    filter_condition: Optional[Callable[[BaseNodeT], bool]] = None,
    stop_condition: Optional[Callable[[BaseNodeT], bool]] = None,
    max_depth: int = 0,
) -> Iterable[Iterable[BaseNodeT]]:
    """Iterate through all children of a tree.

    Level-Order Group Iteration Algorithm
    1. Recursively traverse the nodes on same level, returns nodes level by level in a nested list.

    Examples:
        >>> from bigtree import Node, list_to_tree, levelordergroup_iter
        >>> path_list = ["a/b/d", "a/b/e/g", "a/b/e/h", "a/c/f"]
        >>> root = list_to_tree(path_list)
        >>> root.show()
        a
        ├── b
        │   ├── d
        │   └── e
        │       ├── g
        │       └── h
        └── c
            └── f

        >>> [[node.node_name for node in group] for group in levelordergroup_iter(root)]
        [['a'], ['b', 'c'], ['d', 'e', 'f'], ['g', 'h']]

        >>> [[node.node_name for node in group] for group in levelordergroup_iter(root, filter_condition=lambda x: x.node_name in ["a", "d", "e", "f", "g"])]
        [['a'], [], ['d', 'e', 'f'], ['g']]

        >>> [[node.node_name for node in group] for group in levelordergroup_iter(root, stop_condition=lambda x: x.node_name == "e")]
        [['a'], ['b', 'c'], ['d', 'f']]

        >>> [[node.node_name for node in group] for group in levelordergroup_iter(root, max_depth=3)]
        [['a'], ['b', 'c'], ['d', 'e', 'f']]

    Args:
        tree (BaseNode): input tree
        filter_condition (Optional[Callable[[BaseNode], bool]]): function that takes in node as argument, optional
            Return node if condition evaluates to `True`
        stop_condition (Optional[Callable[[BaseNode], bool]]): function that takes in node as argument, optional
            Stops iteration if condition evaluates to `True`
        max_depth (int): maximum depth of iteration, based on `depth` attribute, defaults to 0 (no depth limit)

    Returns:
        (Iterable[Iterable[BaseNode]])
    """
    # Iterative breadth-first traversal yielding one tuple per level; avoids the
    # recursion-limit risk of a recursive per-level implementation on deep trees.
    current_level: List[BaseNodeT] = [tree]
    while current_level:
        current_group: List[BaseNodeT] = []
        next_level: List[BaseNodeT] = []
        for _tree in current_level:
            if (not max_depth or not _tree.depth > max_depth) and (
                not stop_condition or not stop_condition(_tree)
            ):
                # Filter gates membership in the group; children are still queued,
                # so a fully-filtered level yields an empty tuple.
                if not filter_condition or filter_condition(_tree):
                    current_group.append(_tree)
                next_level.extend([_child for _child in _tree.children if _child])
        yield tuple(current_group)
        # Stop before yielding a level that lies entirely beyond max_depth.
        if next_level and max_depth and next_level[0].depth > max_depth:
            break
        current_level = next_level
|
||||
|
||||
|
||||
def zigzag_iter(
    tree: BaseNodeT,
    filter_condition: Optional[Callable[[BaseNodeT], bool]] = None,
    stop_condition: Optional[Callable[[BaseNodeT], bool]] = None,
    max_depth: int = 0,
) -> Iterable[BaseNodeT]:
    """Iterate through all children of a tree.

    ZigZag Iteration Algorithm
    1. Recursively traverse the nodes on same level, in a zigzag manner across different levels.

    Examples:
        >>> from bigtree import Node, list_to_tree, zigzag_iter
        >>> path_list = ["a/b/d", "a/b/e/g", "a/b/e/h", "a/c/f"]
        >>> root = list_to_tree(path_list)
        >>> root.show()
        a
        ├── b
        │   ├── d
        │   └── e
        │       ├── g
        │       └── h
        └── c
            └── f

        >>> [node.node_name for node in zigzag_iter(root)]
        ['a', 'c', 'b', 'd', 'e', 'f', 'h', 'g']

        >>> [node.node_name for node in zigzag_iter(root, filter_condition=lambda x: x.node_name in ["a", "d", "e", "f", "g"])]
        ['a', 'd', 'e', 'f', 'g']

        >>> [node.node_name for node in zigzag_iter(root, stop_condition=lambda x: x.node_name == "e")]
        ['a', 'c', 'b', 'd', 'f']

        >>> [node.node_name for node in zigzag_iter(root, max_depth=3)]
        ['a', 'c', 'b', 'd', 'e', 'f']

    Args:
        tree (BaseNode): input tree
        filter_condition (Optional[Callable[[BaseNode], bool]]): function that takes in node as argument, optional
            Return node if condition evaluates to `True`
        stop_condition (Optional[Callable[[BaseNode], bool]]): function that takes in node as argument, optional
            Stops iteration if condition evaluates to `True`
        max_depth (int): maximum depth of iteration, based on `depth` attribute, defaults to 0 (no depth limit)

    Returns:
        (Iterable[BaseNode])
    """
    # Iterative zigzag breadth-first traversal; avoids the recursion-limit risk
    # of a recursive per-level implementation on very deep trees.
    current_level: List[BaseNodeT] = [tree]
    reverse_indicator = False
    while current_level:
        next_level: List[BaseNodeT] = []
        for _tree in current_level:
            if _tree:
                if (not max_depth or not _tree.depth > max_depth) and (
                    not stop_condition or not stop_condition(_tree)
                ):
                    if not filter_condition or filter_condition(_tree):
                        yield _tree
                    child_nodes = list(_tree.children)
                    # On reversed levels children are collected right-to-left so
                    # that reversing the whole level restores zigzag order.
                    if reverse_indicator:
                        child_nodes = child_nodes[::-1]
                    next_level.extend(child_nodes)
        current_level = next_level[::-1]
        reverse_indicator = not reverse_indicator
|
||||
|
||||
|
||||
def zigzaggroup_iter(
    tree: BaseNodeT,
    filter_condition: Optional[Callable[[BaseNodeT], bool]] = None,
    stop_condition: Optional[Callable[[BaseNodeT], bool]] = None,
    max_depth: int = 0,
) -> Iterable[Iterable[BaseNodeT]]:
    """Iterate through all children of a tree.

    ZigZag Group Iteration Algorithm
    1. Recursively traverse the nodes on same level, in a zigzag manner across different levels,
    returns nodes level by level in a nested list.

    Examples:
        >>> from bigtree import Node, list_to_tree, zigzaggroup_iter
        >>> path_list = ["a/b/d", "a/b/e/g", "a/b/e/h", "a/c/f"]
        >>> root = list_to_tree(path_list)
        >>> root.show()
        a
        ├── b
        │   ├── d
        │   └── e
        │       ├── g
        │       └── h
        └── c
            └── f

        >>> [[node.node_name for node in group] for group in zigzaggroup_iter(root)]
        [['a'], ['c', 'b'], ['d', 'e', 'f'], ['h', 'g']]

        >>> [[node.node_name for node in group] for group in zigzaggroup_iter(root, filter_condition=lambda x: x.node_name in ["a", "d", "e", "f", "g"])]
        [['a'], [], ['d', 'e', 'f'], ['g']]

        >>> [[node.node_name for node in group] for group in zigzaggroup_iter(root, stop_condition=lambda x: x.node_name == "e")]
        [['a'], ['c', 'b'], ['d', 'f']]

        >>> [[node.node_name for node in group] for group in zigzaggroup_iter(root, max_depth=3)]
        [['a'], ['c', 'b'], ['d', 'e', 'f']]

    Args:
        tree (BaseNode): input tree
        filter_condition (Optional[Callable[[BaseNode], bool]]): function that takes in node as argument, optional
            Return node if condition evaluates to `True`
        stop_condition (Optional[Callable[[BaseNode], bool]]): function that takes in node as argument, optional
            Stops iteration if condition evaluates to `True`
        max_depth (int): maximum depth of iteration, based on `depth` attribute, defaults to 0 (no depth limit)

    Returns:
        (Iterable[Iterable[BaseNode]])
    """
    # Iterative zigzag breadth-first traversal yielding one tuple per level;
    # avoids the recursion-limit risk of a recursive implementation on deep trees.
    current_level: List[BaseNodeT] = [tree]
    reverse_indicator = False
    while current_level:
        current_group: List[BaseNodeT] = []
        next_level: List[BaseNodeT] = []
        for _tree in current_level:
            if (not max_depth or not _tree.depth > max_depth) and (
                not stop_condition or not stop_condition(_tree)
            ):
                # Filter gates membership in the group; children are still queued,
                # so a fully-filtered level yields an empty tuple.
                if not filter_condition or filter_condition(_tree):
                    current_group.append(_tree)
                child_nodes = [_child for _child in _tree.children if _child]
                # On reversed levels children are collected right-to-left so
                # that reversing the whole level restores zigzag order.
                if reverse_indicator:
                    child_nodes = child_nodes[::-1]
                next_level.extend(child_nodes)
        yield tuple(current_group)
        # Stop before yielding a level that lies entirely beyond max_depth.
        if next_level and max_depth and next_level[0].depth > max_depth:
            break
        current_level = next_level[::-1]
        reverse_indicator = not reverse_indicator
|
||||
|
||||
|
||||
def dag_iterator(dag: DAGNodeT) -> Iterable[Tuple[DAGNodeT, DAGNodeT]]:
    """Iterate through all nodes of a Directed Acyclic Graph (DAG).
    Note that node names must be unique.
    Note that DAG must at least have two nodes to be shown on graph.

    1. Visit the current node.
    2. Recursively traverse the current node's parents.
    3. Recursively traverse the current node's children.

    Examples:
        >>> from bigtree import DAGNode, dag_iterator
        >>> a = DAGNode("a", step=1)
        >>> b = DAGNode("b", step=1)
        >>> c = DAGNode("c", step=2, parents=[a, b])
        >>> d = DAGNode("d", step=2, parents=[a, c])
        >>> e = DAGNode("e", step=3, parents=[d])
        >>> [(parent.node_name, child.node_name) for parent, child in dag_iterator(a)]
        [('a', 'c'), ('a', 'd'), ('b', 'c'), ('c', 'd'), ('d', 'e')]

    Args:
        dag (DAGNode): input dag

    Returns:
        (Iterable[Tuple[DAGNode, DAGNode]])
    """
    # Names of nodes already visited; since node names are required to be unique,
    # tracking names (not objects) is sufficient to avoid emitting an edge twice.
    visited_nodes = set()

    def _dag_iterator(node: DAGNodeT) -> Iterable[Tuple[DAGNodeT, DAGNodeT]]:
        """Iterate through all children of a DAG.

        Args:
            node (DAGNode): current node

        Returns:
            (Iterable[Tuple[DAGNode, DAGNode]])
        """
        node_name = node.node_name
        # Mark this node visited *before* emitting edges so neighbours recursed
        # into later do not re-emit edges that touch this node.
        visited_nodes.add(node_name)

        # Parse upwards
        # First emit all edges incident to this node (parents then children) so
        # that edges appear grouped per node, before any recursion happens.
        for parent in node.parents:
            parent_name = parent.node_name
            if parent_name not in visited_nodes:
                yield parent, node

        # Parse downwards
        for child in node.children:
            child_name = child.node_name
            if child_name not in visited_nodes:
                yield node, child

        # Parse upwards
        # Only now recurse into unvisited neighbours; the membership check is
        # re-done here because earlier recursions may have visited them already.
        for parent in node.parents:
            parent_name = parent.node_name
            if parent_name not in visited_nodes:
                yield from _dag_iterator(parent)

        # Parse downwards
        for child in node.children:
            child_name = child.node_name
            if child_name not in visited_nodes:
                yield from _dag_iterator(child)

    yield from _dag_iterator(dag)
|
||||
354
python310/packages/bigtree/utils/plot.py
Normal file
354
python310/packages/bigtree/utils/plot.py
Normal file
@@ -0,0 +1,354 @@
|
||||
from typing import Optional, TypeVar
|
||||
|
||||
from bigtree.node.basenode import BaseNode
|
||||
|
||||
T = TypeVar("T", bound=BaseNode)
|
||||
|
||||
__all__ = [
|
||||
"reingold_tilford",
|
||||
]
|
||||
|
||||
|
||||
def reingold_tilford(
    tree_node: T,
    sibling_separation: float = 1.0,
    subtree_separation: float = 1.0,
    level_separation: float = 1.0,
    x_offset: float = 0.0,
    y_offset: float = 0.0,
) -> None:
    """
    Algorithm for drawing tree structure, retrieves `(x, y)` coordinates for a tree structure.
    Adds `x` and `y` attributes to every node in the tree. Modifies tree in-place.

    This algorithm[1] is an improvement over Reingold Tilford algorithm[2].

    According to Reingold Tilford's paper, a tree diagram should satisfy the following aesthetic rules,

    1. Nodes at the same depth should lie along a straight line, and the straight lines defining the depths should be parallel.
    2. A left child should be positioned to the left of its parent node and a right child to the right.
    3. A parent should be centered over their children.
    4. A tree and its mirror image should produce drawings that are reflections of one another; a subtree should be drawn the same way regardless of where it occurs in the tree.

    Examples:
        >>> from bigtree import reingold_tilford, list_to_tree
        >>> path_list = ["a/b/d", "a/b/e/g", "a/b/e/h", "a/c/f"]
        >>> root = list_to_tree(path_list)
        >>> root.show()
        a
        ├── b
        │   ├── d
        │   └── e
        │       ├── g
        │       └── h
        └── c
            └── f

        >>> reingold_tilford(root)
        >>> root.show(attr_list=["x", "y"])
        a [x=1.25, y=3.0]
        ├── b [x=0.5, y=2.0]
        │   ├── d [x=0.0, y=1.0]
        │   └── e [x=1.0, y=1.0]
        │       ├── g [x=0.5, y=0.0]
        │       └── h [x=1.5, y=0.0]
        └── c [x=2.0, y=2.0]
            └── f [x=2.0, y=1.0]

    References

    - [1] Walker, J. (1991). Positioning Nodes for General Trees. https://www.drdobbs.com/positioning-nodes-for-general-trees/184402320?pgno=4
    - [2] Reingold, E., Tilford, J. (1981). Tidier Drawings of Trees. IEEE Transactions on Software Engineering. https://reingold.co/tidier-drawings.pdf

    Args:
        tree_node (BaseNode): tree to compute (x, y) coordinate
        sibling_separation (float): minimum distance between adjacent siblings of the tree
        subtree_separation (float): minimum distance between adjacent subtrees of the tree
        level_separation (float): fixed distance between adjacent levels of the tree
        x_offset (float): graph offset of x-coordinates
        y_offset (float): graph offset of y-coordinates
    """
    # Pass 1 (bottom-up): assign preliminary `x`, `mod`, `shift` attributes.
    _first_pass(tree_node, sibling_separation, subtree_separation)
    # Pass 2 (top-down): finalize (x, y); returns how far left of 0 any node landed.
    adjustment = _second_pass(tree_node, level_separation, x_offset, y_offset)
    # Pass 3: shift every x-coordinate right so none is negative.
    _third_pass(tree_node, adjustment)
|
||||
|
||||
|
||||
def _first_pass(
    tree_node: T, sibling_separation: float, subtree_separation: float
) -> None:
    """
    Performs post-order traversal of tree and assigns `x`, `mod` and `shift` values to each node.
    Modifies tree in-place.

    Notation:
        - `lsibling`: left-sibling of node
        - `lchild`: last child of node
        - `fchild`: first child of node
        - `midpoint`: midpoint of node wrt children, :math:`midpoint = (lchild.x + fchild.x) / 2`
        - `sibling distance`: sibling separation
        - `subtree distance`: subtree separation

    There are two parts in the first pass,

    1. In the first part, we assign `x` and `mod` values to each node

        `x` value is the initial x-position of each node purely based on the node's position
        - :math:`x = 0` for leftmost node and :math:`x = lsibling.x + sibling distance` for other nodes
        - Special case when leftmost node has children, then it will try to center itself, :math:`x = midpoint`

        `mod` value is the amount to shift the subtree (all descendant nodes excluding itself) to make the children centered with itself
        - :math:`mod = 0` for node does not have children (no need to shift subtree) or it is a leftmost node (parent is already centered, from above point)
        - Special case when non-leftmost nodes have children, :math:`mod = x - midpoint`

    2. In the second part, we assign `shift` value of nodes due to overlapping subtrees.

        For each node on the same level, ensure that the leftmost descendant does not intersect with the rightmost
        descendant of any left sibling at every subsequent level. Intersection happens when the subtrees are not
        at least `subtree distance` apart.

        If there are any intersections, shift the whole subtree by a new `shift` value, shift any left sibling by a
        fraction of `shift` value, and shift any right sibling by `shift` + a multiple of the fraction of
        `shift` value to keep nodes centralized at the level.

    Args:
        tree_node (BaseNode): tree to compute (x, y) coordinate
        sibling_separation (float): minimum distance between adjacent siblings of the tree
        subtree_separation (float): minimum distance between adjacent subtrees of the tree
    """
    # Post-order iteration (LRN)
    # Children must be processed first: this node's x/midpoint depends on theirs.
    for child in tree_node.children:
        _first_pass(child, sibling_separation, subtree_separation)

    # Defaults for a childless, leftmost node; overwritten below where applicable.
    _x = 0.0
    _mod = 0.0
    _shift = 0.0
    _midpoint = 0.0

    if tree_node.is_root:
        # Root only needs to center itself over its children; mod/shift stay 0.
        tree_node.set_attrs({"x": _get_midpoint_of_children(tree_node)})
        tree_node.set_attrs({"mod": _mod})
        tree_node.set_attrs({"shift": _shift})

    else:
        # First part - assign x and mod values

        if tree_node.children:
            _midpoint = _get_midpoint_of_children(tree_node)

        # Non-leftmost node
        if tree_node.left_sibling:
            _x = tree_node.left_sibling.get_attr("x") + sibling_separation
            if tree_node.children:
                # Shift subtree so the children become centered under the new x.
                _mod = _x - _midpoint
        # Leftmost node
        else:
            if tree_node.children:
                _x = _midpoint

        tree_node.set_attrs({"x": _x})
        tree_node.set_attrs({"mod": _mod})
        # Preserve any shift already applied to this node by a right sibling's pass.
        tree_node.set_attrs({"shift": tree_node.get_attr("shift", _shift)})

        # Second part - assign shift values due to overlapping subtrees

        parent_node = tree_node.parent
        tree_node_idx = parent_node.children.index(tree_node)
        if tree_node_idx:
            # Compare against every left sibling's subtree; take the largest
            # shift needed so this subtree clears all of them.
            for idx_node in range(tree_node_idx):
                left_subtree = parent_node.children[idx_node]
                _shift = max(
                    _shift,
                    _get_subtree_shift(
                        left_subtree=left_subtree,
                        right_subtree=tree_node,
                        left_idx=idx_node,
                        right_idx=tree_node_idx,
                        subtree_separation=subtree_separation,
                    ),
                )

            # Shift siblings (left siblings, itself, right siblings) accordingly
            # Each sibling gets a proportional fraction of the shift, keeping the
            # level's nodes evenly distributed rather than only pushing this one.
            for multiple, sibling in enumerate(parent_node.children):
                sibling.set_attrs(
                    {
                        "shift": sibling.get_attr("shift", 0)
                        + (_shift * multiple / tree_node_idx)
                    }
                )
|
||||
|
||||
|
||||
def _get_midpoint_of_children(tree_node: BaseNode) -> float:
|
||||
"""Get midpoint of children of a node
|
||||
|
||||
Args:
|
||||
tree_node (BaseNode): tree node to obtain midpoint of their child/children
|
||||
|
||||
Returns:
|
||||
(float)
|
||||
"""
|
||||
if tree_node.children:
|
||||
first_child_x: float = tree_node.children[0].get_attr("x") + tree_node.children[
|
||||
0
|
||||
].get_attr("shift")
|
||||
last_child_x: float = tree_node.children[-1].get_attr("x") + tree_node.children[
|
||||
-1
|
||||
].get_attr("shift")
|
||||
return (last_child_x + first_child_x) / 2
|
||||
return 0.0
|
||||
|
||||
|
||||
def _get_subtree_shift(
    left_subtree: T,
    right_subtree: T,
    left_idx: int,
    right_idx: int,
    subtree_separation: float,
    left_cum_shift: float = 0,
    right_cum_shift: float = 0,
    cum_shift: float = 0,
    initial_run: bool = True,
) -> float:
    """Get shift amount to shift the right subtree towards the right such that it does not overlap with the left subtree

    Walks the right contour of `left_subtree` and the left contour of
    `right_subtree` level by level, accumulating the shift needed to keep them
    at least `subtree_separation` apart.

    Args:
        left_subtree (BaseNode): left subtree, with right contour to be traversed
        right_subtree (BaseNode): right subtree, with left contour to be traversed
        left_idx (int): index of left subtree, to compute overlap for relative shift (constant across iteration)
        right_idx (int): index of right subtree, to compute overlap for relative shift (constant across iteration)
        subtree_separation (float): minimum distance between adjacent subtrees of the tree (constant across iteration)
        left_cum_shift (float): cumulative `mod + shift` for left subtree from the ancestors, defaults to 0
        right_cum_shift (float): cumulative `mod + shift` for right subtree from the ancestors, defaults to 0
        cum_shift (float): cumulative shift amount for right subtree, defaults to 0
        initial_run (bool): indicates whether left_subtree and right_subtree are the main subtrees, defaults to True

    Returns:
        (float)
    """
    new_shift = 0.0

    # The top level needs no comparison: the two roots are siblings whose
    # spacing was already fixed by sibling_separation in the first pass.
    if not initial_run:
        # Effective positions include each node's own shift plus everything
        # inherited from ancestors along this contour.
        x_left = (
            left_subtree.get_attr("x") + left_subtree.get_attr("shift") + left_cum_shift
        )
        x_right = (
            right_subtree.get_attr("x")
            + right_subtree.get_attr("shift")
            + right_cum_shift
            + cum_shift
        )
        # Scale the raw overlap by 1 / (1 - left_idx/right_idx): intermediate
        # siblings absorb a fraction of the shift, so the full shift applied to
        # right_idx must be proportionally larger. Never shift left (min 0).
        new_shift = max(
            (x_left + subtree_separation - x_right) / (1 - left_idx / right_idx), 0
        )

    # Search for a left sibling of left_subtree that has children
    while left_subtree and not left_subtree.children and left_subtree.left_sibling:
        left_subtree = left_subtree.left_sibling

    # Search for a right sibling of right_subtree that has children
    while (
        right_subtree and not right_subtree.children and right_subtree.right_sibling
    ):
        right_subtree = right_subtree.right_sibling

    if left_subtree.children and right_subtree.children:
        # Iterate down the level, for the rightmost child of left_subtree and the leftmost child of right_subtree
        return _get_subtree_shift(
            left_subtree=left_subtree.children[-1],
            right_subtree=right_subtree.children[0],
            left_idx=left_idx,
            right_idx=right_idx,
            subtree_separation=subtree_separation,
            left_cum_shift=(
                left_cum_shift
                + left_subtree.get_attr("mod")
                + left_subtree.get_attr("shift")
            ),
            right_cum_shift=(
                right_cum_shift
                + right_subtree.get_attr("mod")
                + right_subtree.get_attr("shift")
            ),
            cum_shift=cum_shift + new_shift,
            initial_run=False,
        )

    # One contour ran out of levels: total shift accumulated so far is final.
    return cum_shift + new_shift
|
||||
|
||||
|
||||
def _second_pass(
|
||||
tree_node: T,
|
||||
level_separation: float,
|
||||
x_offset: float,
|
||||
y_offset: float,
|
||||
cum_mod: Optional[float] = 0.0,
|
||||
max_depth: Optional[int] = None,
|
||||
x_adjustment: Optional[float] = 0.0,
|
||||
) -> float:
|
||||
"""
|
||||
Performs pre-order traversal of tree and determines the final `x` and `y` values for each node.
|
||||
Modifies tree in-place.
|
||||
|
||||
Notation:
|
||||
- `depth`: maximum depth of tree
|
||||
- `distance`: level separation
|
||||
- `x'`: x offset
|
||||
- `y'`: y offset
|
||||
|
||||
Final position of each node
|
||||
- :math:`x = node.x + node.shift + sum(ancestor.mod) + x'`
|
||||
- :math:`y = (depth - node.depth) * distance + y'`
|
||||
|
||||
Args:
|
||||
tree_node (BaseNode): tree to compute (x, y) coordinate
|
||||
level_separation (float): fixed distance between adjacent levels of the tree (constant across iteration)
|
||||
x_offset (float): graph offset of x-coordinates (constant across iteration)
|
||||
y_offset (float): graph offset of y-coordinates (constant across iteration)
|
||||
cum_mod (Optional[float]): cumulative `mod + shift` for tree/subtree from the ancestors
|
||||
max_depth (Optional[int]): maximum depth of tree (constant across iteration)
|
||||
x_adjustment (Optional[float]): amount of x-adjustment for third pass, in case any x-coordinates goes below 0
|
||||
|
||||
Returns
|
||||
(float)
|
||||
"""
|
||||
if not max_depth:
|
||||
max_depth = tree_node.max_depth
|
||||
|
||||
final_x: float = (
|
||||
tree_node.get_attr("x") + tree_node.get_attr("shift") + cum_mod + x_offset
|
||||
)
|
||||
final_y: float = (max_depth - tree_node.depth) * level_separation + y_offset
|
||||
tree_node.set_attrs({"x": final_x, "y": final_y})
|
||||
|
||||
# Pre-order iteration (NLR)
|
||||
if tree_node.children:
|
||||
return max(
|
||||
[
|
||||
_second_pass(
|
||||
child,
|
||||
level_separation,
|
||||
x_offset,
|
||||
y_offset,
|
||||
cum_mod + tree_node.get_attr("mod") + tree_node.get_attr("shift"),
|
||||
max_depth,
|
||||
x_adjustment,
|
||||
)
|
||||
for child in tree_node.children
|
||||
]
|
||||
)
|
||||
return max(x_adjustment, -final_x)
|
||||
|
||||
|
||||
def _third_pass(tree_node: BaseNode, x_adjustment: float) -> None:
|
||||
"""Adjust all x-coordinates by an adjustment value so that every x-coordinate is greater than or equal to 0.
|
||||
Modifies tree in-place.
|
||||
|
||||
Args:
|
||||
tree_node (BaseNode): tree to compute (x, y) coordinate
|
||||
x_adjustment (float): amount of adjustment for x-coordinates (constant across iteration)
|
||||
"""
|
||||
if x_adjustment:
|
||||
tree_node.set_attrs({"x": tree_node.get_attr("x") + x_adjustment})
|
||||
|
||||
# Pre-order iteration (NLR)
|
||||
for child in tree_node.children:
|
||||
_third_pass(child, x_adjustment)
|
||||
0
python310/packages/bigtree/workflows/__init__.py
Normal file
0
python310/packages/bigtree/workflows/__init__.py
Normal file
200
python310/packages/bigtree/workflows/app_calendar.py
Normal file
200
python310/packages/bigtree/workflows/app_calendar.py
Normal file
@@ -0,0 +1,200 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import datetime as dt
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from bigtree.node.node import Node
|
||||
from bigtree.tree.construct import add_path_to_tree
|
||||
from bigtree.tree.export import tree_to_dataframe
|
||||
from bigtree.tree.search import find_full_path, findall
|
||||
|
||||
try:
|
||||
import pandas as pd
|
||||
except ImportError: # pragma: no cover
|
||||
pd = None
|
||||
|
||||
|
||||
class Calendar:
|
||||
"""
|
||||
Calendar Implementation with Big Tree.
|
||||
- Calendar has four levels - year, month, day, and event name (with event attributes)
|
||||
|
||||
Examples:
|
||||
*Initializing and Adding Events*
|
||||
|
||||
>>> from bigtree import Calendar
|
||||
>>> calendar = Calendar("My Calendar")
|
||||
>>> calendar.add_event("Gym", "2023-01-01 18:00")
|
||||
>>> calendar.add_event("Dinner", "2023-01-01", date_format="%Y-%m-%d", budget=20)
|
||||
>>> calendar.add_event("Gym", "2023-01-02 18:00")
|
||||
>>> calendar.show()
|
||||
My Calendar
|
||||
2023-01-01 00:00:00 - Dinner (budget: 20)
|
||||
2023-01-01 18:00:00 - Gym
|
||||
2023-01-02 18:00:00 - Gym
|
||||
|
||||
*Search for Events*
|
||||
|
||||
>>> calendar.find_event("Gym")
|
||||
2023-01-01 18:00:00 - Gym
|
||||
2023-01-02 18:00:00 - Gym
|
||||
|
||||
*Removing Events*
|
||||
|
||||
>>> import datetime as dt
|
||||
>>> calendar.delete_event("Gym", dt.date(2023, 1, 1))
|
||||
>>> calendar.show()
|
||||
My Calendar
|
||||
2023-01-01 00:00:00 - Dinner (budget: 20)
|
||||
2023-01-02 18:00:00 - Gym
|
||||
|
||||
*Export Calendar*
|
||||
|
||||
>>> calendar.to_dataframe()
|
||||
path name date time budget
|
||||
0 /My Calendar/2023/01/01/Dinner Dinner 2023-01-01 00:00:00 20.0
|
||||
1 /My Calendar/2023/01/02/Gym Gym 2023-01-02 18:00:00 NaN
|
||||
"""
|
||||
|
||||
    def __init__(self, name: str):
        """Initialize an empty calendar.

        Args:
            name (str): name of the calendar; becomes the root node of the event tree
        """
        self.calendar = Node(name)
        # Sorted flag: cleared when events are added, restored by _sort().
        self.__sorted = True
|
||||
|
||||
    def add_event(
        self,
        event_name: str,
        event_datetime: Union[str, dt.datetime],
        date_format: str = "%Y-%m-%d %H:%M",
        **kwargs: Any,
    ) -> None:
        """Add event to calendar.

        The event is stored at path ``<calendar>/<year>/<month>/<day>/<event_name>``;
        adding an event with the same name on the same day overwrites the
        existing entry (a message is printed when this happens).

        Args:
            event_name (str): event name to be added
            event_datetime (Union[str, dt.datetime]): event date and time
            date_format (str): specify datetime format if event_datetime is str
            **kwargs (Any): additional attributes to store on the event node (e.g. budget)
        """
        if isinstance(event_datetime, str):
            event_datetime = dt.datetime.strptime(event_datetime, date_format)
        # Month and day are zero-padded so sibling nodes sort lexicographically.
        year, month, day, date, time = (
            event_datetime.year,
            str(event_datetime.month).zfill(2),
            str(event_datetime.day).zfill(2),
            event_datetime.date(),
            event_datetime.time(),
        )
        event_path = f"{self.calendar.node_name}/{year}/{month}/{day}/{event_name}"
        event_attr = {"date": date, "time": time, **kwargs}
        if find_full_path(self.calendar, event_path):
            print(
                f"Event {event_name} exists on {date}, overwriting information for {event_name}"
            )
        add_path_to_tree(
            tree=self.calendar,
            path=event_path,
            node_attrs=event_attr,
        )
        # New event may be out of order; mark for re-sort on next read.
        self.__sorted = False
|
||||
|
||||
def delete_event(
|
||||
self, event_name: str, event_date: Optional[dt.date] = None
|
||||
) -> None:
|
||||
"""Delete event from calendar
|
||||
|
||||
Args:
|
||||
event_name (str): event name to be deleted
|
||||
event_date (dt.date): event date to be deleted
|
||||
"""
|
||||
if event_date:
|
||||
year, month, day = (
|
||||
event_date.year,
|
||||
str(event_date.month).zfill(2),
|
||||
str(event_date.day).zfill(2),
|
||||
)
|
||||
event_path = f"{self.calendar.node_name}/{year}/{month}/{day}/{event_name}"
|
||||
event = find_full_path(self.calendar, event_path)
|
||||
if event:
|
||||
self._delete_event(event)
|
||||
else:
|
||||
print(f"Event {event_name} does not exist on {event_date}")
|
||||
else:
|
||||
for event in findall(
|
||||
self.calendar, lambda node: node.node_name == event_name
|
||||
):
|
||||
self._delete_event(event)
|
||||
|
||||
def find_event(self, event_name: str) -> None:
    """Search the calendar for events named *event_name*, printing each match.

    Args:
        event_name (str): event name
    """
    # Re-sort lazily so output appears in chronological order
    if not self.__sorted:
        self._sort()
    matches = findall(self.calendar, lambda n: n.node_name == event_name)
    for match in matches:
        self._show(match)
|
||||
|
||||
def show(self) -> None:
    """Show calendar, prints result to console

    Events are printed one per line in chronological order (the calendar is
    lazily re-sorted if events were added or removed since the last sort).

    Raises:
        Exception: if the calendar has no events
    """
    if not len(self.calendar.children):
        raise Exception("Calendar is empty!")
    if not self.__sorted:
        self._sort()
    print(self.calendar.node_name)
    # Leaves of the year/month/day tree are the individual events
    for event in self.calendar.leaves:
        self._show(event)
|
||||
|
||||
def to_dataframe(self) -> pd.DataFrame:
    """
    Export calendar to DataFrame

    Leaf (event) nodes are exported with all their attributes; the
    ``path``, ``name``, ``date`` and ``time`` columns always come first.

    Returns:
        (pd.DataFrame)

    Raises:
        Exception: if the calendar has no events
    """
    if not len(self.calendar.children):
        raise Exception("Calendar is empty!")
    data = tree_to_dataframe(self.calendar, all_attrs=True, leaf_only=True)
    compulsory_cols = ["path", "name", "date", "time"]
    # Fix: the previous `list(set(columns) - set(compulsory_cols))` produced a
    # nondeterministic column order; preserve the DataFrame's original order.
    other_cols = [col for col in data.columns if col not in compulsory_cols]
    return data[compulsory_cols + other_cols]
|
||||
|
||||
def _delete_event(self, event: Node) -> None:
    """Private method to delete event, delete parent node as well

    Recursively prunes ancestor nodes (day/month/year) that would become
    empty, so the calendar keeps no empty date branches after removal.

    Args:
        event (Node): event to be deleted
    """
    if len(event.parent.children) == 1:
        # Event is the parent's only child: first prune upwards if there is a
        # grandparent, then detach the now-childless parent from the tree.
        if event.parent.parent:
            self._delete_event(event.parent)
        event.parent.parent = None
    else:
        # Siblings exist: just detach this event node
        event.parent = None
|
||||
|
||||
def _sort(self) -> None:
    """Private method to sort calendar by event date, followed by event time

    Tree levels are calendar root (depth 1), year (2), month (3), day (4);
    events are the children of day nodes (see the path built in add_event).
    """
    for day_event in findall(self.calendar, lambda node: node.depth <= 4):
        if day_event.depth < 4:
            # Year/month/day levels: children sort chronologically by name
            # because the path components are zero-padded numbers
            day_event.sort(key=lambda attr: attr.node_name)
        else:
            # Day level: sort the day's events by their time attribute
            day_event.sort(key=lambda attr: attr.time)
    self.__sorted = True
|
||||
|
||||
@staticmethod
def _show(event: Node) -> None:
    """Private method to show event, handles the formatting of event
    Prints result to console

    Output format: ``<datetime> - <event name> (<attr>: <value>, ...)``;
    the parenthesised attribute list is omitted when the event carries no
    extra attributes.

    Args:
        event (Node): event
    """
    event_datetime = dt.datetime.combine(
        event.get_attr("date"), event.get_attr("time")
    )
    # Collect user-supplied attributes, excluding internals (prefix "_") and
    # the date/time/name fields already shown on the main line
    event_attrs = event.describe(
        exclude_attributes=["date", "time", "name"], exclude_prefix="_"
    )
    event_attrs_str = ", ".join([f"{attr[0]}: {attr[1]}" for attr in event_attrs])
    if event_attrs_str:
        event_attrs_str = f" ({event_attrs_str})"
    print(f"{event_datetime} - {event.node_name}{event_attrs_str}")
|
||||
261
python310/packages/bigtree/workflows/app_todo.py
Normal file
261
python310/packages/bigtree/workflows/app_todo.py
Normal file
@@ -0,0 +1,261 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, List, Union
|
||||
|
||||
from bigtree.node.node import Node
|
||||
from bigtree.tree.construct import nested_dict_to_tree
|
||||
from bigtree.tree.export import print_tree, tree_to_nested_dict
|
||||
from bigtree.tree.search import find_child_by_name, find_name
|
||||
|
||||
logging.getLogger(__name__).addHandler(logging.NullHandler())
|
||||
|
||||
|
||||
class AppToDo:
    """
    To-Do List Implementation with Big Tree.
      - To-Do List has three levels - app name, list name, and item name.
      - If list name is not given, item will be assigned to a `General` list.

    Examples:
        *Initializing and Adding Items*

        >>> from bigtree import AppToDo
        >>> app = AppToDo("To Do App")
        >>> app.add_item(item_name="Homework 1", list_name="School")
        >>> app.add_item(item_name=["Milk", "Bread"], list_name="Groceries", description="Urgent")
        >>> app.add_item(item_name="Cook")
        >>> app.show()
        To Do App
        ├── School
        │   └── Homework 1
        ├── Groceries
        │   ├── Milk [description=Urgent]
        │   └── Bread [description=Urgent]
        └── General
            └── Cook

        *Reorder List and Item*

        >>> app.prioritize_list(list_name="General")
        >>> app.show()
        To Do App
        ├── General
        │   └── Cook
        ├── School
        │   └── Homework 1
        └── Groceries
            ├── Milk [description=Urgent]
            └── Bread [description=Urgent]

        >>> app.prioritize_item(item_name="Bread")
        >>> app.show()
        To Do App
        ├── General
        │   └── Cook
        ├── School
        │   └── Homework 1
        └── Groceries
            ├── Bread [description=Urgent]
            └── Milk [description=Urgent]

        *Removing Items*

        >>> app.remove_item("Homework 1")
        >>> app.show()
        To Do App
        ├── General
        │   └── Cook
        └── Groceries
            ├── Bread [description=Urgent]
            └── Milk [description=Urgent]

        *Exporting and Importing List*

        >>> app.save("assets/docstr/list.json")
        >>> app2 = AppToDo.load("assets/docstr/list.json")
        >>> app2.show()
        To Do App
        ├── General
        │   └── Cook
        └── Groceries
            ├── Bread [description=Urgent]
            └── Milk [description=Urgent]
    """

    def __init__(
        self,
        app_name: str = "",
    ):
        """Initialize To-Do app

        Args:
            app_name (str): name of to-do app, optional
        """
        # Root of the three-level tree: app (depth 1) / list (2) / item (3)
        self._root = Node(app_name)

    def add_list(self, list_name: str, **kwargs: Any) -> Node:
        """Add list to app

        If list is present, return list node, else a new list will be created

        Args:
            list_name (str): name of list

        Returns:
            (Node)
        """
        list_node = find_child_by_name(self._root, list_name)
        if not list_node:
            list_node = Node(list_name, parent=self._root, **kwargs)
            logging.info(f"Created list {list_name}")
        return list_node

    def prioritize_list(self, list_name: str) -> None:
        """Prioritize list in app, shift it to be the first list

        Args:
            list_name (str): name of list

        Raises:
            ValueError: if the list does not exist
        """
        list_node = find_child_by_name(self._root, list_name)
        if not list_node:
            raise ValueError(f"List {list_name} not found")
        # Reassigning `children` reorders them: move the list to the front
        current_children = list(self._root.children)
        current_children.remove(list_node)
        current_children.insert(0, list_node)
        self._root.children = current_children  # type: ignore

    def add_item(
        self, item_name: Union[str, List[str]], list_name: str = "", **kwargs: Any
    ) -> None:
        """Add items to list

        Args:
            item_name (str/List[str]): items to be added
            list_name (str): list to add items to, optional; defaults to `General`

        Raises:
            TypeError: if item_name is neither a string nor a list
        """
        if not isinstance(item_name, str) and not isinstance(item_name, list):
            raise TypeError("Invalid data type for item")
        if isinstance(item_name, str):
            item_name = [item_name]

        # Get list to add to
        if list_name:
            list_node = self.add_list(list_name)
        else:
            list_node = self.add_list("General")

        # Add items to list; extra kwargs become attributes on every item
        for _item in item_name:
            _ = Node(_item, parent=list_node, **kwargs)
        logging.info(f"Created item(s) {', '.join(item_name)}")

    def remove_item(
        self, item_name: Union[str, List[str]], list_name: str = ""
    ) -> None:
        """Remove items from list

        Args:
            item_name (str/List[str]): items to be removed
            list_name (str): list to remove items from, optional; when omitted
                items are searched for across all lists

        Raises:
            TypeError: if item_name is neither a string nor a list
            ValueError: if the list or an item cannot be found, or a match is
                not at item level
        """
        if not isinstance(item_name, str) and not isinstance(item_name, list):
            raise TypeError("Invalid data type for item")
        if isinstance(item_name, str):
            item_name = [item_name]

        # Check if items can be found (validate everything before mutating)
        items_to_remove = []
        parent_to_check: set[Node] = set()
        if list_name:
            list_node = find_child_by_name(self._root, list_name)
            if not list_node:
                raise ValueError(f"List {list_name} does not exist!")
            for _item in item_name:
                item_node = find_child_by_name(list_node, _item)
                if not item_node:
                    raise ValueError(f"Item {_item} does not exist!")
                assert isinstance(item_node.parent, Node)  # for mypy type checking
                items_to_remove.append(item_node)
                parent_to_check.add(item_node.parent)
        else:
            for _item in item_name:
                item_node = find_name(self._root, _item)
                if not item_node:
                    raise ValueError(f"Item {_item} does not exist!")
                assert isinstance(item_node.parent, Node)  # for mypy type checking
                items_to_remove.append(item_node)
                parent_to_check.add(item_node.parent)

        # Remove items (detach from parent); depth 3 == item level
        for item_to_remove in items_to_remove:
            if item_to_remove.depth != 3:
                raise ValueError(
                    f"Check item to remove {item_to_remove} is an item at item-level"
                )
            item_to_remove.parent = None
        logging.info(
            f"Removed item(s) {', '.join(item.node_name for item in items_to_remove)}"
        )

        # Remove list if empty
        for list_node in parent_to_check:
            if not len(list(list_node.children)):
                list_node.parent = None
                logging.info(f"Removed list {list_node.node_name}")

    def prioritize_item(self, item_name: str) -> None:
        """Prioritize item in list, shift it to be the first item in list

        Args:
            item_name (str): name of item

        Raises:
            ValueError: if the item is not found or is not at item level
        """
        item_node = find_name(self._root, item_name)
        if not item_node:
            raise ValueError(f"Item {item_name} not found")
        if item_node.depth != 3:
            raise ValueError(f"{item_name} is not an item")
        assert isinstance(item_node.parent, Node)  # for mypy type checking
        # Reassigning `children` reorders them: move the item to the front
        current_parent = item_node.parent
        current_children = list(current_parent.children)
        current_children.remove(item_node)
        current_children.insert(0, item_node)
        current_parent.children = current_children  # type: ignore

    def show(self, **kwargs: Any) -> None:
        """Print tree to console"""
        print_tree(self._root, all_attrs=True, **kwargs)

    @staticmethod
    def load(json_path: str) -> AppToDo:
        """Load To-Do app from json

        Args:
            json_path (str): json load path

        Returns:
            (Self)

        Raises:
            ValueError: if json_path does not end with .json
        """
        if not json_path.endswith(".json"):
            raise ValueError("Path should end with .json")

        with open(json_path, "r") as fp:
            app_dict = json.load(fp)
        # Build a placeholder app, then swap its root for the deserialized tree
        _app = AppToDo("dummy")
        AppToDo.__setattr__(_app, "_root", nested_dict_to_tree(app_dict["root"]))
        return _app

    def save(self, json_path: str) -> None:
        """Save To-Do app as json

        Args:
            json_path (str): json save path

        Raises:
            ValueError: if json_path does not end with .json
        """
        if not json_path.endswith(".json"):
            raise ValueError("Path should end with .json")

        node_dict = tree_to_nested_dict(self._root, all_attrs=True)
        app_dict = {"root": node_dict}
        with open(json_path, "w") as fp:
            json.dump(app_dict, fp)
|
||||
1
python310/packages/elog.pth
Normal file
1
python310/packages/elog.pth
Normal file
@@ -0,0 +1 @@
|
||||
./elog-1.3.4-py3.7.egg
|
||||
13
python310/packages/elog/__init__.py
Normal file
13
python310/packages/elog/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
from elog.logbook import Logbook
|
||||
from elog.logbook import LogbookError, LogbookAuthenticationError, LogbookServerProblem, LogbookMessageRejected, \
|
||||
LogbookInvalidMessageID, LogbookInvalidAttachmentType
|
||||
|
||||
|
||||
def open(*args, **kwargs):
    """
    Will return a Logbook object. All arguments are passed to the logbook constructor.

    NOTE: deliberately shadows the builtin ``open`` within this module so the
    package exposes the ``elog.open(...)`` entry point.

    :param args: positional arguments forwarded to the Logbook constructor
    :param kwargs: keyword arguments forwarded to the Logbook constructor
    :return: Logbook() instance
    """
    return Logbook(*args, **kwargs)
|
||||
571
python310/packages/elog/logbook.py
Normal file
571
python310/packages/elog/logbook.py
Normal file
@@ -0,0 +1,571 @@
|
||||
import requests
|
||||
import urllib.parse
|
||||
import os
|
||||
import builtins
|
||||
import re
|
||||
from elog.logbook_exceptions import *
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
class Logbook(object):
    """
    Logbook provides methods to interface with logbook on location: "server:port/subdir/logbook". User can create,
    edit, delete logbook messages.

    NOTE(review): every HTTP request in this class is issued with
    ``verify=False``, i.e. TLS certificate verification is disabled —
    confirm this is intentional for the deployment environment.
    """

    def __init__(self, hostname, logbook='', port=None, user=None, password=None, subdir='', use_ssl=True,
                 encrypt_pwd=True):
        """
        :param hostname: elog server hostname. If whole url is specified here, it will be parsed and arguments:
                         "logbook, port, subdir, use_ssl" will be overwritten by parsed values.
        :param logbook: name of the logbook on the elog server
        :param port: elog server port (if not specified will default to '80' if use_ssl=False or '443' if use_ssl=True
        :param user: username (if authentication needed)
        :param password: password (if authentication needed) Password will be encrypted with sha256 unless
                         encrypt_pwd=False (default: True)
        :param subdir: subdirectory of logbooks locations
        :param use_ssl: connect using ssl (ignored if url starts with 'http://'' or 'https://'?
        :param encrypt_pwd: To avoid exposing password in the code, this flag can be set to False and password
                            will then be handled as it is (user needs to provide sha256 encrypted password with
                            salt= '' and rounds=5000)
        :return:
        """
        hostname = hostname.strip()

        # parse url to see if some parameters are defined with url
        parsed_url = urllib.parse.urlsplit(hostname)

        # ---- handle SSL -----
        # hostname must be modified according to use_ssl flag. If hostname starts with https:// or http://
        # the use_ssl flag is ignored
        url_scheme = parsed_url.scheme
        if url_scheme == 'http':
            use_ssl = False

        elif url_scheme == 'https':
            use_ssl = True

        elif not url_scheme:
            # add http or https
            if use_ssl:
                url_scheme = 'https'
            else:
                url_scheme = 'http'

        # ---- handle port -----
        # 1) by default use port defined in the url
        # 2) remove any 'default' ports such as 80 for http and 443 for https
        # 3) if port not defined in url and not 'default' add it to netloc

        netloc = parsed_url.netloc
        if netloc == "" and "localhost" in hostname:
            netloc = 'localhost'
        netloc_split = netloc.split(':')
        if len(netloc_split) > 1:
            # port defined in url --> remove if needed
            port = netloc_split[1]
            # NOTE(review): `port` is a *string* here (result of str.split), so
            # comparing it with the integers 80/443 is always False and the
            # default port is never stripped from netloc — likely needs
            # int(netloc_split[1]). Confirm before changing vendored behavior.
            if (port == 80 and not use_ssl) or (port == 443 and use_ssl):
                netloc = netloc_split[0]

        else:
            # add port info if needed
            if port is not None and not (port == 80 and not use_ssl) and not (port == 443 and use_ssl):
                netloc += ':{}'.format(port)

        # ---- handle subdir and logbook -----
        # parsed_url.path = /<subdir>/<logbook>/

        # Remove last '/' for easier parsing
        url_path = parsed_url.path
        if url_path.endswith('/'):
            url_path = url_path[:-1]

        splitted_path = url_path.split('/')
        if url_path and len(splitted_path) > 1:
            # If here ... then at least some part of path is defined.

            # If logbook defined --> treat path current path as subdir and add logbook at the end
            # to define the full path. Else treat existing path as <subdir>/<logbook>.
            # Put first and last '/' back on its place
            if logbook:
                url_path += '/{}'.format(logbook)
            else:
                logbook = splitted_path[-1]

        else:
            # There is nothing. Use arguments.
            url_path = subdir + '/' + logbook

        # urllib.parse.quote replaces special characters with %xx escapes
        # self._logbook_path = urllib.parse.quote('/' + url_path + '/').replace('//', '/')
        self._logbook_path = ('/' + url_path + '/').replace('//', '/')

        self._url = url_scheme + '://' + netloc + self._logbook_path
        self.logbook = logbook
        self._user = user
        self._password = _handle_pswd(password, encrypt_pwd)

    def post(self, message, msg_id=None, reply=False, attributes=None, attachments=None, encoding=None,
             **kwargs):
        """
        Posts message to the logbook. If msg_id is not specified new message will be created, otherwise existing
        message will be edited, or a reply (if reply=True) to it will be created. This method returns the msg_id
        of the newly created message.

        :param message: string with message text
        :param msg_id: ID number of message to edit or reply. If not specified new message is created.
        :param reply: If 'True' reply to existing message is created instead of editing it
        :param attributes: Dictionary of attributes. Following attributes are used internally by the elog and will be
                           ignored: Text, Date, Encoding, Reply to, In reply to, Locked by, Attachment
        :param attachments: list of:
                              - file like objects which read() will return bytes (if file_like_object.name is not
                                defined, default name "attachment<i>" will be used.
                              - paths to the files
                            All items will be appended as attachment to the elog entry. In case of unknown
                            attachment an exception LogbookInvalidAttachment will be raised.
        :param encoding: Defines encoding of the message. Can be: 'plain' -> plain text, 'html'->html-text,
                         'ELCode' --> elog formatting syntax
        :param kwargs: Anything in the kwargs will be interpreted as attribute. e.g.: logbook.post('Test text',
                       Author='Rok Vintar), "Author" will be sent as an attribute. If named same as one of the
                       attributes defined in "attributes", kwargs will have priority.

        :return: msg_id
        """

        attributes = attributes or {}
        attributes = {**attributes, **kwargs}  # kwargs as attributes with higher priority

        attachments = attachments or []

        if encoding is not None:
            if encoding not in ['plain', 'HTML', 'ELCode']:
                raise LogbookMessageRejected('Invalid message encoding. Valid options: plain, HTML, ELCode.')
            attributes['Encoding'] = encoding

        attributes_to_edit = dict()
        if msg_id:
            # Message exists, we can continue
            if reply:
                # Verify that there is a message on the server, otherwise do not reply to it!
                self._check_if_message_on_server(msg_id)  # raises exception in case of none existing message

                attributes['reply_to'] = str(msg_id)

            else:  # Edit existing
                attributes['edit_id'] = str(msg_id)
                attributes['skiplock'] = '1'

                # Handle existing attachments
                msg_to_edit, attributes_to_edit, attach_to_edit = self.read(msg_id)

                i = 0
                for attachment in attach_to_edit:
                    if attachment:
                        # Existing attachments must be passed as regular arguments attachment<i> with value= file name
                        # Read message returnes full urls to existing attachments:
                        # <hostname>:[<port>][/<subdir]/<logbook>/<msg_id>/<file_name>
                        attributes['attachment' + str(i)] = os.path.basename(attachment)
                        i += 1

                # Merge new attribute values over the ones read from the server
                for attribute, data in attributes.items():
                    new_data = attributes.get(attribute)
                    if new_data is not None:
                        attributes_to_edit[attribute] = new_data
        else:
            # As we create a new message, specify creation time if not already specified in attributes
            if 'When' not in attributes:
                attributes['When'] = int(datetime.now().timestamp())

        if not attributes_to_edit:
            attributes_to_edit = attributes
        # Remove any attributes that should not be sent
        _remove_reserved_attributes(attributes_to_edit)

        if attachments:
            files_to_attach, objects_to_close = self._prepare_attachments(attachments)
        else:
            objects_to_close = list()
            files_to_attach = list()

        # Make requests module think that Text is a "file". This is the only way to force requests to send data as
        # multipart/form-data even if there are no attachments. Elog understands only multipart/form-data
        files_to_attach.append(('Text', ('', message)))

        # Base attributes are common to all messages
        self._add_base_msg_attributes(attributes_to_edit)

        # Keys in attributes cannot have certain characters like whitespaces or dashes for the http request
        attributes_to_edit = _replace_special_characters_in_attribute_keys(attributes_to_edit)

        try:
            response = requests.post(self._url, data=attributes_to_edit, files=files_to_attach, allow_redirects=False,
                                     verify=False)
            # Validate response. Any problems will raise an Exception.
            resp_message, resp_headers, resp_msg_id = _validate_response(response)

            # Close file like objects that were opened by the elog (if path
            for file_like_object in objects_to_close:
                if hasattr(file_like_object, 'close'):
                    file_like_object.close()

        except requests.RequestException as e:
            # Check if message on server.
            self._check_if_message_on_server(msg_id)  # raises exceptions if no message or no response from server

            # If here: message is on server but cannot be downloaded (should never happen)
            raise LogbookServerProblem('Cannot access logbook server to post a message, ' + 'because of:\n' +
                                       '{0}'.format(e))

        # Any error before here should raise an exception, but check again for nay case.
        if not resp_msg_id or resp_msg_id < 1:
            raise LogbookInvalidMessageID('Invalid message ID: ' + str(resp_msg_id) + ' returned')
        return resp_msg_id

    def read(self, msg_id):
        """
        Reads message from the logbook server and returns tuple of (message, attributes, attachments) where:
        message: string with message body
        attributes: dictionary of all attributes returned by the logbook
        attachments: list of urls to attachments on the logbook server

        :param msg_id: ID of the message to be read
        :return: message, attributes, attachments
        """

        request_headers = dict()
        if self._user or self._password:
            request_headers['Cookie'] = self._make_user_and_pswd_cookie()

        try:
            self._check_if_message_on_server(msg_id)  # raises exceptions if no message or no response from server
            response = requests.get(self._url + str(msg_id) + '?cmd=download', headers=request_headers,
                                    allow_redirects=False, verify=False)

            # Validate response. If problems Exception will be thrown.
            resp_message, resp_headers, resp_msg_id = _validate_response(response)

        except requests.RequestException as e:
            # If here: message is on server but cannot be downloaded (should never happen)
            raise LogbookServerProblem('Cannot access logbook server to read the message with ID: ' + str(msg_id) +
                                       'because of:\n' + '{0}'.format(e))

        # Parse message to separate message body, attributes and attachments
        attributes = dict()
        attachments = list()

        # The download format is a header of "Key: value" lines, a 40-'='
        # delimiter line, then the raw message body.
        returned_msg = resp_message.decode('utf-8', 'ignore').splitlines()
        delimiter_idx = returned_msg.index('========================================')

        message = '\n'.join(returned_msg[delimiter_idx + 1:])
        for line in returned_msg[0:delimiter_idx]:
            line = line.split(': ')
            data = ''.join(line[1:])
            if line[0] == 'Attachment':
                attachments = data.split(',')
                # Here are only attachment names, make a full url out of it, so they could be
                # recognisable by others, and downloaded if needed
                attachments = [self._url + '{0}'.format(i) for i in attachments]
            else:
                attributes[line[0]] = data

        return message, attributes, attachments

    def delete(self, msg_id):
        """
        Deletes message thread (!!!message + all replies!!!) from logbook.
        It also deletes all of attachments of corresponding messages from the server.

        :param msg_id: message to be deleted
        :return:
        """

        request_headers = dict()
        if self._user or self._password:
            request_headers['Cookie'] = self._make_user_and_pswd_cookie()

        try:
            self._check_if_message_on_server(msg_id)  # check if something to delete

            response = requests.get(self._url + str(msg_id) + '?cmd=Delete&confirm=Yes', headers=request_headers,
                                    allow_redirects=False, verify=False)

            _validate_response(response)  # raises exception if any other error identified

        except requests.RequestException as e:
            # If here: message is on server but cannot be downloaded (should never happen)
            raise LogbookServerProblem('Cannot access logbook server to delete the message with ID: ' + str(msg_id) +
                                       'because of:\n' + '{0}'.format(e))

        # Additional validation: If successfully deleted then status_code = 302. In case command was not executed at
        # all (not English language --> no download command supported) status_code = 200 and the content is just a
        # html page of this whole message.
        if response.status_code == 200:
            raise LogbookServerProblem('Cannot process delete command (only logbooks in English supported).')

    def search(self, search_term, n_results=20, scope="subtext"):
        """
        Searches the logbook and returns the message ids.

        :param search_term: value to search for (sent as the `scope` query field)
        :param n_results: maximum number of results; values < 1 are clamped to 1
        :param scope: elog query field to search in (default "subtext")
        :return: list of message ids (ints), newest first
        """
        request_headers = dict()
        if self._user or self._password:
            request_headers['Cookie'] = self._make_user_and_pswd_cookie()

        # Putting n_results = 0 crashes the elog. also in the web-gui.
        n_results = 1 if n_results < 1 else n_results

        params = {
            "mode": "full",
            "reverse": "1",
            "npp": n_results,
            scope: search_term
        }

        try:
            response = requests.get(self._url, params=params, headers=request_headers,
                                    allow_redirects=False, verify=False)

            # Validate response. If problems Exception will be thrown.
            _validate_response(response)
            resp_message = response

        except requests.RequestException as e:
            # If here: message is on server but cannot be downloaded (should never happen)
            raise LogbookServerProblem('Cannot access logbook server to read message ids '
                                       'because of:\n' + '{0}'.format(e))

        # NOTE(review): local third-party import; result parsing duplicated in
        # get_message_ids — candidates for a shared helper.
        from lxml import html
        tree = html.fromstring(resp_message.content)
        message_ids = tree.xpath('(//tr/td[@class="list1" or @class="list2"][1])/a/@href')
        message_ids = [int(m.split("/")[-1]) for m in message_ids]
        return message_ids

    def get_last_message_id(self):
        """Return the newest message id on the logbook, or None when empty."""
        ids = self.get_message_ids()
        if len(ids) > 0:
            return ids[0]
        else:
            return None

    def get_message_ids(self):
        """Return the ids of the messages on the logbook's first page, newest first."""
        request_headers = dict()
        if self._user or self._password:
            request_headers['Cookie'] = self._make_user_and_pswd_cookie()

        try:
            response = requests.get(self._url + 'page', headers=request_headers,
                                    allow_redirects=False, verify=False)

            # Validate response. If problems Exception will be thrown.
            _validate_response(response)
            resp_message = response

        except requests.RequestException as e:
            # If here: message is on server but cannot be downloaded (should never happen)
            raise LogbookServerProblem('Cannot access logbook server to read message ids '
                                       'because of:\n' + '{0}'.format(e))

        from lxml import html
        tree = html.fromstring(resp_message.content)
        message_ids = tree.xpath('(//tr/td[@class="list1" or @class="list2"][1])/a/@href')
        message_ids = [int(m.split("/")[-1]) for m in message_ids]
        return message_ids

    def _check_if_message_on_server(self, msg_id):
        """Try to load page for specific message. If there is a htm tag like <td class="errormsg"> then there is no
        such message.

        :param msg_id: ID of message to be checked
        :return:
        """

        request_headers = dict()
        if self._user or self._password:
            request_headers['Cookie'] = self._make_user_and_pswd_cookie()
        try:
            response = requests.get(self._url + str(msg_id), headers=request_headers, allow_redirects=False,
                                    verify=False)

            # If there is no message code 200 will be returned (OK) and _validate_response will not recognise it
            # but there will be some error in the html code.
            resp_message, resp_headers, resp_msg_id = _validate_response(response)
            # If there is no message, code 200 will be returned (OK) but there will be some error indication in
            # the html code.
            if re.findall('<td.*?class="errormsg".*?>.*?</td>',
                          resp_message.decode('utf-8', 'ignore'),
                          flags=re.DOTALL):
                raise LogbookInvalidMessageID('Message with ID: ' + str(msg_id) + ' does not exist on logbook.')

        except requests.RequestException as e:
            raise LogbookServerProblem('No response from the logbook server.\nDetails: ' + '{0}'.format(e))

    def _add_base_msg_attributes(self, data):
        """
        Adds base message attributes which are used by all messages.
        :param data: dict of current attributes (modified in place)
        :return: content string
        """
        data['cmd'] = 'Submit'
        data['exp'] = self.logbook
        if self._user:
            data['unm'] = self._user
        if self._password:
            data['upwd'] = self._password

    def _prepare_attachments(self, files):
        """
        Parses attachments to content objects. Attachments can be:
          - file like objects: must have method read() which returns bytes. If it has attribute .name it will be used
            for attachment name, otherwise generic attribute<i> name will be used.
          - path to the file on disk

        Note that if attachment is is an url pointing to the existing Logbook server it will be ignored and no
        exceptions will be raised. This can happen if attachments returned with read_method are resend.

        :param files: list of file like objects or paths
        :return: content string
        """
        prepared = list()
        i = 0
        objects_to_close = list()  # objects that are created (opened) by elog must be later closed
        for file_obj in files:
            if hasattr(file_obj, 'read'):
                i += 1
                attribute_name = 'attfile' + str(i)

                filename = attribute_name  # If file like object has no name specified use this one
                # NOTE(review): this raises AttributeError for file-like objects
                # without a .name attribute, and the guard below tests `filename`
                # (always non-empty) rather than `candidate_filename` — confirm
                # the intended check before changing vendored behavior.
                candidate_filename = os.path.basename(file_obj.name)

                if filename:  # use only if not empty string
                    filename = candidate_filename

            elif isinstance(file_obj, str):
                # Check if it is:
                #   - a path to the file --> open file and append
                #   - an url pointing to the existing Logbook server --> ignore

                filename = ""
                attribute_name = ""

                if os.path.isfile(file_obj):
                    i += 1
                    attribute_name = 'attfile' + str(i)

                    file_obj = builtins.open(file_obj, 'rb')
                    filename = os.path.basename(file_obj.name)

                    objects_to_close.append(file_obj)

                elif not file_obj.startswith(self._url):
                    raise LogbookInvalidAttachmentType('Invalid type of attachment: \"' + file_obj + '\".')
            else:
                raise LogbookInvalidAttachmentType('Invalid type of attachment[' + str(i) + '].')

            prepared.append((attribute_name, (filename, file_obj)))

        return prepared, objects_to_close

    def _make_user_and_pswd_cookie(self):
        """
        prepares user name and password cookie. It is sent in header when posting a message.
        :return: user name and password value for the Cookie header
        """
        cookie = ''
        if self._user:
            cookie += 'unm=' + self._user + ';'
        if self._password:
            cookie += 'upwd=' + self._password + ';'

        return cookie
|
||||
|
||||
|
||||
def _remove_reserved_attributes(attributes):
|
||||
"""
|
||||
Removes elog reserved attributes (from the attributes dict) that can not be sent.
|
||||
|
||||
:param attributes: dictionary of attributes to be cleaned.
|
||||
:return:
|
||||
"""
|
||||
|
||||
if attributes:
|
||||
attributes.get('$@MID@$', None)
|
||||
attributes.pop('Date', None)
|
||||
attributes.pop('Attachment', None)
|
||||
attributes.pop('Text', None) # Remove this one because it will be send attachment like
|
||||
|
||||
|
||||
def _replace_special_characters_in_attribute_keys(attributes):
|
||||
"""
|
||||
Replaces special characters in elog attribute keys by underscore, otherwise attribute values will be erased in
|
||||
the http request. This is using the same replacement elog itself is using to handle these cases
|
||||
|
||||
:param attributes: dictionary of attributes to be cleaned.
|
||||
:return: attributes with replaced keys
|
||||
"""
|
||||
return {re.sub('[^0-9a-zA-Z]', '_', key): value for key, value in attributes.items()}
|
||||
|
||||
|
||||
def _validate_response(response):
|
||||
""" Validate response of the request."""
|
||||
|
||||
msg_id = None
|
||||
|
||||
if response.status_code not in [200, 302]:
|
||||
# 200 --> OK; 302 --> Found
|
||||
# Html page is returned with error description (handling errors same way as on original client. Looks
|
||||
# like there is no other way.
|
||||
|
||||
err = re.findall('<td.*?class="errormsg".*?>.*?</td>',
|
||||
response.content.decode('utf-8', 'ignore'),
|
||||
flags=re.DOTALL)
|
||||
|
||||
if len(err) > 0:
|
||||
# Remove html tags
|
||||
# If part of the message has: Please go back... remove this part since it is an instruction for
|
||||
# the user when using browser.
|
||||
err = re.sub('(?:<.*?>)', '', err[0])
|
||||
if err:
|
||||
raise LogbookMessageRejected('Rejected because of: ' + err)
|
||||
else:
|
||||
raise LogbookMessageRejected('Rejected because of unknown error.')
|
||||
|
||||
# Other unknown errors
|
||||
raise LogbookMessageRejected('Rejected because of unknown error.')
|
||||
else:
|
||||
location = response.headers.get('Location')
|
||||
if location is not None:
|
||||
if 'has moved' in location:
|
||||
raise LogbookServerProblem('Logbook server has moved to another location.')
|
||||
elif 'fail' in location:
|
||||
raise LogbookAuthenticationError('Invalid username or password.')
|
||||
else:
|
||||
# returned locations is something like: '<host>/<sub_dir>/<logbook>/<msg_id><query>
|
||||
# with urllib.parse.urlparse returns attribute path=<sub_dir>/<logbook>/<msg_id>
|
||||
msg_id = int(urllib.parse.urlsplit(location).path.split('/')[-1])
|
||||
|
||||
if b'form name=form1' in response.content or b'type=password' in response.content:
|
||||
# Not to smart to check this way, but no other indication of this kind of error.
|
||||
# C client does it the same way
|
||||
raise LogbookAuthenticationError('Invalid username or password.')
|
||||
|
||||
return response.content, response.headers, msg_id
|
||||
|
||||
|
||||
def _handle_pswd(password, encrypt=True):
|
||||
"""
|
||||
Takes password string and returns password as needed by elog. If encrypt=True then password will be
|
||||
sha256 encrypted (salt='', rounds=5000). Before returning password, any trailing $5$$ will be removed
|
||||
independent off encrypt flag.
|
||||
|
||||
:param password: password string
|
||||
:param encrypt: encrypt password?
|
||||
:return: elog prepared password
|
||||
"""
|
||||
if encrypt and password:
|
||||
from passlib.hash import sha256_crypt
|
||||
return sha256_crypt.encrypt(password, salt='', rounds=5000)[4:]
|
||||
elif password and password.startswith('$5$$'):
|
||||
return password[4:]
|
||||
else:
|
||||
return password
|
||||
28
python310/packages/elog/logbook_exceptions.py
Normal file
28
python310/packages/elog/logbook_exceptions.py
Normal file
@@ -0,0 +1,28 @@
|
||||
class LogbookError(Exception):
    """Base class for all exceptions raised by the logbook client."""
|
||||
|
||||
|
||||
class LogbookAuthenticationError(LogbookError):
    """Raised when there is a problem with the username or password."""
|
||||
|
||||
|
||||
class LogbookServerProblem(LogbookError):
    """Raised when there is a problem accessing the logbook server."""
|
||||
|
||||
|
||||
class LogbookMessageRejected(LogbookError):
    """Raised when manipulating/creating a message was rejected by the server
    or there was a problem composing the message."""
|
||||
|
||||
|
||||
class LogbookInvalidMessageID(LogbookMessageRejected):
    """Raised when no message with the specified ID exists on the server."""
|
||||
|
||||
|
||||
class LogbookInvalidAttachmentType(LogbookMessageRejected):
    """Raised when a passed attachment has an invalid type."""
|
||||
68
python37/packages/bigtree/__init__.py
Normal file
68
python37/packages/bigtree/__init__.py
Normal file
@@ -0,0 +1,68 @@
|
||||
__version__ = "0.7.2"
|
||||
|
||||
from bigtree.binarytree.construct import list_to_binarytree
|
||||
from bigtree.dag.construct import dataframe_to_dag, dict_to_dag, list_to_dag
|
||||
from bigtree.dag.export import dag_to_dataframe, dag_to_dict, dag_to_dot, dag_to_list
|
||||
from bigtree.node.basenode import BaseNode
|
||||
from bigtree.node.binarynode import BinaryNode
|
||||
from bigtree.node.dagnode import DAGNode
|
||||
from bigtree.node.node import Node
|
||||
from bigtree.tree.construct import (
|
||||
add_dataframe_to_tree_by_name,
|
||||
add_dataframe_to_tree_by_path,
|
||||
add_dict_to_tree_by_name,
|
||||
add_dict_to_tree_by_path,
|
||||
add_path_to_tree,
|
||||
dataframe_to_tree,
|
||||
dataframe_to_tree_by_relation,
|
||||
dict_to_tree,
|
||||
list_to_tree,
|
||||
list_to_tree_by_relation,
|
||||
nested_dict_to_tree,
|
||||
str_to_tree,
|
||||
)
|
||||
from bigtree.tree.export import (
|
||||
print_tree,
|
||||
tree_to_dataframe,
|
||||
tree_to_dict,
|
||||
tree_to_dot,
|
||||
tree_to_nested_dict,
|
||||
tree_to_pillow,
|
||||
yield_tree,
|
||||
)
|
||||
from bigtree.tree.helper import clone_tree, get_tree_diff, prune_tree
|
||||
from bigtree.tree.modify import (
|
||||
copy_nodes,
|
||||
copy_nodes_from_tree_to_tree,
|
||||
copy_or_shift_logic,
|
||||
shift_nodes,
|
||||
)
|
||||
from bigtree.tree.search import (
|
||||
find,
|
||||
find_attr,
|
||||
find_attrs,
|
||||
find_children,
|
||||
find_full_path,
|
||||
find_name,
|
||||
find_names,
|
||||
find_path,
|
||||
find_paths,
|
||||
findall,
|
||||
)
|
||||
from bigtree.utils.exceptions import (
|
||||
CorruptedTreeError,
|
||||
DuplicatedNodeError,
|
||||
LoopError,
|
||||
NotFoundError,
|
||||
SearchError,
|
||||
TreeError,
|
||||
)
|
||||
from bigtree.utils.iterators import (
|
||||
dag_iterator,
|
||||
inorder_iter,
|
||||
levelorder_iter,
|
||||
levelordergroup_iter,
|
||||
postorder_iter,
|
||||
preorder_iter,
|
||||
)
|
||||
from bigtree.workflows.app_todo import AppToDo
|
||||
50
python37/packages/bigtree/binarytree/construct.py
Normal file
50
python37/packages/bigtree/binarytree/construct.py
Normal file
@@ -0,0 +1,50 @@
|
||||
from typing import List, Type, Union
|
||||
|
||||
from bigtree.node.binarynode import BinaryNode
|
||||
|
||||
|
||||
def list_to_binarytree(
    heapq_list: List[Union[int, float]], node_type: Type[BinaryNode] = BinaryNode
) -> BinaryNode:
    """Construct tree from list of numbers (int or float) in heapq format.

    >>> from bigtree import list_to_binarytree, print_tree, tree_to_dot
    >>> nums_list = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
    >>> root = list_to_binarytree(nums_list)
    >>> print_tree(root)
    1
    ├── 2
    │   ├── 4
    │   │   ├── 8
    │   │   └── 9
    │   └── 5
    │       └── 10
    └── 3
        ├── 6
        └── 7
    >>> graph = tree_to_dot(root, node_colour="gold")
    >>> graph.write_png("assets/binarytree.png")

    .. image:: https://github.com/kayjan/bigtree/raw/master/assets/binarytree.png

    Args:
        heapq_list (List[Union[int, float]]): list containing integer node names, ordered in heapq fashion
        node_type (Type[BinaryNode]): node type of tree to be created, defaults to BinaryNode

    Returns:
        (BinaryNode)
    """
    if not heapq_list:
        raise ValueError("Input list does not contain any data, check `heapq_list`")

    root = node_type(heapq_list[0])
    node_list = [root]
    for idx, num in enumerate(heapq_list):
        if idx:
            # Heap parent index is (idx - 1) // 2 for both left (odd idx) and
            # right (even idx) children — the original odd/even branches both
            # reduced to this single integer-division formula.
            parent_idx = (idx - 1) // 2
            node = node_type(num, parent=node_list[parent_idx])
            node_list.append(node)
    return root
|
||||
0
python37/packages/bigtree/dag/__init__.py
Normal file
0
python37/packages/bigtree/dag/__init__.py
Normal file
186
python37/packages/bigtree/dag/construct.py
Normal file
186
python37/packages/bigtree/dag/construct.py
Normal file
@@ -0,0 +1,186 @@
|
||||
from typing import List, Tuple, Type
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from bigtree.node.dagnode import DAGNode
|
||||
|
||||
__all__ = ["list_to_dag", "dict_to_dag", "dataframe_to_dag"]
|
||||
|
||||
|
||||
def list_to_dag(
    relations: List[Tuple[str, str]],
    node_type: Type[DAGNode] = DAGNode,
) -> DAGNode:
    """Construct DAG from list of tuple containing parent-child names.
    Note that node names must be unique.

    >>> from bigtree import list_to_dag, dag_iterator
    >>> relations_list = [("a", "c"), ("a", "d"), ("b", "c"), ("c", "d"), ("d", "e")]
    >>> dag = list_to_dag(relations_list)
    >>> [(parent.node_name, child.node_name) for parent, child in dag_iterator(dag)]
    [('a', 'd'), ('c', 'd'), ('d', 'e'), ('a', 'c'), ('b', 'c')]

    Args:
        relations (list): list containing tuple of parent-child names
        node_type (Type[DAGNode]): node type of DAG to be created, defaults to DAGNode

    Returns:
        (DAGNode)
    """
    if not len(relations):
        raise ValueError("Input list does not contain any data, check `relations`")

    # Delegate to the dataframe-based constructor with fixed column names.
    frame = pd.DataFrame(relations, columns=["parent", "child"])
    return dataframe_to_dag(
        frame, child_col="child", parent_col="parent", node_type=node_type
    )
|
||||
|
||||
|
||||
def dict_to_dag(
    relation_attrs: dict,
    parent_key: str = "parents",
    node_type: Type[DAGNode] = DAGNode,
) -> DAGNode:
    """Construct DAG from nested dictionary, ``key``: child name, ``value``: dict of parent names, attribute name and
    attribute value.
    Note that node names must be unique.

    >>> from bigtree import dict_to_dag, dag_iterator
    >>> relation_dict = {
    ...     "a": {"step": 1},
    ...     "b": {"step": 1},
    ...     "c": {"parents": ["a", "b"], "step": 2},
    ...     "d": {"parents": ["a", "c"], "step": 2},
    ...     "e": {"parents": ["d"], "step": 3},
    ... }
    >>> dag = dict_to_dag(relation_dict, parent_key="parents")
    >>> [(parent.node_name, child.node_name) for parent, child in dag_iterator(dag)]
    [('a', 'd'), ('c', 'd'), ('d', 'e'), ('a', 'c'), ('b', 'c')]

    Args:
        relation_attrs (dict): dictionary containing node, node parents, and node attribute information,
            key: child name, value: dict of parent names, node attribute and attribute value
        parent_key (str): key of dictionary to retrieve list of parents name, defaults to "parents"
        node_type (Type[DAGNode]): node type of DAG to be created, defaults to DAGNode

    Raises:
        ValueError: if `relation_attrs` is empty or `parent_key` is missing from it.

    Returns:
        (DAGNode)
    """
    if not len(relation_attrs):
        raise ValueError("Dictionary does not contain any data, check `relation_attrs`")

    # Convert dictionary to dataframe
    data = pd.DataFrame(relation_attrs).T.rename_axis("_tmp_child").reset_index()
    # Raise instead of assert: asserts are stripped under `python -O`, so input
    # validation must not rely on them.
    if parent_key not in data:
        raise ValueError(
            f"Parent key {parent_key} not in dictionary, check `relation_attrs` and `parent_key`"
        )

    # One row per (child, parent) pair.
    data = data.explode(parent_key)
    return dataframe_to_dag(
        data,
        child_col="_tmp_child",
        parent_col=parent_key,
        node_type=node_type,
    )
|
||||
|
||||
|
||||
def dataframe_to_dag(
    data: pd.DataFrame,
    child_col: str = None,
    parent_col: str = None,
    attribute_cols: list = None,
    node_type: Type[DAGNode] = DAGNode,
) -> DAGNode:
    """Construct DAG from pandas DataFrame.
    Note that node names must be unique.

    `child_col` and `parent_col` specify columns for child name and parent name to construct DAG.
    `attribute_cols` specify columns for node attribute for child name
    If columns are not specified, `child_col` takes first column, `parent_col` takes second column, and all other
    columns are `attribute_cols`.

    >>> import pandas as pd
    >>> from bigtree import dataframe_to_dag, dag_iterator
    >>> relation_data = pd.DataFrame([
    ...     ["a", None, 1],
    ...     ["b", None, 1],
    ...     ["c", "a", 2],
    ...     ["c", "b", 2],
    ...     ["d", "a", 2],
    ...     ["d", "c", 2],
    ...     ["e", "d", 3],
    ... ],
    ...     columns=["child", "parent", "step"]
    ... )
    >>> dag = dataframe_to_dag(relation_data)
    >>> [(parent.node_name, child.node_name) for parent, child in dag_iterator(dag)]
    [('a', 'd'), ('c', 'd'), ('d', 'e'), ('a', 'c'), ('b', 'c')]

    Args:
        data (pandas.DataFrame): data containing path and node attribute information
        child_col (str): column of data containing child name information, defaults to None
            if not set, it will take the first column of data
        parent_col (str): column of data containing parent name information, defaults to None
            if not set, it will take the second column of data
        attribute_cols (list): columns of data containing child node attribute information,
            if not set, it will take all columns of data except `child_col` and `parent_col`
        node_type (Type[DAGNode]): node type of DAG to be created, defaults to DAGNode

    Raises:
        ValueError: on empty data, duplicate child rows with conflicting attributes,
            or null child names.

    Returns:
        (DAGNode)
    """
    if not len(data.columns):
        raise ValueError("Data does not contain any columns, check `data`")
    if not len(data):
        raise ValueError("Data does not contain any rows, check `data`")

    if not child_col:
        child_col = data.columns[0]
    if not parent_col:
        parent_col = data.columns[1]
    # Default was a mutable `[]` argument (shared across calls); use None sentinel.
    # `not attribute_cols` keeps the original semantics: None and [] both mean
    # "derive attribute columns from the data".
    if not attribute_cols:
        attribute_cols = list(data.columns)
        attribute_cols.remove(child_col)
        attribute_cols.remove(parent_col)

    # A child may appear on several rows (one per parent); its attributes must agree.
    data_check = data.copy()[[child_col] + attribute_cols].drop_duplicates()
    _duplicate_check = (
        data_check[child_col]
        .value_counts()
        .to_frame("counts")
        .rename_axis(child_col)
        .reset_index()
    )
    _duplicate_check = _duplicate_check[_duplicate_check["counts"] > 1]
    if len(_duplicate_check):
        raise ValueError(
            f"There exists duplicate child name with different attributes\nCheck {_duplicate_check}"
        )
    if np.any(data[child_col].isnull()):
        raise ValueError(f"Child name cannot be empty, check {child_col}")

    node_dict = dict()
    parent_node = None

    for row in data.reset_index(drop=True).to_dict(orient="index").values():
        child_name = row[child_col]
        parent_name = row[parent_col]
        node_attrs = row.copy()
        del node_attrs[child_col]
        del node_attrs[parent_col]
        # Drop null attributes so missing cells don't become NaN node attributes.
        node_attrs = {k: v for k, v in node_attrs.items() if not pd.isnull(v)}
        child_node = node_dict.get(child_name)
        if not child_node:
            child_node = node_type(child_name)
            node_dict[child_name] = child_node
        child_node.set_attrs(node_attrs)

        if not pd.isnull(parent_name):
            parent_node = node_dict.get(parent_name)
            if not parent_node:
                parent_node = node_type(parent_name)
                node_dict[parent_name] = parent_node
            child_node.parents = [parent_node]

    return parent_node
|
||||
269
python37/packages/bigtree/dag/export.py
Normal file
269
python37/packages/bigtree/dag/export.py
Normal file
@@ -0,0 +1,269 @@
|
||||
from typing import Any, Dict, List, Tuple, Union
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from bigtree.node.dagnode import DAGNode
|
||||
from bigtree.utils.iterators import dag_iterator
|
||||
|
||||
__all__ = ["dag_to_list", "dag_to_dict", "dag_to_dataframe", "dag_to_dot"]
|
||||
|
||||
|
||||
def dag_to_list(
    dag: DAGNode,
) -> List[Tuple[str, str]]:
    """Export DAG to list of tuple containing parent-child names

    >>> from bigtree import DAGNode, dag_to_list
    >>> a = DAGNode("a", step=1)
    >>> b = DAGNode("b", step=1)
    >>> c = DAGNode("c", step=2, parents=[a, b])
    >>> d = DAGNode("d", step=2, parents=[a, c])
    >>> e = DAGNode("e", step=3, parents=[d])
    >>> dag_to_list(a)
    [('a', 'c'), ('a', 'd'), ('b', 'c'), ('c', 'd'), ('d', 'e')]

    Args:
        dag (DAGNode): DAG to be exported

    Returns:
        (List[Tuple[str, str]])
    """
    return [
        (parent_node.node_name, child_node.node_name)
        for parent_node, child_node in dag_iterator(dag)
    ]
|
||||
|
||||
|
||||
def dag_to_dict(
    dag: DAGNode,
    parent_key: str = "parents",
    attr_dict: dict = None,
    all_attrs: bool = False,
) -> Dict[str, Any]:
    """Export tree to dictionary.

    Exported dictionary will have key as child name, and parent names and node attributes as a nested dictionary.

    >>> from bigtree import DAGNode, dag_to_dict
    >>> a = DAGNode("a", step=1)
    >>> b = DAGNode("b", step=1)
    >>> c = DAGNode("c", step=2, parents=[a, b])
    >>> d = DAGNode("d", step=2, parents=[a, c])
    >>> e = DAGNode("e", step=3, parents=[d])
    >>> dag_to_dict(a, parent_key="parent", attr_dict={"step": "step no."})
    {'a': {'step no.': 1}, 'c': {'parent': ['a', 'b'], 'step no.': 2}, 'd': {'parent': ['a', 'c'], 'step no.': 2}, 'b': {'step no.': 1}, 'e': {'parent': ['d'], 'step no.': 3}}

    Args:
        dag (DAGNode): DAG to be exported
        parent_key (str): dictionary key for `node.parent.node_name`, defaults to `parents`
        attr_dict (dict): dictionary mapping node attributes to dictionary key,
            key: node attributes, value: corresponding dictionary key, optional
        all_attrs (bool): indicator whether to retrieve all `Node` attributes

    Returns:
        (dict)
    """
    # Default was a mutable `{}` argument; use a None sentinel (same behavior).
    if attr_dict is None:
        attr_dict = {}
    # Work on a copy so the caller's DAG is never mutated.
    dag = dag.copy()
    data_dict = {}

    for parent_node, child_node in dag_iterator(dag):
        # Root nodes have no parent entry; export only their attributes.
        if parent_node.is_root:
            data_parent = {}
            if all_attrs:
                data_parent.update(
                    parent_node.describe(
                        exclude_attributes=["name"], exclude_prefix="_"
                    )
                )
            else:
                for k, v in attr_dict.items():
                    data_parent[v] = parent_node.get_attr(k)
            data_dict[parent_node.node_name] = data_parent

        if data_dict.get(child_node.node_name):
            # Child already seen from another parent — append this parent name.
            data_dict[child_node.node_name][parent_key].append(parent_node.node_name)
        else:
            data_child = {parent_key: [parent_node.node_name]}
            if all_attrs:
                data_child.update(
                    child_node.describe(exclude_attributes=["name"], exclude_prefix="_")
                )
            else:
                for k, v in attr_dict.items():
                    data_child[v] = child_node.get_attr(k)
            data_dict[child_node.node_name] = data_child
    return data_dict
|
||||
|
||||
|
||||
def dag_to_dataframe(
    dag: DAGNode,
    name_col: str = "name",
    parent_col: str = "parent",
    attr_dict: dict = None,
    all_attrs: bool = False,
) -> pd.DataFrame:
    """Export DAG to pandas DataFrame.

    >>> from bigtree import DAGNode, dag_to_dataframe
    >>> a = DAGNode("a", step=1)
    >>> b = DAGNode("b", step=1)
    >>> c = DAGNode("c", step=2, parents=[a, b])
    >>> d = DAGNode("d", step=2, parents=[a, c])
    >>> e = DAGNode("e", step=3, parents=[d])
    >>> dag_to_dataframe(a, name_col="name", parent_col="parent", attr_dict={"step": "step no."})
      name parent  step no.
    0    a   None         1
    1    c      a         2
    2    d      a         2
    3    b   None         1
    4    c      b         2
    5    d      c         2
    6    e      d         3

    Args:
        dag (DAGNode): DAG to be exported
        name_col (str): column name for `node.node_name`, defaults to 'name'
        parent_col (str): column name for `node.parent.node_name`, defaults to 'parent'
        attr_dict (dict): dictionary mapping node attributes to column name,
            key: node attributes, value: corresponding column in dataframe, optional
        all_attrs (bool): indicator whether to retrieve all `Node` attributes

    Returns:
        (pd.DataFrame)
    """
    # Default was a mutable `{}` argument; use a None sentinel (same behavior).
    if attr_dict is None:
        attr_dict = {}
    # Work on a copy so the caller's DAG is never mutated.
    dag = dag.copy()
    data_list = []

    for parent_node, child_node in dag_iterator(dag):
        # Roots get their own row with a None parent.
        if parent_node.is_root:
            data_parent = {name_col: parent_node.node_name, parent_col: None}
            if all_attrs:
                data_parent.update(
                    parent_node.describe(
                        exclude_attributes=["name"], exclude_prefix="_"
                    )
                )
            else:
                for k, v in attr_dict.items():
                    data_parent[v] = parent_node.get_attr(k)
            data_list.append(data_parent)

        data_child = {name_col: child_node.node_name, parent_col: parent_node.node_name}
        if all_attrs:
            data_child.update(
                child_node.describe(exclude_attributes=["name"], exclude_prefix="_")
            )
        else:
            for k, v in attr_dict.items():
                data_child[v] = child_node.get_attr(k)
        data_list.append(data_child)
    # Duplicates can arise when a node appears under several parents.
    return pd.DataFrame(data_list).drop_duplicates().reset_index(drop=True)
|
||||
|
||||
|
||||
def dag_to_dot(
    dag: Union[DAGNode, List[DAGNode]],
    rankdir: str = "TB",
    bg_colour: str = None,
    node_colour: str = None,
    edge_colour: str = None,
    node_attr: str = None,
    edge_attr: str = None,
):
    r"""Export DAG tree or list of DAG trees to image.
    Note that node names must be unique.
    Possible node attributes include style, fillcolor, shape.

    >>> from bigtree import DAGNode, dag_to_dot
    >>> a = DAGNode("a", step=1)
    >>> b = DAGNode("b", step=1)
    >>> c = DAGNode("c", step=2, parents=[a, b])
    >>> d = DAGNode("d", step=2, parents=[a, c])
    >>> e = DAGNode("e", step=3, parents=[d])
    >>> dag_graph = dag_to_dot(a)

    Export to image, dot file, etc.

    >>> dag_graph.write_png("tree_dag.png")
    >>> dag_graph.write_dot("tree_dag.dot")

    Export to string

    >>> dag_graph.to_string()
    'strict digraph G {\nrankdir=TB;\nc [label=c];\na [label=a];\na -> c;\nd [label=d];\na [label=a];\na -> d;\nc [label=c];\nb [label=b];\nb -> c;\nd [label=d];\nc [label=c];\nc -> d;\ne [label=e];\nd [label=d];\nd -> e;\n}\n'

    Args:
        dag (Union[DAGNode, List[DAGNode]]): DAG or list of DAGs to be exported
        rankdir (str): set direction of graph layout, defaults to 'TB', can be 'BT, 'LR', 'RL'
        bg_colour (str): background color of image, defaults to None
        node_colour (str): fill colour of nodes, defaults to None
        edge_colour (str): colour of edges, defaults to None
        node_attr (str): node attribute for style, overrides node_colour, defaults to None
            Possible node attributes include {"style": "filled", "fillcolor": "gold"}
        edge_attr (str): edge attribute for style, overrides edge_colour, defaults to None
            Possible edge attributes include {"style": "bold", "label": "edge label", "color": "black"}

    Returns:
        (pydot.Dot)
    """
    try:
        import pydot
    except ImportError:  # pragma: no cover
        raise ImportError(
            "pydot not available. Please perform a\n\npip install 'bigtree[image]'\n\nto install required dependencies"
        )

    # Get style
    if bg_colour:
        graph_style = dict(bgcolor=bg_colour)
    else:
        graph_style = dict()

    if node_colour:
        node_style = dict(style="filled", fillcolor=node_colour)
    else:
        node_style = dict()

    if edge_colour:
        edge_style = dict(color=edge_colour)
    else:
        edge_style = dict()

    _graph = pydot.Dot(
        graph_type="digraph", strict=True, rankdir=rankdir, **graph_style
    )

    if not isinstance(dag, list):
        dag = [dag]

    for _dag in dag:
        if not isinstance(_dag, DAGNode):
            raise ValueError(
                "Tree should be of type `DAGNode`, or inherit from `DAGNode`"
            )
        _dag = _dag.copy()

        for parent_node, child_node in dag_iterator(_dag):
            child_name = child_node.name
            child_node_style = node_style.copy()
            if node_attr and child_node.get_attr(node_attr):
                child_node_style.update(child_node.get_attr(node_attr))
            # Bug fix: the base `edge_style` dict was mutated in place, so one
            # edge's attributes leaked into every subsequent edge. Copy per edge,
            # and guard on the attribute existing (consistent with node_attr).
            child_edge_style = edge_style.copy()
            if edge_attr and child_node.get_attr(edge_attr):
                child_edge_style.update(child_node.get_attr(edge_attr))
            pydot_child = pydot.Node(
                name=child_name, label=child_name, **child_node_style
            )
            _graph.add_node(pydot_child)

            parent_name = parent_node.name
            parent_node_style = node_style.copy()
            if node_attr and parent_node.get_attr(node_attr):
                parent_node_style.update(parent_node.get_attr(node_attr))
            pydot_parent = pydot.Node(
                name=parent_name, label=parent_name, **parent_node_style
            )
            _graph.add_node(pydot_parent)

            edge = pydot.Edge(parent_name, child_name, **child_edge_style)
            _graph.add_edge(edge)

    return _graph
|
||||
0
python37/packages/bigtree/node/__init__.py
Normal file
0
python37/packages/bigtree/node/__init__.py
Normal file
696
python37/packages/bigtree/node/basenode.py
Normal file
696
python37/packages/bigtree/node/basenode.py
Normal file
@@ -0,0 +1,696 @@
|
||||
import copy
|
||||
from typing import Any, Dict, Iterable, List
|
||||
|
||||
from bigtree.utils.exceptions import CorruptedTreeError, LoopError, TreeError
|
||||
from bigtree.utils.iterators import preorder_iter
|
||||
|
||||
|
||||
class BaseNode:
|
||||
"""
|
||||
BaseNode extends any Python class to a tree node.
|
||||
Nodes can have attributes if they are initialized from `Node`, *dictionary*, or *pandas DataFrame*.
|
||||
|
||||
Nodes can be linked to each other with `parent` and `children` setter methods,
|
||||
or using bitshift operator with the convention `parent_node >> child_node` or `child_node << parent_node`.
|
||||
|
||||
>>> from bigtree import Node, print_tree
|
||||
>>> root = Node("a", age=90)
|
||||
>>> b = Node("b", age=65)
|
||||
>>> c = Node("c", age=60)
|
||||
>>> d = Node("d", age=40)
|
||||
>>> root.children = [b, c]
|
||||
>>> d.parent = b
|
||||
>>> print_tree(root, attr_list=["age"])
|
||||
a [age=90]
|
||||
├── b [age=65]
|
||||
│ └── d [age=40]
|
||||
└── c [age=60]
|
||||
|
||||
>>> from bigtree import Node
|
||||
>>> root = Node("a", age=90)
|
||||
>>> b = Node("b", age=65)
|
||||
>>> c = Node("c", age=60)
|
||||
>>> d = Node("d", age=40)
|
||||
>>> root >> b
|
||||
>>> root >> c
|
||||
>>> d << b
|
||||
>>> print_tree(root, attr_list=["age"])
|
||||
a [age=90]
|
||||
├── b [age=65]
|
||||
│ └── d [age=40]
|
||||
└── c [age=60]
|
||||
|
||||
Directly passing `parent` argument.
|
||||
|
||||
>>> from bigtree import Node
|
||||
>>> root = Node("a")
|
||||
>>> b = Node("b", parent=root)
|
||||
>>> c = Node("c", parent=root)
|
||||
>>> d = Node("d", parent=b)
|
||||
|
||||
Directly passing `children` argument.
|
||||
|
||||
>>> from bigtree import Node
|
||||
>>> d = Node("d")
|
||||
>>> c = Node("c")
|
||||
>>> b = Node("b", children=[d])
|
||||
>>> a = Node("a", children=[b, c])
|
||||
|
||||
**BaseNode Creation**
|
||||
|
||||
Node can be created by instantiating a `BaseNode` class or by using a *dictionary*.
|
||||
If node is created with dictionary, all keys of dictionary will be stored as class attributes.
|
||||
|
||||
>>> from bigtree import Node
|
||||
>>> root = Node.from_dict({"name": "a", "age": 90})
|
||||
|
||||
**BaseNode Attributes**
|
||||
|
||||
These are node attributes that have getter and/or setter methods.
|
||||
|
||||
Get and set other `BaseNode`
|
||||
|
||||
1. ``parent``: Get/set parent node
|
||||
2. ``children``: Get/set child nodes
|
||||
|
||||
Get other `BaseNode`
|
||||
|
||||
1. ``ancestors``: Get ancestors of node excluding self, iterator
|
||||
2. ``descendants``: Get descendants of node excluding self, iterator
|
||||
3. ``leaves``: Get all leaf node(s) from self, iterator
|
||||
4. ``siblings``: Get siblings of self
|
||||
5. ``left_sibling``: Get sibling left of self
|
||||
6. ``right_sibling``: Get sibling right of self
|
||||
|
||||
Get `BaseNode` configuration
|
||||
|
||||
1. ``node_path``: Get tuple of nodes from root
|
||||
2. ``is_root``: Get indicator if self is root node
|
||||
3. ``is_leaf``: Get indicator if self is leaf node
|
||||
4. ``root``: Get root node of tree
|
||||
5. ``depth``: Get depth of self
|
||||
6. ``max_depth``: Get maximum depth from root to leaf node
|
||||
|
||||
**BaseNode Methods**
|
||||
|
||||
These are methods available to be performed on `BaseNode`.
|
||||
|
||||
Constructor methods
|
||||
|
||||
1. ``from_dict()``: Create BaseNode from dictionary
|
||||
|
||||
`BaseNode` methods
|
||||
|
||||
1. ``describe()``: Get node information sorted by attributes, returns list of tuples
|
||||
2. ``get_attr(attr_name: str)``: Get value of node attribute
|
||||
3. ``set_attrs(attrs: dict)``: Set node attribute name(s) and value(s)
|
||||
4. ``go_to(node: BaseNode)``: Get a path from own node to another node from same tree
|
||||
5. ``copy()``: Deep copy BaseNode
|
||||
6. ``sort()``: Sort child nodes
|
||||
|
||||
----
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, parent=None, children: List = None, **kwargs):
    """Instantiate a node, link it to *parent* and *children*, and store any
    extra keyword arguments as node attributes.

    Args:
        parent (Self): parent node to attach to, optional
        children (List): child nodes to attach, optional
        **kwargs: arbitrary node attributes
    """
    self.__parent = None
    self.__children = []
    # Assigning through the properties runs the type/loop validation.
    self.parent = parent
    self.children = [] if children is None else children
    if "parents" in kwargs:
        raise ValueError(
            "Attempting to set `parents` attribute, do you mean `parent`?"
        )
    self.__dict__.update(**kwargs)
|
||||
|
||||
@property
|
||||
def parent(self):
|
||||
"""Get parent node
|
||||
|
||||
Returns:
|
||||
(Self)
|
||||
"""
|
||||
return self.__parent
|
||||
|
||||
@staticmethod
|
||||
def __check_parent_type(new_parent):
|
||||
"""Check parent type
|
||||
|
||||
Args:
|
||||
new_parent (Self): parent node
|
||||
"""
|
||||
if not (isinstance(new_parent, BaseNode) or new_parent is None):
|
||||
raise TypeError(
|
||||
f"Expect input to be BaseNode type or NoneType, received input type {type(new_parent)}"
|
||||
)
|
||||
|
||||
def __check_parent_loop(self, new_parent):
|
||||
"""Check parent type
|
||||
|
||||
Args:
|
||||
new_parent (Self): parent node
|
||||
"""
|
||||
if new_parent is not None:
|
||||
if new_parent is self:
|
||||
raise LoopError("Error setting parent: Node cannot be parent of itself")
|
||||
if any(
|
||||
ancestor is self
|
||||
for ancestor in new_parent.ancestors
|
||||
if new_parent.ancestors
|
||||
):
|
||||
raise LoopError(
|
||||
"Error setting parent: Node cannot be ancestor of itself"
|
||||
)
|
||||
|
||||
@parent.setter
def parent(self, new_parent):
    """Set parent node, detaching self from its current parent first.

    The whole operation is transactional: if any check or user hook raises,
    all links are restored to their previous state before re-raising.

    Args:
        new_parent (Self): parent node, or None to detach

    Raises:
        TreeError: wrapping whatever exception interrupted the assignment
    """
    self.__check_parent_type(new_parent)
    self.__check_parent_loop(new_parent)

    # Saved for rollback: previous parent and self's position under it
    current_parent = self.parent
    current_child_idx = None

    # Assign new parent - rollback if error
    self.__pre_assign_parent(new_parent)
    try:
        # Remove self from old parent
        if current_parent is not None:
            if not any(
                child is self for child in current_parent.children
            ):  # pragma: no cover
                raise CorruptedTreeError(
                    "Error setting parent: Node does not exist as children of its parent"
                )
            current_child_idx = current_parent.__children.index(self)
            current_parent.__children.remove(self)

        # Assign self to new parent
        self.__parent = new_parent
        if new_parent is not None:
            new_parent.__children.append(self)

        self.__post_assign_parent(new_parent)

    except Exception as exc_info:
        # Remove self from new parent
        if new_parent is not None:
            new_parent.__children.remove(self)

        # Reassign self to old parent, restoring original child position
        self.__parent = current_parent
        if current_child_idx is not None:
            current_parent.__children.insert(current_child_idx, self)
        raise TreeError(exc_info)
def __pre_assign_parent(self, new_parent):
    """Custom hook invoked before attaching a parent; no-op by default.
    Can be overriden with `_BaseNode__pre_assign_parent()`

    Args:
        new_parent (Self): new parent to be added
    """
    pass
def __post_assign_parent(self, new_parent):
    """Custom hook invoked after attaching a parent; no-op by default.
    Can be overriden with `_BaseNode__post_assign_parent()`

    Args:
        new_parent (Self): new parent that was added
    """
    pass
@property
def parents(self) -> None:
    """Do not allow `parents` attribute to be accessed

    Raises:
        ValueError: always; a BaseNode has exactly one `parent`
    """
    raise ValueError(
        "Attempting to access `parents` attribute, do you mean `parent`?"
    )
@parents.setter
def parents(self, new_parent):
    """Do not allow `parents` attribute to be set

    Args:
        new_parent (Self): parent node

    Raises:
        ValueError: always; use `parent` instead
    """
    raise ValueError("Attempting to set `parents` attribute, do you mean `parent`?")
@property
def children(self) -> Iterable:
    """Get child nodes as an immutable snapshot (tuple copy of internal list)

    Returns:
        (Iterable[Self])
    """
    return tuple(self.__children)
def __check_children_type(self, new_children: List):
    """Check that children input is a list.

    Args:
        new_children (List[Self]): child nodes

    Raises:
        TypeError: if `new_children` is not a list
    """
    if not isinstance(new_children, list):
        raise TypeError(
            f"Children input should be list type, received input type {type(new_children)}"
        )
def __check_children_loop(self, new_children: List):
    """Check child nodes for type errors, cycles, and duplicates.

    Args:
        new_children (List[Self]): child nodes

    Raises:
        TypeError: if any child is not a BaseNode
        LoopError: if a child is self or an ancestor of self
        TreeError: if the same node object appears more than once
    """
    # Set gives O(1) duplicate detection (was an O(n) list membership scan)
    seen_children = set()
    for new_child in new_children:
        # Check type
        if not isinstance(new_child, BaseNode):
            raise TypeError(
                f"Expect input to be BaseNode type, received input type {type(new_child)}"
            )

        # Check for loop and tree structure
        if new_child is self:
            raise LoopError("Error setting child: Node cannot be child of itself")
        if any(child is new_child for child in self.ancestors):
            raise LoopError(
                "Error setting child: Node cannot be ancestors of itself"
            )

        # Check for duplicate children (by identity, hence id())
        if id(new_child) in seen_children:
            raise TreeError(
                "Error setting child: Node cannot be added multiple times as a child"
            )
        seen_children.add(id(new_child))
@children.setter
def children(self, new_children: List):
    """Set child nodes, detaching them from their current parents first.

    The whole operation is transactional: if any check or user hook raises,
    every touched link is restored before re-raising.

    Args:
        new_children (List[Self]): child nodes

    Raises:
        TreeError: wrapping whatever exception interrupted the assignment
    """
    self.__check_children_type(new_children)
    self.__check_children_loop(new_children)

    # Saved for rollback: each new child's (index, parent) under its old parent
    current_new_children = {
        new_child: (new_child.parent.__children.index(new_child), new_child.parent)
        for new_child in new_children
        if new_child.parent is not None
    }
    current_new_orphan = [
        new_child for new_child in new_children if new_child.parent is None
    ]
    current_children = list(self.children)

    # Assign new children - rollback if error
    self.__pre_assign_children(new_children)
    try:
        # Remove old children from self
        del self.children

        # Assign new children to self
        self.__children = new_children
        for new_child in new_children:
            if new_child.parent:
                new_child.parent.__children.remove(new_child)
            new_child.__parent = self
        self.__post_assign_children(new_children)
    except Exception as exc_info:
        # Reassign new children to their original parent
        for child, idx_parent in current_new_children.items():
            child_idx, parent = idx_parent
            child.__parent = parent
            parent.__children.insert(child_idx, child)
        for child in current_new_orphan:
            child.__parent = None

        # Reassign old children to self
        self.__children = current_children
        for child in current_children:
            child.__parent = self
        raise TreeError(exc_info)
@children.deleter
def children(self):
    """Delete child node(s): detach every child from self, leaving them orphaned"""
    for child in self.children:
        child.parent.__children.remove(child)
        child.__parent = None
def __pre_assign_children(self, new_children: List):
    """Custom hook invoked before attaching children; no-op by default.
    Can be overriden with `_BaseNode__pre_assign_children()`

    Args:
        new_children (List[Self]): new children to be added
    """
    pass
def __post_assign_children(self, new_children: List):
    """Custom hook invoked after attaching children; no-op by default.
    Can be overriden with `_BaseNode__post_assign_children()`

    Args:
        new_children (List[Self]): new children that were added
    """
    pass
@property
def ancestors(self) -> Iterable:
    """Yield every ancestor of self (parent, grandparent, ...); excludes self.

    Returns:
        (Iterable[Self])
    """
    current = self.parent
    while current is not None:
        yield current
        current = current.parent
@property
def descendants(self) -> Iterable:
    """Get iterator to yield all descendants of self, does not include self

    Returns:
        (Iterable[Self])
    """
    # NOTE(review): filter uses `!=` (equality) rather than `is not`; identical
    # unless a subclass defines __eq__ — confirm against subclasses
    yield from preorder_iter(self, filter_condition=lambda _node: _node != self)
@property
def leaves(self) -> Iterable:
    """Get iterator to yield all leaf nodes in the subtree rooted at self

    Returns:
        (Iterable[Self])
    """
    yield from preorder_iter(self, filter_condition=lambda _node: _node.is_leaf)
@property
def siblings(self) -> Iterable:
    """Get all other children of self's parent; empty tuple when self is root.

    Returns:
        (Iterable[Self])
    """
    parent = self.parent
    if parent is None:
        return ()
    return tuple(node for node in parent.children if node is not self)
@property
def left_sibling(self):
    """Get the sibling immediately to the left of self, or None.

    Returns:
        (Self)
    """
    parent = self.parent
    if parent:
        siblings = parent.children
        idx = siblings.index(self)
        if idx > 0:
            return siblings[idx - 1]
    return None
@property
def right_sibling(self):
    """Get the sibling immediately to the right of self, or None.

    Returns:
        (Self)
    """
    parent = self.parent
    if parent:
        siblings = parent.children
        idx = siblings.index(self)
        if idx + 1 < len(siblings):
            return siblings[idx + 1]
    return None
@property
def node_path(self) -> Iterable:
    """Get tuple of nodes starting from root, ending at self (inclusive).

    Returns:
        (Iterable[Self])
    """
    # Fix: previously the root case returned a list while the recursive case
    # returned a tuple; now both branches return a tuple as documented.
    if self.is_root:
        return (self,)
    return tuple(self.parent.node_path) + (self,)
@property
def is_root(self) -> bool:
    """Get indicator if self is root node (has no parent)

    Returns:
        (bool)
    """
    return self.parent is None
@property
def is_leaf(self) -> bool:
    """Get indicator if self is leaf node (has no children).

    Returns:
        (bool)
    """
    return len(self.children) == 0
@property
def root(self):
    """Get root node of tree by walking up the parent chain.

    Returns:
        (Self)
    """
    node = self
    while node.parent is not None:
        node = node.parent
    return node
@property
def depth(self) -> int:
    """Get depth of self, indexing starts from 1 at the root.

    Returns:
        (int)
    """
    return 1 if self.is_root else self.parent.depth + 1
@property
def max_depth(self) -> int:
    """Get maximum depth from root to any leaf node of the whole tree.

    Returns:
        (int)
    """
    # Iterate the traversal lazily; no need to materialize every node in a list
    return max(node.depth for node in preorder_iter(self.root))
@classmethod
def from_dict(cls, input_dict: Dict[str, Any]):
    """Construct node from dictionary, all keys of dictionary will be stored as class attributes
    Input dictionary must have key `name` if not `Node` will not have any name

    >>> from bigtree import Node
    >>> a = Node.from_dict({"name": "a", "age": 90})

    Args:
        input_dict (Dict[str, Any]): dictionary with node information, key: attribute name, value: attribute value

    Returns:
        (Self)
    """
    # All keys are forwarded as keyword arguments to the constructor
    return cls(**input_dict)
def describe(self, exclude_attributes: List[str] = None, exclude_prefix: str = ""):
    """Get node information sorted by attribute name, returns list of tuples.

    >>> from bigtree.node.node import Node
    >>> a = Node('a', age=90)
    >>> a.describe()
    [('_BaseNode__children', []), ('_BaseNode__parent', None), ('_sep', '/'), ('age', 90), ('name', 'a')]
    >>> a.describe(exclude_prefix="_")
    [('age', 90), ('name', 'a')]
    >>> a.describe(exclude_prefix="_", exclude_attributes=["name"])
    [('age', 90)]

    Args:
        exclude_attributes (List[str]): list of attributes to exclude, defaults to none
        exclude_prefix (str): prefix of attributes to exclude

    Returns:
        (List[str])
    """
    # Fix: mutable default argument replaced with None sentinel (shared-list pitfall)
    if exclude_attributes is None:
        exclude_attributes = []
    return [
        item
        for item in sorted(self.__dict__.items(), key=lambda item: item[0])
        if (item[0] not in exclude_attributes)
        and (not len(exclude_prefix) or not item[0].startswith(exclude_prefix))
    ]
def get_attr(self, attr_name: str) -> Any:
    """Get value of node attribute; returns None if the attribute does not exist.

    >>> from bigtree.node.node import Node
    >>> a = Node('a', age=90)
    >>> a.get_attr("age")
    90

    Args:
        attr_name (str): attribute name

    Returns:
        (Any)
    """
    # getattr with a default covers the missing-attribute case without try/except
    return getattr(self, attr_name, None)
def set_attrs(self, attrs: Dict[str, Any]):
    """Set node attributes

    >>> from bigtree.node.node import Node
    >>> a = Node('a')
    >>> a.set_attrs({"age": 90})
    >>> a
    Node(/a, age=90)

    Args:
        attrs (Dict[str, Any]): attribute dictionary,
            key: attribute name, value: attribute value
    """
    # No validation: any key, including private names, is written to the instance
    self.__dict__.update(attrs)
def go_to(self, node) -> Iterable:
    """Get path from current node to specified node from same tree

    >>> from bigtree import Node, print_tree
    >>> a = Node(name="a")
    >>> b = Node(name="b", parent=a)
    >>> c = Node(name="c", parent=a)
    >>> d = Node(name="d", parent=b)
    >>> e = Node(name="e", parent=b)
    >>> f = Node(name="f", parent=c)
    >>> g = Node(name="g", parent=e)
    >>> h = Node(name="h", parent=e)
    >>> d.go_to(d)
    [Node(/a/b/d, )]
    >>> d.go_to(g)
    [Node(/a/b/d, ), Node(/a/b, ), Node(/a/b/e, ), Node(/a/b/e/g, )]
    >>> d.go_to(f)
    [Node(/a/b/d, ), Node(/a/b, ), Node(/a, ), Node(/a/c, ), Node(/a/c/f, )]

    Args:
        node (Self): node to travel to from current node, inclusive of start and end node

    Returns:
        (Iterable)

    Raises:
        TypeError: if `node` is not a BaseNode
        TreeError: if `node` belongs to a different tree
    """
    if not isinstance(node, BaseNode):
        raise TypeError(
            f"Expect node to be BaseNode type, received input type {type(node)}"
        )
    if self.root != node.root:
        raise TreeError(
            f"Nodes are not from the same tree. Check {self} and {node}"
        )
    if self == node:
        return [self]
    # Path from self up to root, and path from root down to target
    self_path = [self] + list(self.ancestors)
    node_path = ([node] + list(node.ancestors))[::-1]
    common_nodes = set(self_path).intersection(set(node_path))
    # Lowest common ancestor is the common node closest to self.
    # min() replaces sorted(...)[0]: O(n) instead of O(n log n), same result
    # since indices within self_path are unique.
    self_min_index, min_common_node = min(
        (self_path.index(_node), _node) for _node in common_nodes
    )
    node_min_index = node_path.index(min_common_node)
    # Walk up from self to (exclusive) the LCA, then down to the target
    return self_path[:self_min_index] + node_path[node_min_index:]
def copy(self):
    """Deep copy self; clone BaseNode

    >>> from bigtree.node.node import Node
    >>> a = Node('a')
    >>> a_copy = a.copy()

    Returns:
        (Self)
    """
    # Deep copy: the entire subtree and parent chain are cloned as well
    return copy.deepcopy(self)
def sort(self, **kwargs):
    """Sort children, possible keyword arguments include ``key=lambda node: node.name``, ``reverse=True``

    >>> from bigtree import Node, print_tree
    >>> a = Node('a')
    >>> c = Node("c", parent=a)
    >>> b = Node("b", parent=a)
    >>> print_tree(a)
    a
    ├── c
    └── b
    >>> a.sort(key=lambda node: node.name)
    >>> print_tree(a)
    a
    ├── b
    └── c
    """
    # Sort a copy, then swap it in; parent links are unaffected
    children = list(self.children)
    children.sort(**kwargs)
    self.__children = children
def __copy__(self):
    """Shallow copy self: new instance sharing the same attribute objects.

    >>> import copy
    >>> from bigtree.node.node import Node
    >>> a = Node('a')
    >>> a_copy = copy.copy(a)

    Returns:
        (Self)
    """
    # __new__ bypasses __init__ so no parent/children validation runs;
    # attribute dict (including parent/children references) is shared-copied
    obj = type(self).__new__(self.__class__)
    obj.__dict__.update(self.__dict__)
    return obj
def __repr__(self):
    """Return string representation listing class name and public attributes."""
    public_attrs = self.describe(exclude_prefix="_")
    description = ", ".join(f"{key}={value}" for key, value in public_attrs)
    return f"{type(self).__name__}({description})"
def __rshift__(self, other):
    """Set children using >> bitshift operator for self >> other
    (self becomes the parent of other)

    Args:
        other (Self): other node, children
    """
    other.parent = self
def __lshift__(self, other):
    """Set parent using << bitshift operator for self << other
    (other becomes the parent of self)

    Args:
        other (Self): other node, parent
    """
    self.parent = other
|
||||
395
python37/packages/bigtree/node/binarynode.py
Normal file
395
python37/packages/bigtree/node/binarynode.py
Normal file
@@ -0,0 +1,395 @@
|
||||
from typing import Iterable, List, Union
|
||||
|
||||
from bigtree.node.node import Node
|
||||
from bigtree.utils.exceptions import CorruptedTreeError, LoopError, TreeError
|
||||
|
||||
|
||||
class BinaryNode(Node):
    """
    BinaryNode is an extension of Node, and is able to extend to any Python class for Binary Tree implementation.
    Nodes can have attributes if they are initialized from `BinaryNode`, *dictionary*, or *pandas DataFrame*.

    BinaryNode can be linked to each other with `children`, `left`, or `right` setter methods.
    If initialized with `children`, it must be length 2, denoting left and right child.

    >>> from bigtree import BinaryNode, print_tree
    >>> a = BinaryNode(1)
    >>> b = BinaryNode(2)
    >>> c = BinaryNode(3)
    >>> d = BinaryNode(4)
    >>> a.children = [b, c]
    >>> b.right = d
    >>> print_tree(a)
    1
    ├── 2
    │   └── 4
    └── 3

    Directly passing `left`, `right`, or `children` argument.

    >>> from bigtree import BinaryNode
    >>> d = BinaryNode(4)
    >>> c = BinaryNode(3)
    >>> b = BinaryNode(2, right=d)
    >>> a = BinaryNode(1, children=[b, c])

    **BinaryNode Creation**

    Node can be created by instantiating a `BinaryNode` class or by using a *dictionary*.
    If node is created with dictionary, all keys of dictionary will be stored as class attributes.

    >>> from bigtree import BinaryNode
    >>> a = BinaryNode.from_dict({"name": "1"})
    >>> a
    BinaryNode(name=1, val=1)

    **BinaryNode Attributes**

    These are node attributes that have getter and/or setter methods.

    Get `BinaryNode` configuration

    1. ``left``: Get left children
    2. ``right``: Get right children

    ----

    """

    def __init__(
        self,
        name: Union[str, int] = "",
        left=None,
        right=None,
        parent=None,
        children: List = None,
        **kwargs,
    ):
        # NOTE(review): int(name) raises ValueError for the default name="" —
        # presumably callers always pass a numeric name; confirm
        self.val = int(name)
        self.name = str(name)
        self._sep = "/"
        # Private storage; mangled to _BinaryNode__parent / _BinaryNode__children,
        # distinct from the base class's _BaseNode__* slots
        self.__parent = None
        self.__children = []
        if not children:
            children = []
        if len(children):
            if len(children) and len(children) != 2:
                raise ValueError("Children input must have length 2")
            # Reject contradictory left/right vs children inputs
            if left and left != children[0]:
                raise ValueError(
                    f"Attempting to set both left and children with mismatched values\n"
                    f"Check left {left} and children {children}"
                )
            if right and right != children[1]:
                raise ValueError(
                    f"Attempting to set both right and children with mismatched values\n"
                    f"Check right {right} and children {children}"
                )
        else:
            children = [left, right]
        self.parent = parent
        self.children = children
        if "parents" in kwargs:
            raise ValueError(
                "Attempting to set `parents` attribute, do you mean `parent`?"
            )
        self.__dict__.update(**kwargs)

    @property
    def left(self):
        """Get left children (slot 0; may be None)

        Returns:
            (Self)
        """
        return self.__children[0]

    @left.setter
    def left(self, left_child):
        """Set left children

        Args:
            left_child (Self): left child
        """
        self.children = [left_child, self.right]

    @property
    def right(self):
        """Get right children (slot 1; may be None)

        Returns:
            (Self)
        """
        return self.__children[1]

    @right.setter
    def right(self, right_child):
        """Set right children

        Args:
            right_child (Self): right child
        """
        self.children = [self.left, right_child]

    @property
    def parent(self):
        """Get parent node

        Returns:
            (Self)
        """
        return self.__parent

    @staticmethod
    def __check_parent_type(new_parent):
        """Check parent type

        Args:
            new_parent (Self): parent node

        Raises:
            TypeError: if `new_parent` is neither BinaryNode nor None
        """
        if not (isinstance(new_parent, BinaryNode) or new_parent is None):
            raise TypeError(
                f"Expect input to be BinaryNode type or NoneType, received input type {type(new_parent)}"
            )

    @parent.setter
    def parent(self, new_parent):
        """Set parent node, filling the first empty (None) child slot of the
        new parent. Transactional: rolls back on any error.

        Args:
            new_parent (Self): parent node

        Raises:
            TreeError: if the parent already has two children, or wrapping
                whatever exception interrupted the assignment
        """
        self.__check_parent_type(new_parent)
        # Loop check is reused from the base class (mangled name)
        self._BaseNode__check_parent_loop(new_parent)

        current_parent = self.parent
        current_child_idx = None

        # Assign new parent - rollback if error
        self.__pre_assign_parent(new_parent)
        try:
            # Remove self from old parent
            if current_parent is not None:
                if not any(
                    child is self for child in current_parent.children
                ):  # pragma: no cover
                    raise CorruptedTreeError(
                        "Error setting parent: Node does not exist as children of its parent"
                    )
                current_child_idx = current_parent.__children.index(self)
                # Slot is cleared (set to None), not removed, to keep positions
                current_parent.__children[current_child_idx] = None

            # Assign self to new parent
            self.__parent = new_parent
            if new_parent is not None:
                inserted = False
                for child_idx, child in enumerate(new_parent.__children):
                    if not child and not inserted:
                        new_parent.__children[child_idx] = self
                        inserted = True
                if not inserted:
                    raise TreeError(f"Parent {new_parent} already has 2 children")

            self.__post_assign_parent(new_parent)

        except Exception as exc_info:
            # Remove self from new parent
            if new_parent is not None and self in new_parent.__children:
                child_idx = new_parent.__children.index(self)
                new_parent.__children[child_idx] = None

            # Reassign self to old parent
            self.__parent = current_parent
            if current_child_idx is not None:
                current_parent.__children[current_child_idx] = self
            raise TreeError(exc_info)

    def __pre_assign_parent(self, new_parent):
        """Custom method to check before attaching parent
        Can be overriden with `_BinaryNode__pre_assign_parent()`

        Args:
            new_parent (Self): new parent to be added
        """
        pass

    def __post_assign_parent(self, new_parent):
        """Custom method to check after attaching parent
        Can be overriden with `_BinaryNode__post_assign_parent()`

        Args:
            new_parent (Self): new parent to be added
        """
        pass

    @property
    def children(self) -> Iterable:
        """Get child nodes as a tuple (left, right); entries may be None

        Returns:
            (Iterable[Self])
        """
        return tuple(self.__children)

    def __check_children_type(self, new_children: List) -> List:
        """Check child type: normalize empty input to two None slots.

        Args:
            new_children (List[Self]): child node

        Returns:
            (List): children, guaranteed length 2

        Raises:
            ValueError: if non-empty input is not exactly length 2
        """
        if not len(new_children):
            new_children = [None, None]
        if len(new_children) != 2:
            raise ValueError("Children input must have length 2")
        return new_children

    def __check_children_loop(self, new_children: List):
        """Check child loop

        Args:
            new_children (List[Self]): child node

        Raises:
            TypeError: if a child is neither BinaryNode nor None
            LoopError: if a child is self or an ancestor of self
            TreeError: if the same node appears in both slots
        """
        seen_children = []
        for new_child in new_children:
            # Check type (None is a valid empty slot)
            if new_child is not None and not isinstance(new_child, BinaryNode):
                raise TypeError(
                    f"Expect input to be BinaryNode type or NoneType, received input type {type(new_child)}"
                )

            # Check for loop and tree structure
            if new_child is self:
                raise LoopError("Error setting child: Node cannot be child of itself")
            if any(child is new_child for child in self.ancestors):
                raise LoopError(
                    "Error setting child: Node cannot be ancestors of itself"
                )

            # Check for duplicate children
            if new_child is not None:
                if id(new_child) in seen_children:
                    raise TreeError(
                        "Error setting child: Node cannot be added multiple times as a child"
                    )
                else:
                    seen_children.append(id(new_child))

    @children.setter
    def children(self, new_children: List):
        """Set child nodes (exactly two slots), detaching them from their
        current parents first. Transactional: rolls back on any error.

        Args:
            new_children (List[Self]): child node

        Raises:
            TreeError: wrapping whatever exception interrupted the assignment
        """
        # Base-class check ensures it is a list; local check normalizes length
        self._BaseNode__check_children_type(new_children)
        new_children = self.__check_children_type(new_children)
        self.__check_children_loop(new_children)

        # Saved for rollback: each new child's (slot index, parent)
        current_new_children = {
            new_child: (
                new_child.parent.__children.index(new_child),
                new_child.parent,
            )
            for new_child in new_children
            if new_child is not None and new_child.parent is not None
        }
        current_new_orphan = [
            new_child
            for new_child in new_children
            if new_child is not None and new_child.parent is None
        ]
        current_children = list(self.children)

        # Assign new children - rollback if error
        self.__pre_assign_children(new_children)
        try:
            # Remove old children from self
            del self.children

            # Assign new children to self
            self.__children = new_children
            for new_child in new_children:
                if new_child is not None:
                    if new_child.parent:
                        child_idx = new_child.parent.__children.index(new_child)
                        # Clear (not remove) the slot in the old parent
                        new_child.parent.__children[child_idx] = None
                    new_child.__parent = self
            self.__post_assign_children(new_children)
        except Exception as exc_info:
            # Reassign new children to their original parent
            for child, idx_parent in current_new_children.items():
                child_idx, parent = idx_parent
                child.__parent = parent
                parent.__children[child_idx] = child
            for child in current_new_orphan:
                child.__parent = None

            # Reassign old children to self
            self.__children = current_children
            for child in current_children:
                if child:
                    child.__parent = self
            raise TreeError(exc_info)

    @children.deleter
    def children(self):
        """Delete child node(s); empty (None) slots are skipped"""
        for child in self.children:
            if child is not None:
                child.parent.__children.remove(child)
                child.__parent = None

    def __pre_assign_children(self, new_children: List):
        """Custom method to check before attaching children
        Can be overriden with `_BinaryNode__pre_assign_children()`

        Args:
            new_children (List[Self]): new children to be added
        """
        pass

    def __post_assign_children(self, new_children: List):
        """Custom method to check after attaching children
        Can be overriden with `_BinaryNode__post_assign_children()`

        Args:
            new_children (List[Self]): new children to be added
        """
        pass

    @property
    def is_leaf(self) -> bool:
        """Get indicator if self is leaf node (both child slots empty)

        Returns:
            (bool)
        """
        return not len([child for child in self.children if child])

    def sort(self, **kwargs):
        """Sort children, possible keyword arguments include ``key=lambda node: node.name``, ``reverse=True``

        >>> from bigtree import BinaryNode, print_tree
        >>> a = BinaryNode(1)
        >>> c = BinaryNode(3, parent=a)
        >>> b = BinaryNode(2, parent=a)
        >>> print_tree(a)
        1
        ├── 3
        └── 2
        >>> a.sort(key=lambda node: node.val)
        >>> print_tree(a)
        1
        ├── 2
        └── 3
        """
        # Only sorts when both slots are occupied; a single child keeps its slot
        children = [child for child in self.children if child]
        if len(children) == 2:
            children.sort(**kwargs)
            self.__children = children

    def __repr__(self):
        """Return string representation listing class name and public attributes"""
        class_name = self.__class__.__name__
        node_dict = self.describe(exclude_prefix="_", exclude_attributes=[])
        node_description = ", ".join([f"{k}={v}" for k, v in node_dict])
        return f"{class_name}({node_description})"
|
||||
570
python37/packages/bigtree/node/dagnode.py
Normal file
570
python37/packages/bigtree/node/dagnode.py
Normal file
@@ -0,0 +1,570 @@
|
||||
import copy
|
||||
from typing import Any, Dict, Iterable, List
|
||||
|
||||
from bigtree.utils.exceptions import LoopError, TreeError
|
||||
from bigtree.utils.iterators import preorder_iter
|
||||
|
||||
|
||||
class DAGNode:
|
||||
"""
|
||||
Base DAGNode extends any Python class to a DAG node, for DAG implementation.
|
||||
In DAG implementation, a node can have multiple parents.
|
||||
Parents and children cannot be reassigned once assigned, as Nodes are allowed to have multiple parents and children.
|
||||
If each node only has one parent, use `Node` class.
|
||||
DAGNodes can have attributes if they are initialized from `DAGNode` or dictionary.
|
||||
|
||||
DAGNode can be linked to each other with `parents` and `children` setter methods,
|
||||
or using bitshift operator with the convention `parent_node >> child_node` or `child_node << parent_node`.
|
||||
|
||||
>>> from bigtree import DAGNode
|
||||
>>> a = DAGNode("a")
|
||||
>>> b = DAGNode("b")
|
||||
>>> c = DAGNode("c")
|
||||
>>> d = DAGNode("d")
|
||||
>>> c.parents = [a, b]
|
||||
>>> c.children = [d]
|
||||
|
||||
>>> from bigtree import DAGNode
|
||||
>>> a = DAGNode("a")
|
||||
>>> b = DAGNode("b")
|
||||
>>> c = DAGNode("c")
|
||||
>>> d = DAGNode("d")
|
||||
>>> a >> c
|
||||
>>> b >> c
|
||||
>>> d << c
|
||||
|
||||
Directly passing `parents` argument.
|
||||
|
||||
>>> from bigtree import DAGNode
|
||||
>>> a = DAGNode("a")
|
||||
>>> b = DAGNode("b")
|
||||
>>> c = DAGNode("c", parents=[a, b])
|
||||
>>> d = DAGNode("d", parents=[c])
|
||||
|
||||
Directly passing `children` argument.
|
||||
|
||||
>>> from bigtree import DAGNode
|
||||
>>> d = DAGNode("d")
|
||||
>>> c = DAGNode("c", children=[d])
|
||||
>>> b = DAGNode("b", children=[c])
|
||||
>>> a = DAGNode("a", children=[c])
|
||||
|
||||
**DAGNode Creation**
|
||||
|
||||
Node can be created by instantiating a `DAGNode` class or by using a *dictionary*.
|
||||
If node is created with dictionary, all keys of dictionary will be stored as class attributes.
|
||||
|
||||
>>> from bigtree import DAGNode
|
||||
>>> a = DAGNode.from_dict({"name": "a", "age": 90})
|
||||
|
||||
**DAGNode Attributes**
|
||||
|
||||
These are node attributes that have getter and/or setter methods.
|
||||
|
||||
Get and set other `DAGNode`
|
||||
|
||||
1. ``parents``: Get/set parent nodes
|
||||
2. ``children``: Get/set child nodes
|
||||
|
||||
Get other `DAGNode`
|
||||
|
||||
1. ``ancestors``: Get ancestors of node excluding self, iterator
|
||||
2. ``descendants``: Get descendants of node excluding self, iterator
|
||||
3. ``siblings``: Get siblings of self
|
||||
|
||||
Get `DAGNode` configuration
|
||||
|
||||
1. ``node_name``: Get node name, without accessing `name` directly
|
||||
2. ``is_root``: Get indicator if self is root node
|
||||
3. ``is_leaf``: Get indicator if self is leaf node
|
||||
|
||||
**DAGNode Methods**
|
||||
|
||||
These are methods available to be performed on `DAGNode`.
|
||||
|
||||
Constructor methods
|
||||
|
||||
1. ``from_dict()``: Create DAGNode from dictionary
|
||||
|
||||
`DAGNode` methods
|
||||
|
||||
1. ``describe()``: Get node information sorted by attributes, returns list of tuples
|
||||
2. ``get_attr(attr_name: str)``: Get value of node attribute
|
||||
3. ``set_attrs(attrs: dict)``: Set node attribute name(s) and value(s)
|
||||
4. ``go_to(node: BaseNode)``: Get a path from own node to another node from same DAG
|
||||
5. ``copy()``: Deep copy DAGNode
|
||||
|
||||
----
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, name: str = "", parents: List = None, children: List = None, **kwargs
|
||||
):
|
||||
self.name = name
|
||||
self.__parents = []
|
||||
self.__children = []
|
||||
if parents is None:
|
||||
parents = []
|
||||
if children is None:
|
||||
children = []
|
||||
self.parents = parents
|
||||
self.children = children
|
||||
if "parent" in kwargs:
|
||||
raise ValueError(
|
||||
"Attempting to set `parent` attribute, do you mean `parents`?"
|
||||
)
|
||||
self.__dict__.update(**kwargs)
|
||||
|
||||
@property
def parent(self) -> None:
    """Guard attribute: a DAG node has `parents`, never a single `parent`.

    Raises:
        ValueError: always, to redirect callers to `parents`
    """
    raise ValueError("Attempting to access `parent` attribute, do you mean `parents`?")
|
||||
|
||||
@parent.setter
def parent(self, new_parent):
    """Guard setter: assigning `parent` on a DAG node is always an error.

    Args:
        new_parent (Self): ignored; present only to match setter signature

    Raises:
        ValueError: always, to redirect callers to `parents`
    """
    raise ValueError(
        "Attempting to set `parent` attribute, do you mean `parents`?"
    )
|
||||
|
||||
@property
def parents(self) -> Iterable:
    """Parent nodes of this node, returned as an immutable tuple.

    Returns:
        (Iterable[Self])
    """
    return tuple(self.__parents)
|
||||
|
||||
@staticmethod
def __check_parent_type(new_parents: List):
    """Validate that the parents argument is a list.

    Args:
        new_parents (List[Self]): parent nodes

    Raises:
        TypeError: if `new_parents` is not a list
    """
    if isinstance(new_parents, list):
        return
    raise TypeError(
        f"Parents input should be list type, received input type {type(new_parents)}"
    )
|
||||
|
||||
def __check_parent_loop(self, new_parents: List):
    """Check candidate parents for type errors, loops, and duplicates.

    Note: original docstring said "Check parent type" — copy-paste from the
    sibling `__check_parent_type`; this method actually checks DAG invariants.

    Args:
        new_parents (List[Self]): parent nodes

    Raises:
        TypeError: if any candidate is not a `DAGNode`
        LoopError: if a candidate is self or an item having self as ancestor
        TreeError: if the same candidate appears more than once
    """
    seen_parent = set()  # set gives O(1) duplicate lookup vs list's O(n)
    for new_parent in new_parents:
        # Check type
        if not isinstance(new_parent, DAGNode):
            raise TypeError(
                f"Expect input to be DAGNode type, received input type {type(new_parent)}"
            )

        # Check for loop and tree structure
        if new_parent is self:
            raise LoopError("Error setting parent: Node cannot be parent of itself")
        if new_parent.ancestors:
            if any(ancestor is self for ancestor in new_parent.ancestors):
                raise LoopError(
                    "Error setting parent: Node cannot be ancestor of itself"
                )

        # Check for duplicate parents (by identity, not equality)
        if id(new_parent) in seen_parent:
            raise TreeError(
                "Error setting parent: Node cannot be added multiple times as a parent"
            )
        seen_parent.add(id(new_parent))
|
||||
|
||||
@parents.setter
def parents(self, new_parents: List):
    """Link parent nodes, rolling the links back if any hook fails.

    Args:
        new_parents (List[Self]): parent nodes

    Raises:
        TreeError: if assignment fails; partially-added links are undone first
    """
    self.__check_parent_type(new_parents)
    self.__check_parent_loop(new_parents)

    current_parents = self.__parents.copy()

    # Assign new parents - rollback if error
    self.__pre_assign_parents(new_parents)
    try:
        # Create the bidirectional links
        for parent_node in new_parents:
            if parent_node not in self.__parents:
                self.__parents.append(parent_node)
                parent_node.__children.append(self)
        self.__post_assign_parents(new_parents)
    except Exception as exc_info:
        # Undo only the links that this call added
        for parent_node in new_parents:
            if parent_node not in current_parents:
                self.__parents.remove(parent_node)
                parent_node.__children.remove(self)
        raise TreeError(
            f"{exc_info}, current parents {current_parents}, new parents {new_parents}"
        )
|
||||
|
||||
def __pre_assign_parents(self, new_parents: List):
    """Hook invoked before parents are attached; default does nothing.
    Override with `_DAGNode__pre_assign_parent()` in subclasses.

    Args:
        new_parents (List): new parents to be added
    """
|
||||
|
||||
def __post_assign_parents(self, new_parents: List):
    """Hook invoked after parents are attached; default does nothing.
    Override with `_DAGNode__post_assign_parent()` in subclasses.

    Args:
        new_parents (List): new parents to be added
    """
|
||||
|
||||
@property
def children(self) -> Iterable:
    """Child nodes of this node, returned as an immutable tuple.

    Returns:
        (Iterable[Self])
    """
    return tuple(self.__children)
|
||||
|
||||
def __check_children_type(self, new_children: List):
    """Validate that the children argument is a list.

    Args:
        new_children (List[Self]): child node

    Raises:
        TypeError: if `new_children` is not a list
    """
    if isinstance(new_children, list):
        return
    raise TypeError(
        f"Children input should be list type, received input type {type(new_children)}"
    )
|
||||
|
||||
def __check_children_loop(self, new_children: List):
    """Check candidate children for type errors, loops, and duplicates.

    Args:
        new_children (List[Self]): child node

    Raises:
        TypeError: if any candidate is not a `DAGNode`
        LoopError: if a candidate is self or one of self's ancestors
        TreeError: if the same candidate appears more than once
    """
    seen_children = []
    for new_child in new_children:
        # Check type
        if not isinstance(new_child, DAGNode):
            raise TypeError(
                f"Expect input to be DAGNode type, received input type {type(new_child)}"
            )

        # Check for loop and tree structure
        if new_child is self:
            raise LoopError("Error setting child: Node cannot be child of itself")
        if any(child is new_child for child in self.ancestors):
            raise LoopError(
                "Error setting child: Node cannot be ancestors of itself"
            )

        # Check for duplicate children (by identity, not equality)
        if id(new_child) in seen_children:
            raise TreeError(
                "Error setting child: Node cannot be added multiple times as a child"
            )
        seen_children.append(id(new_child))
|
||||
|
||||
@children.setter
def children(self, new_children: List):
    """Link child nodes, rolling the links back if any hook fails.

    Args:
        new_children (List[Self]): child node

    Raises:
        TreeError: if assignment fails; partially-added links are undone first.
            Message now includes current/new children for consistency with the
            `parents` setter, which reports current/new parents.
    """
    self.__check_children_type(new_children)
    self.__check_children_loop(new_children)

    current_children = list(self.children)

    # Assign new children - rollback if error
    self.__pre_assign_children(new_children)
    try:
        # Create the bidirectional links
        for new_child in new_children:
            if self not in new_child.__parents:
                new_child.__parents.append(self)
                self.__children.append(new_child)
        self.__post_assign_children(new_children)
    except Exception as exc_info:
        # Undo only the links that this call added
        for new_child in new_children:
            if new_child not in current_children:
                new_child.__parents.remove(self)
                self.__children.remove(new_child)
        raise TreeError(
            f"{exc_info}, current children {current_children}, new children {new_children}"
        )
|
||||
|
||||
def __pre_assign_children(self, new_children: List):
    """Hook invoked before children are attached; default does nothing.
    Override with `_DAGNode__pre_assign_children()` in subclasses.

    Args:
        new_children (List[Self]): new children to be added
    """
|
||||
|
||||
def __post_assign_children(self, new_children: List):
    """Hook invoked after children are attached; default does nothing.
    Override with `_DAGNode__post_assign_children()` in subclasses.

    Args:
        new_children (List[Self]): new children to be added
    """
|
||||
|
||||
@property
def ancestors(self) -> Iterable:
    """All ancestors of self (deduplicated, self excluded).

    Order is depth-first from the furthest ancestor downward, with the first
    occurrence of each node kept.

    Returns:
        (Iterable[Self])
    """
    if not len(list(self.parents)):
        return ()

    def _walk_up(node):
        # Yield each parent's own ancestry before the parent itself
        for parent_node in node.parents:
            yield from _walk_up(parent_node)
            yield parent_node

    # dict.fromkeys removes duplicates while preserving first-seen order
    return list(dict.fromkeys(_walk_up(self)))
|
||||
|
||||
@property
def descendants(self) -> Iterable:
    """All descendants of self (deduplicated, self excluded), preorder.

    Returns:
        (Iterable[Self])
    """
    nodes = list(preorder_iter(self, filter_condition=lambda _node: _node != self))
    # dict.fromkeys removes duplicates while preserving first-seen order
    return list(dict.fromkeys(nodes))
|
||||
|
||||
@property
def siblings(self) -> Iterable:
    """Other children of this node's parents (self excluded).

    Returns:
        (Iterable[Self])
    """
    if self.is_root:
        return ()
    result = []
    for parent_node in self.parents:
        for child in parent_node.children:
            if child is not self:
                result.append(child)
    return tuple(result)
|
||||
|
||||
@property
def node_name(self) -> str:
    """Node name (read-only alias for ``name``).

    Returns:
        (str)
    """
    return self.name
|
||||
|
||||
@property
def is_root(self) -> bool:
    """Whether this node has no parents.

    Returns:
        (bool)
    """
    return len(list(self.parents)) == 0
|
||||
|
||||
@property
def is_leaf(self) -> bool:
    """Whether this node has no children.

    Returns:
        (bool)
    """
    return len(list(self.children)) == 0
|
||||
|
||||
@classmethod
def from_dict(cls, input_dict: Dict[str, Any]):
    """Alternate constructor: build a node from a dictionary.

    Every key becomes an attribute; include key `name` to name the node.

    >>> from bigtree import DAGNode
    >>> a = DAGNode.from_dict({"name": "a", "age": 90})

    Args:
        input_dict (Dict[str, Any]): key: attribute name, value: attribute value

    Returns:
        (Self)
    """
    return cls(**input_dict)
|
||||
|
||||
def describe(self, exclude_attributes: List[str] = None, exclude_prefix: str = ""):
    """Get node information sorted by attribute name, returns list of tuples.

    Fix: the original used a mutable default argument (`= []`), which is
    shared across calls; default is now None (behavior unchanged for callers).

    Args:
        exclude_attributes (List[str]): list of attributes to exclude
        exclude_prefix (str): prefix of attributes to exclude

    Returns:
        (List[str])
    """
    if exclude_attributes is None:
        exclude_attributes = []
    return [
        item
        for item in sorted(self.__dict__.items(), key=lambda item: item[0])
        if (item[0] not in exclude_attributes)
        and (not len(exclude_prefix) or not item[0].startswith(exclude_prefix))
    ]
|
||||
|
||||
def get_attr(self, attr_name: str) -> Any:
    """Look up a node attribute by name; None if it does not exist.

    Args:
        attr_name (str): attribute name

    Returns:
        (Any)
    """
    value = None
    try:
        # __getattribute__ deliberately skips any __getattr__ fallback
        value = self.__getattribute__(attr_name)
    except AttributeError:
        pass
    return value
|
||||
|
||||
def set_attrs(self, attrs: Dict[str, Any]):
    """Store attribute name/value pairs on this node.

    >>> from bigtree.node.dagnode import DAGNode
    >>> a = DAGNode('a')
    >>> a.set_attrs({"age": 90})
    >>> a
    DAGNode(a, age=90)

    Args:
        attrs (Dict[str, Any]): key: attribute name, value: attribute value
    """
    # Write straight into the instance dict, bypassing descriptors
    self.__dict__.update(attrs)
|
||||
|
||||
def go_to(self, node) -> Iterable[Iterable]:
    """Get list of possible paths from current node to specified node from same tree.

    Fix: the original threaded an `_ans` accumulator parameter through the
    recursion but never used it; it is removed (internal helper only).

    >>> from bigtree import DAGNode
    >>> a = DAGNode("a")
    >>> b = DAGNode("b")
    >>> c = DAGNode("c")
    >>> d = DAGNode("d")
    >>> a >> c
    >>> b >> c
    >>> c >> d
    >>> a >> d
    >>> a.go_to(c)
    [[DAGNode(a, ), DAGNode(c, )]]
    >>> a.go_to(d)
    [[DAGNode(a, ), DAGNode(c, ), DAGNode(d, )], [DAGNode(a, ), DAGNode(d, )]]
    >>> a.go_to(b)
    Traceback (most recent call last):
        ...
    bigtree.utils.exceptions.TreeError: It is not possible to go to DAGNode(b, )

    Args:
        node (Self): node to travel to from current node, inclusive of start and end node

    Returns:
        (Iterable[Iterable])
    """
    if not isinstance(node, DAGNode):
        raise TypeError(
            f"Expect node to be DAGNode type, received input type {type(node)}"
        )
    if self == node:
        return [self]
    if node not in self.descendants:
        raise TreeError(f"It is not possible to go to {node}")

    self.__path = []

    def recursive_path(current, trail):
        # Depth-first search; a copy of `trail` is passed down each branch
        if current:  # pragma: no cover
            trail.append(current)
            if current == node:
                return trail
            for child in current.children:
                complete = recursive_path(child, trail.copy())
                if complete:
                    self.__path.append(complete)

    recursive_path(self, [])
    return self.__path
|
||||
|
||||
def copy(self):
    """Clone this node and everything it references (deep copy).

    >>> from bigtree.node.dagnode import DAGNode
    >>> a = DAGNode('a')
    >>> a_copy = a.copy()

    Returns:
        (Self)
    """
    return copy.deepcopy(self)
|
||||
|
||||
def __copy__(self):
    """Shallow copy self.

    Fix: the original docstring example called ``copy.deepcopy(a)``, which
    does NOT invoke this method; the shallow-copy entry point is ``copy.copy``.

    >>> import copy
    >>> from bigtree.node.dagnode import DAGNode
    >>> a = DAGNode('a')
    >>> a_copy = copy.copy(a)

    Returns:
        (Self)
    """
    # Bypass __init__ (and its validation/link-building) entirely
    obj = type(self).__new__(self.__class__)
    obj.__dict__.update(self.__dict__)
    return obj
|
||||
|
||||
def __rshift__(self, other):
    """`self >> other` links `other` as a child of `self`.

    Args:
        other (Self): other node, children
    """
    other.parents = [self]
|
||||
|
||||
def __lshift__(self, other):
    """`self << other` links `other` as a parent of `self`.

    Args:
        other (Self): other node, parent
    """
    self.parents = [other]
|
||||
|
||||
def __repr__(self):
    """Debug representation: class name, node name, and public attributes."""
    class_name = self.__class__.__name__
    attrs = self.describe(exclude_attributes=["name"])
    description = ", ".join(
        f"{k}={v}" for k, v in attrs if not k.startswith("_")
    )
    return f"{class_name}({self.node_name}, {description})"
|
||||
204
python37/packages/bigtree/node/node.py
Normal file
204
python37/packages/bigtree/node/node.py
Normal file
@@ -0,0 +1,204 @@
|
||||
from collections import Counter
|
||||
from typing import List
|
||||
|
||||
from bigtree.node.basenode import BaseNode
|
||||
from bigtree.utils.exceptions import TreeError
|
||||
|
||||
|
||||
class Node(BaseNode):
    """
    Node is an extension of BaseNode, and is able to extend to any Python class.
    Nodes can have attributes if they are initialized from `Node`, *dictionary*, or *pandas DataFrame*.

    Nodes can be linked with the `parent` and `children` setters, or by passing
    `parent`/`children` arguments directly to the constructor.

    >>> from bigtree import Node
    >>> a = Node("a")
    >>> b = Node("b", parent=a)
    >>> c = Node("c", parent=b)
    >>> d = Node("d", parent=b)

    **Node Creation**

    Node can be created by instantiating a `Node` class or by using a *dictionary*.
    If node is created with dictionary, all keys of dictionary will be stored as class attributes.

    >>> from bigtree import Node
    >>> a = Node.from_dict({"name": "a", "age": 90})

    **Node Attributes**

    Get and set `Node` configuration

    1. ``sep``: Get/set separator for path name

    Get `Node` configuration

    1. ``node_name``: Get node name, without accessing `name` directly
    2. ``path_name``: Get path name from root, separated by `sep`

    ----

    """

    def __init__(self, name: str = "", **kwargs):
        # Name is set before super().__init__ so that any parent/children
        # passed via kwargs can already see `node_name` (used by the
        # duplicate-path checks below)
        self.name = name
        self._sep: str = "/"  # path separator; effective value lives on the root
        super().__init__(**kwargs)
        if not self.node_name:
            raise TreeError("Node must have a `name` attribute")

    @property
    def node_name(self) -> str:
        """Node name (read-only alias for ``name``).

        Returns:
            (str)
        """
        return self.name

    @property
    def sep(self) -> str:
        """Path separator; always resolved from the root node.

        Returns:
            (str)
        """
        return self._sep if self.is_root else self.parent.sep

    @sep.setter
    def sep(self, value: str):
        """Set path separator; stored on the root so it applies tree-wide.

        Args:
            value (str): separator to replace default separator
        """
        self.root._sep = value

    @property
    def path_name(self) -> str:
        """Full path from root to this node, joined by ``sep``.

        Returns:
            (str)
        """
        prefix = "" if self.is_root else self.parent.path_name
        return f"{prefix}{self.sep}{self.name}"

    def __pre_assign_children(self, new_children: List):
        """Hook invoked before children are attached; default does nothing.
        Override with `_Node__pre_assign_children()`.

        Args:
            new_children (List[Self]): new children to be added
        """

    def __post_assign_children(self, new_children: List):
        """Hook invoked after children are attached; default does nothing.
        Override with `_Node__post_assign_children()`.

        Args:
            new_children (List[Self]): new children to be added
        """

    def __pre_assign_parent(self, new_parent):
        """Hook invoked before parent is attached; default does nothing.
        Override with `_Node__pre_assign_parent()`.

        Args:
            new_parent (Self): new parent to be added
        """

    def __post_assign_parent(self, new_parent):
        """Hook invoked after parent is attached; default does nothing.
        Override with `_Node__post_assign_parent()`.

        Args:
            new_parent (Self): new parent to be added
        """

    def _BaseNode__pre_assign_parent(self, new_parent):
        """Reject a parent that already has another child with this node's name.

        Args:
            new_parent (Self): new parent to be added
        """
        self.__pre_assign_parent(new_parent)
        if new_parent is not None:
            for existing in new_parent.children:
                if existing.node_name == self.node_name and existing is not self:
                    raise TreeError(
                        f"Error: Duplicate node with same path\n"
                        f"There exist a node with same path {new_parent.path_name}{self.sep}{self.node_name}"
                    )

    def _BaseNode__post_assign_parent(self, new_parent):
        """No extra rules after parent assignment.

        Args:
            new_parent (Self): new parent to be added
        """
        self.__post_assign_parent(new_parent)

    def _BaseNode__pre_assign_children(self, new_children: List):
        """Reject candidate children whose names would collide under this node.

        Args:
            new_children (List[Self]): new children to be added
        """
        self.__pre_assign_children(new_children)
        name_counts = Counter(node.node_name for node in new_children)
        clashes = [name for name, count in name_counts.items() if count > 1]
        if clashes:
            clash_paths = " and ".join(
                f"{self.path_name}{self.sep}{name}" for name in clashes
            )
            raise TreeError(
                f"Error: Duplicate node with same path\n"
                f"Attempting to add nodes same path {clash_paths}"
            )

    def _BaseNode__post_assign_children(self, new_children: List):
        """No extra rules after children assignment.

        Args:
            new_children (List[Self]): new children to be added
        """
        self.__post_assign_children(new_children)

    def __repr__(self):
        """Debug representation: class name, full path, public attributes."""
        class_name = self.__class__.__name__
        attrs = self.describe(exclude_prefix="_", exclude_attributes=["name"])
        description = ", ".join(f"{k}={v}" for k, v in attrs)
        return f"{class_name}({self.path_name}, {description})"
|
||||
0
python37/packages/bigtree/tree/__init__.py
Normal file
0
python37/packages/bigtree/tree/__init__.py
Normal file
914
python37/packages/bigtree/tree/construct.py
Normal file
914
python37/packages/bigtree/tree/construct.py
Normal file
@@ -0,0 +1,914 @@
|
||||
import re
|
||||
from collections import OrderedDict
|
||||
from typing import List, Tuple, Type
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from bigtree.node.node import Node
|
||||
from bigtree.tree.export import tree_to_dataframe
|
||||
from bigtree.tree.search import find_children, find_name
|
||||
from bigtree.utils.exceptions import DuplicatedNodeError, TreeError
|
||||
|
||||
__all__ = [
|
||||
"add_path_to_tree",
|
||||
"add_dict_to_tree_by_path",
|
||||
"add_dict_to_tree_by_name",
|
||||
"add_dataframe_to_tree_by_path",
|
||||
"add_dataframe_to_tree_by_name",
|
||||
"str_to_tree",
|
||||
"list_to_tree",
|
||||
"list_to_tree_by_relation",
|
||||
"dict_to_tree",
|
||||
"nested_dict_to_tree",
|
||||
"dataframe_to_tree",
|
||||
"dataframe_to_tree_by_relation",
|
||||
]
|
||||
|
||||
|
||||
def add_path_to_tree(
    tree: Node,
    path: str,
    sep: str = "/",
    duplicate_name_allowed: bool = True,
    node_attrs: dict = None,
) -> Node:
    """Add nodes and attributes to existing tree *in-place*, return node of added path.
    Adds to existing tree from list of path strings.

    Fix: `node_attrs` previously used a mutable default (`= {}`); now None,
    with identical behavior for callers.

    Path should contain `Node` name, separated by `sep`.
    - For example: Path string "a/b" refers to Node("b") with parent Node("a").
    - Path separator `sep` is for the input `path` and can be different from that of existing tree.
    - Path can start from root node `name`, or start with `sep` (e.g. "/a/b" or "a/b").
    - All paths should start from the same root node.

    >>> from bigtree import add_path_to_tree, print_tree
    >>> root = Node("a")
    >>> add_path_to_tree(root, "a/b/c")
    Node(/a/b/c, )

    Args:
        tree (Node): existing tree
        path (str): path to be added to tree
        sep (str): path separator for input `path`
        duplicate_name_allowed (bool): indicator if nodes with duplicated `Node` name is allowed, defaults to True
        node_attrs (dict): attributes to add to node, key: attribute name, value: attribute value, optional

    Returns:
        (Node)
    """
    if not len(path):
        raise ValueError("Path is empty, check `path`")
    if node_attrs is None:
        node_attrs = {}

    tree_root = tree.root
    tree_sep = tree_root.sep
    node_type = tree_root.__class__
    # strip(sep) == lstrip(sep).rstrip(sep): drop leading/trailing separators
    branch = path.strip(sep).split(sep)
    if branch[0] != tree_root.node_name:
        raise TreeError(
            f"Error: Path does not have same root node, expected {tree_root.node_name}, received {branch[0]}\n"
            f"Check your input paths or verify that path separator `sep` is set correctly"
        )

    # Grow tree
    node = tree_root
    parent_node = tree_root
    for idx in range(1, len(branch)):
        node_name = branch[idx]
        node_path = tree_sep.join(branch[: idx + 1])
        if not duplicate_name_allowed:
            # Global name search: same name may only exist at the same path
            node = find_name(tree_root, node_name)
            if node and not node.path_name.endswith(node_path):
                raise DuplicatedNodeError(
                    f"Node {node_name} already exists, try setting `duplicate_name_allowed` to True "
                    f"to allow `Node` with same node name"
                )
        else:
            node = find_children(parent_node, node_name)
        if not node:
            node = node_type(branch[idx])
            node.parent = parent_node
        parent_node = node
    node.set_attrs(node_attrs)
    return node
|
||||
|
||||
|
||||
def add_dict_to_tree_by_path(
    tree: Node,
    path_attrs: dict,
    sep: str = "/",
    duplicate_name_allowed: bool = True,
) -> Node:
    """Add nodes and attributes to tree *in-place*, return root of tree.
    Adds to existing tree from nested dictionary, ``key``: path, ``value``: dict of attribute name and attribute value.

    Path should contain `Node` name, separated by `sep`.
    - For example: Path string "a/b" refers to Node("b") with parent Node("a").
    - Path separator `sep` is for the input paths and can be different from that of existing tree.
    - Path can start from root node `name`, or start with `sep` (e.g. "/a/b" or "a/b").
    - All paths should start from the same root node.

    >>> from bigtree import Node, add_dict_to_tree_by_path, print_tree
    >>> root = Node("a")
    >>> path_dict = {
    ...     "a": {"age": 90},
    ...     "a/b": {"age": 65},
    ...     "a/c": {"age": 60},
    ... }
    >>> root = add_dict_to_tree_by_path(root, path_dict)

    Args:
        tree (Node): existing tree
        path_attrs (dict): dictionary containing node path and attribute information,
            key: node path, value: dict of node attribute name and attribute value
        sep (str): path separator for input `path_attrs`
        duplicate_name_allowed (bool): indicator if nodes with duplicated `Node` name is allowed, defaults to True

    Returns:
        (Node)
    """
    if not len(path_attrs):
        raise ValueError("Dictionary does not contain any data, check `path_attrs`")

    tree_root = tree.root
    # Delegate each path entry to add_path_to_tree, which grows the tree in place
    for node_path, attrs in path_attrs.items():
        add_path_to_tree(
            tree_root,
            node_path,
            sep=sep,
            duplicate_name_allowed=duplicate_name_allowed,
            node_attrs=attrs,
        )
    return tree_root
|
||||
|
||||
|
||||
def add_dict_to_tree_by_name(
    tree: Node, path_attrs: dict, join_type: str = "left"
) -> Node:
    """Add attributes to tree, return *new* root of tree.
    Adds to existing tree from nested dictionary, ``key``: name, ``value``: dict of attribute name and attribute value.

    Can return all existing tree nodes ('left') or only those present in the
    input keys ('inner'). Keys that match no node name are ignored; if several
    nodes share a name, all of them receive the attributes.

    >>> from bigtree import Node, add_dict_to_tree_by_name, print_tree
    >>> root = Node("a")
    >>> b = Node("b", parent=root)
    >>> root = add_dict_to_tree_by_name(root, {"a": {"age": 90}, "b": {"age": 65}})
    >>> print_tree(root, attr_list=["age"])
    a [age=90]
    └── b [age=65]

    Args:
        tree (Node): existing tree
        path_attrs (dict): dictionary containing node name and attribute information,
            key: node name, value: dict of node attribute name and attribute value
        join_type (str): join type with attribute, default of 'left' takes existing tree nodes,
            if join_type is set to 'inner' it will only take tree nodes that are in `path_attrs` key and drop others

    Returns:
        (Node)
    """
    if join_type not in ("inner", "left"):
        raise ValueError("`join_type` must be one of 'inner' or 'left'")
    if not len(path_attrs):
        raise ValueError("Dictionary does not contain any data, check `path_attrs`")

    # Convert dictionary to dataframe and reuse the dataframe-based routine
    name_frame = pd.DataFrame(path_attrs).T.rename_axis("NAME").reset_index()
    return add_dataframe_to_tree_by_name(tree, data=name_frame, join_type=join_type)
|
||||
|
||||
|
||||
def add_dataframe_to_tree_by_path(
    tree: Node,
    data: pd.DataFrame,
    path_col: str = "",
    attribute_cols: list = None,
    sep: str = "/",
    duplicate_name_allowed: bool = True,
) -> Node:
    """Add nodes and attributes to tree *in-place*, return root of tree.

    Fixes: `attribute_cols` previously used a mutable default (`= []`); now
    None. The function also used to mutate the caller's DataFrame when
    stripping separators from the path column; it now works on a copy.

    `path_col` and `attribute_cols` specify columns for node path and attributes to add to existing tree.
    If columns are not specified, `path_col` takes first column and all other columns are `attribute_cols`.

    Path in path column should contain `Node` name, separated by `sep`.
    - For example: Path string "a/b" refers to Node("b") with parent Node("a").
    - Path separator `sep` is for the input `path_col` and can be different from that of existing tree.
    - Path can start from root node `name`, or start with `sep` (e.g. "/a/b" or "a/b").
    - All paths should start from the same root node.

    >>> import pandas as pd
    >>> from bigtree import add_dataframe_to_tree_by_path, print_tree
    >>> root = Node("a")
    >>> path_data = pd.DataFrame([["a", 90], ["a/b", 65]], columns=["PATH", "age"])
    >>> root = add_dataframe_to_tree_by_path(root, path_data)

    Args:
        tree (Node): existing tree
        data (pandas.DataFrame): data containing node path and attribute information
        path_col (str): column of data containing `path_name` information,
            if not set, it will take the first column of data
        attribute_cols (list): columns of data containing node attribute information,
            if not set, it will take all columns of data except `path_col`
        sep (str): path separator for input `path_col`
        duplicate_name_allowed (bool): indicator if nodes with duplicated `Node` name is allowed, defaults to True

    Returns:
        (Node)
    """
    if not len(data.columns):
        raise ValueError("Data does not contain any columns, check `data`")
    if not len(data):
        raise ValueError("Data does not contain any rows, check `data`")

    if not path_col:
        path_col = data.columns[0]
    if not attribute_cols:
        attribute_cols = list(data.columns)
        attribute_cols.remove(path_col)

    tree_root = tree.root
    # Work on a copy so the caller's DataFrame is not mutated
    data = data.copy()
    data[path_col] = data[path_col].str.strip(sep)

    # Reject paths that appear more than once with differing attributes
    data2 = data[[path_col] + attribute_cols].astype(str).drop_duplicates()
    _duplicate_check = (
        data2[path_col]
        .value_counts()
        .to_frame("counts")
        .rename_axis(path_col)
        .reset_index()
    )
    _duplicate_check = _duplicate_check[_duplicate_check["counts"] > 1]
    if len(_duplicate_check):
        raise ValueError(
            f"There exists duplicate path with different attributes\nCheck {_duplicate_check}"
        )

    for row in data.to_dict(orient="index").values():
        node_attrs = row.copy()
        del node_attrs[path_col]
        # Drop attributes that are entirely null so they are not set on nodes
        node_attrs = {k: v for k, v in node_attrs.items() if not np.all(pd.isnull(v))}
        add_path_to_tree(
            tree_root,
            row[path_col],
            sep=sep,
            duplicate_name_allowed=duplicate_name_allowed,
            node_attrs=node_attrs,
        )
    return tree_root
|
||||
|
||||
|
||||
def add_dataframe_to_tree_by_name(
    tree: Node,
    data: pd.DataFrame,
    name_col: str = "",
    attribute_cols: list = None,
    join_type: str = "left",
) -> Node:
    """Add attributes to tree, return *new* root of tree.

    `name_col` and `attribute_cols` specify columns for node name and attributes to add to existing tree.
    If columns are not specified, the first column will be taken as name column and all other columns as attributes.

    Function can return all existing tree nodes or only tree nodes that are in the input data node names.
    Input data node names that are not existing node names will be ignored.
    Note that if multiple nodes have the same name, attributes will be added to all nodes sharing same name.

    >>> import pandas as pd
    >>> from bigtree import Node, add_dataframe_to_tree_by_name, print_tree
    >>> root = Node("a")
    >>> b = Node("b", parent=root)
    >>> name_data = pd.DataFrame([
    ...     ["a", 90],
    ...     ["b", 65],
    ... ],
    ...     columns=["NAME", "age"]
    ... )
    >>> root = add_dataframe_to_tree_by_name(root, name_data)
    >>> print_tree(root, attr_list=["age"])
    a [age=90]
    └── b [age=65]

    Args:
        tree (Node): existing tree
        data (pandas.DataFrame): data containing node name and attribute information
        name_col (str): column of data containing `name` information,
            if not set, it will take the first column of data
        attribute_cols (list): column(s) of data containing node attribute information,
            if not set, it will take all columns of data except `name_col`
        join_type (str): join type with attribute, default of 'left' takes existing tree nodes,
            if join_type is set to 'inner' it will only take tree nodes with attributes and drop the other nodes

    Returns:
        (Node)

    Raises:
        ValueError: if `join_type` is invalid, `data` is empty, or a name appears
            multiple times with conflicting attribute values
    """
    if join_type not in ["inner", "left"]:
        raise ValueError("`join_type` must be one of 'inner' or 'left'")

    if not len(data.columns):
        raise ValueError("Data does not contain any columns, check `data`")
    if not len(data):
        raise ValueError("Data does not contain any rows, check `data`")

    if not name_col:
        name_col = data.columns[0]
    # `attribute_cols` defaults to None (avoids a shared mutable default
    # argument); derive it from the remaining columns when not supplied.
    if not attribute_cols:
        attribute_cols = list(data.columns)
        attribute_cols.remove(name_col)

    # Reject names that appear multiple times with conflicting attribute values
    # (comparison is done on the string form so null-valued rows compare equal).
    data2 = data.copy()[[name_col] + attribute_cols].astype(str).drop_duplicates()
    _duplicate_check = (
        data2[name_col]
        .value_counts()
        .to_frame("counts")
        .rename_axis(name_col)
        .reset_index()
    )
    _duplicate_check = _duplicate_check[_duplicate_check["counts"] > 1]
    if len(_duplicate_check):
        raise ValueError(
            f"There exists duplicate name with different attributes\nCheck {_duplicate_check}"
        )

    # Flatten the existing tree into a path/attribute table so the new
    # attributes can be merged in by node name, then rebuild the tree.
    path_col = "PATH"
    tree_root = tree.root
    sep = tree_root.sep
    node_type = tree_root.__class__
    data_tree = tree_to_dataframe(
        tree_root, name_col=name_col, path_col=path_col, all_attrs=True
    )
    # Drop existing attribute columns that the incoming data also provides,
    # so the merged values come from `data`.
    common_cols = list(set(data_tree.columns).intersection(attribute_cols))
    data_tree = data_tree.drop(columns=common_cols)

    # 'left' keeps all tree nodes; 'inner' keeps only nodes present in `data`.
    data_tree_attrs = pd.merge(data_tree, data, on=name_col, how=join_type)
    data_tree_attrs = data_tree_attrs.drop(columns=name_col)

    return dataframe_to_tree(
        data_tree_attrs, path_col=path_col, sep=sep, node_type=node_type
    )
|
||||
|
||||
|
||||
def str_to_tree(
    tree_string: str,
    tree_prefix_list: List[str] = None,
    node_type: Type[Node] = Node,
) -> Node:
    r"""Construct tree from tree string

    >>> from bigtree import str_to_tree, print_tree
    >>> tree_str = 'a\n├── b\n│   ├── d\n│   └── e\n│       ├── g\n│       └── h\n└── c\n    └── f'
    >>> root = str_to_tree(tree_str, tree_prefix_list=["├──", "└──"])
    >>> print_tree(root)
    a
    ├── b
    │   ├── d
    │   └── e
    │       ├── g
    │       └── h
    └── c
        └── f

    Args:
        tree_string (str): String to construct tree
        tree_prefix_list (list): List of prefix to mark the end of tree branch/stem and start of node name, optional.
            If not specified, it will infer unicode characters and whitespace as prefix.
        node_type (Type[Node]): node type of tree to be created, defaults to Node

    Returns:
        (Node)

    Raises:
        ValueError: if `tree_string` is empty or the indentation prefixes are inconsistent
    """
    tree_string = tree_string.strip("\n")
    if not len(tree_string):
        raise ValueError("Tree string does not contain any data, check `tree_string`")
    tree_list = tree_string.split("\n")
    tree_root = node_type(tree_list[0])

    # Infer prefix length from the first non-root line; a line's depth is then
    # its prefix length divided by this unit.
    prefix_length = None
    cur_parent = tree_root
    for node_str in tree_list[1:]:
        if tree_prefix_list:
            # Escape each prefix so regex metacharacters (e.g. "+--" or "`--"
            # from the ascii/ansi styles) are matched literally when splitting
            # the node name off the branch decoration.
            node_name = re.split(
                "|".join(re.escape(prefix) for prefix in tree_prefix_list), node_str
            )[-1].lstrip()
        else:
            # No prefixes given: drop non-ascii stem/branch characters and
            # leading whitespace to recover the node name.
            node_name = node_str.encode("ascii", "ignore").decode("ascii").lstrip()

        # Find node parent
        if not prefix_length:
            prefix_length = node_str.index(node_name)
            if not prefix_length:
                raise ValueError(
                    f"Invalid prefix, prefix should be unicode character or whitespace, "
                    f"otherwise specify one or more prefixes in `tree_prefix_list`, check: {node_str}"
                )
        node_prefix_length = node_str.index(node_name)
        if node_prefix_length % prefix_length:
            raise ValueError(
                f"Tree string have different prefix length, check branch: {node_str}"
            )
        # Walk up until the parent candidate is one level above this node
        while cur_parent.depth > node_prefix_length / prefix_length:
            cur_parent = cur_parent.parent

        # Link node
        child_node = node_type(node_name)
        child_node.parent = cur_parent
        cur_parent = child_node

    return tree_root
|
||||
|
||||
|
||||
def list_to_tree(
    paths: list,
    sep: str = "/",
    duplicate_name_allowed: bool = True,
    node_type: Type[Node] = Node,
) -> Node:
    """Construct tree from list of path strings.

    Path should contain `Node` name, separated by `sep`.
    - For example: Path string "a/b" refers to Node("b") with parent Node("a").

    Path can start from root node `name`, or start with `sep`.
    - For example: Path string can be "/a/b" or "a/b", if sep is "/".

    All paths should start from the same root node.
    - For example: Path strings should be "a/b", "a/c", "a/b/d" etc. and should not start with another root node.

    >>> from bigtree import list_to_tree, print_tree
    >>> path_list = ["a/b", "a/c", "a/b/d", "a/b/e", "a/c/f", "a/b/e/g", "a/b/e/h"]
    >>> root = list_to_tree(path_list)
    >>> print_tree(root)
    a
    ├── b
    │   ├── d
    │   └── e
    │       ├── g
    │       └── h
    └── c
        └── f

    Args:
        paths (list): list containing path strings
        sep (str): path separator for input `paths` and created tree, defaults to `/`
        duplicate_name_allowed (bool): indicator if nodes with duplicated `Node` name is allowed, defaults to True
        node_type (Type[Node]): node type of tree to be created, defaults to Node

    Returns:
        (Node)

    Raises:
        ValueError: if `paths` is empty
    """
    if not len(paths):
        raise ValueError("Path list does not contain any data, check `paths`")

    # Remove duplicate paths while preserving first-seen order
    paths = list(OrderedDict.fromkeys(paths))

    # Construct root node from the first component of the first path;
    # `sep` is assigned exactly once here (the original redundantly set it
    # a second time after the loop).
    root_name = paths[0].lstrip(sep).split(sep)[0]
    root_node = node_type(root_name)
    root_node.sep = sep

    for path in paths:
        add_path_to_tree(
            root_node, path, sep=sep, duplicate_name_allowed=duplicate_name_allowed
        )
    return root_node
|
||||
|
||||
|
||||
def list_to_tree_by_relation(
    relations: List[Tuple[str, str]],
    node_type: Type[Node] = Node,
) -> Node:
    """Construct tree from list of tuple containing parent-child names.

    Note that node names must be unique since tree is created from parent-child names,
    except for leaf nodes - names of leaf nodes may be repeated as there is no confusion.

    >>> from bigtree import list_to_tree_by_relation, print_tree
    >>> relations_list = [("a", "b"), ("a", "c"), ("b", "d"), ("b", "e"), ("c", "f"), ("e", "g"), ("e", "h")]
    >>> root = list_to_tree_by_relation(relations_list)
    >>> print_tree(root)
    a
    ├── b
    │   ├── d
    │   └── e
    │       ├── g
    │       └── h
    └── c
        └── f

    Args:
        relations (list): list containing tuple containing parent-child names
        node_type (Type[Node]): node type of tree to be created, defaults to Node

    Returns:
        (Node)

    Raises:
        ValueError: if `relations` is empty
    """
    if not len(relations):
        # Message fixed: the input is a relation list, not a path list.
        raise ValueError("Relation list does not contain any data, check `relations`")

    # Delegate to the dataframe-based constructor with fixed column names.
    relation_data = pd.DataFrame(relations, columns=["parent", "child"])
    return dataframe_to_tree_by_relation(
        relation_data, child_col="child", parent_col="parent", node_type=node_type
    )
|
||||
|
||||
|
||||
def dict_to_tree(
    path_attrs: dict,
    sep: str = "/",
    duplicate_name_allowed: bool = True,
    node_type: Type[Node] = Node,
) -> Node:
    """Construct tree from a dictionary keyed by node path,
    ``key``: path, ``value``: dict of attribute name and attribute value.

    Path should contain `Node` name, separated by `sep`.
    - For example: Path string "a/b" refers to Node("b") with parent Node("a").

    Path can start from root node `name`, or start with `sep`.
    - For example: Path string can be "/a/b" or "a/b", if sep is "/".

    All paths should start from the same root node.
    - For example: Path strings should be "a/b", "a/c", "a/b/d" etc. and should not start with another root node.

    >>> from bigtree import dict_to_tree, print_tree
    >>> path_dict = {
    ...     "a": {"age": 90},
    ...     "a/b": {"age": 65},
    ...     "a/c": {"age": 60},
    ...     "a/b/d": {"age": 40},
    ...     "a/b/e": {"age": 35},
    ...     "a/c/f": {"age": 38},
    ...     "a/b/e/g": {"age": 10},
    ...     "a/b/e/h": {"age": 6},
    ... }
    >>> root = dict_to_tree(path_dict)
    >>> print_tree(root, attr_list=["age"])
    a [age=90]
    ├── b [age=65]
    │   ├── d [age=40]
    │   └── e [age=35]
    │       ├── g [age=10]
    │       └── h [age=6]
    └── c [age=60]
        └── f [age=38]

    Args:
        path_attrs (dict): dictionary containing path and node attribute information,
            key: path, value: dict of tree attribute and attribute value
        sep (str): path separator of input `path_attrs` and created tree, defaults to `/`
        duplicate_name_allowed (bool): indicator if nodes with duplicated `Node` name is allowed, defaults to True
        node_type (Type[Node]): node type of tree to be created, defaults to Node

    Returns:
        (Node)

    Raises:
        ValueError: if `path_attrs` is empty
    """
    if not path_attrs:
        raise ValueError("Dictionary does not contain any data, check `path_attrs`")

    # Reshape {path: {attr: value}} into a table with a "PATH" column plus one
    # column per attribute, then delegate to the dataframe-based constructor.
    path_frame = pd.DataFrame(path_attrs).T.rename_axis("PATH").reset_index()
    return dataframe_to_tree(
        path_frame,
        sep=sep,
        duplicate_name_allowed=duplicate_name_allowed,
        node_type=node_type,
    )
|
||||
|
||||
|
||||
def nested_dict_to_tree(
    node_attrs: dict,
    name_key: str = "name",
    child_key: str = "children",
    node_type: Type[Node] = Node,
) -> Node:
    """Construct tree from nested recursive dictionary.
    - ``key``: `name_key`, `child_key`, or any attributes key.
    - ``value`` of `name_key` (str): node name.
    - ``value`` of `child_key` (list): list of dict containing `name_key` and `child_key` (recursive).

    >>> from bigtree import nested_dict_to_tree, print_tree
    >>> path_dict = {
    ...     "name": "a",
    ...     "age": 90,
    ...     "children": [
    ...         {"name": "b",
    ...          "age": 65,
    ...          "children": [
    ...              {"name": "d", "age": 40},
    ...              {"name": "e", "age": 35, "children": [
    ...                  {"name": "g", "age": 10},
    ...              ]},
    ...          ]},
    ...     ],
    ... }
    >>> root = nested_dict_to_tree(path_dict)
    >>> print_tree(root, attr_list=["age"])
    a [age=90]
    └── b [age=65]
        ├── d [age=40]
        └── e [age=35]
            └── g [age=10]

    Args:
        node_attrs (dict): dictionary containing node, children, and node attribute information,
            key: `name_key` and `child_key`
            value of `name_key` (str): node name
            value of `child_key` (list): list of dict containing `name_key` and `child_key` (recursive)
        name_key (str): key of node name, value is type str
        child_key (str): key of child list, value is type list
        node_type (Type[Node]): node type of tree to be created, defaults to Node

    Returns:
        (Node)
    """

    def _build_subtree(spec, parent=None):
        # Work on a shallow copy so popping keys never mutates caller input.
        attrs = dict(spec)
        node_name = attrs.pop(name_key)
        child_specs = attrs.pop(child_key, [])
        # Everything left after removing name/children becomes node attributes.
        current = node_type(node_name, parent=parent, **attrs)
        for child_spec in child_specs:
            _build_subtree(child_spec, parent=current)
        return current

    return _build_subtree(node_attrs)
|
||||
|
||||
|
||||
def dataframe_to_tree(
    data: pd.DataFrame,
    path_col: str = "",
    attribute_cols: list = None,
    sep: str = "/",
    duplicate_name_allowed: bool = True,
    node_type: Type[Node] = Node,
) -> Node:
    """Construct tree from pandas DataFrame using path, return root of tree.

    `path_col` and `attribute_cols` specify columns for node path and attributes to construct tree.
    If columns are not specified, `path_col` takes first column and all other columns are `attribute_cols`.

    Path in path column can start from root node `name`, or start with `sep`.
    - For example: Path string can be "/a/b" or "a/b", if sep is "/".

    Path in path column should contain `Node` name, separated by `sep`.
    - For example: Path string "a/b" refers to Node("b") with parent Node("a").

    All paths should start from the same root node.
    - For example: Path strings should be "a/b", "a/c", "a/b/d" etc. and should not start with another root node.

    >>> import pandas as pd
    >>> from bigtree import dataframe_to_tree, print_tree
    >>> path_data = pd.DataFrame([
    ...     ["a", 90],
    ...     ["a/b", 65],
    ...     ["a/c", 60],
    ...     ["a/b/d", 40],
    ...     ["a/b/e", 35],
    ...     ["a/c/f", 38],
    ...     ["a/b/e/g", 10],
    ...     ["a/b/e/h", 6],
    ... ],
    ...     columns=["PATH", "age"]
    ... )
    >>> root = dataframe_to_tree(path_data)
    >>> print_tree(root, attr_list=["age"])
    a [age=90]
    ├── b [age=65]
    │   ├── d [age=40]
    │   └── e [age=35]
    │       ├── g [age=10]
    │       └── h [age=6]
    └── c [age=60]
        └── f [age=38]

    Args:
        data (pandas.DataFrame): data containing path and node attribute information
        path_col (str): column of data containing `path_name` information,
            if not set, it will take the first column of data
        attribute_cols (list): columns of data containing node attribute information,
            if not set, it will take all columns of data except `path_col`
        sep (str): path separator of input `path_col` and created tree, defaults to `/`
        duplicate_name_allowed (bool): indicator if nodes with duplicated `Node` name is allowed, defaults to True
        node_type (Type[Node]): node type of tree to be created, defaults to Node

    Returns:
        (Node)

    Raises:
        ValueError: if `data` is empty or a path appears multiple times with
            conflicting attribute values
    """
    if not len(data.columns):
        raise ValueError("Data does not contain any columns, check `data`")
    if not len(data):
        raise ValueError("Data does not contain any rows, check `data`")

    # Work on a copy: the path normalization below assigns into the frame and
    # must not mutate the caller's DataFrame (the original modified it in place).
    data = data.copy()

    if not path_col:
        path_col = data.columns[0]
    # `attribute_cols` defaults to None instead of a shared mutable list.
    if not attribute_cols:
        attribute_cols = list(data.columns)
        attribute_cols.remove(path_col)

    # Normalize paths so "/a/b", "a/b", and "a/b/" are equivalent
    data[path_col] = data[path_col].str.lstrip(sep).str.rstrip(sep)

    # Reject paths that appear multiple times with conflicting attribute values
    # (comparison is on the string form so null-valued rows compare equal).
    data2 = data[[path_col] + attribute_cols].astype(str).drop_duplicates()
    _duplicate_check = (
        data2[path_col]
        .value_counts()
        .to_frame("counts")
        .rename_axis(path_col)
        .reset_index()
    )
    _duplicate_check = _duplicate_check[_duplicate_check["counts"] > 1]
    if len(_duplicate_check):
        raise ValueError(
            f"There exists duplicate path with different attributes\nCheck {_duplicate_check}"
        )

    # Root name is the first component of the first path
    root_name = data[path_col].values[0].split(sep)[0]
    root_node = node_type(root_name)
    add_dataframe_to_tree_by_path(
        root_node,
        data,
        sep=sep,
        duplicate_name_allowed=duplicate_name_allowed,
    )
    root_node.sep = sep
    return root_node
|
||||
|
||||
|
||||
def dataframe_to_tree_by_relation(
    data: pd.DataFrame,
    child_col: str = "",
    parent_col: str = "",
    attribute_cols: list = [],
    node_type: Type[Node] = Node,
) -> Node:
    """Construct tree from pandas DataFrame using parent and child names, return root of tree.

    Note that node names must be unique since tree is created from parent-child names,
    except for leaf nodes - names of leaf nodes may be repeated as there is no confusion.

    `child_col` and `parent_col` specify columns for child name and parent name to construct tree.
    `attribute_cols` specify columns for node attribute for child name
    If columns are not specified, `child_col` takes first column, `parent_col` takes second column, and all other
    columns are `attribute_cols`.

    >>> import pandas as pd
    >>> from bigtree import dataframe_to_tree_by_relation, print_tree
    >>> relation_data = pd.DataFrame([
    ...     ["a", None, 90],
    ...     ["b", "a", 65],
    ...     ["c", "a", 60],
    ...     ["d", "b", 40],
    ...     ["e", "b", 35],
    ...     ["f", "c", 38],
    ...     ["g", "e", 10],
    ...     ["h", "e", 6],
    ... ],
    ...     columns=["child", "parent", "age"]
    ... )
    >>> root = dataframe_to_tree_by_relation(relation_data)
    >>> print_tree(root, attr_list=["age"])
    a [age=90]
    ├── b [age=65]
    │   ├── d [age=40]
    │   └── e [age=35]
    │       ├── g [age=10]
    │       └── h [age=6]
    └── c [age=60]
        └── f [age=38]

    Args:
        data (pandas.DataFrame): data containing path and node attribute information
        child_col (str): column of data containing child name information, defaults to None
            if not set, it will take the first column of data
        parent_col (str): column of data containing parent name information, defaults to None
            if not set, it will take the second column of data
        attribute_cols (list): columns of data containing node attribute information,
            if not set, it will take all columns of data except `child_col` and `parent_col`
        node_type (Type[Node]): node type of tree to be created, defaults to Node

    Returns:
        (Node)
    """
    if not len(data.columns):
        raise ValueError("Data does not contain any columns, check `data`")
    if not len(data):
        raise ValueError("Data does not contain any rows, check `data`")

    # Default column selection: first column = child, second = parent,
    # everything else = attributes.
    if not child_col:
        child_col = data.columns[0]
    if not parent_col:
        parent_col = data.columns[1]
    if not len(attribute_cols):
        attribute_cols = list(data.columns)
        attribute_cols.remove(child_col)
        attribute_cols.remove(parent_col)

    data_check = data.copy()[[child_col, parent_col]].drop_duplicates()
    # Filter for child nodes that are parent of other nodes
    data_check = data_check[data_check[child_col].isin(data_check[parent_col])]
    # A non-leaf child listed under two different parents is ambiguous — reject.
    _duplicate_check = (
        data_check[child_col]
        .value_counts()
        .to_frame("counts")
        .rename_axis(child_col)
        .reset_index()
    )
    _duplicate_check = _duplicate_check[_duplicate_check["counts"] > 1]
    if len(_duplicate_check):
        raise ValueError(
            f"There exists duplicate child with different parent where the child is also a parent node.\n"
            f"Duplicated node names should not happen, but can only exist in leaf nodes to avoid confusion.\n"
            f"Check {_duplicate_check}"
        )

    # If parent-child contains None -> root
    root_row = data[data[parent_col].isnull()]
    root_names = list(root_row[child_col])
    if not len(root_names):
        # No explicit null-parent row: the root is the parent name that never
        # appears as a child.
        root_names = list(set(data[parent_col]) - set(data[child_col]))
    if len(root_names) != 1:
        raise ValueError(f"Unable to determine root node\nCheck {root_names}")
    root_name = root_names[0]
    root_node = node_type(root_name)

    def retrieve_attr(row):
        # Convert a row dict into node constructor kwargs: child name becomes
        # "name"; parent column is dropped; null-valued attributes are skipped.
        node_attrs = row.copy()
        node_attrs["name"] = node_attrs[child_col]
        del node_attrs[child_col]
        del node_attrs[parent_col]
        _node_attrs = {k: v for k, v in node_attrs.items() if not np.all(pd.isnull(v))}
        return _node_attrs

    def recursive_create_child(parent_node):
        # Depth-first: attach every row whose parent matches, then recurse.
        # NOTE(review): scans the full frame once per node (O(n^2) overall) —
        # presumably acceptable for typical tree sizes.
        child_rows = data[data[parent_col] == parent_node.node_name]

        for row in child_rows.to_dict(orient="index").values():
            child_node = node_type(**retrieve_attr(row))
            child_node.parent = parent_node
            recursive_create_child(child_node)

    # Create root node attributes
    if len(root_row):
        row = list(root_row.to_dict(orient="index").values())[0]
        root_node.set_attrs(retrieve_attr(row))
    recursive_create_child(root_node)
    return root_node
|
||||
831
python37/packages/bigtree/tree/export.py
Normal file
831
python37/packages/bigtree/tree/export.py
Normal file
@@ -0,0 +1,831 @@
|
||||
import collections
|
||||
from typing import Any, Dict, List, Tuple, Union
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from bigtree.node.node import Node
|
||||
from bigtree.tree.search import find_path
|
||||
from bigtree.utils.iterators import preorder_iter
|
||||
|
||||
# Public API of this module (what `from ... import *` exposes).
__all__ = [
    "print_tree",
    "yield_tree",
    "tree_to_dict",
    "tree_to_nested_dict",
    "tree_to_dataframe",
    "tree_to_dot",
    "tree_to_pillow",
]
|
||||
|
||||
|
||||
# Mapping of style name to the 3-tuple (stem, branch, final stem) used when
# rendering tree branches. yield_tree requires all three strings of a style
# to share the same width so columns stay aligned. "custom" is a placeholder:
# its actual strings come from the `custom_style` argument.
available_styles = {
    "ansi": ("|   ", "|-- ", "`-- "),
    "ascii": ("|   ", "|-- ", "+-- "),
    "const": ("\u2502   ", "\u251c\u2500\u2500 ", "\u2514\u2500\u2500 "),
    "const_bold": ("\u2503   ", "\u2523\u2501\u2501 ", "\u2517\u2501\u2501 "),
    "rounded": ("\u2502   ", "\u251c\u2500\u2500 ", "\u2570\u2500\u2500 "),
    "double": ("\u2551   ", "\u2560\u2550\u2550 ", "\u255a\u2550\u2550 "),
    "custom": ("", "", ""),
}
|
||||
|
||||
|
||||
def print_tree(
    tree: Node,
    node_name_or_path: str = "",
    max_depth: int = None,
    attr_list: List[str] = None,
    all_attrs: bool = False,
    attr_omit_null: bool = True,
    attr_bracket: List[str] = ["[", "]"],
    style: str = "const",
    custom_style: List[str] = [],
):
    """Print tree to console, starting from `tree`.

    - Able to select which node to print from, resulting in a subtree, using `node_name_or_path`
    - Able to customize for maximum depth to print, using `max_depth`
    - Able to choose which attributes to show or show all attributes, using `attr_list` and `all_attrs`
    - Able to omit showing of attributes if it is null, using `attr_omit_null`
    - Able to customize open and close brackets if attributes are shown, using `attr_bracket`
    - Able to customize style, to choose from `ansi`, `ascii`, `const`, `rounded`, `double`, and `custom` style
        - Default style is `const` style
        - If style is set to custom, user can choose their own style for stem, branch and final stem icons
        - Stem, branch, and final stem symbol should have the same number of characters

    **Printing tree**

    >>> from bigtree import Node, print_tree
    >>> root = Node("a", age=90)
    >>> b = Node("b", age=65, parent=root)
    >>> c = Node("c", age=60, parent=root)
    >>> d = Node("d", age=40, parent=b)
    >>> e = Node("e", age=35, parent=b)
    >>> print_tree(root)
    a
    ├── b
    │   ├── d
    │   └── e
    └── c

    **Printing Sub-tree**

    >>> print_tree(root, node_name_or_path="b")
    b
    ├── d
    └── e

    >>> print_tree(root, max_depth=2)
    a
    ├── b
    └── c

    **Printing Attributes**

    >>> print_tree(root, attr_list=["age"])
    a [age=90]
    ├── b [age=65]
    │   ├── d [age=40]
    │   └── e [age=35]
    └── c [age=60]

    >>> print_tree(root, attr_list=["age"], attr_bracket=["*(", ")"])
    a *(age=90)
    ├── b *(age=65)
    │   ├── d *(age=40)
    │   └── e *(age=35)
    └── c *(age=60)

    **Available Styles**

    >>> print_tree(root, style="ansi")
    a
    |-- b
    |   |-- d
    |   `-- e
    `-- c

    >>> print_tree(root, style="ascii")
    a
    |-- b
    |   |-- d
    |   +-- e
    +-- c

    >>> print_tree(root, style="const")
    a
    ├── b
    │   ├── d
    │   └── e
    └── c

    >>> print_tree(root, style="const_bold")
    a
    ┣━━ b
    ┃   ┣━━ d
    ┃   ┗━━ e
    ┗━━ c

    >>> print_tree(root, style="rounded")
    a
    ├── b
    │   ├── d
    │   ╰── e
    ╰── c

    >>> print_tree(root, style="double")
    a
    ╠══ b
    ║   ╠══ d
    ║   ╚══ e
    ╚══ c

    Args:
        tree (Node): tree to print
        node_name_or_path (str): node to print from, becomes the root node of printing
        max_depth (int): maximum depth of tree to print, based on `depth` attribute, optional
        attr_list (list): list of node attributes to print, optional
        all_attrs (bool): indicator to show all attributes, overrides `attr_list`
        attr_omit_null (bool): indicator whether to omit showing of null attributes, defaults to True
        attr_bracket (List[str]): open and close bracket for `all_attrs` or `attr_list`
        style (str): style of print, defaults to const style
        custom_style (List[str]): style of stem, branch and final stem, used when `style` is set to 'custom'
    """
    show_attrs = bool(all_attrs or attr_list)
    attr_bracket_open = attr_bracket_close = ""
    if show_attrs:
        # Validate and unpack the brackets once up front (fail fast) instead
        # of re-checking them on every node inside the printing loop, as the
        # original did.
        if len(attr_bracket) != 2:
            raise ValueError(
                f"Expect open and close brackets in `attr_bracket`, received {attr_bracket}"
            )
        attr_bracket_open, attr_bracket_close = attr_bracket

    for pre_str, fill_str, _node in yield_tree(
        tree=tree,
        node_name_or_path=node_name_or_path,
        max_depth=max_depth,
        style=style,
        custom_style=custom_style,
    ):
        # Assemble the node name plus an optional " [attr=value, ...]" suffix
        attr_str = ""
        if show_attrs:
            if all_attrs:
                attrs = _node.describe(exclude_attributes=["name"], exclude_prefix="_")
                attr_str_list = [f"{k}={v}" for k, v in attrs]
            elif attr_omit_null:
                # Skip attributes that are missing or falsy on this node
                attr_str_list = [
                    f"{attr_name}={_node.get_attr(attr_name)}"
                    for attr_name in attr_list
                    if _node.get_attr(attr_name)
                ]
            else:
                attr_str_list = [
                    f"{attr_name}={_node.get_attr(attr_name)}"
                    for attr_name in attr_list
                ]
            attr_str = ", ".join(attr_str_list)
            if attr_str:
                attr_str = f" {attr_bracket_open}{attr_str}{attr_bracket_close}"
        node_str = f"{_node.node_name}{attr_str}"
        print(f"{pre_str}{fill_str}{node_str}")
|
||||
|
||||
|
||||
def yield_tree(
    tree: Node,
    node_name_or_path: str = "",
    max_depth: int = None,
    style: str = "const",
    custom_style: List[str] = [],
):
    """Generator method for customizing printing of tree, starting from `tree`.

    - Able to select which node to print from, resulting in a subtree, using `node_name_or_path`
    - Able to customize for maximum depth to print, using `max_depth`
    - Able to customize style, to choose from `ansi`, `ascii`, `const`, `rounded`, `double`, and `custom` style
        - Default style is `const` style
        - If style is set to custom, user can choose their own style for stem, branch and final stem icons
        - Stem, branch, and final stem symbol should have the same number of characters

    **Printing tree**

    >>> from bigtree import Node, print_tree
    >>> root = Node("a", age=90)
    >>> b = Node("b", age=65, parent=root)
    >>> c = Node("c", age=60, parent=root)
    >>> d = Node("d", age=40, parent=b)
    >>> e = Node("e", age=35, parent=b)
    >>> for branch, stem, node in yield_tree(root):
    ...     print(f"{branch}{stem}{node.node_name}")
    a
    ├── b
    │   ├── d
    │   └── e
    └── c

    **Printing Sub-tree**

    >>> for branch, stem, node in yield_tree(root, node_name_or_path="b"):
    ...     print(f"{branch}{stem}{node.node_name}")
    b
    ├── d
    └── e

    >>> for branch, stem, node in yield_tree(root, max_depth=2):
    ...     print(f"{branch}{stem}{node.node_name}")
    a
    ├── b
    └── c

    **Available Styles**

    >>> for branch, stem, node in yield_tree(root, style="ansi"):
    ...     print(f"{branch}{stem}{node.node_name}")
    a
    |-- b
    |   |-- d
    |   `-- e
    `-- c

    >>> for branch, stem, node in yield_tree(root, style="ascii"):
    ...     print(f"{branch}{stem}{node.node_name}")
    a
    |-- b
    |   |-- d
    |   +-- e
    +-- c

    >>> for branch, stem, node in yield_tree(root, style="const"):
    ...     print(f"{branch}{stem}{node.node_name}")
    a
    ├── b
    │   ├── d
    │   └── e
    └── c

    >>> for branch, stem, node in yield_tree(root, style="const_bold"):
    ...     print(f"{branch}{stem}{node.node_name}")
    a
    ┣━━ b
    ┃   ┣━━ d
    ┃   ┗━━ e
    ┗━━ c

    >>> for branch, stem, node in yield_tree(root, style="rounded"):
    ...     print(f"{branch}{stem}{node.node_name}")
    a
    ├── b
    │   ├── d
    │   ╰── e
    ╰── c

    >>> for branch, stem, node in yield_tree(root, style="double"):
    ...     print(f"{branch}{stem}{node.node_name}")
    a
    ╠══ b
    ║   ╠══ d
    ║   ╚══ e
    ╚══ c

    **Printing Attributes**

    >>> for branch, stem, node in yield_tree(root, style="const"):
    ...     print(f"{branch}{stem}{node.node_name} [age={node.age}]")
    a [age=90]
    ├── b [age=65]
    │   ├── d [age=40]
    │   └── e [age=35]
    └── c [age=60]

    Args:
        tree (Node): tree to print
        node_name_or_path (str): node to print from, becomes the root node of printing, optional
        max_depth (int): maximum depth of tree to print, based on `depth` attribute, optional
        style (str): style of print, defaults to const style
        custom_style (List[str]): style of stem, branch and final stem, used when `style` is set to 'custom'
    """
    if style not in available_styles.keys():
        raise ValueError(
            f"Choose one of {available_styles.keys()} style, use `custom` to define own style"
        )

    # Copy the tree so re-rooting a subtree below never modifies the caller's tree
    tree = tree.copy()
    if node_name_or_path:
        tree = find_path(tree, node_name_or_path)
        # NOTE(review): find_path may return a falsy result when nothing
        # matches `node_name_or_path`; the `tree.is_root` access below would
        # then fail — confirm intended behavior for missing paths.
    if not tree.is_root:
        # Detach so the selected node is rendered as the root of the output
        tree.parent = None

    # Set style
    if style == "custom":
        if len(custom_style) != 3:
            raise ValueError(
                "Custom style selected, please specify the style of stem, branch, and final stem in `custom_style`"
            )
        style_stem, style_branch, style_stem_final = custom_style
    else:
        style_stem, style_branch, style_stem_final = available_styles[style]

    # Equal widths keep the columns of the rendered tree aligned
    if not len(style_stem) == len(style_branch) == len(style_stem_final):
        raise ValueError(
            "`style_stem`, `style_branch`, and `style_stem_final` are of different length"
        )

    # gap_str fills columns under branches that are already closed
    gap_str = " " * len(style_stem)
    # Depths that still have a later sibling pending: these columns draw a
    # vertical stem until the last sibling at that depth is reached
    unclosed_depth = set()
    initial_depth = tree.depth
    for _node in preorder_iter(tree, max_depth=max_depth):
        pre_str = ""
        fill_str = ""
        if not _node.is_root:
            # Depth relative to the (possibly re-rooted) printing root
            node_depth = _node.depth - initial_depth

            # Get fill_str (style_branch or style_stem_final)
            if _node.right_sibling:
                # More siblings follow: keep this depth "open"
                unclosed_depth.add(node_depth)
                fill_str = style_branch
            else:
                # Last sibling at this depth: close the column
                if node_depth in unclosed_depth:
                    unclosed_depth.remove(node_depth)
                fill_str = style_stem_final

            # Get pre_str (style_stem, style_branch, style_stem_final, or gap)
            pre_str = ""
            for _depth in range(1, node_depth):
                if _depth in unclosed_depth:
                    pre_str += style_stem
                else:
                    pre_str += gap_str

        yield pre_str, fill_str, _node
|
||||
|
||||
|
||||
def tree_to_dict(
    tree: Node,
    name_key: str = "name",
    parent_key: str = "",
    attr_dict: dict = None,
    all_attrs: bool = False,
    max_depth: int = None,
    skip_depth: int = None,
    leaf_only: bool = False,
) -> Dict[str, Any]:
    """Export tree to dictionary.

    All descendants from `tree` will be exported, `tree` can be the root node or child node of tree.

    Exported dictionary will have key as node path, and node attributes as a nested dictionary.

    >>> from bigtree import Node, tree_to_dict
    >>> root = Node("a", age=90)
    >>> b = Node("b", age=65, parent=root)
    >>> c = Node("c", age=60, parent=root)
    >>> d = Node("d", age=40, parent=b)
    >>> e = Node("e", age=35, parent=b)
    >>> tree_to_dict(root, name_key="name", parent_key="parent", attr_dict={"age": "person age"})
    {'/a': {'name': 'a', 'parent': None, 'person age': 90}, '/a/b': {'name': 'b', 'parent': 'a', 'person age': 65}, '/a/b/d': {'name': 'd', 'parent': 'b', 'person age': 40}, '/a/b/e': {'name': 'e', 'parent': 'b', 'person age': 35}, '/a/c': {'name': 'c', 'parent': 'a', 'person age': 60}}

    For a subset of a tree

    >>> tree_to_dict(c, name_key="name", parent_key="parent", attr_dict={"age": "person age"})
    {'/a/c': {'name': 'c', 'parent': 'a', 'person age': 60}}

    Args:
        tree (Node): tree to be exported
        name_key (str): dictionary key for `node.node_name`, defaults to 'name'
        parent_key (str): dictionary key for `node.parent.node_name`, optional
        attr_dict (dict): dictionary mapping node attributes to dictionary key,
            key: node attributes, value: corresponding dictionary key, optional
        all_attrs (bool): indicator whether to retrieve all `Node` attributes
        max_depth (int): maximum depth to export tree, optional
        skip_depth (int): number of initial depth to skip, optional
        leaf_only (bool): indicator to retrieve only information from leaf nodes

    Returns:
        (dict)
    """
    # Use a None sentinel instead of a mutable default argument: a literal `{}`
    # default is shared across all calls and is a well-known Python pitfall.
    if attr_dict is None:
        attr_dict = {}
    # Work on a copy so the exported structure is decoupled from the caller's tree.
    tree = tree.copy()
    data_dict = {}

    def recursive_append(node):
        """Depth-first traversal; record node info when it passes all filters.

        Recursion always continues into children (even when the current node is
        filtered out) so that `skip_depth` can skip shallow levels while still
        exporting the deeper ones.
        """
        if node:
            if (
                (not max_depth or node.depth <= max_depth)
                and (not skip_depth or node.depth > skip_depth)
                and (not leaf_only or node.is_leaf)
            ):
                data_child = {}
                if name_key:
                    data_child[name_key] = node.node_name
                if parent_key:
                    parent_name = None
                    if node.parent:
                        parent_name = node.parent.node_name
                    data_child[parent_key] = parent_name
                if all_attrs:
                    # Export every public attribute; `name` is already covered by name_key.
                    data_child.update(
                        dict(
                            node.describe(
                                exclude_attributes=["name"], exclude_prefix="_"
                            )
                        )
                    )
                else:
                    for k, v in attr_dict.items():
                        data_child[v] = node.get_attr(k)
                data_dict[node.path_name] = data_child
            for _node in node.children:
                recursive_append(_node)

    recursive_append(tree)
    return data_dict
|
||||
|
||||
|
||||
def tree_to_nested_dict(
    tree: Node,
    name_key: str = "name",
    child_key: str = "children",
    attr_dict: dict = None,
    all_attrs: bool = False,
    max_depth: int = None,
) -> Dict[str, Any]:
    """Export tree to nested dictionary.

    All descendants from `tree` will be exported, `tree` can be the root node or child node of tree.

    Exported dictionary will have key as node attribute names, and children as a nested recursive dictionary.

    >>> from bigtree import Node, tree_to_nested_dict
    >>> root = Node("a", age=90)
    >>> b = Node("b", age=65, parent=root)
    >>> c = Node("c", age=60, parent=root)
    >>> d = Node("d", age=40, parent=b)
    >>> e = Node("e", age=35, parent=b)
    >>> tree_to_nested_dict(root, all_attrs=True)
    {'name': 'a', 'age': 90, 'children': [{'name': 'b', 'age': 65, 'children': [{'name': 'd', 'age': 40}, {'name': 'e', 'age': 35}]}, {'name': 'c', 'age': 60}]}

    Args:
        tree (Node): tree to be exported
        name_key (str): dictionary key for `node.node_name`, defaults to 'name'
        child_key (str): dictionary key for list of children, optional
        attr_dict (dict): dictionary mapping node attributes to dictionary key,
            key: node attributes, value: corresponding dictionary key, optional
        all_attrs (bool): indicator whether to retrieve all `Node` attributes
        max_depth (int): maximum depth to export tree, optional

    Returns:
        (dict)
    """
    # Use a None sentinel instead of a mutable default argument: a literal `{}`
    # default is shared across all calls and is a well-known Python pitfall.
    if attr_dict is None:
        attr_dict = {}
    # Work on a copy so the exported structure is decoupled from the caller's tree.
    tree = tree.copy()
    data_dict = {}

    def recursive_append(node, parent_dict):
        """Append this node's dict to the parent's child list, then recurse.

        Unlike `tree_to_dict`, recursion stops once `max_depth` is exceeded
        because children cannot be attached without their (filtered) parent.
        """
        if node:
            if not max_depth or node.depth <= max_depth:
                data_child = {name_key: node.node_name}
                if all_attrs:
                    # Export every public attribute; `name` is already covered by name_key.
                    data_child.update(
                        dict(
                            node.describe(
                                exclude_attributes=["name"], exclude_prefix="_"
                            )
                        )
                    )
                else:
                    for k, v in attr_dict.items():
                        data_child[v] = node.get_attr(k)
                if child_key in parent_dict:
                    parent_dict[child_key].append(data_child)
                else:
                    parent_dict[child_key] = [data_child]

                for _node in node.children:
                    recursive_append(_node, data_child)

    recursive_append(tree, data_dict)
    # The synthetic top-level dict holds a single-element child list: the root.
    return data_dict[child_key][0]
|
||||
|
||||
|
||||
def tree_to_dataframe(
    tree: Node,
    path_col: str = "path",
    name_col: str = "name",
    parent_col: str = "",
    attr_dict: dict = None,
    all_attrs: bool = False,
    max_depth: int = None,
    skip_depth: int = None,
    leaf_only: bool = False,
) -> pd.DataFrame:
    """Export tree to pandas DataFrame.

    All descendants from `tree` will be exported, `tree` can be the root node or child node of tree.

    >>> from bigtree import Node, tree_to_dataframe
    >>> root = Node("a", age=90)
    >>> b = Node("b", age=65, parent=root)
    >>> c = Node("c", age=60, parent=root)
    >>> d = Node("d", age=40, parent=b)
    >>> e = Node("e", age=35, parent=b)
    >>> tree_to_dataframe(root, name_col="name", parent_col="parent", path_col="path", attr_dict={"age": "person age"})
         path name parent  person age
    0      /a    a   None          90
    1    /a/b    b      a          65
    2  /a/b/d    d      b          40
    3  /a/b/e    e      b          35
    4    /a/c    c      a          60

    For a subset of a tree.

    >>> tree_to_dataframe(b, name_col="name", parent_col="parent", path_col="path", attr_dict={"age": "person age"})
         path name parent  person age
    0    /a/b    b      a          65
    1  /a/b/d    d      b          40
    2  /a/b/e    e      b          35

    Args:
        tree (Node): tree to be exported
        path_col (str): column name for `node.path_name`, optional
        name_col (str): column name for `node.node_name`, defaults to 'name'
        parent_col (str): column name for `node.parent.node_name`, optional
        attr_dict (dict): dictionary mapping node attributes to column name,
            key: node attributes, value: corresponding column in dataframe, optional
        all_attrs (bool): indicator whether to retrieve all `Node` attributes
        max_depth (int): maximum depth to export tree, optional
        skip_depth (int): number of initial depth to skip, optional
        leaf_only (bool): indicator to retrieve only information from leaf nodes

    Returns:
        (pd.DataFrame)
    """
    # Use a None sentinel instead of a mutable default argument: a literal `{}`
    # default is shared across all calls and is a well-known Python pitfall.
    if attr_dict is None:
        attr_dict = {}
    # Work on a copy so the exported structure is decoupled from the caller's tree.
    tree = tree.copy()
    data_list = []

    def recursive_append(node):
        """Depth-first traversal; append a row per node that passes all filters.

        Recursion always continues into children so `skip_depth` can skip
        shallow levels while still exporting deeper ones.
        """
        if node:
            if (
                (not max_depth or node.depth <= max_depth)
                and (not skip_depth or node.depth > skip_depth)
                and (not leaf_only or node.is_leaf)
            ):
                data_child = {}
                if path_col:
                    data_child[path_col] = node.path_name
                if name_col:
                    data_child[name_col] = node.node_name
                if parent_col:
                    parent_name = None
                    if node.parent:
                        parent_name = node.parent.node_name
                    data_child[parent_col] = parent_name

                if all_attrs:
                    # Export every public attribute; `name` is already covered by name_col.
                    data_child.update(
                        node.describe(exclude_attributes=["name"], exclude_prefix="_")
                    )
                else:
                    for k, v in attr_dict.items():
                        data_child[v] = node.get_attr(k)
                data_list.append(data_child)
            for _node in node.children:
                recursive_append(_node)

    recursive_append(tree)
    return pd.DataFrame(data_list)
|
||||
|
||||
|
||||
def tree_to_dot(
    tree: Union[Node, List[Node]],
    directed: bool = True,
    rankdir: str = "TB",
    bg_colour: str = None,
    node_colour: str = None,
    node_shape: str = None,
    edge_colour: str = None,
    node_attr: str = None,
    edge_attr: str = None,
):
    r"""Export tree or list of trees to image.
    Possible node attributes include style, fillcolor, shape.

    >>> from bigtree import Node, tree_to_dot
    >>> root = Node("a", age=90)
    >>> b = Node("b", age=65, parent=root)
    >>> c = Node("c", age=60, parent=root)
    >>> d = Node("d", age=40, parent=b)
    >>> e = Node("e", age=35, parent=b)
    >>> graph = tree_to_dot(root)

    Export to image, dot file, etc.

    >>> graph.write_png("tree.png")
    >>> graph.write_dot("tree.dot")

    Export to string

    >>> graph.to_string()
    'strict digraph G {\nrankdir=TB;\na0 [label=a];\nb0 [label=b];\na0 -> b0;\nd0 [label=d];\nb0 -> d0;\ne0 [label=e];\nb0 -> e0;\nc0 [label=c];\na0 -> c0;\n}\n'

    Defining node and edge attributes

    >>> class CustomNode(Node):
    ...     def __init__(self, name, node_shape="", edge_label="", **kwargs):
    ...         super().__init__(name, **kwargs)
    ...         self.node_shape = node_shape
    ...         self.edge_label = edge_label
    ...
    ...     @property
    ...     def edge_attr(self):
    ...         if self.edge_label:
    ...             return {"label": self.edge_label}
    ...         return {}
    ...
    ...     @property
    ...     def node_attr(self):
    ...         if self.node_shape:
    ...             return {"shape": self.node_shape}
    ...         return {}
    >>>
    >>>
    >>> root = CustomNode("a", node_shape="circle")
    >>> b = CustomNode("b", edge_label="child", parent=root)
    >>> c = CustomNode("c", edge_label="child", parent=root)
    >>> d = CustomNode("d", node_shape="square", edge_label="child", parent=b)
    >>> e = CustomNode("e", node_shape="square", edge_label="child", parent=b)
    >>> graph = tree_to_dot(root, node_colour="gold", node_shape="diamond", node_attr="node_attr", edge_attr="edge_attr")
    >>> graph.write_png("assets/custom_tree.png")

    .. image:: https://github.com/kayjan/bigtree/raw/master/assets/custom_tree.png

    Args:
        tree (Node/List[Node]): tree or list of trees to be exported
        directed (bool): indicator whether graph should be directed or undirected, defaults to True
        rankdir (str): set direction of graph layout, defaults to 'TB' (top to bottom), can be 'BT' (bottom to top),
            'LR' (left to right), 'RL' (right to left)
        bg_colour (str): background color of image, defaults to None
        node_colour (str): fill colour of nodes, defaults to None
        node_shape (str): shape of nodes, defaults to None
            Possible node_shape include "circle", "square", "diamond", "triangle"
        edge_colour (str): colour of edges, defaults to None
        node_attr (str): `Node` attribute for node style, overrides `node_colour` and `node_shape`, defaults to None.
            Possible node style (attribute value) include {"style": "filled", "fillcolor": "gold", "shape": "diamond"}
        edge_attr (str): `Node` attribute for edge style, overrides `edge_colour`, defaults to None.
            Possible edge style (attribute value) include {"style": "bold", "label": "edge label", "color": "black"}

    Returns:
        (pydot.Dot)
    """
    try:
        import pydot
    except ImportError:  # pragma: no cover
        raise ImportError(
            "pydot not available. Please perform a\n\npip install 'bigtree[image]'\n\nto install required dependencies"
        )

    # Get style
    if bg_colour:
        graph_style = dict(bgcolor=bg_colour)
    else:
        graph_style = dict()

    if node_colour:
        node_style = dict(style="filled", fillcolor=node_colour)
    else:
        node_style = dict()

    if node_shape:
        node_style["shape"] = node_shape

    if edge_colour:
        edge_style = dict(color=edge_colour)
    else:
        edge_style = dict()

    if directed:
        _graph = pydot.Dot(
            graph_type="digraph", strict=True, rankdir=rankdir, **graph_style
        )
    else:
        _graph = pydot.Dot(
            graph_type="graph", strict=True, rankdir=rankdir, **graph_style
        )

    if not isinstance(tree, list):
        tree = [tree]

    for _tree in tree:
        if not isinstance(_tree, Node):
            raise ValueError("Tree should be of type `Node`, or inherit from `Node`")
        # Bug fix: copy each tree individually AFTER normalising to a list.
        # Previously `tree.copy()` ran before the isinstance(list) check, so a
        # list input only had the *list object* copied, not the trees inside it.
        _tree = _tree.copy()

        # Track path names per label so duplicate labels get distinct node ids
        # (e.g. two nodes named "b" become "b0" and "b1").
        name_dict = collections.defaultdict(list)

        def recursive_create_node_and_edges(parent_name, child_node):
            """Create the pydot node for `child_node`, link it to its parent, recurse."""
            _node_style = node_style.copy()
            _edge_style = edge_style.copy()

            child_label = child_node.node_name
            if child_node.path_name not in name_dict[child_label]:  # pragma: no cover
                name_dict[child_label].append(child_node.path_name)
            child_name = child_label + str(
                name_dict[child_label].index(child_node.path_name)
            )
            # Per-node attributes override the global style arguments.
            if node_attr and child_node.get_attr(node_attr):
                _node_style.update(child_node.get_attr(node_attr))
            if edge_attr:
                _edge_style.update(child_node.get_attr(edge_attr))
            node = pydot.Node(name=child_name, label=child_label, **_node_style)
            _graph.add_node(node)
            if parent_name is not None:
                edge = pydot.Edge(parent_name, child_name, **_edge_style)
                _graph.add_edge(edge)
            for child in child_node.children:
                if child:
                    recursive_create_node_and_edges(child_name, child)

        recursive_create_node_and_edges(None, _tree.root)
    return _graph
|
||||
|
||||
|
||||
def tree_to_pillow(
    tree: Node,
    width: int = 0,
    height: int = 0,
    start_pos: Tuple[float, float] = (10, 10),
    font_family: str = "assets/DejaVuSans.ttf",
    font_size: int = 12,
    font_colour: Union[Tuple[float, float, float], str] = "black",
    bg_colour: Union[Tuple[float, float, float], str] = "white",
    **kwargs,
):
    """Export tree to image (JPG, PNG).
    Image will be similar format as `print_tree`, accepts additional keyword arguments as input to `yield_tree`

    >>> from bigtree import Node, tree_to_pillow
    >>> root = Node("a", age=90)
    >>> b = Node("b", age=65, parent=root)
    >>> c = Node("c", age=60, parent=root)
    >>> d = Node("d", age=40, parent=b)
    >>> e = Node("e", age=35, parent=b)
    >>> pillow_image = tree_to_pillow(root)

    Export to image (PNG, JPG) file, etc.

    >>> pillow_image.save("tree_pillow.png")
    >>> pillow_image.save("tree_pillow.jpg")

    Args:
        tree (Node): tree to be exported
        width (int): width of image, optional as width of image is calculated automatically
        height (int): height of image, optional as height of image is calculated automatically
        start_pos (Tuple[float, float]): start position of text, (x-offset, y-offset), defaults to (10, 10)
        font_family (str): file path of font family, requires .ttf file, defaults to DejaVuSans
        font_size (int): font size, defaults to 12
        font_colour (Union[List[int], str]): font colour, accepts tuple of RGB values or string, defaults to black
        bg_colour (Union[List[int], str]): background of image, accepts tuple of RGB values or string, defaults to white

    Returns:
        (PIL.Image.Image)
    """
    # Pillow is an optional dependency; fail with an actionable install hint.
    try:
        from PIL import Image, ImageDraw, ImageFont
    except ImportError:  # pragma: no cover
        raise ImportError(
            "Pillow not available. Please perform a\n\npip install 'bigtree[image]'\n\nto install required dependencies"
        )

    # Initialize font
    # NOTE(review): font_family must point at an existing .ttf file relative to
    # the working directory; truetype raises OSError otherwise — confirm callers.
    font = ImageFont.truetype(font_family, font_size)

    # Initialize text
    # Render the tree to text lines using the same traversal as `print_tree`.
    image_text = []
    for branch, stem, node in yield_tree(tree, **kwargs):
        image_text.append(f"{branch}{stem}{node.node_name}\n")

    # Calculate image dimension from text, otherwise override with argument
    def get_list_of_text_dimensions(text_list):
        """Get list dimensions

        Args:
            text_list (List[str]): list of texts

        Returns:
            (List[Iterable[int]]): list of (left, top, right, bottom) bounding box
        """
        # A throwaway 0x0 image is only needed to obtain a Draw context for
        # measuring text; nothing is actually drawn on it.
        _image = Image.new("RGB", (0, 0))
        _draw = ImageDraw.Draw(_image)
        return [_draw.textbbox((0, 0), text_line, font=font) for text_line in text_list]

    text_dimensions = get_list_of_text_dimensions(image_text)
    # Bounding boxes are (left, top, right, bottom); heights are summed over all
    # lines, width takes the widest line.
    text_height = sum(
        [text_dimension[3] + text_dimension[1] for text_dimension in text_dimensions]
    )
    text_width = max(
        [text_dimension[2] + text_dimension[0] for text_dimension in text_dimensions]
    )
    image_text = "".join(image_text)
    # Auto-computed size is padded by the start offset on both sides; explicit
    # width/height arguments only take effect when larger than the computed size.
    # NOTE(review): if start_pos contains floats, width/height become floats and
    # Image.new below may reject them — confirm callers pass integer offsets.
    width = max(width, text_width + 2 * start_pos[0])
    height = max(height, text_height + 2 * start_pos[1])

    # Initialize and draw image
    image = Image.new("RGB", (width, height), bg_colour)
    image_draw = ImageDraw.Draw(image)
    image_draw.text(start_pos, image_text, font=font, fill=font_colour)
    return image
|
||||
201
python37/packages/bigtree/tree/helper.py
Normal file
201
python37/packages/bigtree/tree/helper.py
Normal file
@@ -0,0 +1,201 @@
|
||||
from typing import Optional, Type
|
||||
|
||||
import numpy as np
|
||||
|
||||
from bigtree.node.basenode import BaseNode
|
||||
from bigtree.node.binarynode import BinaryNode
|
||||
from bigtree.node.node import Node
|
||||
from bigtree.tree.construct import dataframe_to_tree
|
||||
from bigtree.tree.export import tree_to_dataframe
|
||||
from bigtree.tree.search import find_path
|
||||
from bigtree.utils.exceptions import NotFoundError
|
||||
|
||||
__all__ = ["clone_tree", "prune_tree", "get_tree_diff"]
|
||||
|
||||
|
||||
def clone_tree(tree: BaseNode, node_type: Type[BaseNode]) -> BaseNode:
    """Clone tree to another `Node` type.
    If the same type is needed, simply do a tree.copy().

    >>> from bigtree import BaseNode, Node, clone_tree
    >>> root = BaseNode(name="a")
    >>> b = BaseNode(name="b", parent=root)
    >>> clone_tree(root, Node)
    Node(/a, )

    Args:
        tree (BaseNode): tree to be cloned, must inherit from BaseNode
        node_type (Type[BaseNode]): type of cloned tree

    Returns:
        (BaseNode)
    """
    if not isinstance(tree, BaseNode):
        raise ValueError(
            "Tree should be of type `BaseNode`, or inherit from `BaseNode`"
        )

    # Always clone starting from the root of the given tree.
    source_root = tree.root
    cloned_root = node_type(**dict(source_root.describe(exclude_prefix="_")))

    def _clone_descendants(source_parent, cloned_parent):
        """Recreate each child of `source_parent` under `cloned_parent`, recursively."""
        for source_child in source_parent.children:
            if source_child:
                attrs = dict(source_child.describe(exclude_prefix="_"))
                cloned_child = node_type(**attrs)
                cloned_child.parent = cloned_parent
                _clone_descendants(source_child, cloned_child)

    _clone_descendants(source_root, cloned_root)
    return cloned_root
|
||||
|
||||
|
||||
def prune_tree(tree: Node, prune_path: str, sep: str = "/") -> Node:
    """Prune tree to leave only the prune path, returns the root of a *copy* of the original tree.

    All siblings along the prune path will be removed.
    Prune path name should be unique, can be full path or partial path (trailing part of path) or node name.

    Path should contain `Node` name, separated by `sep`.
    - For example: Path string "a/b" refers to Node("b") with parent Node("a").

    >>> from bigtree import Node, prune_tree, print_tree
    >>> root = Node("a")
    >>> b = Node("b", parent=root)
    >>> c = Node("c", parent=root)
    >>> print_tree(root)
    a
    ├── b
    └── c

    >>> root_pruned = prune_tree(root, "a/b")
    >>> print_tree(root_pruned)
    a
    └── b

    Args:
        tree (Node): existing tree
        prune_path (str): prune path, all siblings along the prune path will be removed
        sep (str): path separator

    Returns:
        (Node)
    """
    # Translate the caller's separator into the tree's own separator.
    target_path = prune_path.replace(sep, tree.sep)
    pruned_root = tree.copy()
    current = find_path(pruned_root, target_path)
    if not current:
        raise NotFoundError(
            f"Cannot find any node matching path_name ending with {target_path}"
        )

    # Walk back up to the root, keeping only the prune-path child at each level.
    # BinaryNode children must always be a pair, so the removed sibling slot
    # is filled with None instead of being dropped.
    binary = isinstance(current.parent, BinaryNode)
    while current.parent:
        current.parent.children = [current, None] if binary else [current]
        current = current.parent
    return pruned_root
|
||||
|
||||
|
||||
def get_tree_diff(
    tree: Node, other_tree: Node, only_diff: bool = True
) -> Optional[Node]:
    """Get difference of `tree` to `other_tree`, changes are relative to `tree`.

    (+) and (-) will be added relative to `tree`.
    - For example: (+) refers to nodes that are in `other_tree` but not `tree`.
    - For example: (-) refers to nodes that are in `tree` but not `other_tree`.

    Note that only leaf nodes are compared and have (+) or (-) indicator. Intermediate parent nodes are not compared.

    Function can return all original tree nodes and differences, or only the differences.

    Args:
        tree (Node): tree to be compared against
        other_tree (Node): tree to be compared with
        only_diff (bool): indicator to show all nodes or only nodes that are different (+/-), defaults to True

    Returns:
        (Node)
    """
    tree = tree.copy()
    other_tree = other_tree.copy()
    name_col = "name"
    path_col = "PATH"
    indicator_col = "Exists"

    # Flatten both trees to (path, name) rows, leaf nodes only.
    left_frame = tree_to_dataframe(
        tree, name_col=name_col, path_col=path_col, leaf_only=True
    )
    right_frame = tree_to_dataframe(
        other_tree, name_col=name_col, path_col=path_col, leaf_only=True
    )
    # Outer merge with indicator column flags rows unique to either side.
    combined = left_frame[[path_col, name_col]].merge(
        right_frame[[path_col, name_col]], how="outer", indicator=indicator_col
    )

    # Annotate names: (-) for leaves only in `tree`, (+) for leaves only in `other_tree`.
    left_only = combined[indicator_col] == "left_only"
    right_only = combined[indicator_col] == "right_only"
    combined.loc[left_only, name_col] = combined.loc[left_only, name_col] + " (-)"
    combined.loc[right_only, name_col] = combined.loc[right_only, name_col] + " (+)"

    if only_diff:
        combined = combined[combined[indicator_col] != "both"]
    combined = combined.drop(columns=indicator_col).sort_values(path_col)
    # An empty diff yields None (implicitly) rather than an empty tree.
    if len(combined):
        return dataframe_to_tree(
            combined,
            node_type=tree.__class__,
        )
|
||||
856
python37/packages/bigtree/tree/modify.py
Normal file
856
python37/packages/bigtree/tree/modify.py
Normal file
@@ -0,0 +1,856 @@
|
||||
import logging
|
||||
from typing import List, Optional
|
||||
|
||||
from bigtree.node.node import Node
|
||||
from bigtree.tree.search import find_path
|
||||
from bigtree.utils.exceptions import NotFoundError, TreeError
|
||||
|
||||
logging.getLogger(__name__).addHandler(logging.NullHandler())
|
||||
|
||||
|
||||
__all__ = [
|
||||
"shift_nodes",
|
||||
"copy_nodes",
|
||||
"copy_nodes_from_tree_to_tree",
|
||||
"copy_or_shift_logic",
|
||||
]
|
||||
|
||||
|
||||
def shift_nodes(
    tree: Node,
    from_paths: List[str],
    to_paths: List[str],
    sep: str = "/",
    skippable: bool = False,
    overriding: bool = False,
    merge_children: bool = False,
    merge_leaves: bool = False,
    delete_children: bool = False,
):
    """Shift nodes from `from_paths` to `to_paths` *in-place*.

    - Creates intermediate nodes if to path is not present
    - Able to skip nodes if from path is not found, defaults to False (from-nodes must be found; not skippable).
    - Able to override existing node if it exists, defaults to False (to-nodes must not exist; not overridden).
    - Able to merge children and remove intermediate parent node, defaults to False (nodes are shifted; not merged).
    - Able to merge only leaf nodes and remove all intermediate nodes, defaults to False (nodes are shifted; not merged)
    - Able to shift node only and delete children, defaults to False (nodes are shifted together with children).

    For paths in `from_paths` and `to_paths`,
    - Path name can be with or without leading tree path separator symbol.
    - Path name can be partial path (trailing part of path) or node name.
    - Path name must be unique to one node.

    For paths in `to_paths`,
    - Can set to empty string or None to delete the path in `from_paths`, note that ``copy`` must be set to False.

    If ``merge_children=True``,
    - If `to_path` is not present, it shifts children of `from_path`.
    - If `to_path` is present, and ``overriding=False``, original and new children are merged.
    - If `to_path` is present and ``overriding=True``, it behaves like overriding and only new children are retained.

    If ``merge_leaves=True``,
    - If `to_path` is not present, it shifts leaves of `from_path`.
    - If `to_path` is present, and ``overriding=False``, original children and leaves are merged.
    - If `to_path` is present and ``overriding=True``, it behaves like overriding and only new leaves are retained,
        original node in `from_path` is retained.

    >>> from bigtree import Node, shift_nodes, print_tree
    >>> root = Node("a")
    >>> b = Node("b", parent=root)
    >>> c = Node("c", parent=root)
    >>> d = Node("d", parent=root)
    >>> print_tree(root)
    a
    ├── b
    ├── c
    └── d

    >>> shift_nodes(root, ["a/c", "a/d"], ["a/b/c", "a/dummy/d"])
    >>> print_tree(root)
    a
    ├── b
    │   └── c
    └── dummy
        └── d

    To delete node,

    >>> root = Node("a")
    >>> b = Node("b", parent=root)
    >>> c = Node("c", parent=root)
    >>> print_tree(root)
    a
    ├── b
    └── c

    >>> shift_nodes(root, ["a/b"], [None])
    >>> print_tree(root)
    a
    └── c

    In overriding case,

    >>> root = Node("a")
    >>> b = Node("b", parent=root)
    >>> c = Node("c", parent=root)
    >>> d = Node("d", parent=c)
    >>> c2 = Node("c", parent=b)
    >>> e = Node("e", parent=c2)
    >>> print_tree(root)
    a
    ├── b
    │   └── c
    │       └── e
    └── c
        └── d

    >>> shift_nodes(root, ["a/b/c"], ["a/c"], overriding=True)
    >>> print_tree(root)
    a
    ├── b
    └── c
        └── e

    In ``merge_children`` case, child nodes are shifted instead of the parent node.
    - If the path already exists, child nodes are merged with existing children.
    - If same node is shifted, the child nodes of the node are merged with the node's parent.

    >>> root = Node("a")
    >>> b = Node("b", parent=root)
    >>> c = Node("c", parent=root)
    >>> d = Node("d", parent=c)
    >>> c2 = Node("c", parent=b)
    >>> e = Node("e", parent=c2)
    >>> z = Node("z", parent=b)
    >>> y = Node("y", parent=z)
    >>> f = Node("f", parent=root)
    >>> g = Node("g", parent=f)
    >>> h = Node("h", parent=g)
    >>> print_tree(root)
    a
    ├── b
    │   ├── c
    │   │   └── e
    │   └── z
    │       └── y
    ├── c
    │   └── d
    └── f
        └── g
            └── h

    >>> shift_nodes(root, ["a/b/c", "z", "a/f"], ["a/c", "a/z", "a/f"], merge_children=True)
    >>> print_tree(root)
    a
    ├── b
    ├── c
    │   ├── d
    │   └── e
    ├── y
    └── g
        └── h

    In ``merge_leaves`` case, leaf nodes are copied instead of the parent node.
    - If the path already exists, leaf nodes are merged with existing children.
    - If same node is copied, the leaf nodes of the node are merged with the node's parent.

    >>> root = Node("a")
    >>> b = Node("b", parent=root)
    >>> c = Node("c", parent=root)
    >>> d = Node("d", parent=c)
    >>> c2 = Node("c", parent=b)
    >>> e = Node("e", parent=c2)
    >>> z = Node("z", parent=b)
    >>> y = Node("y", parent=z)
    >>> f = Node("f", parent=root)
    >>> g = Node("g", parent=f)
    >>> h = Node("h", parent=g)
    >>> print_tree(root)
    a
    ├── b
    │   ├── c
    │   │   └── e
    │   └── z
    │       └── y
    ├── c
    │   └── d
    └── f
        └── g
            └── h

    >>> shift_nodes(root, ["a/b/c", "z", "a/f"], ["a/c", "a/z", "a/f"], merge_leaves=True)
    >>> print_tree(root)
    a
    ├── b
    │   ├── c
    │   └── z
    ├── c
    │   ├── d
    │   └── e
    ├── f
    │   └── g
    ├── y
    └── h

    In ``delete_children`` case, only the node is shifted without its accompanying children/descendants.

    >>> root = Node("a")
    >>> b = Node("b", parent=root)
    >>> c = Node("c", parent=root)
    >>> d = Node("d", parent=c)
    >>> c2 = Node("c", parent=b)
    >>> e = Node("e", parent=c2)
    >>> z = Node("z", parent=b)
    >>> y = Node("y", parent=z)
    >>> print_tree(root)
    a
    ├── b
    │   ├── c
    │   │   └── e
    │   └── z
    │       └── y
    └── c
        └── d

    >>> shift_nodes(root, ["a/b/z"], ["a/z"], delete_children=True)
    >>> print_tree(root)
    a
    ├── b
    │   └── c
    │       └── e
    ├── c
    │   └── d
    └── z

    Args:
        tree (Node): tree to modify
        from_paths (list): original paths to shift nodes from
        to_paths (list): new paths to shift nodes to
        sep (str): path separator for input paths, applies to `from_path` and `to_path`
        skippable (bool): indicator to skip if from path is not found, defaults to False
        overriding (bool): indicator to override existing to path if there is clashes, defaults to False
        merge_children (bool): indicator to merge children and remove intermediate parent node, defaults to False
        merge_leaves (bool): indicator to merge leaf nodes and remove intermediate parent node(s), defaults to False
        delete_children (bool): indicator to shift node only without children, defaults to False
    """
    # Thin wrapper: all validation and shifting logic lives in
    # `copy_or_shift_logic`; copy=False makes this a move rather than a copy.
    return copy_or_shift_logic(
        tree=tree,
        from_paths=from_paths,
        to_paths=to_paths,
        sep=sep,
        copy=False,
        skippable=skippable,
        overriding=overriding,
        merge_children=merge_children,
        merge_leaves=merge_leaves,
        delete_children=delete_children,
        to_tree=None,
    )  # pragma: no cover
|
||||
|
||||
|
||||
def copy_nodes(
    tree: Node,
    from_paths: List[str],
    to_paths: List[str],
    sep: str = "/",
    skippable: bool = False,
    overriding: bool = False,
    merge_children: bool = False,
    merge_leaves: bool = False,
    delete_children: bool = False,
):
    """Copy nodes from `from_paths` to `to_paths` *in-place*, within the same tree.

    Thin wrapper over `copy_or_shift_logic` with ``copy=True`` and ``to_tree=None``;
    all copying behaviour is implemented there.

    - Creates intermediate nodes if to path is not present.
    - Able to skip nodes if from path is not found, defaults to False (from-nodes must be found; not skippable).
    - Able to override existing node if it exists, defaults to False (to-nodes must not exist; not overridden).
    - Able to merge children and remove intermediate parent node, defaults to False.
    - Able to merge only leaf nodes and remove all intermediate nodes, defaults to False.
    - Able to copy node only and delete children, defaults to False (nodes are copied together with children).

    For paths in `from_paths` and `to_paths`,

    - Path name can be with or without leading tree path separator symbol.
    - Path name can be partial path (trailing part of path) or node name.
    - Path name must be unique to one node.

    If ``merge_children=True``,

    - If `to_path` is not present, it copies children of `from_path`.
    - If `to_path` is present, and ``overriding=False``, original and new children are merged.
    - If `to_path` is present and ``overriding=True``, it behaves like overriding and only new children are retained.

    If ``merge_leaves=True``,

    - If `to_path` is not present, it copies leaves of `from_path`.
    - If `to_path` is present, and ``overriding=False``, original children and leaves are merged.
    - If `to_path` is present and ``overriding=True``, it behaves like overriding and only new leaves are retained.

    In ``delete_children`` case, only the node is copied without its accompanying
    children/descendants.

    Args:
        tree (Node): tree to modify
        from_paths (list): original paths to copy nodes from
        to_paths (list): new paths to copy nodes to
        sep (str): path separator for input paths, applies to `from_path` and `to_path`
        skippable (bool): indicator to skip if from path is not found, defaults to False
        overriding (bool): indicator to override existing to path if there is clashes, defaults to False
        merge_children (bool): indicator to merge children and remove intermediate parent node, defaults to False
        merge_leaves (bool): indicator to merge leaf nodes and remove intermediate parent node(s), defaults to False
        delete_children (bool): indicator to copy node only without children, defaults to False
    """
    # This wrapper only pins copy=True (originals stay in place) and
    # to_tree=None (copying happens within the same tree).
    merge_options = dict(
        skippable=skippable,
        overriding=overriding,
        merge_children=merge_children,
        merge_leaves=merge_leaves,
        delete_children=delete_children,
    )
    return copy_or_shift_logic(
        tree=tree,
        from_paths=from_paths,
        to_paths=to_paths,
        sep=sep,
        copy=True,
        to_tree=None,
        **merge_options,
    )  # pragma: no cover
|
||||
|
||||
|
||||
def copy_nodes_from_tree_to_tree(
    from_tree: Node,
    to_tree: Node,
    from_paths: List[str],
    to_paths: List[str],
    sep: str = "/",
    skippable: bool = False,
    overriding: bool = False,
    merge_children: bool = False,
    merge_leaves: bool = False,
    delete_children: bool = False,
):
    """Copy nodes from `from_paths` in `from_tree` to `to_paths` in `to_tree` *in-place*.

    Thin wrapper over `copy_or_shift_logic` with ``copy=True`` and an explicit
    destination tree; all copying behaviour is implemented there.

    - Creates intermediate nodes if to path is not present.
    - Able to skip nodes if from path is not found, defaults to False (from-nodes must be found; not skippable).
    - Able to override existing node if it exists, defaults to False (to-nodes must not exist; not overridden).
    - Able to merge children and remove intermediate parent node, defaults to False.
    - Able to merge only leaf nodes and remove all intermediate nodes, defaults to False.
    - Able to copy node only and delete children, defaults to False (nodes are copied together with children).

    For paths in `from_paths` and `to_paths`,

    - Path name can be with or without leading tree path separator symbol.
    - Path name can be partial path (trailing part of path) or node name.
    - Path name must be unique to one node.

    If ``merge_children=True``,

    - If `to_path` is not present, it copies children of `from_path`.
    - If `to_path` is present, and ``overriding=False``, original and new children are merged.
    - If `to_path` is present and ``overriding=True``, it behaves like overriding and only new children are retained.

    If ``merge_leaves=True``,

    - If `to_path` is not present, it copies leaves of `from_path`.
    - If `to_path` is present, and ``overriding=False``, original children and leaves are merged.
    - If `to_path` is present and ``overriding=True``, it behaves like overriding and only new leaves are retained.

    In ``delete_children`` case, only the node is copied without its accompanying
    children/descendants.

    Args:
        from_tree (Node): tree to copy nodes from
        to_tree (Node): tree to copy nodes to
        from_paths (list): original paths to copy nodes from
        to_paths (list): new paths to copy nodes to
        sep (str): path separator for input paths, applies to `from_path` and `to_path`
        skippable (bool): indicator to skip if from path is not found, defaults to False
        overriding (bool): indicator to override existing to path if there is clashes, defaults to False
        merge_children (bool): indicator to merge children and remove intermediate parent node, defaults to False
        merge_leaves (bool): indicator to merge leaf nodes and remove intermediate parent node(s), defaults to False
        delete_children (bool): indicator to copy node only without children, defaults to False
    """
    # This wrapper pins copy=True and passes the destination tree through, so
    # nodes found in from_tree are attached (as copies) inside to_tree.
    merge_options = dict(
        skippable=skippable,
        overriding=overriding,
        merge_children=merge_children,
        merge_leaves=merge_leaves,
        delete_children=delete_children,
    )
    return copy_or_shift_logic(
        tree=from_tree,
        from_paths=from_paths,
        to_paths=to_paths,
        sep=sep,
        copy=True,
        to_tree=to_tree,
        **merge_options,
    )  # pragma: no cover
|
||||
|
||||
|
||||
def copy_or_shift_logic(
    tree: Node,
    from_paths: List[str],
    to_paths: List[str],
    sep: str = "/",
    copy: bool = False,
    skippable: bool = False,
    overriding: bool = False,
    merge_children: bool = False,
    merge_leaves: bool = False,
    delete_children: bool = False,
    to_tree: Optional[Node] = None,
):
    """Shift or copy nodes from `from_paths` to `to_paths` *in-place*.

    - Creates intermediate nodes if to path is not present
    - Able to copy node, defaults to False (nodes are shifted; not copied).
    - Able to skip nodes if from path is not found, defaults to False (from-nodes must be found; not skippable)
    - Able to override existing node if it exists, defaults to False (to-nodes must not exist; not overridden)
    - Able to merge children and remove intermediate parent node, defaults to False (nodes are shifted; not merged)
    - Able to merge only leaf nodes and remove all intermediate nodes, defaults to False (nodes are shifted; not merged)
    - Able to shift/copy node only and delete children, defaults to False (nodes are shifted/copied together with children).
    - Able to shift/copy nodes from one tree to another tree, defaults to None (shifting/copying happens within same tree)

    For paths in `from_paths` and `to_paths`,

    - Path name can be with or without leading tree path separator symbol.
    - Path name can be partial path (trailing part of path) or node name.
    - Path name must be unique to one node.

    For paths in `to_paths`,

    - Can set to empty string or None to delete the path in `from_paths`, note that ``copy`` must be set to False.

    If ``merge_children=True``,

    - If `to_path` is not present, it shifts/copies children of `from_path`.
    - If `to_path` is present, and ``overriding=False``, original and new children are merged.
    - If `to_path` is present and ``overriding=True``, it behaves like overriding and only new children are retained.

    If ``merge_leaves=True``,

    - If `to_path` is not present, it shifts/copies leaves of `from_path`.
    - If `to_path` is present, and ``overriding=False``, original children and leaves are merged.
    - If `to_path` is present and ``overriding=True``, it behaves like overriding and only new leaves are retained,
        original non-leaf nodes in `from_path` are retained.

    Args:
        tree (Node): tree to modify
        from_paths (list): original paths to shift nodes from
        to_paths (list): new paths to shift nodes to
        sep (str): path separator for input paths, applies to `from_path` and `to_path`
        copy (bool): indicator to copy node, defaults to False
        skippable (bool): indicator to skip if from path is not found, defaults to False
        overriding (bool): indicator to override existing to path if there is clashes, defaults to False
        merge_children (bool): indicator to merge children and remove intermediate parent node, defaults to False
        merge_leaves (bool): indicator to merge leaf nodes and remove intermediate parent node(s), defaults to False
        delete_children (bool): indicator to shift/copy node only without children, defaults to False
        to_tree (Node): tree to copy to, defaults to None

    Raises:
        ValueError: if both merge modes are requested, paths are not lists,
            path lists differ in length, or a from/to pair names different nodes
        NotFoundError: if a from path (and `skippable` is False) or a to path's
            parent cannot be found
        TreeError: if a to path exists and cannot be overridden/merged, or a
            node is shifted onto itself without a merge mode
    """
    if merge_children and merge_leaves:
        raise ValueError(
            "Invalid shifting, can only specify one type of merging, check `merge_children` and `merge_leaves`"
        )
    if not (isinstance(from_paths, list) and isinstance(to_paths, list)):
        raise ValueError(
            "Invalid type, `from_paths` and `to_paths` should be list type"
        )
    if len(from_paths) != len(to_paths):
        raise ValueError(
            f"Paths are different length, input `from_paths` have {len(from_paths)} entries, "
            f"while output `to_paths` have {len(to_paths)} entries"
        )
    # Validate up-front that every from/to pair refers to the same node name,
    # so no partial modification happens before a bad pair is detected.
    for from_path, to_path in zip(from_paths, to_paths):
        if to_path:
            if from_path.split(sep)[-1] != to_path.split(sep)[-1]:
                raise ValueError(
                    f"Unable to assign from_path {from_path} to to_path {to_path}\n"
                    f"Verify that `sep` is defined correctly for path\n"
                    f"Alternatively, check that `from_path` and `to_path` is reassigning the same node"
                )

    transfer_indicator = False
    node_type = tree.__class__
    tree_sep = tree.sep
    if to_tree:
        # Cross-tree transfer: destinations are resolved (and intermediate
        # nodes created) against to_tree, using its class and separator.
        transfer_indicator = True
        node_type = to_tree.__class__
        tree_sep = to_tree.sep
    for from_path, to_path in zip(from_paths, to_paths):
        # BUGFIX: work on a per-iteration copy of `merge_children`; the
        # overriding branch below disables merging for the current path only,
        # and mutating the parameter would leak into subsequent iterations.
        merge_children_path = merge_children
        from_path = from_path.replace(sep, tree.sep)
        from_node = find_path(tree, from_path)

        # From node not found
        if not from_node:
            if not skippable:
                raise NotFoundError(
                    f"Unable to find from_path {from_path}\n"
                    f"Set `skippable` to True to skip shifting for nodes not found"
                )
            else:
                logging.info(f"Unable to find from_path {from_path}")

        # From node found
        else:
            # Node to be deleted
            if not to_path:
                to_node = None
            # Node to be copied/shifted
            else:
                to_path = to_path.replace(sep, tree_sep)
                if transfer_indicator:
                    to_node = find_path(to_tree, to_path)
                else:
                    to_node = find_path(tree, to_path)

                # To node found
                if to_node:
                    if from_node == to_node:
                        # Same node: only valid with a merge mode, where the
                        # destination becomes the node's own parent.
                        if merge_children_path:
                            parent = to_node.parent
                            to_node.parent = None
                            to_node = parent
                        elif merge_leaves:
                            to_node = to_node.parent
                        else:
                            raise TreeError(
                                f"Attempting to shift the same node {from_node.node_name} back to the same position\n"
                                f"Check from path {from_path} and to path {to_path}\n"
                                f"Alternatively, set `merge_children` or `merge_leaves` to True if intermediate node is to be removed"
                            )
                    elif merge_children_path:
                        # Specify override to remove existing node, else children are merged
                        if not overriding:
                            logging.info(
                                f"Path {to_path} already exists and children are merged"
                            )
                        else:
                            logging.info(
                                f"Path {to_path} already exists and its children will be overridden by the merge"
                            )
                            parent = to_node.parent
                            to_node.parent = None
                            to_node = parent
                            # Existing node removed: attach from_node itself
                            # (not merely its children) in the final step.
                            merge_children_path = False
                    elif merge_leaves:
                        # Specify override to remove existing node, else leaves are merged
                        if not overriding:
                            logging.info(
                                f"Path {to_path} already exists and leaves are merged"
                            )
                        else:
                            logging.info(
                                f"Path {to_path} already exists and its leaves will be overridden by the merge"
                            )
                            del to_node.children
                    else:
                        if not overriding:
                            raise TreeError(
                                f"Path {to_path} already exists and unable to override\n"
                                f"Set `overriding` to True to perform overrides\n"
                                f"Alternatively, set `merge_children` to True if nodes are to be merged"
                            )
                        logging.info(
                            f"Path {to_path} already exists and will be overridden"
                        )
                        parent = to_node.parent
                        to_node.parent = None
                        to_node = parent

                # To node not found
                else:
                    # Find the deepest existing ancestor of to_path.
                    to_path_list = to_path.split(tree_sep)
                    idx = 1
                    to_path_parent = tree_sep.join(to_path_list[:-idx])
                    if transfer_indicator:
                        to_node = find_path(to_tree, to_path_parent)
                    else:
                        to_node = find_path(tree, to_path_parent)

                    # Create intermediate parent node, if applicable
                    while not to_node and idx + 1 < len(to_path_list):
                        idx += 1
                        # BUGFIX: rebuild the candidate path with the tree's
                        # separator (was `sep.join`, which breaks whenever the
                        # input separator differs from the tree separator).
                        to_path_parent = tree_sep.join(to_path_list[:-idx])
                        if transfer_indicator:
                            to_node = find_path(to_tree, to_path_parent)
                        else:
                            to_node = find_path(tree, to_path_parent)
                    if not to_node:
                        raise NotFoundError(
                            f"Unable to find to_path {to_path}\n"
                            f"Please specify valid path to shift node to"
                        )
                    # Create the missing intermediate nodes below the deepest
                    # existing ancestor, excluding the final node itself.
                    for depth in range(len(to_path_list) - idx, len(to_path_list) - 1):
                        intermediate_child_node = node_type(to_path_list[depth])
                        intermediate_child_node.parent = to_node
                        to_node = intermediate_child_node

            # Reassign from_node to new parent
            if copy:
                logging.debug(f"Copying {from_node.node_name}")
                from_node = from_node.copy()
            if merge_children_path:
                # Attach children of from_node, dropping from_node itself.
                logging.debug(
                    f"Reassigning children from {from_node.node_name} to {to_node.node_name}"
                )
                for children in from_node.children:
                    if delete_children:
                        del children.children
                    children.parent = to_node
                from_node.parent = None
            elif merge_leaves:
                # Attach only the leaf descendants; from_node and its
                # non-leaf descendants stay where they are.
                logging.debug(
                    f"Reassigning leaf nodes from {from_node.node_name} to {to_node.node_name}"
                )
                for children in from_node.leaves:
                    children.parent = to_node
            else:
                if delete_children:
                    del from_node.children
                # to_node is None in the deletion case, which detaches from_node.
                from_node.parent = to_node
|
||||
316
python37/packages/bigtree/tree/search.py
Normal file
316
python37/packages/bigtree/tree/search.py
Normal file
@@ -0,0 +1,316 @@
|
||||
from typing import Any, Callable, Iterable
|
||||
|
||||
from bigtree.node.basenode import BaseNode
|
||||
from bigtree.node.node import Node
|
||||
from bigtree.utils.exceptions import CorruptedTreeError, SearchError
|
||||
from bigtree.utils.iterators import preorder_iter
|
||||
|
||||
__all__ = [
|
||||
"findall",
|
||||
"find",
|
||||
"find_name",
|
||||
"find_names",
|
||||
"find_full_path",
|
||||
"find_path",
|
||||
"find_paths",
|
||||
"find_attr",
|
||||
"find_attrs",
|
||||
"find_children",
|
||||
]
|
||||
|
||||
|
||||
def findall(
    tree: BaseNode,
    condition: Callable,
    max_depth: int = None,
    min_count: int = None,
    max_count: int = None,
) -> tuple:
    """
    Search tree for nodes matching condition (callable function).

    >>> from bigtree import Node, findall
    >>> root = Node("a", age=90)
    >>> b = Node("b", age=65, parent=root)
    >>> c = Node("c", age=60, parent=root)
    >>> d = Node("d", age=40, parent=c)
    >>> findall(root, lambda node: node.age > 62)
    (Node(/a, age=90), Node(/a/b, age=65))

    Args:
        tree (BaseNode): tree to search
        condition (Callable): function that takes in node as argument, returns node if condition evaluates to `True`
        max_depth (int): maximum depth to search for, based on `depth` attribute, defaults to None
        min_count (int): checks for minimum number of occurrence,
            raise SearchError if number of results do not meet min_count, defaults to None
        max_count (int): checks for maximum number of occurrence,
            raise SearchError if number of results do not meet min_count, defaults to None

    Returns:
        (tuple)
    """
    # Pre-order traversal with the condition applied as a filter.
    matches = tuple(
        preorder_iter(tree, filter_condition=condition, max_depth=max_depth)
    )
    # Count bounds are enforced only when explicitly supplied.
    if min_count and len(matches) < min_count:
        raise SearchError(
            f"Expected more than {min_count} element(s), found {len(matches)} elements\n{matches}"
        )
    if max_count and len(matches) > max_count:
        raise SearchError(
            f"Expected less than {max_count} element(s), found {len(matches)} elements\n{matches}"
        )
    return matches
|
||||
|
||||
|
||||
def find(tree: BaseNode, condition: Callable, max_depth: int = None) -> BaseNode:
    """
    Search tree for *single node* matching condition (callable function).
    Returns None when no node matches; raises SearchError when more than one does.

    >>> from bigtree import Node, find
    >>> root = Node("a", age=90)
    >>> b = Node("b", age=65, parent=root)
    >>> c = Node("c", age=60, parent=root)
    >>> d = Node("d", age=40, parent=c)
    >>> find(root, lambda node: node.age == 65)
    Node(/a/b, age=65)
    >>> find(root, lambda node: node.age > 5)
    Traceback (most recent call last):
        ...
    bigtree.utils.exceptions.SearchError: Expected less than 1 element(s), found 4 elements
    (Node(/a, age=90), Node(/a/b, age=65), Node(/a/c, age=60), Node(/a/c/d, age=40))

    Args:
        tree (BaseNode): tree to search
        condition (Callable): function that takes in node as argument, returns node if condition evaluates to `True`
        max_depth (int): maximum depth to search for, based on `depth` attribute, defaults to None

    Returns:
        (BaseNode)
    """
    # max_count=1 makes findall raise SearchError on multiple matches.
    matches = findall(tree, condition, max_depth, max_count=1)
    return matches[0] if matches else None
|
||||
|
||||
|
||||
def find_name(tree: Node, name: str, max_depth: int = None) -> Node:
    """
    Search tree for single node matching name attribute.

    >>> from bigtree import Node, find_name
    >>> root = Node("a", age=90)
    >>> b = Node("b", age=65, parent=root)
    >>> c = Node("c", age=60, parent=root)
    >>> d = Node("d", age=40, parent=c)
    >>> find_name(root, "c")
    Node(/a/c, age=60)

    Args:
        tree (Node): tree to search
        name (str): value to match for name attribute
        max_depth (int): maximum depth to search for, based on `depth` attribute, defaults to None

    Returns:
        (Node)
    """
    # Exact name equality; uniqueness is enforced by find().
    def name_matches(node) -> bool:
        return node.node_name == name

    return find(tree, name_matches, max_depth)
|
||||
|
||||
|
||||
def find_names(tree: Node, name: str, max_depth: int = None) -> Iterable[Node]:
    """
    Search tree for multiple node(s) matching name attribute.

    >>> from bigtree import Node, find_names
    >>> root = Node("a", age=90)
    >>> b = Node("b", age=65, parent=root)
    >>> c = Node("c", age=60, parent=root)
    >>> d = Node("b", age=40, parent=c)
    >>> find_names(root, "c")
    (Node(/a/c, age=60),)
    >>> find_names(root, "b")
    (Node(/a/b, age=65), Node(/a/c/b, age=40))

    Args:
        tree (Node): tree to search
        name (str): value to match for name attribute
        max_depth (int): maximum depth to search for, based on `depth` attribute, defaults to None

    Returns:
        (Iterable[Node])
    """
    # Exact name equality; unlike find_name, every match is returned.
    def name_matches(node) -> bool:
        return node.node_name == name

    return findall(tree, name_matches, max_depth)
|
||||
|
||||
|
||||
def find_full_path(tree: Node, path_name: str) -> Node:
    """
    Search tree for single node matching path attribute.

    - Path name can be with or without leading tree path separator symbol.
    - Path name must be full path, works similar to `find_path` but faster.

    Walks the tree level by level from the root instead of scanning every
    node, so lookup cost is bounded by the path depth and sibling counts.
    Returns None if any path component is missing.

    >>> from bigtree import Node, find_full_path
    >>> root = Node("a", age=90)
    >>> b = Node("b", age=65, parent=root)
    >>> c = Node("c", age=60, parent=root)
    >>> d = Node("d", age=40, parent=c)
    >>> find_full_path(root, "/a/c/d")
    Node(/a/c/d, age=40)

    Args:
        tree (Node): tree to search
        path_name (str): value to match (full path) of path_name attribute

    Returns:
        (Node)

    Raises:
        ValueError: if the first path component does not match the root name
    """
    # strip() removes leading and trailing separators in one call
    # (equivalent to the chained rstrip().lstrip() it replaces).
    path_name = path_name.strip(tree.sep)
    path_list = path_name.split(tree.sep)
    if path_list[0] != tree.root.node_name:
        raise ValueError(
            f"Path {path_name} does not match the root node name {tree.root.node_name}"
        )
    parent_node = tree.root
    child_node = parent_node
    # Descend one component at a time; a missing component yields None.
    for child_name in path_list[1:]:
        child_node = find_children(parent_node, child_name)
        if not child_node:
            break
        parent_node = child_node
    return child_node
|
||||
|
||||
|
||||
def find_path(tree: Node, path_name: str) -> Node:
    """
    Search tree for single node matching path attribute.

    - Path name can be with or without leading tree path separator symbol.
    - Path name can be full path or partial path (trailing part of path) or node name.

    >>> from bigtree import Node, find_path
    >>> root = Node("a", age=90)
    >>> b = Node("b", age=65, parent=root)
    >>> c = Node("c", age=60, parent=root)
    >>> d = Node("d", age=40, parent=c)
    >>> find_path(root, "c")
    Node(/a/c, age=60)
    >>> find_path(root, "/c")
    Node(/a/c, age=60)

    Args:
        tree (Node): tree to search
        path_name (str): value to match (full path) or trailing part (partial path) of path_name attribute

    Returns:
        (Node)
    """
    path_name = path_name.rstrip(tree.sep)
    # BUGFIX: a bare endswith(path_name) also matched partial node names
    # (searching "c" matched a node named "abc"). Anchor the match at a
    # separator boundary, or accept exact full-path equality.
    suffix = tree.sep + path_name.lstrip(tree.sep)
    return find(
        tree,
        lambda node: node.path_name == path_name or node.path_name.endswith(suffix),
    )
|
||||
|
||||
|
||||
def find_paths(tree: Node, path_name: str) -> tuple:
    """
    Search tree for multiple nodes matching path attribute.

    - Path name can be with or without leading tree path separator symbol.
    - Path name can be partial path (trailing part of path) or node name.

    >>> from bigtree import Node, find_paths
    >>> root = Node("a", age=90)
    >>> b = Node("b", age=65, parent=root)
    >>> c = Node("c", age=60, parent=root)
    >>> d = Node("c", age=40, parent=c)
    >>> find_paths(root, "/a/c")
    (Node(/a/c, age=60),)
    >>> find_paths(root, "/c")
    (Node(/a/c, age=60), Node(/a/c/c, age=40))

    Args:
        tree (Node): tree to search
        path_name (str): value to match (full path) or trailing part (partial path) of path_name attribute

    Returns:
        (tuple)
    """
    path_name = path_name.rstrip(tree.sep)
    # BUGFIX: a bare endswith(path_name) also matched partial node names
    # (searching "c" matched a node named "abc"). Anchor the match at a
    # separator boundary, or accept exact full-path equality.
    suffix = tree.sep + path_name.lstrip(tree.sep)
    return findall(
        tree,
        lambda node: node.path_name == path_name or node.path_name.endswith(suffix),
    )
|
||||
|
||||
|
||||
def find_attr(
    tree: BaseNode, attr_name: str, attr_value: Any, max_depth: int = None
) -> BaseNode:
    """
    Search tree for single node matching custom attribute.

    >>> from bigtree import Node, find_attr
    >>> root = Node("a", age=90)
    >>> b = Node("b", age=65, parent=root)
    >>> c = Node("c", age=60, parent=root)
    >>> d = Node("d", age=40, parent=c)
    >>> find_attr(root, "age", 65)
    Node(/a/b, age=65)

    Args:
        tree (BaseNode): tree to search
        attr_name (str): attribute name to perform matching
        attr_value (Any): value to match for attr_name attribute
        max_depth (int): maximum depth to search for, based on `depth` attribute, defaults to None

    Returns:
        (BaseNode)
    """
    # getattr() is the idiomatic spelling of node.__getattribute__(attr_name).
    return find(
        tree, lambda node: getattr(node, attr_name) == attr_value, max_depth
    )
|
||||
|
||||
|
||||
def find_attrs(
    tree: BaseNode, attr_name: str, attr_value: Any, max_depth: int = None
) -> tuple:
    """
    Search tree for node(s) matching custom attribute.

    Nodes that do not define ``attr_name`` at all are treated as
    non-matching instead of raising ``AttributeError``.

    >>> from bigtree import Node, find_attrs
    >>> root = Node("a", age=90)
    >>> b = Node("b", age=65, parent=root)
    >>> c = Node("c", age=65, parent=root)
    >>> d = Node("d", age=40, parent=c)
    >>> find_attrs(root, "age", 65)
    (Node(/a/b, age=65), Node(/a/c, age=65))

    Args:
        tree (BaseNode): tree to search
        attr_name (str): attribute name to perform matching
        attr_value (Any): value to match for attr_name attribute
        max_depth (int): maximum depth to search for, based on `depth` attribute, defaults to None

    Returns:
        (tuple)
    """
    # getattr with a sentinel default: __getattribute__ would raise
    # AttributeError on nodes lacking the attribute; the sentinel keeps
    # attr_value=None from matching absent attributes.
    missing = object()
    return findall(
        tree,
        lambda node: getattr(node, attr_name, missing) == attr_value,
        max_depth,
    )
|
||||
|
||||
|
||||
def find_children(tree: Node, name: str) -> Node:
    """
    Search the immediate children of ``tree`` for a single node whose
    name attribute equals ``name``.  Returns ``None`` (implicitly) when
    no child matches.

    >>> from bigtree import Node, find_children
    >>> root = Node("a", age=90)
    >>> b = Node("b", age=65, parent=root)
    >>> c = Node("c", age=60, parent=root)
    >>> d = Node("d", age=40, parent=c)
    >>> find_children(root, "c")
    Node(/a/c, age=60)
    >>> find_children(c, "d")
    Node(/a/c/d, age=40)

    Args:
        tree (Node): tree to search, parent node
        name (str): value to match for name attribute, child node

    Returns:
        (Node)

    Raises:
        CorruptedTreeError: if more than one child carries the same name
            (should be impossible in a well-formed tree).
    """
    # Skip falsy placeholder children (e.g. unset binary-node slots).
    child = [c for c in tree.children if c and c.node_name == name]
    if len(child) > 1:  # pragma: no cover
        raise CorruptedTreeError(
            f"There are more than one path for {child[0].path_name}, check {child}"
        )
    if child:
        return child[0]
|
||||
0
python37/packages/bigtree/utils/__init__.py
Normal file
0
python37/packages/bigtree/utils/__init__.py
Normal file
32
python37/packages/bigtree/utils/exceptions.py
Normal file
32
python37/packages/bigtree/utils/exceptions.py
Normal file
@@ -0,0 +1,32 @@
|
||||
class TreeError(Exception):
    """Base class for all exceptions raised by this package."""

    pass
|
||||
|
||||
|
||||
class LoopError(TreeError):
    """Error during node creation.

    Name suggests it guards against cyclic parent/child assignments
    (a node becoming its own ancestor) — confirm at raise sites.
    """

    pass
|
||||
|
||||
|
||||
class CorruptedTreeError(TreeError):
    """Error during node creation or tree creation.

    Signals an internally inconsistent tree (e.g. duplicate paths
    among siblings) that should not occur in normal use.
    """

    pass
|
||||
|
||||
|
||||
class DuplicatedNodeError(TreeError):
    """Error during tree creation.

    Presumably raised when an input would create two nodes with the
    same identity where duplicates are disallowed — confirm at raise sites.
    """

    pass
|
||||
|
||||
|
||||
class NotFoundError(TreeError):
    """Error during tree creation or tree search.

    Raised when a required node cannot be located.
    """

    pass
|
||||
|
||||
|
||||
class SearchError(TreeError):
    """Error during tree search."""

    pass
|
||||
371
python37/packages/bigtree/utils/iterators.py
Normal file
371
python37/packages/bigtree/utils/iterators.py
Normal file
@@ -0,0 +1,371 @@
|
||||
from typing import Callable, Iterable, List, Tuple
|
||||
|
||||
# Public API of the iterators module; star-imports expose only these names.
__all__ = [
    "inorder_iter",
    "preorder_iter",
    "postorder_iter",
    "levelorder_iter",
    "levelordergroup_iter",
    "dag_iterator",
]
|
||||
|
||||
|
||||
def inorder_iter(
    tree,
    filter_condition: Callable = None,
    max_depth: int = None,
) -> Iterable:
    """Yield all nodes of a binary tree in in-order (LNR) sequence.

    In-order traversal:
    1. Recursively traverse the current node's left subtree.
    2. Visit the current node.
    3. Recursively traverse the current node's right subtree.

    Nodes rejected by ``filter_condition`` are skipped but their
    subtrees are still traversed; nodes deeper than ``max_depth``
    (and everything below them) are pruned entirely.

    Args:
        tree (BaseNode): input tree
        filter_condition (Callable): function that takes in node as argument, optional
            Yields node only if condition evaluates to `True`
        max_depth (int): maximum depth of iteration, based on `depth` attribute, optional

    Returns:
        (Iterable[BaseNode])
    """
    # Guard clauses: empty subtree, or subtree pruned by depth limit.
    if not tree:
        return
    if max_depth and tree.depth > max_depth:
        return
    yield from inorder_iter(tree.left, filter_condition, max_depth)
    if not filter_condition or filter_condition(tree):
        yield tree
    yield from inorder_iter(tree.right, filter_condition, max_depth)
|
||||
|
||||
|
||||
def preorder_iter(
    tree,
    filter_condition: Callable = None,
    stop_condition: Callable = None,
    max_depth: int = None,
) -> Iterable:
    """Yield all nodes of a tree in pre-order (NLR) sequence.

    Pre-order traversal:
    1. Visit the current node.
    2. Recursively traverse each subtree, left to right.

    The result is topologically sorted: a parent is always yielded
    before its children.  ``filter_condition`` skips individual nodes
    but keeps traversing below them; ``stop_condition`` and
    ``max_depth`` prune a node together with its entire subtree.

    Args:
        tree (BaseNode): input tree
        filter_condition (Callable): function that takes in node as argument, optional
            Yields node only if condition evaluates to `True`
        stop_condition (Callable): function that takes in node as argument, optional
            Prunes subtree if condition evaluates to `True`
        max_depth (int): maximum depth of iteration, based on `depth` attribute, optional

    Returns:
        (Iterable[BaseNode])
    """
    # Guard clauses: empty subtree, depth pruning, stop pruning.
    if not tree:
        return
    if max_depth and tree.depth > max_depth:
        return
    if stop_condition and stop_condition(tree):
        return
    if not filter_condition or filter_condition(tree):
        yield tree
    for subtree in tree.children:
        yield from preorder_iter(subtree, filter_condition, stop_condition, max_depth)
|
||||
|
||||
|
||||
def postorder_iter(
    tree,
    filter_condition: Callable = None,
    stop_condition: Callable = None,
    max_depth: int = None,
) -> Iterable:
    """Yield all nodes of a tree in post-order (LRN) sequence.

    Post-order traversal:
    1. Recursively traverse each subtree, left to right.
    2. Visit the current node.

    ``filter_condition`` skips individual nodes but keeps traversing
    below them; ``stop_condition`` and ``max_depth`` prune a node
    together with its entire subtree.

    Args:
        tree (BaseNode): input tree
        filter_condition (Callable): function that takes in node as argument, optional
            Yields node only if condition evaluates to `True`
        stop_condition (Callable): function that takes in node as argument, optional
            Prunes subtree if condition evaluates to `True`
        max_depth (int): maximum depth of iteration, based on `depth` attribute, optional

    Returns:
        (Iterable[BaseNode])
    """
    # Guard clauses: empty subtree, depth pruning, stop pruning.
    if not tree:
        return
    if max_depth and tree.depth > max_depth:
        return
    if stop_condition and stop_condition(tree):
        return
    for subtree in tree.children:
        yield from postorder_iter(subtree, filter_condition, stop_condition, max_depth)
    if not filter_condition or filter_condition(tree):
        yield tree
|
||||
|
||||
|
||||
def levelorder_iter(
    tree,
    filter_condition: Callable = None,
    stop_condition: Callable = None,
    max_depth: int = None,
) -> Iterable:
    """Yield all nodes of a tree in level-order (breadth-first) sequence.

    Each level is yielded left to right before descending to the next.
    ``filter_condition`` skips individual nodes but keeps traversing
    below them; ``stop_condition`` and ``max_depth`` prune a node
    together with its entire subtree.

    Args:
        tree (BaseNode): input tree, or internal list of nodes for one level
        filter_condition (Callable): function that takes in node as argument, optional
            Yields node only if condition evaluates to `True`
        stop_condition (Callable): function that takes in node as argument, optional
            Prunes subtree if condition evaluates to `True`
        max_depth (int): maximum depth of iteration, based on `depth` attribute, defaults to None

    Returns:
        (Iterable[BaseNode])
    """
    # Recursion carries one whole level as a list; a single root node is
    # normalised into a one-element level first.
    if not isinstance(tree, List):
        tree = [tree]
    next_level = []
    for node in tree:
        if not node:
            continue
        if max_depth and node.depth > max_depth:
            continue
        if stop_condition and stop_condition(node):
            continue
        if not filter_condition or filter_condition(node):
            yield node
        next_level.extend(node.children)
    if next_level:
        yield from levelorder_iter(next_level, filter_condition, stop_condition, max_depth)
|
||||
|
||||
|
||||
def levelordergroup_iter(
    tree,
    filter_condition: Callable = None,
    stop_condition: Callable = None,
    max_depth: int = None,
) -> Iterable[Iterable]:
    """Yield tree nodes level by level, one tuple per level.

    Like :func:`levelorder_iter`, but nodes are grouped so the caller
    can tell which level each node came from.  A level whose nodes are
    all rejected by ``filter_condition`` still appears (as an empty
    tuple); ``stop_condition`` and ``max_depth`` prune a node together
    with its entire subtree.

    Args:
        tree (BaseNode): input tree, or internal list of nodes for one level
        filter_condition (Callable): function that takes in node as argument, optional
            Includes node only if condition evaluates to `True`
        stop_condition (Callable): function that takes in node as argument, optional
            Prunes subtree if condition evaluates to `True`
        max_depth (int): maximum depth of iteration, based on `depth` attribute, defaults to None

    Returns:
        (Iterable[Iterable])
    """
    # Recursion carries one whole level as a list; a single root node is
    # normalised into a one-element level first.
    if not isinstance(tree, List):
        tree = [tree]

    group = []
    next_level = []
    for node in tree:
        if max_depth and node.depth > max_depth:
            continue
        if stop_condition and stop_condition(node):
            continue
        if not filter_condition or filter_condition(node):
            group.append(node)
        # Falsy placeholder children are dropped here rather than on entry.
        next_level.extend(child for child in node.children if child)
    yield tuple(group)
    # Recurse only when the next level exists and is within the depth limit
    # (all nodes of a level share the same depth, so checking one suffices).
    if next_level and not (max_depth and next_level[0].depth > max_depth):
        yield from levelordergroup_iter(
            next_level, filter_condition, stop_condition, max_depth
        )
|
||||
|
||||
|
||||
def dag_iterator(dag) -> Iterable[Tuple]:
    """Iterate through all edges of a Directed Acyclic Graph (DAG) as
    (parent, child) pairs, starting from any node of the graph.

    Note that node names must be unique (the visited set is keyed by name).
    Note that DAG must at least have two nodes to be shown on graph.

    1. Visit the current node.
    2. Recursively traverse the current node's parents.
    3. Recursively traverse the current node's children.

    >>> from bigtree import DAGNode, dag_iterator
    >>> a = DAGNode("a", step=1)
    >>> b = DAGNode("b", step=1)
    >>> c = DAGNode("c", step=2, parents=[a, b])
    >>> d = DAGNode("d", step=2, parents=[a, c])
    >>> e = DAGNode("e", step=3, parents=[d])
    >>> [(parent.node_name, child.node_name) for parent, child in dag_iterator(a)]
    [('a', 'c'), ('a', 'd'), ('b', 'c'), ('c', 'd'), ('d', 'e')]

    Args:
        dag (DAGNode): input dag

    Returns:
        (Iterable[Tuple[DAGNode, DAGNode]])
    """
    # Names of nodes whose edges have already been emitted; shared by all
    # recursive calls via closure.  An edge (u, v) is emitted by whichever
    # endpoint is processed first, so each edge appears exactly once.
    visited_nodes = set()

    def recursively_parse_dag(node):
        # Mark this node BEFORE emitting its edges so that neither
        # direction re-emits an edge back to it during recursion.
        node_name = node.node_name
        visited_nodes.add(node_name)

        # Parse upwards: emit all edges from not-yet-visited parents first...
        for parent in node.parents:
            parent_name = parent.node_name
            if parent_name not in visited_nodes:
                yield parent, node

        # Parse downwards: ...then all edges to not-yet-visited children.
        for child in node.children:
            child_name = child.node_name
            if child_name not in visited_nodes:
                yield node, child

        # Parse upwards: only now recurse into the neighbours.  The visited
        # check is repeated because an earlier recursion in this loop may
        # have reached (and marked) a later parent already.
        for parent in node.parents:
            parent_name = parent.node_name
            if parent_name not in visited_nodes:
                yield from recursively_parse_dag(parent)

        # Parse downwards: same re-check applies to children.
        for child in node.children:
            child_name = child.node_name
            if child_name not in visited_nodes:
                yield from recursively_parse_dag(child)

    yield from recursively_parse_dag(dag)
|
||||
0
python37/packages/bigtree/workflows/__init__.py
Normal file
0
python37/packages/bigtree/workflows/__init__.py
Normal file
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user