Source code for bonobo.structs.graphs

import html
import json
from collections import namedtuple
from copy import copy

from graphviz import ExecutableNotFound
from graphviz import Digraph

from bonobo.constants import BEGIN
from bonobo.util import get_name

GraphRange = namedtuple('GraphRange', ['graph', 'input', 'output'])

class Graph:
    """
    Represents a directed graph of nodes.

    Instance state:

    * ``edges``: dict mapping a node index (or the ``BEGIN`` marker) to the
      set of indexes it outputs to.
    * ``named``: dict mapping a chain name to the index of that chain's
      first node.
    * ``nodes``: list of node objects; a node's index is its position here.
    """

    name = ''

    def __init__(self, *chain):
        self.edges = {BEGIN: set()}
        self.named = {}
        self.nodes = []
        self.add_chain(*chain)

    def __iter__(self):
        yield from self.nodes

    def __len__(self):
        """ Node count.
        """
        return len(self.nodes)

    def __getitem__(self, key):
        return self.nodes[key]

    def _invalidate_caches(self):
        """ Drop every memoized view derived from the graph structure.
        """
        # NOTE: the attribute name keeps its historical spelling ("topologcally").
        for attr in ('_topologcally_sorted_indexes_cache', '_graphviz'):
            if hasattr(self, attr):
                delattr(self, attr)

    def outputs_of(self, idx, create=False):
        """ Get a set of the outputs for a given node index.

        If ``create`` is true and ``idx`` has no entry yet, an empty set is
        registered first; otherwise an unknown ``idx`` raises ``KeyError``.
        """
        if create and idx not in self.edges:
            self.edges[idx] = set()
        return self.edges[idx]

    def add_node(self, c):
        """ Add a node without connections in this graph and returns its index.
        """
        idx = len(self.nodes)
        self.edges[idx] = set()
        self.nodes.append(c)
        # A new node makes any cached ordering / rendering stale, even when
        # add_node() is called directly rather than through add_chain().
        self._invalidate_caches()
        return idx

    def add_chain(self, *nodes, _input=BEGIN, _output=None, _name=None):
        """ Add a chain in this graph.

        Each node is appended and linked to the previous one, the first being
        linked from ``_input``. When ``_output`` is given, the last node is
        linked to it. When ``_name`` is given, it is registered as an alias
        for the first node of the chain.

        Returns a ``GraphRange(graph, first_index, last_index)``; both indexes
        are ``None`` when no node was passed.

        Raises ``ValueError`` if ``_input`` or ``_output`` cannot be resolved
        and ``KeyError`` if ``_name`` is already in use.
        """
        if not nodes:
            return GraphRange(self, None, None)

        _input = self._resolve_index(_input)
        _output = self._resolve_index(_output)

        # Validate the name before touching the graph, so a failing call does
        # not leave a dangling, unconnected node behind.
        if _name and _name in self.named:
            raise KeyError('Duplicate name {!r} in graph.'.format(_name))

        _first = None
        _last = None

        for node in nodes:
            _last = self.add_node(node)
            if _first is None:
                _first = _last
                if _name:
                    self.named[_name] = _first
            self.outputs_of(_input, create=True).add(_last)
            _input = _last

        if _output is not None:
            self.outputs_of(_input, create=True).add(_output)

        # Edges were added after the last add_node() call.
        self._invalidate_caches()

        return GraphRange(self, _first, _last)

    def copy(self):
        """ Return a new Graph with the same nodes, names and edges.

        The node objects themselves are shared, but the containers are not:
        in particular every edge set is duplicated, so extending the copy
        never alters the edges of the original graph.
        """
        g = Graph()

        g.edges = {idx: set(outputs) for idx, outputs in self.edges.items()}
        g.named = copy(self.named)
        g.nodes = copy(self.nodes)

        return g

    @property
    def topologically_sorted_indexes(self):
        """Iterate in topological order, based on networkx's topological_sort() function.

        The result is a tuple of integer node indexes, cached until the graph
        is modified. Raises ``RuntimeError`` if the graph contains a cycle.
        """
        try:
            return self._topologcally_sorted_indexes_cache
        except AttributeError:
            pass

        seen = set()
        order = []
        explored = set()

        for i in self.edges:
            if i in explored:
                continue
            fringe = [i]
            while fringe:
                w = fringe[-1]  # depth first search
                if w in explored:  # already looked down this branch
                    fringe.pop()
                    continue
                seen.add(w)  # mark as seen
                # Check successors for cycles and for new nodes
                new_nodes = []
                for n in self.outputs_of(w):
                    if n not in explored:
                        if n in seen:  # CYCLE !!
                            raise RuntimeError("Graph contains a cycle.")
                        new_nodes.append(n)
                if new_nodes:  # Add new_nodes to fringe
                    fringe.extend(new_nodes)
                else:  # No new nodes so w is fully explored
                    explored.add(w)
                    order.append(w)
                    fringe.pop()  # done considering this node

        # Keep only plain integer keys: anything else in ``edges`` (such as
        # the BEGIN marker) is not a node index.
        self._topologcally_sorted_indexes_cache = tuple(
            filter(lambda i: type(i) is int, reversed(order))
        )
        return self._topologcally_sorted_indexes_cache

    @property
    def graphviz(self):
        """ A ``Digraph`` rendering of this graph, laid out left to right.

        The result is cached until the graph is modified.
        """
        try:
            return self._graphviz
        except AttributeError:
            pass

        g = Digraph()
        g.attr(rankdir='LR')
        g.node('BEGIN', shape='point')
        for i in self.outputs_of(BEGIN):
            g.edge('BEGIN', str(i))
        for ix in self.topologically_sorted_indexes:
            g.node(str(ix), label=get_name(self[ix]))
            for iy in self.outputs_of(ix):
                g.edge(str(ix), str(iy))

        self._graphviz = g
        return self._graphviz

    def _repr_dot_(self):
        return str(self.graphviz)

    def _repr_html_(self):
        try:
            return '<div>{}</div><pre>{}</pre>'.format(self.graphviz._repr_svg_(), html.escape(repr(self)))
        except (ExecutableNotFound, FileNotFoundError) as exc:
            # Rendering failed (the exceptions caught suggest a missing
            # graphviz executable); show the error instead of raising.
            return '<strong>{}</strong>: {}'.format(type(exc).__name__, str(exc))

    def _resolve_index(self, mixed):
        """ Find the index based on various strategies for a node, probably an input or output of chain. Supported
        inputs are indexes, node values or names.

        ``None`` is passed through unchanged. Raises ``ValueError`` when
        nothing matches.
        """
        if mixed is None:
            return None

        # Integers and existing ``edges`` keys (e.g. BEGIN) are already indexes.
        if type(mixed) is int or mixed in self.edges:
            return mixed

        if isinstance(mixed, str) and mixed in self.named:
            return self.named[mixed]

        if mixed in self.nodes:
            return self.nodes.index(mixed)

        raise ValueError('Cannot find node matching {!r}.'.format(mixed))
def _get_graphviz_node_id(graph, i):
    """ Build the ``{<index> [label=<name>]}`` string for node ``i`` of ``graph``.

    The label is the node's name (as given by ``get_name``) encoded with
    ``json.dumps``, so it comes out as a double-quoted, escaped string.
    """
    node = graph[i]
    quoted_label = json.dumps(get_name(node))
    return '{{{} [label={}]}}'.format(str(i), quoted_label)