def plot_workflow_graph(yaml_file_path,
                        output_dir=".",
                        output_name="workflow_graph",
                        output_format="png",
                        dpi=300,
                        show_parameters=False):
    """Render the steps and files of a workflow YAML file as a Graphviz graph.

    The YAML document must contain a top-level ``steps`` mapping of step name
    to step description. Each step may carry ``inputs`` and ``outputs`` lists
    whose items are mappings of ``name -> {"path": ...}``, and a
    ``parameters`` list whose items are mappings of ``name -> {"value": ...}``.
    Steps are drawn as light-blue boxes and files as ellipses; a file keeps
    the colour of the role in which it is first seen (grey for an input,
    green for an output).

    Args:
        yaml_file_path: Path of the workflow YAML file to read.
        output_dir: Directory for the rendered file; created if missing.
        output_name: File name of the rendered graph, without extension.
        output_format: Graphviz output format (e.g. ``"png"``, ``"svg"``).
        dpi: Resolution, applied only when ``output_format`` is PNG.
        show_parameters: If true, list each step's ``name = value``
            parameters above the step name in its box.

    Returns:
        An ``IPython.display.Image`` of the rendered file for raster formats
        (png, jpg, jpeg, gif); otherwise the rendered file path as a string.
    """

    def shorten_path(path, keep_start=1, keep_end=2):
        """Keep a few leading/trailing path elements.

        E.g. 'a/b/c/d/e/f.txt' -> 'a/.../e/f.txt'
        """
        parts = path.strip('/').split('/')
        if len(parts) <= keep_start + keep_end:
            return path
        head = '/'.join(parts[:keep_start])
        # Slice from an explicit index so keep_end=0 keeps nothing
        # (parts[-0:] would keep everything).
        tail = '/'.join(parts[len(parts) - keep_end:])
        return f"{head}/.../{tail}"

    def split_path_label(path):
        """Put the final path element on its own line of the label."""
        directory, sep, filename = path.rpartition('/')
        if sep:
            return f"{directory}/\n{filename}"
        return path

    def iter_paths(items):
        """Yield every ``path`` value found in an inputs/outputs list."""
        # `items` is None when the YAML key is present but left empty.
        for item in items or []:
            if not isinstance(item, dict):
                continue
            for spec in item.values():
                if isinstance(spec, dict) and spec.get('path') is not None:
                    yield str(spec['path'])

    # Load YAML workflow file
    with open(yaml_file_path, 'r', encoding='utf-8') as f:
        workflow_full = yaml.safe_load(f)

    dot = Digraph(format=output_format)
    dot.attr(rankdir='LR')  # left-to-right layout ('TB' would be top-to-bottom)

    # Set DPI only if format supports it (like png)
    if output_format.lower() == 'png':
        dot.attr(dpi=str(dpi))

    # Full path -> synthetic node id. Nodes are keyed by the FULL path so that
    # two different files whose shortened labels coincide stay distinct, and
    # synthetic ids keep ':' in a path or step name from being parsed by
    # Graphviz as a node:port separator in edge endpoints.
    file_node_ids = {}

    def file_node(path, fillcolor):
        """Return the node id for `path`, creating the node on first use."""
        node_id = file_node_ids.get(path)
        if node_id is None:
            node_id = f"file_{len(file_node_ids)}"
            file_node_ids[path] = node_id
            dot.node(node_id,
                     label=split_path_label(shorten_path(path)),
                     tooltip=path,  # full path, since the label is shortened
                     shape="ellipse", style="filled", fillcolor=fillcolor)
        return node_id

    for index, (step_name, step) in enumerate(workflow_full['steps'].items()):
        step = step or {}
        step_id = f"step_{index}"

        # Extract parameters if enabled
        param_lines = []
        if show_parameters:
            for param in step.get("parameters") or []:
                for key, spec in param.items():
                    # Accept both {"value": x} mappings and bare scalar values.
                    value = spec.get("value", "") if isinstance(spec, dict) else spec
                    param_lines.append(f"{key} = {value}")

        if param_lines:
            label = "\n".join(param_lines) + f"\n{step_name}"
        else:
            label = str(step_name)

        dot.node(step_id, label=label, shape="box", style="filled", fillcolor="lightblue")

        for path in iter_paths(step.get('inputs')):
            dot.edge(file_node(path, "lightgrey"), step_id)

        for path in iter_paths(step.get('outputs')):
            dot.edge(step_id, file_node(path, "lightgreen"))

    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, output_name)
    # render() returns the path of the file it actually wrote.
    rendered_path = dot.render(output_path)

    # Raster formats can be displayed inline; for SVG, PDF, etc. return the path.
    if output_format.lower() in ('png', 'jpg', 'jpeg', 'gif'):
        return Image(rendered_path)
    print(f"Graph saved to: {rendered_path}")
    return rendered_path