Source code for gristmill.optimize

"""Optimizer for the contraction computations."""

import collections
import enum
import functools
import heapq
import itertools
import operator
import typing
import warnings

from drudge import TensorDef, prod_, Term, Range, sum_
from networkx import Graph
from sympy import (
    Integer, Symbol, Expr, IndexedBase, Mul, Indexed, primitive, Wild,
    default_sort_key, Pow
)

from .utils import (
    Size, get_total_size, DSF, Tuple4Cmp, form_sized_range
)


#
#  The public driver
#  -----------------
#


class ContrStrat(enum.Enum):
    """The strategies for handling tensor contractions.

    This class holds the possible options for the different ways of
    handling contractions in the optimization, for both the termination of
    the main loop and the retention of parenthesizations for the sum
    optimization.  Specifically, we have the options

    ``GREEDY``
        The contraction within each term will be optimized greedily.  This
        accelerates the optimization at a big sacrifice of result quality.
        So it should only be used for inputs whose terms contain many
        factors in a very dense pattern.

    ``OPT``
        The global minimum of each tensor contraction will be found by the
        advanced algorithm in gristmill.  And only the optimal
        contraction(s) will be kept for the sum optimization.

    ``TRAV``
        The same strategy as ``OPT`` will be attempted for the optimization
        of contractions.  But all evaluations traversed in the optimization
        process will be kept and considered in the subsequent summation
        optimizations.

    ``EXHAUST``
        All possible parenthesizations will be considered for all terms.
        This can be extremely slow.  But it might be helpful for problems
        whose terms all have a manageable number of factors.

    """

    GREEDY = 0
    OPT = 1
    TRAV = 2
    EXHAUST = 3

class RepeatedTermsStrat(enum.Enum):
    """Optimization strategies for repeated terms in a sum.

    In some sums of tensor contractions, some terms might be different
    components of the same computed tensor.  For instance, in

    .. math::

        r_{a, b} = s_a t_b + s_b t_a

    if we define

    .. math::

        i_{a, b} = s_a t_b

    the two terms are actually :math:`i_{a, b}` and :math:`i_{b, a}`.

    For problems with repeated terms, we have the strategies,

    ``SKIP``
        Repeated terms are simply skipped during the optimization by
        factorization.  In this way, repeated terms are guaranteed not to
        be computed twice, even implicitly.

    ``NATURAL``
        Repeated terms participate in factorization only when a faster
        evaluation is given by it.  Technically, this is achieved by
        setting the excess cost of the evaluation of the terms to be the
        **full cost** of the evaluation, rather than the difference with
        the optimal cost.  This setting should give acceptable results for
        most purposes.

    ``IGNORE``
        Ignore the fact that the terms are repeated.  They are going to be
        treated exactly like other terms.

    """

    SKIP = 0
    NATURAL = 1
    IGNORE = 2

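
# A small numeric illustration of the repeated-term situation described
# above, using numpy purely for demonstration (numpy is not otherwise used
# in this module, so the snippet is kept as a comment):
#
#     import numpy as np
#
#     s = np.random.rand(10)
#     t = np.random.rand(10)
#     i = np.outer(s, t)  # i[a, b] = s[a] * t[b], computed once.
#     r = i + i.T         # r[a, b] = s[a] * t[b] + s[b] * t[a].
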
def optimize(
        computs: typing.Iterable[TensorDef], substs=None, simplify=True,
        interm_fmt='tau^{}', contr_strat=ContrStrat.TRAV, opt_sum=True,
        repeated_terms_strat=RepeatedTermsStrat.NATURAL, opt_symm=True,
        req_an_opt=False, greedy_cutoff=-1, drop_cutoff=-1,
        remove_shallow=True
) -> typing.List[TensorDef]:
    """Optimize the evaluation of the given tensor computations.

    This function will transform the given computations, given as tensor
    definitions, into another list of computations mathematically
    equivalent to the given ones while requiring fewer arithmetic
    operations.

    Parameters
    ----------

    computs
        The computations, can be given as an iterable of tensor
        definitions.

    substs
        A dictionary for making substitutions inside the sizes of ranges.
        All the ranges need to have sizes in at most one undetermined
        variable after the substitution, so that they can be totally
        ordered.  When one symbol still remains in the sizes, the
        asymptotic cost (scaling and prefactor) will be optimized.  Or when
        all symbols are gone after the substitution, the optimization is
        going to be based on the numeric sizes.  Numeric sizes tend to make
        the optimization faster, due to the usage of built-in integer or
        floating-point arithmetic in lieu of the more complex polynomial
        arithmetic.

    simplify
        If the input is going to be simplified before processing.  It can
        be disabled when the input is already simplified.

    interm_fmt
        The format for the names of the intermediates.

    contr_strat
        The strategy for handling contractions, as explained in
        :py:class:`ContrStrat`.

    repeated_terms_strat
        The strategy for handling repeated terms in sums, as explained in
        :py:class:`RepeatedTermsStrat`.

    opt_sum
        If sums of multiple terms will be attempted to be optimized by
        using constriction (factorization).

    opt_symm
        If common symmetrization of multiple tensors, input or
        intermediate, is going to be optimized.  For instance, with it,
        :math:`x_{a, b} + y_{a, b} - 2 x_{b, a} - 2 y_{b, a}` can be
        optimized into first computing :math:`p_{a, b} = x_{a, b} +
        y_{a, b}` followed by :math:`p_{a, b} - 2 p_{b, a}`.

    req_an_opt
        If each constriction operation is required to have an optimal
        parenthesization for at least one of its terms.  This requirement
        attempts to accelerate the constriction searching by having a
        smaller number of branches at the first-edge level of the recursion
        tree.  However, it has a chance of giving deteriorated
        optimization, and it is not guaranteed to be faster, since pivoting
        at this level has to be disabled.  So it is set to False by
        default.  It might be worth experimenting with for large inputs,
        especially with the exhaust strategy for contractions, or when
        greedy is turned on.

    greedy_cutoff
        The depth cutoff for making greedy selections in constriction.
        Beyond this depth in the recursion tree (inclusive), only the
        choices making the locally best saving will be considered.  With
        negative values, full Bron-Kerbosch backtracking is performed.

    drop_cutoff
        The depth cutoff for picking only a random one with greedy saving
        in the summation optimization.  The difference with the option
        ``greedy_cutoff`` is that here only **one** choice giving the
        locally best saving will be considered, rather than all of them.
        This could give better acceleration than ``greedy_cutoff`` in the
        presence of large degeneracy, while the results could be less
        optimized.  For large inputs, a value of ``2`` is advised.

    remove_shallow
        Shallow intermediates are outer-product intermediates that come
        with no summations.  Normally these intermediates cannot give a
        saving big enough to justify their memory usage.  So by default,
        they are just dropped, with their content inlined into the places
        where they are referenced.

    """

    # This interface function is primarily just for sanity checking and
    # normalization of the input.

    substs = {} if substs is None else substs

    computs = [
        i.simplify() if simplify else i.reset_dumms()
        for i in computs
    ]
    if len(computs) == 0:
        raise ValueError('No computation is given!')

    if not isinstance(contr_strat, ContrStrat):
        raise TypeError('Invalid contraction strategy', contr_strat)

    opt = _Optimizer(
        computs, substs=substs, interm_fmt=interm_fmt,
        contr_strat=contr_strat, opt_sum=opt_sum,
        repeated_terms_strat=repeated_terms_strat, opt_symm=opt_symm,
        req_an_opt=req_an_opt, greedy_cutoff=greedy_cutoff,
        drop_cutoff=drop_cutoff, remove_shallow=remove_shallow
    )

    return opt.optimize()

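
# A minimal end-to-end usage sketch of :py:func:`optimize`, kept as a
# comment so that the module stays import-clean.  The drudge setup follows
# the common pattern from the drudge documentation; the range ``r``, its
# size ``n``, the dummies, and the input bases are all illustrative names.
#
#     from pyspark import SparkContext
#     from sympy import Symbol, IndexedBase, symbols
#     from drudge import Drudge, Range
#     from gristmill import optimize, ContrStrat, get_flop_cost
#
#     dr = Drudge(SparkContext())
#     n = Symbol('n')
#     r = Range('R', 0, n)
#     a, b, c, d = dumms = symbols('a b c d')
#     dr.set_dumms(r, dumms)
#     dr.add_resolver_for_dumms()
#
#     u, v, w = [IndexedBase(i) for i in 'uvw']
#     f = dr.define_einst(
#         IndexedBase('f')[a, b], u[a, c] * v[c, d] * w[d, b]
#     )
#
#     # The result is a list of TensorDef with the intermediates first;
#     # its cost can be inspected with ``get_flop_cost``.
#     evals = optimize([f], contr_strat=ContrStrat.TRAV)
#     print(get_flop_cost(evals))
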
#
# The internal optimization engine
# --------------------------------
#

#
# General small type definitions
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#

# Base for tensor definitions.
#
# Symbol for 0-order tensors, IndexedBase for other cases.
_Base = typing.Union[Symbol, IndexedBase]

# Symbol/range pairs.
_SrPairs = typing.Sequence[typing.Tuple[Symbol, Range]]

# Sequences of terms.
_Terms = typing.Sequence[Term]

# Indices to tensor bases.
_Indices = typing.Tuple[Expr]


class _Grain(typing.NamedTuple):
    """A piece of grain ready for optimization.

    Basically it is a tensor definition with localized terms.
    """
    base: _Base
    exts: _SrPairs
    terms: _Terms


class _IntermRef(typing.NamedTuple):
    """A reference to an intermediate."""

    coeff: Expr
    base: _Base
    indices: _Indices
    power: int

    @property
    def ref(self):
        """The reference to the intermediate without the coefficient."""
        return _index(self.base, self.indices) ** self.power


#
# Utility constants
# ~~~~~~~~~~~~~~~~~
#

_ZERO = Integer(0)
_UNITY = Integer(1)
_NEG_UNITY = Integer(-1)

_EXT = 0
_SUMMED_EXT = 1
_SUMMED = 2

_SUBSTED_EVAL_BASE = Symbol('gristmillSubstitutedEvalBase')


#
# Global factories
# ~~~~~~~~~~~~~~~~
#

class _SymbFactory(dict):
    """A small symbol factory."""

    def __missing__(self, key):
        return Symbol('gristmillInternalSymbol{}'.format(key))


_SYMB_FACTORY = _SymbFactory()


class _WildFactory(dict):
    """A small wild symbol factory."""

    def __missing__(self, key):
        return Wild('gristmillInternalWild{}'.format(key))


_WILD_FACTORY = _WildFactory()


#
# Utility static functions
# ~~~~~~~~~~~~~~~~~~~~~~~~
#

def _get_canon_coeff(coeffs, preferred):
    """Get the canonical coefficient from a list of coefficients."""

    expr = sum(
        v * _SYMB_FACTORY[i] for i, v in enumerate(coeffs)
    ).together()

    frac = _UNITY  # The fractional part.
    if isinstance(expr, Mul):
        for i in expr.args:
            if isinstance(i, Pow) and i.args[1] < 0:
                frac *= i
            continue
    expr /= frac

    coeff, _ = primitive(expr, *[
        _SYMB_FACTORY[i] for i, _ in enumerate(coeffs)
    ])

    # Initial coefficient without phase.
    init_coeff = coeff * frac

    # The primitive computation does not take phase into account.
    negs = []
    poses = []
    for i in coeffs:
        i /= init_coeff
        if i.has(_NEG_UNITY) or i.is_negative:
            negs.append(-i)
        else:
            poses.append(i)
        continue

    neg_sig, pos_sig = [
        (len(i), tuple(sorted(default_sort_key(j) for j in i)))
        for i in [negs, poses]
    ]

    if neg_sig > pos_sig:
        phase = _NEG_UNITY
    elif pos_sig > neg_sig:
        phase = _UNITY
    else:
        preferred_phase = (
            _NEG_UNITY
            if preferred.has(_NEG_UNITY) or preferred.is_negative
            else _UNITY
        )
        phase = preferred_phase

    return (coeff * phase * frac).simplify()


def _index(base, indices, strip=False) -> Expr:
    """Index the given base with indices.

    When strip is set to true, the indices are assumed to be a list of
    symbol/range pairs.
    """
    if strip:
        indices = tuple(i for i, _ in indices)
    else:
        indices = tuple(indices)

    return base if len(indices) == 0 else IndexedBase(base)[indices]


#
# Core evaluation DAG nodes
# ~~~~~~~~~~~~~~~~~~~~~~~~~
#

class _EvalNode:
    """A node in the evaluation graph."""

    def __init__(self, base: Symbol, exts: _SrPairs):
        """Initialize the evaluation node."""
        self.base = base
        self.exts = exts

        # For optimization.
        self.evals = []  # type: typing.List[_EvalNode]
        self.total_cost = None

        # For result finalization.
        self.n_refs = 0
        self.generated = False

    def get_substs(self, indices):
        """Get substitutions and symbols requiring exclusion before indexing.

        First resetting dummies excluding the returned symbols and then
        making the returned substitution on each term could achieve
        indexing.

        Since the real free symbols are already gathered from all inputs,
        the free symbols are not considered here.  But they should be added
        for the resetting.
        """

        substs = {}
        excl = set()

        assert len(indices) == len(self.exts)
        for i, j in zip(indices, self.exts):
            dumm = j[0]
            substs[dumm] = i
            excl.add(dumm)
            excl |= i.atoms(Symbol)
            continue

        return substs, excl


class _Sum(_EvalNode):
    """Sum nodes in the evaluation graph."""

    def __init__(self, base, exts, sum_terms):
        """Initialize the node."""
        super().__init__(base, exts)
        self.sum_terms = sum_terms

    def __repr__(self):
        """Form a representation string for the node."""
        return '_Sum(base={}, exts={}, sum_terms={})'.format(
            repr(self.base), repr(self.exts), repr(self.sum_terms)
        )


class _Prod(_EvalNode):
    """Product nodes in the evaluation graph."""

    def __init__(self, base, exts, sums, coeff, factors):
        """Initialize the node."""
        super().__init__(base, exts)
        self.sums = sums
        self.coeff = coeff
        self.factors = factors

    def __repr__(self):
        """Form a representation string for the node."""
        return '_Prod(base={}, exts={}, sums={}, coeff={}, factors={})'.format(
            repr(self.base), repr(self.exts), repr(self.sums),
            repr(self.coeff), repr(self.factors)
        )


class _Interm(typing.NamedTuple):
    """Newly formed intermediate.

    This small utility carries both a symbolic reference to an intermediate
    and the actual node for it, which can be helpful for getting
    information about a newly-formed intermediate.
    """
    ref: Expr
    node: _EvalNode


#
# Internals for product optimization
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#

def _get_prod_final_cost(exts_total_size, sums_total_size) -> Size:
    """Compute the final cost for a pairwise product evaluation."""
    if sums_total_size == 1:
        return exts_total_size
    else:
        return 2 * exts_total_size * sums_total_size


def _gen_broken_sums(sums):
    """Generate broken summations in increasing size of broken summations.

    The size and the actual subset of broken summations are generated.
    """
    sizes = [i.size for _, i in sums]  # Sizes are assumed to be sorted.
    n_sums = len(sizes)

    init = Tuple4Cmp((1, 0))  # Nothing is broken.
    queue = [init]
    while len(queue) > 0:
        curr = heapq.heappop(queue)
        yield curr
        curr_size, curr_broken = curr
        next_idx = curr_broken.bit_length()
        if next_idx < n_sums:
            joined_size = curr_size * sizes[next_idx]
            joined_set = curr_broken | 1 << next_idx
            heapq.heappush(queue, Tuple4Cmp((
                joined_size, joined_set
            )))
            if next_idx > 0:
                top_idx = next_idx - 1
                new_size, rem = divmod(joined_size, sizes[top_idx])
                assert rem == 0
                assert joined_set & 1 << top_idx
                heapq.heappush(queue, Tuple4Cmp((
                    new_size, joined_set ^ 1 << top_idx
                )))
        continue

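
# A minimal standalone sketch of the enumeration strategy used by
# ``_gen_broken_sums``, on plain integer sizes (illustrative only, not used
# elsewhere in gristmill).  Subsets are encoded as bit masks, and a heap
# keyed on the product of the chosen sizes yields them in non-decreasing
# total size, which is what allows the product optimization to stop early.

def _illus_gen_subsets_by_size(sizes):
    """Yield (product, bitmask) for all subsets, smallest product first.

    The sizes are assumed to be sorted increasingly, just as in
    ``_gen_broken_sums``, so that both successors of a subset never have a
    smaller product than the subset itself.
    """
    n = len(sizes)
    queue = [(1, 0)]  # The empty subset, with product 1.
    while queue:
        size, subset = heapq.heappop(queue)
        yield size, subset
        next_idx = subset.bit_length()  # First index above the top bit.
        if next_idx < n:
            joined = subset | 1 << next_idx
            # Either extend the subset by the next index...
            heapq.heappush(queue, (size * sizes[next_idx], joined))
            if next_idx > 0:
                # ... or move the current top index one step up.
                heapq.heappush(queue, (
                    size * sizes[next_idx] // sizes[next_idx - 1],
                    joined ^ 1 << (next_idx - 1)
                ))

# For sizes [2, 3, 5], the products come out as 1, 2, 3, 5, 6, 10, 15, 30.
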
#
# Internals for summation optimization
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# Convention: nodes always refer to nodes in the DAG for tensor
# computations, while vertices are used for vertices in the constriction
# graph.
#

# Organized references to products in a summation.
#
# Intermediate base -> (indices -> coefficient)
_OrgTerms = typing.DefaultDict[
    Symbol, typing.DefaultDict[typing.Tuple[Expr, ...], Expr]
]

#
# Symbolic names for the parts of the bicliques.
#

_LEFT = 0
_RIGHT = 1
_OPPOS = {
    _LEFT: _RIGHT,
    _RIGHT: _LEFT
}

# For type annotation; actually it should be ``_LEFT | _RIGHT`` in Haskell
# algebraic data type notation.
_LR = typing.NewType('_LR', int)
_LRS = (_LEFT, _RIGHT)


class _LastStepIdxes(typing.NamedTuple):
    """The involved indices of the last step of a constriction.

    The external and summation indices involved by the left/right factor
    in the last step of a contraction.

    This is going to be used as the key for accessing the actual graph.
    """
    exts: typing.Tuple[_SrPairs, _SrPairs]
    sums: _SrPairs


class _EdgeInfo(typing.NamedTuple):
    """Information about an edge on a constriction graph."""
    term: int
    eval_: _Prod
    coeff: Expr
    exc_cost: Size


class _BaseInfo:
    """Information about a base referenced in a sum node.

    This is an open struct, with most of its manipulation done inside the
    optimizer class.
    """

    __slots__ = [
        'count',
        'base',
        'node'
    ]

    def __init__(self, base: _Base, node: _Prod):
        """Initialize the information.

        The count is initialized to **zero**.
        """
        self.base = base
        self.node = node
        self.count: int = 0
        return


#
# Intermediate data and results for the Bron-Kerbosch process.
#

class _VertInfo(typing.NamedTuple):
    """Information about a vertex in a constriction graph.

    expr
        The original expression for the factor.

    exts
        The involved external indices.

    canon
        The canonicalized content for the factor.
    """
    exts: int
    expr: Expr
    canon: _Terms


class _Delta(object):
    """Additional information about augmentation by a designated vertex."""

    __slots__ = [
        'coeff',
        'leading_coeff',
        'terms',
        'exc_cost',
        'saving'
    ]

    def __init__(
            self, coeff: Expr, leading_coeff: typing.Optional[Expr],
            terms: int, exc_cost: Size
    ):
        """Initialize the delta."""
        self.coeff = coeff
        self.leading_coeff = leading_coeff
        self.terms = terms
        self.exc_cost = exc_cost
        self.saving: Size = 0


class _DesVert(typing.NamedTuple):
    """Vertices designated for a specific part."""
    part: int
    vert: int


# Sets of designated vertices.
_DesVerts = typing.Set[_DesVert]

# Dictionary of the designated vertices augmenting the current biclique,
# along with their delta.
_DesVertsWDelta = typing.Mapping[_DesVert, _Delta]

# Zipped vertices and coefficients.
_VertsWCoeff = typing.List[typing.Tuple[int, Expr]]

# Parts for a constriction, left and right.
_ConstrParts = typing.Tuple[_VertsWCoeff, _VertsWCoeff]


class _Biclique(typing.NamedTuple):
    """A biclique to be yielded."""
    parts: _ConstrParts
    leading_coeff: Expr
    terms: int
    saving: Size
    constr_graph: '_ConstrGraph'


#
# Cost-related utilities for the Bron-Kerbosch process.
#

class _CostCoeffs(typing.NamedTuple):
    """Cached quantities for getting the gross saving of bicliques.

    final
        The final cost for the contraction and making an addition of the
        result.

    preps
        The cost of making an addition for the left and right factors.
    """
    final: Size
    preps: typing.Tuple[Size, Size]


def _get_cost_coeffs(last_step_idxes: _LastStepIdxes) -> _CostCoeffs:
    """Get the cost coefficients for the given last step indices."""
    sums = last_step_idxes.sums
    exts = last_step_idxes.exts

    ext_size = get_total_size(itertools.chain.from_iterable(exts))
    final = _get_prod_final_cost(
        ext_size, get_total_size(sums)
    ) + ext_size

    preps = (
        get_total_size(itertools.chain(exts[0], sums)),
        get_total_size(itertools.chain(exts[1], sums))
    )  # Explicitly repeated for the linter.

    return _CostCoeffs(final=final, preps=preps)


class _VertGross(typing.Dict[
    typing.Tuple[int, int], typing.Tuple[Size, Size]
]):
    """Gross saving of vertices.

    Given the numbers of vertices currently in the two parts, the gross
    saving of an additional vertex in either of the two parts can be
    queried.  The results are memoized for performance.
    """

    __slots__ = [
        '_cost_coeffs'
    ]

    def __init__(self, last_step_idxes: _LastStepIdxes):
        """Initialize the dictionary."""
        self._cost_coeffs = _get_cost_coeffs(last_step_idxes)

    def __missing__(self, key):
        """Compute the gross savings for new keys."""
        assert len(key) == 2
        assert all(i >= 0 for i in key)

        if any(i == 0 for i in key):
            res = (0, 0)
        else:
            cost_coeffs = self._cost_coeffs
            res = tuple(
                key[_OPPOS[i]] * cost_coeffs.final - cost_coeffs.preps[i]
                for i in _LRS
            )

        self[key] = res
        return res

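
# A worked numeric illustration, with hypothetical sizes, of the
# gross-saving bookkeeping in ``_get_cost_coeffs`` and ``_VertGross``
# above (illustrative only, not used elsewhere in gristmill).

def _illus_gross_saving(n=100, n_right=2):
    """Gross saving of adding a left vertex beside n_right right vertices.

    All index ranges are taken to be of size n, with one external index on
    each part and one summation index in the last step.
    """
    ext_size = n * n  # Total size of the external indices of both parts.
    sum_size = n  # Total size of the last-step summations.
    # One pairwise contraction plus adding its result into the target.
    final = 2 * ext_size * sum_size + ext_size
    # Folding one more factor into the left part costs one addition on a
    # tensor of the shape of the left factor.
    prep_left = n * sum_size
    # Each existing right vertex previously needed its own contraction
    # with the new factor; the new factor now costs a single addition.
    return n_right * final - prep_left
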
""" __slots__ = [ '_cost_coeffs' ] def __init__(self, last_step_idxes: _LastStepIdxes): """Initialize the dictionary.""" self._cost_coeffs = _get_cost_coeffs(last_step_idxes) def __missing__(self, key): """Compute the gross savings for new keys.""" assert len(key) == 2 assert all(i >= 0 for i in key) if any(i == 0 for i in key): res = (0, 0) else: cost_coeffs = self._cost_coeffs res = tuple( key[_OPPOS[i]] * cost_coeffs.final - cost_coeffs.preps[i] for i in _LRS ) self[key] = res return res # # The core classes. # class _BronKerbosch: """Iterable for the maximal bicliques. For performance reasons, the bicliques generated will contain references to internal mutable data. It is the **responsibility of the caller** to make proper copy when it is necessary. """ def __init__( self, last_step_idxes: _LastStepIdxes, constr_graph: '_ConstrGraph' ): """Initialize the iterator.""" # Static data during the recursion, cached here for easier and faster # access. self._constr_graph = constr_graph self._opt = constr_graph.constr_graphs.opt self._req_an_opt = self._opt.req_an_opt self._greedy_cutoff = self._opt.greedy_cutoff self._drop_cutoff = self._opt.drop_cutoff # Dynamic data during the recursion. # # Zipped nodes and coefficients, for left and right. self._curr: _ConstrParts = ([], []) # The leading coefficient. self._leading_coeff = None # The stack of actual saving. # # Keeping the saving as stack could save the cost of subtraction by # using some additional memory. self._savings = [] # The set of terms currently in the biclique. self._terms = 0 # Gross saving of new vertices. self._vert_gross: _VertGross = _VertGross(last_step_idxes) def __iter__(self): """Iterate over the maximal bicliques.""" exts = self._constr_graph.exts # All left and right nodes. subg = { _DesVert(part=part, vert=vert): _Delta( coeff=_UNITY, leading_coeff=None, terms=0, exc_cost=0 ) for vert, info in self._constr_graph.verts for part in _LRS if info.exts == exts[part] } assert len(subg) > 0 yield from self._expand(subg, set(subg.keys())) # If things all goes correctly, the stack should be reverted to initial # state by now. assert all(len(i) == 0 for i in self._curr) assert self._terms == 0 assert len(self._savings) == 0 assert self._leading_coeff is None return def _expand( self, subg: _DesVertsWDelta, cand: _DesVerts, ): """Generate the bicliques from the current state. This is the core of the Bron-Kerbosch algorithm. """ # Cached variables of the current state. curr = self._curr n_verts = tuple(len(i) for i in curr) savings = self._savings depth = len(savings) curr_saving = savings[-1] if depth > 0 else 0 exts = self._constr_graph.exts # The code here are adapted from the code in NetworkX for maximal clique # problem of simple general graphs. The original code are kept as much # as possible and put in comments. The original code on which the code # is based can be found at, # # https://github.com/networkx/networkx/blob # /48f4b5736174844c77044fae90e3e7adf1dabc10/networkx/algorithms # /clique.py#L277-L299 # if_maximal = all(i.saving < 0 for i in subg.values()) # Redundant check on biclique size is used to skip the possibly # expansive saving comparison. if_profitable = all( i > 0 for i in n_verts ) and any(i > 1 for i in n_verts) and curr_saving >= 0 if if_maximal and if_profitable: # If maximal and profitable. # # if not subg_q: # yield Q[:] # yield _Biclique( parts=curr, leading_coeff=self._leading_coeff, terms=self._terms, saving=curr_saving, constr_graph=self._constr_graph ) # The quadratic loop. 
subgq = {} for q_v, q_d in subg.items(): subg_q = {} subgq[q_v] = subg_q for r_v, r_d in subg.items(): updated_r_d = self._update_delta(q_v, q_d, r_v, r_d) if updated_r_d is not None: subg_q[r_v] = updated_r_d continue continue # # u = max(subg, key=lambda u: len(cand & adj[u])) # for q in cand - adj[u]: # # to_loop need to be eagerly evaluated for avoiding complication with # the mutation of cand during the loop and the set operations for # pivoting. pivots: typing.Iterable[_DesVert] = [] if n_verts[0] == 0: to_loop = {i for i in cand if i.part == 0} elif n_verts[1] == 0: to_loop = {i for i in cand if i.part == 1} if exts[0] == exts[1]: # First part, first vertex, the vertex exist_vert: int = curr[0][0][0] to_loop = {i for i in to_loop if i.vert > exist_vert} if self._req_an_opt: to_loop = {i for i in to_loop if subg[i].exc_cost == 0} else: gross = self._vert_gross[(1, 1)][1] pivots = ( k for k, v in subg.items() if k.part == 1 and gross - v.exc_cost >= 0 ) else: to_loop = {i for i in cand if subg[i].saving >= 0} if len(to_loop) == 0: return pivots = (k for k, v in subg.items() if v.saving > 0) cut_greedy = 0 <= self._greedy_cutoff <= depth cut_full = 0 <= self._drop_cutoff <= depth if cut_greedy or cut_full: greedy_saving = max(subg[i].saving for i in to_loop) to_loop = { i for i in to_loop if subg[i].saving == greedy_saving } if cut_full: to_loop = {to_loop.pop()} pivots = [] # Designated vertices that can be excluded for each pivot. fqs = ( {i for i in subgq[k].keys() if i.part == k.part} for k in pivots ) try: excl = max(fqs, key=lambda x: len(x & to_loop)) except ValueError: pass else: to_loop -= excl for q_v in to_loop: q_d = subg[q_v] part, vert = q_v.part, q_v.vert # # cand.remove(q) # cand.remove(q_v) # # Q.append(q) # curr[part].append((vert, q_d.coeff)) if q_d.leading_coeff is not None: self._leading_coeff = q_d.leading_coeff new_terms = q_d.terms assert self._terms & new_terms == 0 self._terms |= new_terms savings.append(curr_saving + q_d.saving) # # adj_q = adj[q] # subg_q = subg & adj_q # subg_q = subgq[q_v] # # if not subg_q: # yield Q[:] # # Moved to top for clarity. # # # cand_q = cand & adj_q # cand_q = {i for i in cand if i in subg_q} # if cand_q: # for clique in expand(subg_q, cand_q): # yield clique yield from self._expand(subg_q, cand_q) # # Q.pop() # curr[part].pop() assert self._terms & new_terms == new_terms self._terms ^= new_terms savings.pop() if q_d.leading_coeff is not None: self._leading_coeff = None continue def _update_delta( self, new_v: _DesVert, new_d: _Delta, curr_v: _DesVert, curr_d: _Delta ) -> typing.Optional[_Delta]: """Update the delta assuming a new node is added to the stack. This is the core and performance bottleneck of the Bron-Kerbosch algorithm. 
""" new_terms = new_d.terms curr_terms = curr_d.terms if new_terms & curr_terms != 0: return None new_p = new_v.part curr_p = curr_v.part curr_coeff = curr_d.coeff new_leading_coeff = new_d.leading_coeff curr_leading_coeff = curr_d.leading_coeff updated_d = _Delta( coeff=curr_coeff, leading_coeff=curr_leading_coeff, terms=curr_terms, exc_cost=curr_d.exc_cost ) if new_p == curr_p: if new_leading_coeff is not None: assert curr_leading_coeff is not None updated_d.coeff = ( curr_leading_coeff / new_leading_coeff ).simplify() updated_d.leading_coeff = None else: new_neighb = self._constr_graph.graph[new_v.vert] curr_vert = curr_v.vert if curr_vert not in new_neighb: return None edge = new_neighb[curr_vert]['info'] edge_term = 1 << edge.term if_conflict = ( edge_term & new_terms != 0 or edge_term & curr_terms != 0 or edge_term & self._terms != 0 ) if if_conflict: return None updated_d.terms |= edge_term updated_d.exc_cost += edge.exc_cost edge_coeff = edge.coeff if new_leading_coeff is not None: # The previous node gives the first edge. updated_d.coeff = ( edge_coeff / new_leading_coeff ).simplify() elif self._leading_coeff is None: # This node gives the first edge. updated_d.leading_coeff = edge_coeff else: proj = edge_coeff / (self._leading_coeff * new_d.coeff) if (proj - curr_coeff).simplify() != 0: return None n_verts = [len(i) for i in self._curr] n_verts[new_p] += 1 gross = self._vert_gross[tuple(n_verts)][curr_p] updated_d.saving = gross - updated_d.exc_cost return updated_d class _ConstrGraph: """Constriction graph for a given involvement of indices. We have separate graphs for different involved indices combinations. For each combination, the graph has the factors as vertices, and actual evaluations with the factors as edges. Internally, the graph is stored as a NetworkX graph. """ def __init__( self, constr_graphs: '_ConstrGraphs', exts_l: int, exts_r: int ): """Initialize the constriction graph. graphs The constriction graphs. exts_l, exts_r The pair of integers encoding the external indices involved by the two parts of the graph. """ self.constr_graphs = constr_graphs self.exts = (exts_l, exts_r) self.graph = Graph() self._verts = {} # From canonicalized factor to the vertex number. self.terms = 0 # The optimal biclique in the current graph. None when it is not yet # determined, False when it is determined that there is no profitable # biclique in the current graph. self._opt_saving = None self._opt_biclique = None @property def verts(self): """The nodes in the graph as integers with the information.""" return ( (i, j['info']) for i, j in self.graph.nodes_iter(data=True) ) def add_edge( self, node_infos: typing.Tuple[_VertInfo, _VertInfo], coeff: Expr, term: int, eval_: _Prod ): """Add a new edge to the graph.""" graph = self.graph term_bases = self.constr_graphs.term_bases repeated_terms_strat = self.constr_graphs.opt.repeated_terms_strat # Treat excess cost first, since it might lead to direct return. base_info = term_bases[term] count = base_info.count assert count > 0 if count == 1 or repeated_terms_strat == RepeatedTermsStrat.IGNORE: exc_cost = eval_.total_cost - base_info.node.total_cost elif repeated_terms_strat == RepeatedTermsStrat.SKIP: return elif repeated_terms_strat == RepeatedTermsStrat.NATURAL: exc_cost = eval_.total_cost else: exc_cost = None # For linter. 

class _ConstrGraph:
    """Constriction graph for a given involvement of indices.

    We have separate graphs for different combinations of involved
    indices.  For each combination, the graph has the factors as vertices,
    and actual evaluations with the factors as edges.

    Internally, the graph is stored as a NetworkX graph.
    """

    def __init__(
            self, constr_graphs: '_ConstrGraphs', exts_l: int, exts_r: int
    ):
        """Initialize the constriction graph.

        constr_graphs
            The constriction graphs.

        exts_l, exts_r
            The pair of integers encoding the external indices involved by
            the two parts of the graph.
        """
        self.constr_graphs = constr_graphs
        self.exts = (exts_l, exts_r)
        self.graph = Graph()
        self._verts = {}  # From canonicalized factor to the vertex number.
        self.terms = 0

        # The optimal biclique in the current graph.  None when it is not
        # yet determined, False when it is determined that there is no
        # profitable biclique in the current graph.
        self._opt_saving = None
        self._opt_biclique = None

    @property
    def verts(self):
        """The nodes in the graph as integers with the information."""
        return (
            (i, j['info']) for i, j in self.graph.nodes_iter(data=True)
        )

    def add_edge(
            self, node_infos: typing.Tuple[_VertInfo, _VertInfo],
            coeff: Expr, term: int, eval_: _Prod
    ):
        """Add a new edge to the graph."""
        graph = self.graph
        term_bases = self.constr_graphs.term_bases
        repeated_terms_strat = self.constr_graphs.opt.repeated_terms_strat

        # Treat the excess cost first, since it might lead to a direct
        # return.
        base_info = term_bases[term]
        count = base_info.count
        assert count > 0
        if count == 1 or repeated_terms_strat == RepeatedTermsStrat.IGNORE:
            exc_cost = eval_.total_cost - base_info.node.total_cost
        elif repeated_terms_strat == RepeatedTermsStrat.SKIP:
            return
        elif repeated_terms_strat == RepeatedTermsStrat.NATURAL:
            exc_cost = eval_.total_cost
        else:
            exc_cost = None  # For the linter.
            assert None

        nodes = []
        for i in node_infos:
            canon = i.canon
            if canon in self._verts:
                idx = self._verts[canon]
            else:
                idx = len(self._verts)
                self._verts[canon] = idx
                graph.add_node(idx, info=i)
            nodes.append(idx)
            continue

        edge_info = _EdgeInfo(
            term=term, eval_=eval_, coeff=coeff, exc_cost=exc_cost
        )

        n1, n2 = nodes
        neighb1 = graph[n1]
        if n2 in neighb1:
            # It is possible that two evaluations that are actually the
            # same are recorded twice in the evaluation of product nodes
            # because of symmetry.
            assert neighb1[n2]['info'].term == edge_info.term
        else:
            graph.add_edge(*nodes, info=edge_info)

        self.terms |= 1 << term

    def get_opt_biclique(
            self, last_step_idxes: _LastStepIdxes
    ) -> typing.Tuple[typing.Optional[Size], typing.Optional[_Biclique]]:
        """Get the optimal biclique in the current graph."""

        if self._opt_saving is not None:
            if self._opt_saving is False:
                return None, None
            else:
                return self._opt_saving, self._opt_biclique

        opt_saving = None
        opt_biclique = None
        for biclique in _BronKerbosch(last_step_idxes, self):
            saving = biclique.saving
            if opt_saving is None or saving > opt_saving:
                opt_saving = saving
                # Make copies only when we need them.
                parts = biclique.parts
                assert len(parts) == 2
                opt_biclique = _Biclique(
                    parts=(list(parts[0]), list(parts[1])),
                    leading_coeff=biclique.leading_coeff,
                    terms=biclique.terms, saving=biclique.saving,
                    constr_graph=biclique.constr_graph
                )
            continue

        if opt_saving is None:
            assert opt_biclique is None
            self._opt_saving = False
            self._opt_biclique = None
            return None, None
        else:
            self._opt_saving = opt_saving
            self._opt_biclique = opt_biclique
            return opt_saving, opt_biclique

    def remove_terms(self, terms: int) -> bool:
        """Remove all edges for the given terms.

        Vertices no longer connected to anything are removed as well.  If
        a value of True is returned, we have an empty graph after the
        removal.
        """
        graph = self.graph

        if self.terms & terms != 0:
            edges2remove = [
                (n1, n2)
                for n1, n2, info in graph.edges_iter(data='info')
                if 1 << info.term & terms != 0
            ]
            graph.remove_edges_from(edges2remove)

            nodes2remove = [
                i for i in graph.nodes_iter() if graph.degree(i) == 0
            ]
            graph.remove_nodes_from(nodes2remove)

            self.terms ^= self.terms & terms

            # Reset the cached optimal biclique.
            self._opt_saving = None
            self._opt_biclique = None

        return graph.number_of_nodes() == 0


class _ConstrGraphs(typing.Dict[_LastStepIdxes, _ConstrGraph]):
    """The constriction graphs from a sum of contractions.

    The constriction graphs are organized according to the external and
    summation indices involved by the factors in the last step, to achieve
    better performance with one big graph separated into pieces.  With
    this decomposition, for instance, we can cache maximum bicliques in
    subgraphs unaffected by the latest constriction.

    Here just the basic data is defined, with most actual operations
    directly performed inside the optimizer.

    Attributes
    ----------

    bases
        The mapping from the actual base to the base info.

    term_bases
        The list of base info for each of the terms.

    """

    __slots__ = [
        'opt',
        'bases',
        'term_bases'
    ]

    def __init__(self, opt: '_Optimizer'):
        """Initialize the graphs.

        Here only the most basic resource initialization is performed.
        """
        super().__init__()
        self.opt = opt
        self.bases: typing.Dict[_Base, _BaseInfo] = {}
        # None for plain scalar terms.
        self.term_bases: typing.List[typing.Optional[_BaseInfo]] = []

    def get_opt_biclique(self) -> typing.Tuple[
        typing.Optional[_LastStepIdxes], typing.Optional[_Biclique]
    ]:
        """Choose the most profitable biclique."""
        opt_saving = None
        opt_last_step_idxes = None
        opt_biclique = None
        for last_step_idxes, constr_graph in self.items():
            curr_opt_saving, curr_opt_biclique = constr_graph.get_opt_biclique(
                last_step_idxes
            )
            if curr_opt_saving is None:
                continue
            if opt_saving is None or curr_opt_saving > opt_saving:
                opt_saving = curr_opt_saving
                opt_last_step_idxes = last_step_idxes
                opt_biclique = curr_opt_biclique
            continue

        return opt_last_step_idxes, opt_biclique

    def cleanup_constred(self, if_untouched: int, biclique: _Biclique) -> int:
        """Clean up the terms after a constriction."""
        terms = biclique.terms
        assert if_untouched & terms == terms
        if_untouched ^= terms

        to_remove = []
        for last_step_idxes, constr_graph in self.items():
            if_empty = constr_graph.remove_terms(biclique.terms)
            if if_empty:
                to_remove.append(last_step_idxes)
            continue

        for i in to_remove:
            del self[i]
            continue

        return if_untouched


#
# Core optimizer class
# ~~~~~~~~~~~~~~~~~~~~
#

class _Optimizer:
    """Optimizer for tensor contraction computations.

    This internal optimizer can only be used once for one set of input.
    """

    #
    # Public functions.
    #

    def __init__(
            self, computs, substs, interm_fmt, contr_strat, opt_sum,
            repeated_terms_strat, opt_symm, req_an_opt, greedy_cutoff,
            drop_cutoff, remove_shallow
    ):
        """Initialize the optimizer."""

        # Information to be read from the input computations.
        #
        # The only drudge for the inputs.
        self._drudge = None
        # The only variable for range sizes.
        self._range_var = None
        # Mapping from the substituted range to the original range.
        self._input_ranges = {}
        # Symbols that should not be used for dummies.
        self._excl = set()

        # Read, process, and verify user input.
        self._grist = [
            self._form_grain(comput, substs) for comput in computs
        ]

        # Dummies stock in terms of the substituted ranges.
        assert self._drudge is not None
        self._dumms = {
            k: self._drudge.dumms.value[v]
            for k, v in self._input_ranges.items()
        }

        # Storage of user options to be accessed during the optimization.
        #
        # Public for ease of access from the other internal classes.
        self.interm_fmt = interm_fmt
        self.contr_strat = contr_strat
        self.opt_sum = opt_sum
        self.repeated_terms_strat = repeated_terms_strat
        self.opt_symm = opt_symm
        self.req_an_opt = req_an_opt
        self.greedy_cutoff = greedy_cutoff
        self.drop_cutoff = drop_cutoff
        self.remove_shallow = remove_shallow

        # Other internal data preparation.
        self._next_internal_idx = 0

        # From intermediate base to the actual evaluation node.
        self._interms = {}
        # From the canonical form to the intermediate base.
        self._interms_canon = {}

        self._res = None

    def optimize(self):
        """Optimize the evaluation of the given computations."""

        if self._res is not None:
            return self._res

        res_nodes = [self._form_node(i) for i in self._grist]
        for i in res_nodes:
            self._optimize(i)
            continue

        self._res = self._linearize(res_nodes)
        return self._res

    #
    # User input pre-processing.
    #

    def _form_grain(self, comput, substs):
        """Form a grain for a given computation."""

        curr_drudge = comput.rhs.drudge
        if self._drudge is None:
            self._drudge = curr_drudge
        elif self._drudge is not curr_drudge:
            raise ValueError(
                'Invalid computations to optimize, containing two drudges',
                (self._drudge, curr_drudge)
            )
        else:
            pass

        # Externals processing.
        exts = self._proc_sums(comput.exts, substs)
        ext_symbs = {i for i, _ in exts}

        # Terms processing.
        terms = []
        for term in comput.rhs_terms:
            if not term.is_scalar:
                raise ValueError(
                    'Invalid term to optimize', term, 'expecting scalar'
                )
            sums = self._proc_sums(term.sums, substs, sort=True)
            amp = term.amp

            # Add the true free symbols to the exclusion set.
            self._excl |= term.free_vars - ext_symbs
            terms.append(Term(sums, amp, ()))
            continue

        return _Grain(
            base=comput.base if len(exts) == 0 else comput.base.args[0],
            exts=exts, terms=terms
        )

    def _proc_sums(self, sums, substs, sort=False):
        """Process a summation list.

        The ranges will be replaced with substituted sizes.  Relevant
        members of the optimizer will also be updated.  User errors will
        also be reported.
        """
        res = []
        for symb, range_ in sums:
            new_range, range_var = form_sized_range(range_, substs)

            if range_var is not None:
                if self._range_var is None:
                    self._range_var = range_var
                elif self._range_var != range_var:
                    raise ValueError(
                        'Invalid range', range_, 'unexpected symbol',
                        range_var, 'conflicting with', self._range_var
                    )
                else:
                    pass

            if new_range not in self._input_ranges:
                self._input_ranges[new_range] = range_
            elif range_.size != self._input_ranges[new_range].size:
                raise ValueError(
                    'Invalid ranges',
                    (range_, self._input_ranges[new_range]),
                    'duplicated labels'
                )
            else:
                pass

            res.append((symb, new_range))
            continue

        if sort:
            res.sort(key=lambda x: x[1].size)

        return tuple(res)

    #
    # Optimization result post-processing.
    #

    def _linearize(
            self, optimized: typing.Sequence[_EvalNode]
    ) -> typing.List[TensorDef]:
        """Linearize optimized forms of the evaluation."""

        for node in optimized:
            self._set_n_refs(node)
            continue

        # Separate the intermediates and the results so that the results
        # can be guaranteed to be at the end of the evaluation sequence.
        interms = []
        res = []
        for node in optimized:
            curr = self._linearize_node(node, interms, keep=True)
            assert curr is not None
            res.append(curr)
            continue

        return self._finalize(itertools.chain(interms, res))

    def _set_n_refs(self, node: _EvalNode):
        """Set the reference counts from an evaluation node."""

        if len(node.evals) == 0:
            self._optimize(node)

        # We need to find an evaluation with optimal cost.
        assert len(node.evals) > 0
        node.evals = [next(
            i for i in node.evals if i.total_cost == node.total_cost
        )]
        eval_ = node.evals[0]

        if isinstance(eval_, _Prod):
            possible_refs = [i for i in eval_.factors]
        elif isinstance(eval_, _Sum):
            possible_refs = eval_.sum_terms
        else:
            assert False

        for i in possible_refs:
            ref = self._parse_interm_ref(i)
            if ref is None:
                continue
            dep_node = self._interms[ref.base]
            dep_node.n_refs += 1
            self._set_n_refs(dep_node)
            continue

        return

    def _linearize_node(self, node: _EvalNode, res: list, keep=False):
        """Linearize the evaluation rooted in the given node into the result.

        If keep is set to True, the evaluation of the given node will not
        be appended to the result list.
        """

        if node.generated:
            return None

        def_, deps = self._form_def(node)
        for i in deps:
            self._linearize_node(self._interms[i], res)
            continue

        node.generated = True

        if not keep:
            res.append(def_)

        return def_

    def _form_def(self, node: _EvalNode):
        """Form the final definition of an evaluation node.

        The dependencies will also be returned.
        """
        assert len(node.evals) == 1

        if isinstance(node, _Prod):
            return self._form_prod_def(node)
        elif isinstance(node, _Sum):
            return self._form_sum_def(node)
        else:
            assert False

    def _form_prod_def(self, node: _Prod):
        """Form the final definition of a product evaluation node."""
        exts = node.exts
        eval_ = node.evals[0]
        assert isinstance(eval_, _Prod)
        term, deps = self._form_prod_def_term(eval_)
        return _Grain(
            base=node.base, exts=exts, terms=[term]
        ), deps

    def _form_prod_def_term(self, eval_: _Prod):
        """Form the term in the final definition of a product evaluation
        node.
        """
        amp = eval_.coeff

        deps = []
        for factor in eval_.factors:
            ref = self._parse_interm_ref(factor)
            if ref is not None:
                assert ref.coeff == 1
                interm = self._interms[ref.base]
                if self._is_input(interm):
                    # Inline trivial references to an input.
                    content = self._get_content(factor)
                    assert len(content) == 1
                    assert len(content[0].sums) == 0
                    amp *= content[0].amp ** ref.power
                else:
                    deps.append(ref.base)
                    amp *= factor
            else:
                amp *= factor

        return Term(eval_.sums, amp, ()), deps

    def _form_sum_def(self, node: _Sum):
        """Form the final definition of a sum evaluation node."""
        exts = node.exts
        exts_dict = dict(node.exts)
        terms = []
        deps = []

        eval_ = node.evals[0]
        assert isinstance(eval_, _Sum)

        sum_terms = []
        self._inline_sum_terms(eval_.sum_terms, sum_terms)
        for term in sum_terms:

            ref = self._parse_interm_ref(term)
            if ref is None:
                # No dependency for pure scalars.
                terms.append(Term((), term, ()))
                continue

            assert ref.power == 1  # Higher powers are not possible in sums.

            # Sum terms are guaranteed to be formed from references to
            # products, never directly written in terms of input.
            term_node = self._interms[ref.base]

            if term_node.n_refs == 1 or self._is_input(term_node):
                # Inline intermediates only used here, as well as simple
                # input references.

                eval_ = term_node.evals[0]
                assert isinstance(eval_, _Prod)
                contents = self._index_prod(eval_, ref.indices)
                assert len(contents) == 1
                term = contents[0]
                factors, term_coeff = term.get_amp_factors(
                    self._interms, exts_dict
                )

                # Switch back to an evaluation node for using the
                # facilities for product nodes.
                tmp_node = _Prod(
                    term_node.base, exts, term.sums,
                    ref.coeff * term_coeff, factors
                )

                new_term, term_deps = self._form_prod_def_term(tmp_node)

                terms.append(new_term)
                deps.extend(term_deps)
            else:
                terms.append(Term(
                    (), term, ()
                ))
                deps.append(ref.base)

            continue

        return _Grain(
            base=node.base, exts=exts, terms=terms
        ), deps

    def _inline_sum_terms(
            self, sum_terms: typing.Sequence[Expr],
            res: typing.List[Expr]
    ):
        """Inline the summation terms from single-reference terms.

        This function mutates the given result list rather than returning
        the result, to avoid repeated list creation in recursive calls.
        """

        for sum_term in sum_terms:
            ref = self._parse_interm_ref(sum_term)
            if ref is None:
                res.append(sum_term)
                continue
            assert ref.power == 1

            node = self._interms[ref.base]
            assert len(node.evals) > 0
            eval_ = node.evals[0]

            if_inline = isinstance(eval_, _Sum) and (
                node.n_refs == 1 or len(eval_.sum_terms) == 1
            )
            if if_inline:
                if len(node.exts) == 0:
                    substs = None
                else:
                    substs = {
                        i[0]: j
                        for i, j in zip(eval_.exts, ref.indices)
                    }

                proced_sum_terms = [
                    (
                        i.xreplace(substs)
                        if substs is not None else sum_term
                    ) * ref.coeff
                    for i in eval_.sum_terms
                ]
                self._inline_sum_terms(proced_sum_terms, res)
                continue
            else:
                res.append(sum_term)

            continue

        return

    def _is_input(self, node: _EvalNode):
        """Test if a product node is just a trivial reference to an input."""
        if isinstance(node, _Prod):
            return len(node.sums) == 0 and len(node.factors) == 1 and (
                self._parse_interm_ref(node.factors[0]) is None
            )
        else:
            return False

    def _finalize(
            self, computs: typing.Iterable[_Grain]
    ) -> typing.List[TensorDef]:
        """Finalize the linearization result.

        Things will be cast to drudge tensor definitions, with
        intermediates holding names formed from the format given by the
        user.
        """

        next_idx = 0

        substs = {}  # For normal substitution of bases.
        repls = []  # For removed shallow intermediates.

        def proc_amp(amp):
            """Process the amplitude by making the found substitutions."""
            for i in reversed(repls):
                amp = amp.replace(*i)
                continue
            return amp.xreplace(substs)

        # Cache some properties.
        remove_shallow = self.remove_shallow

        res = []
        for comput in computs:
            exts = tuple(
                (s, self._input_ranges[r]) for s, r in comput.exts
            )
            if_scalar = len(exts) == 0
            base = comput.base if if_scalar else IndexedBase(comput.base)
            terms = [
                i.map(proc_amp, sums=tuple(
                    (s, self._input_ranges[r]) for s, r in i.sums
                ))
                for i in comput.terms
            ]

            # No internal intermediates should be leaked.
            for i in terms:
                assert not any(j in self._interms for j in i.free_vars)

            if comput.base in self._interms:

                if_shallow = (
                    remove_shallow
                    and len(terms) == 1 and len(terms[0].sums) == 0
                )
                if if_shallow:
                    # Remove shallow intermediates.  The saving might be
                    # too modest to justify the additional memory
                    # consumption.
                    #
                    # TODO: Move it earlier to a better place.
                    repl_lhs = base if if_scalar else base[tuple(
                        _WILD_FACTORY[i] for i, _ in enumerate(exts)
                    )]
                    repl_rhs = proc_amp(terms[0].amp.xreplace(
                        {v[0]: _WILD_FACTORY[i] for i, v in enumerate(exts)}
                    ))
                    repls.append((repl_lhs, repl_rhs))
                    continue  # No new intermediate added.

                final_base = (
                    Symbol if if_scalar else IndexedBase
                )(self.interm_fmt.format(next_idx))
                next_idx += 1
                substs[base] = final_base
            else:
                final_base = base

            res.append(TensorDef(
                final_base, exts, self._drudge.create_tensor(terms)
            ).reset_dumms())
            continue

        return res

    #
    # Internal support utilities.
    #

    def _get_next_internal(self):
        """Get the symbol for the next internal intermediate."""
        idx = self._next_internal_idx
        self._next_internal_idx += 1
        return Symbol('gristmillInternalIntermediate{}'.format(idx))

    @staticmethod
    def _write_in_orig_ranges(sums):
        """Write the summations in terms of undecorated bare ranges.

        The labels in the ranges are assumed to be decorated.
        """
        return tuple(
            (i, j.replace_label(j.label[0]))
            for i, j in sums
        )

    def _canon_terms(self, new_sums: _SrPairs, terms: typing.Iterable[Term]):
        """Form a canonical label for a list of terms.

        The new summation list is prepended to the summation list of all
        terms.  The coefficient ahead of the canonical form is returned
        before the canonical form.  And the permuted new summation list is
        also returned after the canonical form.  Note that this list
        contains the original dummies given in the new summation list,
        while the terms have reset new dummies.

        Note that the ranges in the new summation list are assumed to be
        decorated with labels earlier than _SUMMED.  In the result, they
        are still in decorated forms and are guaranteed to be permuted in
        the same way for all given terms.  The summations from the terms
        will be internally decorated but written in bare ranges in the
        final result.

        Note that this is definitely a poor man's version of
        canonicalization of multi-term tensor definitions with external
        indices.  A lot of cases cannot be handled well.  Hopefully it can
        be replaced with a systematic treatment some day in the future.
        """

        new_dumms = {i for i, _ in new_sums}
        coeffs = []

        candidates = collections.defaultdict(list)
        for idx, term in enumerate(terms):
            term, canon_sums = self._canon_term(new_sums, term)

            factors, coeff = term.get_amp_factors(self._interms)
            coeffs.append(coeff)

            candidates[
                term.map(lambda x: prod_(factors))
            ].append((canon_sums, idx))
            continue

        # Poor man's canonicalization of external indices.
        #
        # This algorithm is not guaranteed to work.  Here we just choose an
        # ordering of the external indices that is as safe as possible.
        # But certainly it is not guaranteed to work for all cases.
        #
        # TODO: Fix it!

        chosen = min(candidates.items(), key=lambda x: (
            len(x[1]),
            -len(x[0].amp.atoms(Symbol) & new_dumms),
            x[0].sort_key
        ))

        canon_new_sums = set(i for i, _ in chosen[1])
        if len(canon_new_sums) > 1:
            warnings.warn(
                'Internal deficiency: '
                'summation intermediate may not be fully canonicalized'
            )
        # This could also fail when the chosen term has symmetry among the
        # new summations not present in any other term.  This can be hard
        # to check.

        canon_new_sum = canon_new_sums.pop()
        preferred = prod_(
            coeffs[i] for _, i in candidates[chosen[0]]
        )
        canon_coeff = _get_canon_coeff(coeffs, preferred)

        res_terms = []
        for term in terms:
            canon_term, _ = self._canon_term(canon_new_sum, term, fix_new=True)
            # TODO: Add support for complex conjugation.
            res_terms.append(canon_term.map(lambda x: x / canon_coeff))
            continue

        return canon_coeff, tuple(
            sorted(res_terms, key=lambda x: x.sort_key)
        ), canon_new_sum

    def _canon_term(self, new_sums, term, fix_new=False):
        """Canonicalize a single term.

        Internal method for _canon_terms, not supposed to be called
        directly.
        """

        term = Term(tuple(itertools.chain(
            (
                (v[0], v[1].replace_label((v[1].label[0], _EXT, i)))
                for i, v in enumerate(new_sums)
            ) if fix_new else new_sums,
            (
                (i, j.replace_label((j.label, _SUMMED)))
                for i, j in term.sums
            )
        )), term.amp, ())
        canoned = term.canon(symms=self._drudge.symms.value)

        canon_sums = canoned.sums
        canon_orig_sums = self._write_in_orig_ranges(canon_sums)

        dumm_reset, _ = canoned.map(sums=canon_orig_sums).reset_dumms(
            dumms=self._dumms, excl=self._excl
        )

        canon_new_sums = []
        term_new_sums = []
        term_sums = []
        i_new = 0
        for i, j in zip(dumm_reset.sums, canon_sums):
            if j[1].label[1] == _SUMMED:
                # Existing summations.
                term_sums.append(i)
            else:
                if fix_new:
                    assert j[0] == new_sums[i_new][0]
                    range_ = new_sums[i_new][1]
                else:
                    range_ = j[1]

                canon_new_sums.append((j[0], range_))
                term_new_sums.append((i[0], range_))
                i_new += 1

            continue

        return dumm_reset.map(sums=tuple(itertools.chain(
            term_new_sums, term_sums
        ))), tuple(canon_new_sums)

    def _parse_interm_ref(self, expr: Expr) -> typing.Optional[_IntermRef]:
        """Parse an expression that is possibly an intermediate reference."""

        coeff = _UNITY
        base = None
        indices = None
        power = None

        if isinstance(expr, Mul):
            args = expr.args
        else:
            args = [expr]

        for i in args:
            if any(j in self._interms for j in i.atoms(Symbol)):
                assert base is None
                ref, power = i.as_base_exp()
                if isinstance(ref, Indexed):
                    base = ref.base.args[0]
                    indices = ref.indices
                elif isinstance(ref, Symbol):
                    base = ref
                    indices = ()
                else:
                    assert False
                assert base in self._interms
            else:
                coeff *= i

        return None if base is None else _IntermRef(
            coeff=coeff, base=base, indices=indices, power=power
        )

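
# A small sympy illustration of the decomposition performed by
# ``_parse_interm_ref`` above, kept as a comment (the base here is not
# actually a registered intermediate, so the names are illustrative only):
#
#     from sympy import IndexedBase, symbols
#
#     a, b = symbols('a b')
#     expr = 2 * IndexedBase('tau')[a, b] ** 2
#     coeff, rest = expr.as_coeff_Mul()
#     ref, power = rest.as_base_exp()
#     # coeff == 2, power == 2, and, as in the method above,
#     # ref.base.args[0] gives the base symbol tau while ref.indices
#     # gives (a, b).
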
""" ref = self._parse_interm_ref(interm_ref) assert ref is not None node = self._interms[ref.base] if isinstance(node, _Sum): content = self._index_sum(node, ref.indices) elif isinstance(node, _Prod): content = self._index_prod(node, ref.indices) else: assert False return [ i.scale(ref.coeff) for i in self._raise_power(content, ref.power) ] def _index_sum(self, node: _Sum, indices) -> typing.List[Term]: """Substitute the external indices in the sum node""" substs, _ = node.get_substs(indices) res = [] for i in node.sum_terms: term = i.xreplace(substs) ref = self._parse_interm_ref(term) if ref is None: res.append(term) else: term_def = self._get_content(term) res.extend(term_def) continue return res def _index_prod(self, node: _Prod, indices) -> typing.List[Term]: """Substitute the external indices in the evaluation node.""" substs, excl = node.get_substs(indices) term = Term( node.sums, node.coeff * prod_(node.factors), () ).reset_dumms( self._dumms, excl=self._excl | excl )[0].map(lambda x: x.xreplace(substs)) return [term] def _raise_power( self, terms: typing.Sequence[Term], exp: int ) -> typing.List[Term]: """Raise the sum of the given terms to the given power.""" curr = [] # type: typing.List[Term] for _ in range(exp): if len(curr) == 0: curr = list(terms) else: # TODO: Make the multiplication more efficient. curr = [i.mul_term( j, dumms=self._dumms, excl=self._excl | i.free_vars | j.free_vars ) for i, j in itertools.product(curr, terms)] return curr # # General optimization. # def _form_node(self, grain: _Grain): """Form an evaluation node from a tensor definition. This is the entry point for optimization. """ # We assume it is fully simplified and expanded by grist preparation. exts = grain.exts terms = grain.terms if len(terms) == 0: raise ValueError( 'Tensor is constant zero, probably it is not what you meant', grain.base ) return self._form_sum_from_terms( grain.base, exts, terms ) def _optimize(self, node): """Optimize the evaluation of the given node. The evaluation methods will be filled with, possibly multiple, method of evaluations. """ # For node with known evaluations, skip actual optimization. This # enables the acceleration from dynamic programming. if len(node.evals) > 0: return node if isinstance(node, _Sum): return self._optimize_sum(node) elif isinstance(node, _Prod): return self._optimize_prod(node) else: assert False def _form_prod_interm(self, exts, sums, factors) -> _Interm: """Form a product intermediate. The factors are assumed to be all non-trivial factors needing processing. """ decored_exts = tuple( (i, j.replace_label((j.label, _EXT))) for i, j in exts ) n_exts = len(decored_exts) term = Term(tuple(sums), prod_(factors).simplify(), ()) coeff, key, canon_exts = self._canon_terms( decored_exts, [term] ) assert len(key) == 1 if key in self._interms_canon: base = self._interms_canon[key] else: base = self._get_next_internal() self._interms_canon[key] = base key_term = key[0] key_exts = self._write_in_orig_ranges(key_term.sums[:n_exts]) key_sums = key_term.sums[n_exts:] # The external symbols will automatically be considered in # get_amp_factors since they are in the summation list right now. key_factors, key_coeff = key_term.get_amp_factors(self._interms) interm = _Prod( base, key_exts, key_sums, key_coeff, key_factors ) self._interms[base] = interm return _Interm( ref=coeff * _index(base, canon_exts, strip=True), node=self._interms[base] ) def _form_sum_interm( self, exts: _SrPairs, terms: typing.Sequence[Term] ) -> _Interm: """Form a sum intermediate. 
""" decored_exts = tuple( (i, j.replace_label((j.label, _EXT))) for i, j in exts ) n_exts = len(decored_exts) coeff, canon_terms, canon_exts = self._canon_terms(decored_exts, terms) if canon_terms in self._interms_canon: base = self._interms_canon[canon_terms] else: base = self._get_next_internal() self._interms_canon[canon_terms] = base node_exts = None node_terms = [] for term in canon_terms: term_exts = self._write_in_orig_ranges(term.sums[:n_exts]) if node_exts is None: node_exts = term_exts else: assert node_exts == term_exts node_terms.append(term.map(sums=term.sums[n_exts:])) continue node = self._form_sum_from_terms(base, node_exts, node_terms) self._interms[base] = node self._optimize(node) return _Interm( ref=coeff * _index(base, canon_exts, strip=True), node=self._interms[base] ) def _form_sum_from_terms( self, base: Symbol, exts: _SrPairs, terms: typing.Iterable[Term] ): """Form a summation node for given the terms. No processing is done in this method. It just forms the node. """ sum_terms = [] plain_scalars = [] ext_symbs = {i for i, _ in exts} for term in terms: sums = term.sums factors, coeff = term.get_amp_factors(self._interms, ext_symbs) if len(factors) == 0: plain_scalars.append(coeff) else: interm_ref, _ = self._form_prod_interm(exts, sums, factors) sum_terms.append(interm_ref * coeff) continue if len(plain_scalars) > 0: sum_terms.append(sum_(plain_scalars)) return _Sum(base, exts, sum_terms) # # Sum optimization. # def _optimize_sum(self, sum_node: _Sum): """Optimize the summation node.""" # We first optimize the common terms. exts = sum_node.exts scalars, terms, _ = self._organize_sum_terms(sum_node.sum_terms) if self.opt_sum: new_terms, old_terms = self.constr_sum(terms, exts) else: new_terms = [] old_terms = terms if self.opt_symm: old_terms = self._optimize_common_symmtrization(old_terms, exts) res_terms = scalars + old_terms + new_terms sum_node.evals = [_Sum( sum_node.base, sum_node.exts, res_terms )] return def _organize_sum_terms(self, terms: typing.Iterable[Expr]) -> typing.Tuple[ typing.List[Expr], typing.List[Expr], _OrgTerms ]: """Organize terms in the summation node. """ # Intermediate base -> (indices -> coefficient) # # This first gather terms with the same reference to deeper nodes. org_terms = collections.defaultdict( lambda: collections.defaultdict(lambda: _ZERO) ) plain_scalars = [] for term in terms: ref = self._parse_interm_ref(term) if ref is None: plain_scalars.append(term) continue assert ref.power == 1 org_terms[ref.base][ref.indices] += ref.coeff continue res_terms = [] for k, v in org_terms.items(): assert len(v) > 0 for indices, coeff in v.items(): coeff = coeff.simplify() if coeff != 0: res_terms.append( _index(k, indices) * coeff ) continue return plain_scalars, res_terms, org_terms def _optimize_common_symmtrization(self, terms, exts): """Optimize common symmetrization in the intermediate references. """ res_terms = [] exts_dict = dict(exts) scalars, _, org_terms = self._organize_sum_terms(terms) assert len(scalars) == 0 # Indices, coeffs tuple -> base, coeff pull_info = collections.defaultdict(list) for k, v in org_terms.items(): if len(v) == 0: assert False elif len(v) == 1: indices, coeff = v.popitem() res_terms.append( _index(k, indices) * coeff ) else: # Here we use name for sorting directly, since here we cannot # have general expressions hence no need to use the expensive # sort_key. raw = list(v.items()) # Indices/coefficient pairs. 
raw.sort(key=lambda x: [i.name for i in x[0]]) leading_coeff = raw[0][1] pull_info[tuple( (i, j / leading_coeff) for i, j in raw )].append((k, leading_coeff)) # Now we treat the terms from which new intermediates might be pulled # out. for k, v in pull_info.items(): pivot = k[0][0] assert len(pivot) > 0 assert k[0][1] == 1 if len(v) == 1: # No need to form a new intermediate. base, coeff = v[0] pivot_ref = _index(base, pivot) * coeff else: # We need to form an intermediate here. interm_exts = tuple( (i, exts_dict[i]) for i in pivot ) interm_terms = [ term.scale(coeff) for base, coeff in v for term in self._get_content(_index(base, pivot)) ] pivot_ref, interm_node = self._form_sum_interm( interm_exts, interm_terms ) self._optimize(interm_node) for indices, coeff in k: substs = { i: j for i, j in zip(pivot, indices) } res_terms.append( pivot_ref.xreplace(substs) * coeff ) continue continue return res_terms def constr_sum( self, terms: typing.Sequence[Expr], exts: _SrPairs ): """Constrict the summations greedily. """ if_untouched = (1 << len(terms)) - 1 new_terms = [] constr_graphs = self._form_constr_graphs(terms, exts) while True: last_step_idxes, biclique = constr_graphs.get_opt_biclique() if last_step_idxes is None: break new_terms.append(self._form_constred_term( last_step_idxes, biclique )) if_untouched = constr_graphs.cleanup_constred( if_untouched, biclique ) continue # End Main loop. untouched_terms = [ v for i, v in enumerate(terms) if if_untouched & (1 << i) != 0 ] return new_terms, untouched_terms def _form_constr_graphs( self, terms: typing.Sequence[Expr], exts: _SrPairs ) -> _ConstrGraphs: """Form the constriction graphs for the terms. The additional information about the bases of each of the terms are also returned. """ constr_graphs = _ConstrGraphs(self) base_infos = constr_graphs.bases term_bases = constr_graphs.term_bases term_ref_nodes = [] for term_idx, term in enumerate(terms): ref = self._parse_interm_ref(term) if ref is None: term_bases.append(None) term_ref_nodes.append((None, None)) continue base = ref.base node = self._interms[base] assert isinstance(node, _Prod) term_ref_nodes.append((ref, node)) if base in base_infos: base_info = base_infos[base] else: base_info = _BaseInfo(base, node) base_infos[base] = base_info base_info.count += 1 term_bases.append(base_info) self._optimize(node) continue # This loop should have the correct bases count. for term_idx, term in enumerate(terms): ref, node = term_ref_nodes[term_idx] if ref is None: continue for eval_ in node.evals: assert isinstance(eval_, _Prod) self._aug_constr_graphs_4_eval( constr_graphs, term_idx, ref, eval_, exts ) continue return constr_graphs def _aug_constr_graphs_4_eval( self, res: _ConstrGraphs, term_idx: int, ref: _IntermRef, eval_: _Prod, exts: _SrPairs ): """Augment the constriction graphs for an evaluation. 
""" if len(eval_.factors) < 2: return assert len(eval_.factors) == 2 eval_terms = self._index_prod(eval_, ref.indices) assert len(eval_terms) == 1 eval_term = eval_terms[0] ext_symbs = {i for i, _ in eval_.exts} factors, coeff = eval_term.get_amp_factors( self._interms, ext_symbs ) coeff *= ref.coeff assert len(factors) == 2 assert factors[0] != factors[1] sums = tuple(sorted( eval_term.sums, key=lambda x: (x[1], default_sort_key(x[0])) )) excl = set(self._excl) excl.update(ext_symbs) symms = self._drudge.symms.value factor_infos = [] for f_i in factors: content = self._get_content(f_i) assert len(content) == 1 content = content[0] symbs = f_i.atoms(Symbol) exts_idxes = tuple( i for i, v in enumerate(exts) if v[0] in symbs ) exts_int = functools.reduce(operator.or_, ( 1 << i for i in exts_idxes ), 0) for i, _ in sums: assert i in symbs # In order to really make sure, the content will be re-canonicalized # based on the current ambient. canon = content.canon(symms=symms).reset_dumms( self._dumms, excl=excl | content.free_vars )[0] _, canon_coeff = canon.get_amp_factors( self._interms, ext_symbs ) canon = canon.map( lambda x: x / canon_coeff, skip_vecs=True ) coeff *= canon_coeff factor_infos.append(( tuple(exts[i] for i in exts_idxes), _VertInfo(exts=exts_int, expr=f_i, canon=canon) )) continue factor_infos.sort(key=lambda x: x[1].exts) assert len(factor_infos) == 2 last_step_idxes = _LastStepIdxes( exts=(factor_infos[0][0], factor_infos[1][0]), sums=sums ) if last_step_idxes in res: constr_graph = res[last_step_idxes] else: constr_graph = _ConstrGraph( res, factor_infos[0][1].exts, factor_infos[1][1].exts ) res[last_step_idxes] = constr_graph constr_graph.add_edge( (factor_infos[0][1], factor_infos[1][1]), coeff=coeff, term=term_idx, eval_=eval_ ) return def _form_constred_term( self, last_step_idxes: _LastStepIdxes, biclique: _Biclique ) -> Expr: """Form the factored term for the given constriction.""" verts = biclique.constr_graph.graph.node # Form and optimize the two new summation nodes. factors = [biclique.leading_coeff] for exts_i, part_i in zip(last_step_idxes.exts, biclique.parts): scaled_terms = [ verts[i]['info'].canon.scale(j) for i, j in part_i ] exts = tuple(itertools.chain(exts_i, last_step_idxes.sums)) if len(scaled_terms) > 1: expr, eval_node = self._form_sum_interm(exts, scaled_terms) else: scaled_term = scaled_terms[0] expr, eval_node = self._form_prod_interm( exts, scaled_term.sums, [scaled_term.amp] ) factors.append(expr) self._optimize(eval_node) continue # Form the contraction node for the two new summation nodes. exts = tuple(sorted( set(itertools.chain.from_iterable(last_step_idxes.exts)), key=lambda x: default_sort_key(x[0]) )) expr, eval_node = self._form_prod_interm( exts, last_step_idxes.sums, factors ) # Make phony optimization of the intermediate. eval_node.total_cost = 1 eval_node.evals = [eval_node] return expr # # Product optimization. # def _optimize_prod(self, prod_node): """Optimize the product evaluation node. """ # This function should not be called on an already-optimized node. 
        assert len(prod_node.evals) == 0

        n_factors = len(prod_node.factors)

        if n_factors < 2:
            assert n_factors == 1
            prod_node.evals.append(prod_node)
            sums_size = get_total_size(prod_node.sums)
            prod_node.total_cost = (
                get_total_size(prod_node.exts) * sums_size
            ) if sums_size != 1 else 0
            return

        contr_strat = self.contr_strat
        greedy_mode = contr_strat == ContrStrat.GREEDY
        normal_mode = (
            contr_strat == ContrStrat.OPT
            or contr_strat == ContrStrat.TRAV
        )
        exhaust_mode = contr_strat == ContrStrat.EXHAUST
        if_inclusive = (
            contr_strat == ContrStrat.TRAV
            or contr_strat == ContrStrat.EXHAUST
        )

        evals = prod_node.evals
        optimal_cost = None
        for final_cost, broken_sums, biparts_gen in self._gen_factor_biparts(
                prod_node
        ):

            def need_break() -> bool:
                """If we need to break the current loop."""
                if optimal_cost is None:
                    return False
                if greedy_mode:
                    return True
                elif normal_mode:
                    return final_cost > optimal_cost
                elif exhaust_mode:
                    return False
                else:
                    assert False

            if need_break():
                break
            # Else
            for bipart in biparts_gen:
                if need_break():
                    break

                # Recurse, two parts.
                assert len(bipart) == 2
                for i in bipart:
                    self._optimize(i.node)
                    continue

                total_cost = (
                    final_cost
                    + bipart[0].node.total_cost
                    + bipart[1].node.total_cost
                )
                if_new_optimal = (
                    optimal_cost is None or optimal_cost > total_cost
                )

                if if_new_optimal:
                    optimal_cost = total_cost
                    if not if_inclusive:
                        evals.clear()
                if if_new_optimal or if_inclusive:
                    new_eval = self._form_prod_eval(
                        prod_node, broken_sums, bipart
                    )
                    new_eval.total_cost = total_cost
                    evals.append(new_eval)

                continue

        assert len(evals) > 0
        prod_node.total_cost = optimal_cost
        return

    def _gen_factor_biparts(self, prod_node: _Prod):
        """Generate all the bipartitions of factors in a product node."""

        #
        # Compute things invariant across the different summation choices,
        # for performance.
        #

        exts = prod_node.exts
        exts_total_size = get_total_size(exts)
        factors = prod_node.factors
        sums = prod_node.sums
        ext2idx, sum2idx = tuple(
            {v[0]: j for j, v in enumerate(i)}
            for i in (prod_node.exts, sums)
        )

        # Factors involving each of the summations, as lists of factor
        # indices.
        sum_infos: typing.List[typing.List[int]] = [
            [] for _ in range(len(sum2idx))
        ]
        # External and summation involvements of the factors, as bitmasks.
        factor_infos = [[0, 0] for _ in factors]
        for i, v in enumerate(factors):
            for j in v.atoms(Symbol):
                if j in sum2idx:
                    sum_idx = sum2idx[j]
                    sum_infos[sum_idx].append(i)
                    factor_infos[i][1] |= 1 << sum_idx
                elif j in ext2idx:
                    factor_infos[i][0] |= 1 << ext2idx[j]
                else:
                    pass

        #
        # Actual two-level generation.
        #

        for broken_size, broken in _gen_broken_sums(sums):
            # Sums to be retained in the evaluation.
            broken_sums = [
                v for i, v in enumerate(sums) if broken & (1 << i)
            ]
            final_cost = _get_prod_final_cost(
                exts_total_size, broken_size
            )
            yield final_cost, broken_sums, self._gen_biparts_w_kept_sums(
                prod_node, broken, sum_infos, factor_infos
            )
            continue

    def _gen_biparts_w_kept_sums(
            self, prod_node: _Prod, broken, sum_infos, factor_infos
    ):
        """Generate all bipartitions with the given summations kept.

        First the factors are merged into chunks that cannot be split without
        breaking a kept summation.  Then only those bipartitions of the
        chunks in which both parts involve all of the broken summations are
        generated.
        """

        n_factors = len(factor_infos)
        dsf = DSF(n_factors)
        for i, v in enumerate(sum_infos):
            if not (broken & 1 << i):
                dsf.union(v)
            continue

        if dsf.n_sets < 2:
            return

        # The sums, externals, and factors involved by each chunk.
        sums = []
        factors = []
        exts = []
        # Map root factors to the indices of the chunks in the above lists.
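        #
        # All three lists hold integer bitmasks.  For instance
        # (hypothetical), if factors 0 and 2 fall into one chunk, that chunk
        # gets factors == 0b101, with its exts and sums masks being the
        # bitwise OR of the corresponding factor_infos entries.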
        indices = {}
        index = 0
        for i in dsf:
            root = dsf.find(i)
            if root not in indices:
                indices[root] = index
                index += 1
                factors.append(0)
                sums.append(0)
                exts.append(0)
            chunk = indices[root]
            factors[chunk] |= 1 << i
            exts[chunk] |= factor_infos[i][0]
            sums[chunk] |= factor_infos[i][1]
            continue

        # Loop over bipartitions of the indivisible chunks.  Odd masks keep
        # the first chunk always in part one, so that each bipartition is
        # enumerated only once.
        n_chunks = index
        for p1 in range(1, 2 ** n_chunks - 1, 2):

            # Get the sums in the two parts first.
            sums1, sums2 = 0, 0
            for i in range(n_chunks):
                if p1 & 1 << i:
                    sums1 |= sums[i]
                else:
                    sums2 |= sums[i]
                continue

            if all(i & broken == broken for i in (sums1, sums2)):
                # Only now do we gather the factors and the externals.
                factors1, factors2 = 0, 0
                exts1, exts2 = 0, 0
                for i in range(n_chunks):
                    if p1 & 1 << i:
                        factors1 |= factors[i]
                        exts1 |= exts[i]
                    else:
                        factors2 |= factors[i]
                        exts2 |= exts[i]
                    continue

                yield tuple(
                    self._form_part_interm(prod_node, broken, *i)
                    for i in [
                        (exts1, sums1, factors1),
                        (exts2, sums2, factors2)
                    ]
                )

        return

    def _form_part_interm(self, prod_node, broken, exts, sums, factors):
        """Form an intermediate for one part of a bipartition of the
        factors."""

        factors_list = [
            v for i, v in enumerate(prod_node.factors)
            if factors & 1 << i
        ]
        exts_list = [
            v for i, v in enumerate(prod_node.exts)
            if exts & 1 << i
        ]
        sums_list = []
        for i, v in enumerate(prod_node.sums):
            mask = 1 << i
            if not (sums & mask):
                # Sums not involved.
                continue
            elif broken & mask:
                exts_list.append(v)
            else:
                sums_list.append(v)
            continue

        return self._form_prod_interm(exts_list, sums_list, factors_list)

    def _form_prod_eval(
            self, prod_node: _Prod, broken_sums,
            parts: typing.Tuple[_Interm, ...]
    ):
        """Form an evaluation for a product node."""

        assert len(parts) == 2
        coeff = _UNITY
        factors = []
        for i in parts:
            curr_ref = self._parse_interm_ref(i.ref)
            coeff *= curr_ref.coeff
            factors.append(curr_ref.ref)
            continue

        assert len(factors) == 2
        if factors[0] == factors[1]:
            factors = [factors[0] ** 2]

        return _Prod(
            prod_node.base, prod_node.exts, broken_sums,
            coeff * prod_node.coeff, factors
        )


#
# Optimization result verification
# --------------------------------
#
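
# A minimal usage sketch (hypothetical tensor definitions and size symbol
# ``n``; ``optimize`` above and ``verify_eval_seq`` below are the public
# entry points):
#
#     eval_seq = optimize([target_def], substs={n: 1024})
#     assert verify_eval_seq(eval_seq, [target_def], simplify=True)
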
def verify_eval_seq(
        eval_seq: typing.Sequence[TensorDef], res: typing.Sequence[TensorDef],
        simplify=False
) -> bool:
    """Verify the correctness of an evaluation sequence for the results.

    The last entries of the evaluation sequence should be in one-to-one
    correspondence with the original forms in the ``res`` argument.  This
    function returns ``True`` when the evaluation sequence is symbolically
    equivalent to the given raw form.  When a difference is found,
    ``ValueError`` will be raised with relevant information.

    Note that this function can be very slow for large evaluations, but it is
    advised for all optimizations in mission-critical tasks.

    Parameters
    ----------

    eval_seq
        The evaluation sequence to verify, can be the output from
        :py:func:`optimize` directly.

    res
        The original result to test the evaluation sequence against.  It can
        be the input to :py:func:`optimize` directly.

    simplify
        If simplification is going to be performed after each step of the
        back-substitution.  It is advised for larger, more complex
        evaluations.

    """

    n_res = len(res)
    n_interms = len(eval_seq) - n_res

    substed_eval_seq = []
    defs_dict = {}
    for idx, eval_ in enumerate(eval_seq):
        base = eval_.base
        free_vars = eval_.rhs.free_vars
        curr_defs = [
            defs_dict[i] for i in free_vars if i in defs_dict
        ]
        exts = {i for i, _ in eval_.exts}
        rhs = eval_.rhs.subst_all(curr_defs, simplify=simplify, excl=exts)
        new_def = TensorDef(base, eval_.exts, rhs)
        substed_eval_seq.append(new_def)

        if idx < n_interms:
            defs_dict[
                base.label if isinstance(base, IndexedBase) else base
            ] = new_def
        continue

    for i, j in zip(substed_eval_seq[-n_res:], res):
        ref = j.simplify()
        if i.lhs != ref.lhs:
            raise ValueError(
                'Unequal left-hand sides', i.lhs, 'with', ref.lhs
            )
        diff = (i.rhs - ref.rhs).simplify()
        if diff != 0:
            raise ValueError(
                'Unequal definition for', j.lhs, j
            )
        continue

    return True
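
# A self-contained sketch (plain SymPy, hypothetical names, not part of the
# gristmill API) of the back-substitution idea used in verify_eval_seq above:
# substituting the definition of an intermediate into the result should
# reproduce the original right-hand side.
#
#     from sympy import IndexedBase, symbols
#     a, b = symbols('a b')
#     x, y = IndexedBase('x'), IndexedBase('y')
#     t_rhs = x[a, b] + y[a, b]          # intermediate t[a, b]
#     r_rhs = 2 * t_rhs                  # result r[a, b] after substitution
#     orig = 2 * x[a, b] + 2 * y[a, b]   # original raw definition
#     assert (r_rhs - orig).expand() == 0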