Source code for gristmill.optimize

"""Optimizer for the contraction computations."""

import collections
import enum
import heapq
import itertools
import typing
import warnings

from drudge import TensorDef, prod_, Term, Range
from sympy import (
    Integer, Symbol, Expr, IndexedBase, Mul, Indexed, sympify, primitive, Wild
)
from sympy.utilities.iterables import multiset_partitions

from .utils import get_cost_key, add_costs, get_total_size, DSF


#
#  The public driver
#  -----------------
#


[docs]class Strategy(enum.Enum): """The optimization strategy for tensor contractions. This enumeration type gives possible options for the optimization strategy for tensor contractions. Supported values includes, ``GREEDY`` The contraction will be optimized greedily. This should only be used for large inputs where the other strategies cannot finish within a reasonable time. ``BEST`` The global minimum of each tensor contraction will be found by the advanced algorithm in gristmill. And only the optimal contraction(s) will be kept for the summation optimization. ``SEARCHED`` The same strategy as ``BEST`` will be attempted for the optimization of contractions. But all evaluations searched in the optimization process will be kept and considered in subsequent summation optimizations. ``ALL`` All possible contraction sequences will be considered for all contractions. This can be extremely slow. But it might be helpful for manageable problems. """ GREEDY = 0 BEST = 1 SEARCHED = 2 ALL = 3
[docs]def optimize( computs: typing.Iterable[TensorDef], substs=None, interm_fmt='tau^{}', simplify=True, strategy=Strategy.SEARCHED ) -> typing.List[TensorDef]: """Optimize the valuation of the given tensor contractions. This function will transform the given computations, given as tensor definitions, into another list computations mathematically equivalent to the given computation while requiring less floating-point operations (FLOPs). Parameters ---------- computs The computations, can be given as an iterable of tensor definitions. substs A dictionary for making substitutions inside the sizes of ranges. All the ranges need to have size in at most one undetermined variable after the substitution so that they can be totally ordered. interm_fmt The format for the names of the intermediates. simplify If the input is going to be simplified before processing. It can be disabled when the input is already simplified. strategy The optimization strategy, as explained in :py:class:`Strategy`. """ substs = {} if substs is None else substs if simplify: computs = [i.simplify() for i in computs] else: computs = list(computs) if not isinstance(strategy, Strategy): raise TypeError('Invalid optimization strategy', strategy) opt = _Optimizer( computs, substs=substs, interm_fmt=interm_fmt, strategy=strategy ) return opt.optimize()
# # The internal optimization engine # -------------------------------- # # Internal small type definitions # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # # # Small type definitions. # _Grain = collections.namedtuple('_Grain', [ 'base', 'exts', 'terms' ]) # # The information on collecting a collectible. # # Interpretation, after the substitutions given in ``substs``, the ``lr`` factor # in the evaluation ``eval_`` will be turned into ``coeff`` times the actual # collectible. # _CollectInfo = collections.namedtuple('_CollectInfo', [ 'eval_', 'lr', 'coeff', 'substs', 'ranges', 'add_cost' ]) _Ranges = collections.namedtuple('_Ranges', [ 'involved_exts', 'sums', 'other_exts' ]) _Collectible = typing.Tuple[Term, ...] _CollectInfos = typing.Dict[int, _CollectInfo] _Collectibles = typing.Dict[_Collectible, _CollectInfos] _Part = collections.namedtuple('_Part', [ 'ref', 'node' ]) # # Core evaluation DAG nodes. # class _EvalNode: """A node in the evaluation graph. """ def __init__(self, base, exts): """Initialize the evaluation node. """ self.base = base self.exts = exts self.evals = [] # type: typing.List[_EvalNode] self.total_cost = None self.n_refs = 0 self.generated = False def get_substs(self, indices): """Get substitutions and symbols requiring exclusion before indexing. """ substs = {} excl = set() assert len(indices) == len(self.exts) for i, j in zip(indices, self.exts): dumm = j[0] substs[dumm] = i excl.add(dumm) excl |= i.atoms(Symbol) continue return substs, excl class _Sum(_EvalNode): """Sum nodes in the evaluation graph.""" def __init__(self, base, exts, sum_terms): """Initialize the node.""" super().__init__(base, exts) self.sum_terms = sum_terms def __repr__(self): """Form a representation string for the node.""" return '_Sum(base={}, exts={}, sum_terms={})'.format( repr(self.base), repr(self.exts), repr(self.sum_terms) ) class _Prod(_EvalNode): """Product nodes in the evaluation graph. """ def __init__(self, base, exts, sums, coeff, factors): """Initialize the node.""" super().__init__(base, exts) self.sums = sums self.coeff = coeff self.factors = factors def __repr__(self): """Form a representation string for the node.""" return '_Prod(base={}, exts={}, coeff={}, factors={})'.format( repr(self.base), repr(self.exts), repr(self.coeff), repr(self.factors) ) # # Core optimizer class # ~~~~~~~~~~~~~~~~~~~~ # class _Optimizer: """Optimizer for tensor contraction computations. This internal optimizer can only be used once for one set of input. """ def __init__(self, computs, substs, interm_fmt, strategy): """Initialize the optimizer.""" self._prepare_grist(computs, substs) self._interm_fmt = interm_fmt self._strategy = strategy self._next_internal_idx = 0 self._interms = {} self._interms_canon = {} self._res = None def optimize(self): """Optimize the evaluation of the given computations. """ if self._res is not None: return self._res res_nodes = [self._form_node(i) for i in self._grist] for i in res_nodes: self._optimize(i) continue self._res = self._linearize(res_nodes) return self._res # # User input pre-processing. # def _prepare_grist(self, computs, substs): """Prepare tensor definitions for optimization. """ self._grist = [] self._drudge = None self._range_var = None # The only variable for range sizes. self._excl = set() self._input_ranges = {} # Substituted range to original range. # Form pre-grist, basically everything is set except the dummy variables # for external indices and summations. pre_grist = [ self._form_pre_grist(comput, substs) for comput in computs ] # Finalize grist formation by resetting the dummies. self._dumms = { k: self._drudge.dumms.value[v] for k, v in self._input_ranges.items() } self._grist = [self._reset_dumms(grain) for grain in pre_grist] return def _form_pre_grist(self, comput, substs): """Form grist from a given computation. """ curr_drudge = comput.rhs.drudge if self._drudge is None: self._drudge = curr_drudge elif self._drudge is not curr_drudge: raise ValueError( 'Invalid computations to optimize, containing two drudges', (self._drudge, curr_drudge) ) else: pass # Externals processing. exts = self._proc_sums(comput.exts, substs) ext_symbs = {i for i, _ in exts} # Terms processing. terms = [] for term in comput.rhs_terms: if not term.is_scalar: raise ValueError( 'Invalid term to optimize', term, 'expecting scalar' ) sums = self._proc_sums(term.sums, substs) amp = term.amp # Add the true free symbols to the exclusion set. self._excl |= term.free_vars - ext_symbs terms.append(Term(sums, amp, ())) continue return _Grain(base=comput.base, exts=exts, terms=terms) def _proc_sums(self, sums, substs): """Process a summation list. The ranges will be replaced with substitution sizes. Relevant members of the object will also be updated. User error will also be reported. """ res = [] for symb, range_ in sums: if not range_.bounded: raise ValueError( 'Invalid range for optimization', range_, 'expecting explicit bound' ) lower, upper = [ self._check_range_var(i.xreplace(substs), range_) for i in [range_.lower, range_.upper] ] new_range = Range(range_.label, lower=lower, upper=upper) if new_range not in self._input_ranges: self._input_ranges[new_range] = range_ elif range_ != self._input_ranges[new_range]: raise ValueError( 'Invalid ranges', (range_, self._input_ranges[new_range]), 'duplicated labels' ) else: pass res.append((symb, new_range)) continue return tuple(res) def _check_range_var(self, expr, range_) -> Expr: """Check size expression for valid symbol presence.""" range_vars = expr.atoms(Symbol) if len(range_vars) == 0: pass elif len(range_vars) == 1: range_var = range_vars.pop() if self._range_var is None: self._range_var = range_var elif self._range_var != range_var: raise ValueError( 'Invalid range', range_, 'unexpected symbol', range_var, 'conflicting with', self._range_var ) else: pass else: raise ValueError( 'Invalid range', range_, 'containing multiple symbols', range_vars ) return expr def _reset_dumms(self, grain): """Reset the dummies in a grain.""" exts, ext_substs, dummbegs = Term.reset_sums( grain.exts, self._dumms, excl=self._excl ) terms = [] for term in grain.terms: sums, curr_substs, _ = Term.reset_sums( term.sums, self._dumms, dummbegs=dict(dummbegs), excl=self._excl ) curr_substs.update(ext_substs) terms.append(term.map( lambda x: x.xreplace(curr_substs), sums=sums )) continue return _Grain(base=grain.base, exts=exts, terms=terms) # # Optimization result post-processing. # def _linearize( self, optimized: typing.Sequence[_EvalNode] ) -> typing.List[TensorDef]: """Linearize optimized forms of the evaluation. """ for node in optimized: self._set_n_refs(node) continue interms = [] res = [] for node in optimized: curr = self._linearize_node(node, interms, keep=True) assert curr is not None res.append(curr) continue return self._finalize(itertools.chain(interms, res)) def _set_n_refs(self, node: _EvalNode): """Set reference counts from an evaluation node. It is always the first evaluation that is going to be used, all rest will be removed to avoid further complications. """ if len(node.evals) == 0: self._optimize(node) assert len(node.evals) > 0 del node.evals[1:] eval_ = node.evals[0] if isinstance(eval_, _Prod): refs = [i for i in eval_.factors if self._is_interm_ref(i)] elif isinstance(eval_, _Sum): refs = eval_.sum_terms else: assert False for i in refs: _, ref = self._parse_interm_ref(i) dep = ref.base if isinstance(ref, Indexed) else ref dep_node = self._interms[dep] dep_node.n_refs += 1 self._set_n_refs(dep_node) continue return def _linearize_node(self, node: _EvalNode, res: list, keep=False): """Linearize evaluation rooted in the given node into the result. If keep if set to True, the evaluation of the given node will not be appended to the result list. """ if node.generated: return None def_, deps = self._form_def(node) for i in deps: self._linearize_node(self._interms[i], res) continue node.generated = True if not keep: res.append(def_) return def_ def _form_def(self, node: _EvalNode): """Form the final definition of an evaluation node.""" assert len(node.evals) == 1 if isinstance(node, _Prod): return self._form_prod_def(node) elif isinstance(node, _Sum): return self._form_sum_def(node) else: assert False def _form_prod_def(self, node: _Prod): """Form the final definition of a product evaluation node.""" exts = node.exts eval_ = node.evals[0] assert isinstance(eval_, _Prod) term, deps = self._form_prod_def_term(eval_) return _Grain(base=node.base, exts=exts, terms=[term]), deps def _form_prod_def_term(self, eval_: _Prod): """Form the term in the final definition of a product evaluation node. """ amp = eval_.coeff deps = [] for factor in eval_.factors: if self._is_interm_ref(factor): dep = factor.base if isinstance(factor, Indexed) else factor interm = self._interms[dep] if self._is_input(interm): # Inline trivial reference to an input. content = self._get_def(factor) assert len(content) == 1 amp *= content[0].amp else: deps.append(dep) amp *= factor else: amp *= factor return Term(eval_.sums, amp, ()), deps def _form_sum_def(self, node: _Sum): """Form the final definition of a sum evaluation node.""" exts = node.exts terms = [] deps = [] eval_ = node.evals[0] assert isinstance(eval_, _Sum) sum_terms = [] self._inline_sum_terms(eval_.sum_terms, sum_terms) for term in sum_terms: coeff, ref = self._parse_interm_ref(term) # Sum term are guaranteed to be formed from references to products, # never directly written in terms of input. term_base = ref.base if isinstance(ref, Indexed) else ref term_node = self._interms[term_base] if term_node.n_refs == 1 or self._is_input(term_node): # Inline intermediates only used here and simple input # references. eval_ = term_node.evals[0] assert isinstance(eval_, _Prod) indices = ref.indices if isinstance(ref, Indexed) else () term = self._index_prod(eval_, indices)[0] factors, term_coeff = term.get_amp_factors(self._interms) # Switch back to evaluation node for using the facilities for # product nodes. new_term, term_deps = self._form_prod_def_term(_Prod( term_node.base, exts, term.sums, coeff * term_coeff, factors )) terms.append(new_term) deps.extend(term_deps) else: terms.append(Term( (), term, () )) deps.append(term_base) continue return _Grain(base=node.base, exts=exts, terms=terms), deps def _inline_sum_terms(self, sum_terms, res): """Inline the summation terms from single-reference terms.""" for sum_term in sum_terms: coeff, ref = self._parse_interm_ref(sum_term) node = self._interms[ ref.base if isinstance(ref, Indexed) else ref ] assert len(node.evals) > 0 eval_ = node.evals[0] if_inline = isinstance(eval_, _Sum) and ( node.n_refs == 1 or len(eval_.sum_terms) == 1 ) if if_inline: if len(node.exts) == 0: substs = None else: substs = { i[0]: j for i, j in zip(eval_.exts, ref.indices) } proced_sum_terms = [ ( i.xreplace(substs) if substs is not None else sum_term ) * coeff for i in eval_.sum_terms ] self._inline_sum_terms(proced_sum_terms, res) continue else: res.append(sum_term) continue return def _is_input(self, node: _EvalNode): """Test if a product node is just a trivial reference to an input.""" if isinstance(node, _Prod): return len(node.sums) == 0 and len(node.factors) == 1 and ( not self._is_interm_ref(node.factors[0]) ) else: return False def _finalize( self, computs: typing.Iterable[_Grain] ) -> typing.List[TensorDef]: """Finalize the linearization result. Things will be cast to drudge tensor definitions, with intermediates holding names formed from the format given by user. """ next_idx = 0 substs = {} # For normal substitution of bases. repls = [] # For removed shallow intermediates def proc_amp(amp): """Process the amplitude by making the found substitutions.""" for i in reversed(repls): amp = amp.replace(*i) continue return amp.xreplace(substs) res = [] for comput in computs: base = comput.base exts = tuple((s, self._input_ranges[r]) for s, r in comput.exts) terms = [ i.map(proc_amp, sums=tuple( (s, self._input_ranges[r]) for s, r in i.sums )) for i in comput.terms ] if base in self._interms: if len(terms) == 1 and len(terms[0].sums) == 0: # Remove shallow intermediates. The saving might be too # modest to justify the additional memory consumption. repl_lhs = base[tuple( _WILD_FACTORY[i] for i, _ in enumerate(exts) )] if len(exts) > 0 else base repl_rhs = proc_amp(terms[0].amp.xreplace( {v[0]: _WILD_FACTORY[i] for i, v in enumerate(exts)} )) repls.append((repl_lhs, repl_rhs)) continue # No new intermediate added. final_base = type(base)(self._interm_fmt.format(next_idx)) next_idx += 1 substs[base] = final_base else: final_base = base res.append(TensorDef( final_base, exts, self._drudge.create_tensor(terms) )) continue return res # # Internal support utilities. # def _get_next_internal(self, symbol=False): """Get the base or symbol for the next internal intermediate. """ idx = self._next_internal_idx self._next_internal_idx += 1 cls = Symbol if symbol else IndexedBase return cls('gristmillInternalIntermediate{}'.format(idx)) @staticmethod def _write_in_orig_ranges(sums): """Write the summations in terms of undecorated bare ranges. The labels in the ranges are assumed to be decorated. """ return tuple( (i, j.replace_label(j.label[0])) for i, j in sums ) def _canon_terms(self, new_sums, terms): """Form a canonical label for a list of terms. The new summation list is prepended to the summation list of all terms. The coefficient ahead of the canonical form is returned before the canonical form. And the permuted new summation list is also returned after the canonical form. Note that this list contains the original dummies given in the new summation list, while the terms has reset new dummies. Note that the ranges in the new summation list are assumed to be decorated with labels earlier than _SUMMED. In the result, they are still in decorated forms and are guaranteed to be permuted in the same way for all given terms. The summations from the terms will be internally decorated but written in bare ranges in the final result. Note that this is definitely a poor man's version of canonicalization of multi-term tensor definitions with external indices. A lot of cases cannot be handled well. Hopefully it can be replaced with a systematic treatment some day in the future. """ new_dumms = {i for i, _ in new_sums} coeffs = [] candidates = collections.defaultdict(list) for term in terms: term, canon_sums = self._canon_term(new_sums, term) factors, coeff = term.amp_factors coeffs.append(coeff) candidates[ term.map(lambda x: prod_(factors)) ].append(canon_sums) continue # Poor man's canonicalization of external indices. # # This algorithm is not guaranteed to work. Here we just choose an # ordering of the external indices that is as safe as possible. But # certainly it is not guaranteed to work for all cases. # # TODO: Fix it! chosen = min(candidates.items(), key=lambda x: ( len(x[1]), -len(x[0].amp.atoms(Symbol) & new_dumms), x[0].sort_key )) canon_new_sums = set(chosen[1]) if len(canon_new_sums) > 1: warnings.warn( 'Internal deficiency: ' 'summation intermediate may not be fully canonicalized' ) # This could also fail when the chosen term has symmetry among the new # summations not present in any other term. This can be hard to check. canon_new_sum = canon_new_sums.pop() preferred = chosen[0].amp_factors[1] canon_coeff = _get_canon_coeff(coeffs, preferred) res_terms = [] for term in terms: canon_term, _ = self._canon_term(canon_new_sum, term, fix_new=True) # TODO: Add support for complex conjugation. res_terms.append(canon_term.map(lambda x: x / canon_coeff)) continue return canon_coeff, tuple( sorted(res_terms, key=lambda x: x.sort_key) ), canon_new_sum def _canon_term(self, new_sums, term, fix_new=False): """Canonicalize a single term. Internal method for _canon_terms, not supposed to be directly called. """ term = Term(tuple(itertools.chain( ( (v[0], v[1].replace_label((v[1].label[0], _EXT, i))) for i, v in enumerate(new_sums) ) if fix_new else new_sums, ( (i, j.replace_label((j.label, _SUMMED))) for i, j in term.sums ) )), term.amp, ()) canoned = term.canon(symms=self._drudge.symms.value) canon_sums = canoned.sums canon_orig_sums = self._write_in_orig_ranges(canon_sums) dumm_reset, _ = canoned.map( lambda x: x, sums=canon_orig_sums ).reset_dumms( dumms=self._dumms, excl=self._excl ) canon_new_sums = [] term_new_sums = [] term_sums = [] i_new = 0 for i, j in zip(dumm_reset.sums, canon_sums): if j[1].label[1] == _SUMMED: # Existing summations. term_sums.append(i) else: if fix_new: assert j[0] == new_sums[i_new][0] range_ = new_sums[i_new][1] else: range_ = j[1] canon_new_sums.append((j[0], range_)) term_new_sums.append((i[0], range_)) i_new += 1 continue return dumm_reset.map(lambda x: x, sums=tuple(itertools.chain( term_new_sums, term_sums ))), tuple(canon_new_sums) def _parse_interm_ref(self, sum_term: Expr): """Get the coefficient and pure intermediate reference in a reference. Despite being SymPy expressions, actually intermediate reference, for instance in a term in an summation node, is very rigid. """ if isinstance(sum_term, Mul): args = sum_term.args assert len(args) == 2 if self._is_interm_ref(args[1]): return args else: assert self._is_interm_ref(args[0]) return args[1], args[0] else: return _UNITY, sum_term def _is_interm_ref(self, expr: Expr): """Test if an expression is a reference to an intermediate.""" return (isinstance(expr, Indexed) and expr.base in self._interms) or ( expr in self._interms ) def _get_def(self, interm_ref: Expr) -> typing.List[Term]: """Get the definition of an intermediate reference. The intermediate reference need to be a pure intermediate reference without any factor. """ if isinstance(interm_ref, Indexed): base = interm_ref.base indices = interm_ref.indices elif isinstance(interm_ref, Symbol): base = interm_ref indices = () else: raise TypeError('Invalid intermediate reference', interm_ref) if base not in self._interms: raise ValueError('Invalid intermediate base', base) node = self._interms[base] if isinstance(node, _Sum): return self._index_sum(node, indices) elif isinstance(node, _Prod): return self._index_prod(node, indices) else: assert False def _index_sum(self, node: _Sum, indices) -> typing.List[Term]: """Substitute the external indices in the sum node""" substs, _ = node.get_substs(indices) res = [] for i in node.sum_terms: coeff, ref = self._parse_interm_ref(i.xreplace(substs)) term = self._get_def(ref)[0].scale(coeff) res.append(term) return res def _index_prod(self, node: _Prod, indices) -> typing.List[Term]: """Substitute the external indices in the evaluation node.""" substs, excl = node.get_substs(indices) # TODO: Add handling of sum intermediate reference in factors. term = Term( node.sums, node.coeff * prod_(node.factors), () ).reset_dumms( self._dumms, excl=self._excl | excl )[0].map(lambda x: x.xreplace(substs)) return [term] # # General optimization. # def _form_node(self, grain: _Grain): """Form an evaluation node from a tensor definition. """ # We assume it is fully simplified and expanded by grist preparation. exts = grain.exts terms = grain.terms if len(terms) == 0: assert False # Should be removed by grist preparation. else: return self._form_sum_from_terms(grain.base, exts, terms) def _optimize(self, node): """Optimize the evaluation of the given node. The evaluation methods will be filled with, possibly multiple, method of evaluations. """ if len(node.evals) > 0: return node if isinstance(node, _Sum): return self._optimize_sum(node) elif isinstance(node, _Prod): return self._optimize_prod(node) else: assert False def _form_prod_interm( self, exts, sums, factors ) -> typing.Tuple[Expr, _EvalNode]: """Form a product intermediate. The factors are assumed to be all non-trivial factors needing processing. """ decored_exts = tuple( (i, j.replace_label((j.label, _EXT))) for i, j in exts ) n_exts = len(decored_exts) term = Term(tuple(sums), prod_(factors), ()) coeff, key, canon_exts = self._canon_terms( decored_exts, [term] ) assert len(key) == 1 if key in self._interms_canon: base = self._interms_canon[key] else: base = self._get_next_internal(n_exts == 0) self._interms_canon[key] = base key_term = key[0] key_exts = self._write_in_orig_ranges(key_term.sums[:n_exts]) key_sums = key_term.sums[n_exts:] key_factors, key_coeff = key_term.get_amp_factors(self._interms) interm = _Prod( base, key_exts, key_sums, key_coeff, key_factors ) self._interms[base] = interm return coeff * base[tuple( i for i, _ in canon_exts )] if isinstance(base, IndexedBase) else base, self._interms[base] def _form_sum_interm(self, exts, terms) -> typing.Tuple[Expr, _EvalNode]: """Form a sum intermediate. """ decored_exts = tuple( (i, j.replace_label((j.label, _EXT))) for i, j in exts ) n_exts = len(decored_exts) coeff, canon_terms, canon_exts = self._canon_terms(decored_exts, terms) if canon_terms in self._interms_canon: base = self._interms_canon[canon_terms] else: base = self._get_next_internal(n_exts == 0) self._interms_canon[canon_terms] = base node_exts = None node_terms = [] for term in canon_terms: term_exts = self._write_in_orig_ranges(term.sums[:n_exts]) if node_exts is None: node_exts = term_exts else: assert node_exts == term_exts node_terms.append(term.map( lambda x: x, sums=term.sums[n_exts:] )) continue node = self._form_sum_from_terms(base, node_exts, node_terms) self._interms[base] = node self._optimize(node) return coeff * base[tuple( i for i, _ in canon_exts )] if isinstance(base, IndexedBase) else base, self._interms[base] def _form_sum_from_terms(self, base, exts, terms): """Form a summation node for given the terms. No processing is done in this method. """ sum_terms = [] for term in terms: sums = term.sums factors, coeff = term.amp_factors interm_ref, _ = self._form_prod_interm(exts, sums, factors) sum_terms.append(interm_ref * coeff) continue return _Sum(base, exts, sum_terms) # # Sum optimization. # def _optimize_sum(self, sum_node: _Sum): """Optimize the summation node.""" # We first optimize the common terms. exts = sum_node.exts terms, new_term_idxes = self._optimize_common_terms(sum_node) # Now we embark upon the heroic factorization. collectibles = collections.defaultdict(dict) # type: _Collectibles while True: for idx in new_term_idxes: term = terms[idx] # Loop over collectibles the new term can offer. for i, j in self._find_collectibles(exts, term): infos = collectibles[i] if idx not in infos: # The same term cannot provide the same collectible # twice. infos[idx] = j continue continue new_term_idxes.clear() to_collect, infos, total_cost = self._choose_collectible( collectibles ) if to_collect is None: break new_term_idx = self._collect(terms, infos, total_cost) new_term_idxes.append(new_term_idx) del collectibles[to_collect] for i in infos.keys(): for j in collectibles.values(): if i in j: del j[i] continue # End Main loop. rem_terms = [i for i in terms if i is not None] sum_node.evals = [_Sum( sum_node.base, sum_node.exts, rem_terms )] return def _optimize_common_terms(self, sum_node: _Sum) -> typing.Tuple[ typing.List[Expr], typing.List[int] ]: """Perform optimization of common intermediate references. """ exts_dict = dict(sum_node.exts) # Intermediate base -> (indices -> coefficient) # # This also gather terms with the same reference to deeper nodes. interm_refs = collections.defaultdict( lambda: collections.defaultdict(lambda: 0) ) for term in sum_node.sum_terms: coeff, ref = self._parse_interm_ref(term) if isinstance(ref, Symbol): base = ref indices = () elif isinstance(ref, Indexed): base = ref.base indices = ref.indices else: assert False interm_refs[base][indices] += coeff continue # Intermediate referenced only once goes to the result directly and wait # to be factored, others wait to be pulled and do not participate in # factorization. res_terms = [] res_collectible_idxes = [] # Indices, coeffs tuple -> base, coeff pull_info = collections.defaultdict(list) for k, v in interm_refs.items(): if len(v) == 0: assert False elif len(v) == 1: res_collectible_idxes.append(len(res_terms)) indices, coeff = v.popitem() res_terms.append( (k[indices] if len(indices) > 0 else k) * coeff ) else: # Here we use name for sorting directly, since here we cannot # have general expressions hence no need to use the expensive # sort_key. raw = list(sorted(v.items(), key=lambda x: tuple( i.name for i in x[0] ))) leading_coeff = raw[0][1] pull_info[tuple( (i, j / leading_coeff) for i, j in raw )].append((k, leading_coeff)) # Now we treat the terms from which new intermediates might be pulled # out. for k, v in pull_info.items(): pivot = k[0][0] assert k[0][1] == 1 if len(v) == 1: # No need to form a new intermediate. base, coeff = v[0] pivot_ref = base[pivot] * coeff else: # We need to form an intermediate here. interm_exts = tuple( (i, exts_dict[i]) for i in pivot ) pivot_ref, interm_node = self._form_sum_interm(interm_exts, [ term.scale(coeff) for base, coeff in v for term in self._get_def(base[pivot]) ]) self._optimize(interm_node) for indices, coeff in k: substs = { i: j for i, j in zip(pivot, indices) } res_terms.append( pivot_ref.xreplace(substs) * coeff / k[0][1] ) continue continue return res_terms, res_collectible_idxes def _find_collectibles(self, exts, term): """Find the collectibles from a given term. Collectibles are going to be yielded as key and infos pairs. """ coeff, ref = self._parse_interm_ref(term) res = [] # type: typing.List[typing.Tuple[_Collectible, _CollectInfo]] if coeff != 1 and coeff != -1: # TODO: Add attempt to collect the coefficient. # # This could give some minor saving. pass prod_node = self._interms[ ref.base if isinstance(ref, Indexed) else ref ] self._optimize(prod_node) if len(prod_node.factors) > 1: # Single-factor does not offer collectible, # collectible * (something + 1) is so rare in real applications. for eval_i in prod_node.evals: res.extend(self._find_collectibles_eval( exts, ref, eval_i, prod_node.total_cost )) continue return res def _find_collectibles_eval( self, exts, ref: Expr, eval_: _Prod, opt_cost ): """Get the collectibles for a particular evaluations of a product.""" # To begin, we first need to substitute the external indices in for this # particular evaluation inside its ambient. total_cost = eval_.total_cost assert total_cost is not None if len(eval_.exts) == 0: assert isinstance(ref, Symbol) else: assert isinstance(ref, Indexed) eval_terms = self._index_prod(eval_, ref.indices) assert len(eval_terms) == 1 eval_term = eval_terms[0] factors, coeff = eval_term.get_amp_factors(self._interms) eval_ = _Prod( _SUBSTED_EVAL_BASE, exts, eval_term.sums, coeff, factors ) eval_.total_cost = total_cost sums = eval_.sums factors = eval_.factors assert len(factors) == 2 # Each evaluation could give two collectibles. res = [] for lr in range(2): factor = factors[lr] collectible, ranges, coeff, substs = self._get_collectible_interm( exts, sums, factor ) res.append((collectible, _CollectInfo( eval_=eval_, lr=lr, coeff=coeff, substs=substs, ranges=ranges, add_cost=eval_.total_cost - opt_cost ))) continue return res def _get_collectible_interm(self, exts, sums, interm_ref): """Get a collectible from an intermediate reference.""" terms = self._get_def(interm_ref) involved_symbs = interm_ref.atoms(Symbol) involved_exts = [] other_exts = [] for i, v in enumerate(exts): symb, range_ = v if symb in involved_symbs: involved_exts.append(( symb, range_.replace_label((range_.label, _EXT, i)) )) else: other_exts.append((symb, range_)) # Undecorated. continue involved_sums = [] for i, j in sums: # Sums not involved in both should be pushed in. assert i in involved_symbs involved_sums.append(( i, j.replace_label((j.label, _SUMMED_EXT)) )) continue coeff, key, all_sums = self._canon_terms( tuple(itertools.chain(involved_exts, involved_sums)), terms ) ranges = _Ranges( involved_exts=self._write_in_orig_ranges(involved_exts), sums=self._write_in_orig_ranges(involved_sums), other_exts=other_exts ) new_sums = (i for i in all_sums if i[1].label[1] == _SUMMED_EXT) return key, ranges, coeff, { i[0]: j[0] for i, j in zip(involved_sums, new_sums) } def _choose_collectible(self, collectibles: _Collectibles): """Choose the most profitable collectible factor. The collectible, its infos, and the final cost of the evaluation after the collection will be returned. """ with_saving = ( i for i in collectibles.items() if len(i[1]) > 1 ) optimal = None new_total_cost = None largest_saving = None for collectible, infos in with_saving: # Any range is sufficient for the determination of savings. raw_saving = self._get_collectible_saving( next(iter(infos.values())).ranges ) saving = raw_saving - sum( i.add_cost for i in infos.values() ) saving_key = get_cost_key(saving) if_save = len(saving_key[1]) > 0 and saving_key[1][0] > 0 if_better = ( largest_saving is None or saving_key > largest_saving[1] ) if if_save and if_better: largest_saving = (saving, saving_key) optimal = (collectible, infos) orig_cost = sum(i.eval_.total_cost for i in infos.values()) new_total_cost = orig_cost - raw_saving continue if optimal is None: return None, None, None else: return optimal[0], optimal[1], new_total_cost @staticmethod def _get_collectible_saving(ranges: _Ranges) -> Expr: """Get the saving factor for a collectible.""" other_size = get_total_size(ranges.other_exts) sum_size = get_total_size(ranges.sums) ext_size = get_total_size(ranges.involved_exts) return other_size * add_costs( 2 * sum_size * ext_size, ext_size, -sum_size ) def _collect(self, terms, collect_infos: _CollectInfos, new_cost): """Collect the given collectible factor. This function will mutate the given terms list. Set one of the collected terms to the new sum term, whose index is going to be returned, with all the rest collected terms set to None. """ residue_terms = [] residue_exts = None new_term_idx = min(collect_infos.keys()) for k, v in collect_infos.items(): coeff, _ = self._parse_interm_ref(terms[k]) eval_ = v.eval_ coeff *= eval_.coeff * v.coeff # Three levels of coefficients. residue_terms.extend( i.map(lambda x: coeff * x) for i in self._get_def( eval_.factors[0 if v.lr == 1 else 1].xreplace(v.substs) ) ) curr_exts = tuple( itertools.chain(v.ranges.other_exts, v.ranges.sums), ) if residue_exts is None: residue_exts = curr_exts else: assert residue_exts == curr_exts continue new_ref, _ = self._form_sum_interm(residue_exts, residue_terms) for k, v in collect_infos.items(): if k == new_term_idx: terms[k] = self._form_collected(terms[k], v, new_ref, new_cost) else: terms[k] = None return new_term_idx def _form_collected( self, term, info: _CollectInfo, new_ref, new_cost ) -> Expr: """Form new sum term with some factors collected based on a term. """ eval_ = info.eval_ collected_factor = eval_.factors[info.lr].xreplace(info.substs) interm_coeff, interm = self._parse_interm_ref(new_ref) coeff = interm_coeff / info.coeff _, orig_ref = self._parse_interm_ref(term) orig_node = self._interms[ orig_ref.base if isinstance(orig_ref, Indexed) else orig_ref ] orig_exts = orig_node.exts base = self._get_next_internal(len(orig_exts) == 0) new_node = _Prod(base, orig_exts, eval_.sums, coeff, [ collected_factor, interm ]) new_node.total_cost = new_cost new_node.evals = [new_node] self._interms[base] = new_node return ( base[tuple(i for i, _ in orig_exts)] if isinstance(base, IndexedBase) else base ) # # Product optimization. # def _optimize_prod(self, prod_node): """Optimize the product evaluation node. """ assert len(prod_node.evals) == 0 n_factors = len(prod_node.factors) if n_factors < 2: assert n_factors == 1 prod_node.evals.append(prod_node) prod_node.total_cost = self._get_prod_final_cost( get_total_size(prod_node.exts), get_total_size(prod_node.sums) ) return strategy = self._strategy evals = prod_node.evals optimal_cost = None for final_cost, broken_sums, parts_gen in self._gen_factor_parts( prod_node ): def need_break(): """If we need to break the current loop.""" if strategy == Strategy.GREEDY: return True elif strategy == Strategy.BEST or strategy == Strategy.SEARCHED: return get_cost_key(final_cost) > optimal_cost[0] elif strategy == Strategy.ALL: return False else: assert False if (optimal_cost is not None) and need_break(): break # Else for parts in parts_gen: # Recurse, two parts. assert len(parts) == 2 for i in parts: self._optimize(i.node) continue total_cost = ( final_cost + parts[0].node.total_cost + parts[1].node.total_cost ) total_cost_key = get_cost_key(total_cost) if_new_optimal = ( optimal_cost is None or optimal_cost[0] > total_cost_key ) if if_new_optimal: optimal_cost = (total_cost_key, total_cost) if self._strategy == Strategy.BEST: evals.clear() # New optimal is always added. def need_add_eval(): """If the current evaluation need to be added.""" if self._strategy == Strategy.BEST: return total_cost_key == optimal_cost[0] else: return True if if_new_optimal or need_add_eval(): new_eval = self._form_prod_eval( prod_node, broken_sums, parts ) new_eval.total_cost = total_cost evals.append(new_eval) continue assert len(evals) > 0 prod_node.total_cost = optimal_cost[1] return def _gen_factor_parts(self, prod_node: _Prod): """Generate all the partitions of factors in a product node.""" # Compute things invariant to different summations for performance. exts = prod_node.exts exts_total_size = get_total_size(exts) factor_atoms = [ i.atoms(Symbol) for i in prod_node.factors ] sum_involve = [ {j for j, v in enumerate(factor_atoms) if i in v} for i, _ in prod_node.sums ] dumm2index = tuple( {v[0]: j for j, v in enumerate(i)} for i in [prod_node.exts, prod_node.sums] ) # Indices of external and internal dummies involved by each factors. factor_infos = [ tuple( set(i[j] for j in atoms if j in i) for i in dumm2index ) for atoms in factor_atoms ] # Actual generation. for broken_size, kept in self._gen_kept_sums(prod_node.sums): broken_sums = [i for i, j in zip(prod_node.sums, kept) if not j] final_cost = self._get_prod_final_cost( exts_total_size, broken_size ) yield final_cost, broken_sums, self._gen_parts_w_kept_sums( prod_node, kept, sum_involve, factor_infos ) continue @staticmethod def _gen_kept_sums(sums): """Generate kept summations in increasing size of broken summations. The results will be given as boolean array giving if the corresponding entry is to be kept. """ sizes = [i.size for _, i in sums] n_sums = len(sizes) def get_size(kept): """Wrap the kept summation with its size.""" size = sympify(prod_( i for i, j in zip(sizes, kept) if not j )) return get_cost_key(size), size, kept init = [True] * n_sums # Everything is kept. queue = [get_size(init)] while len(queue) > 0: curr = heapq.heappop(queue) yield curr[1], curr[2] curr_kept = curr[2] for i in range(n_sums): if curr_kept[i]: new_kept = list(curr_kept) new_kept[i] = False heapq.heappush(queue, get_size(new_kept)) continue else: break continue def _gen_parts_w_kept_sums( self, prod_node: _Prod, kept, sum_involve, factor_infos ): """Generate all partitions with given summations kept. First we the factors are divided into chunks indivisible according to the kept summations. Then their bipartitions which really break the broken sums are generated. """ dsf = DSF(i for i, _ in enumerate(factor_infos)) for i, j in zip(kept, sum_involve): if i: dsf.union(j) continue chunks = dsf.sets if len(chunks) < 2: return for part in self._gen_parts_from_chunks(kept, chunks, sum_involve): assert len(part) == 2 yield tuple( self._form_part(prod_node, i, sum_involve, factor_infos) for i in part ) return @staticmethod def _gen_parts_from_chunks(kept, chunks, sum_involve): """Generate factor partitions from chunks. Here special care is taken to respect the broken summations in the result. """ n_chunks = len(chunks) for chunks_part in multiset_partitions(n_chunks, m=2): factors_part = tuple(set( factor_i for chunk_i in chunk_part_i for factor_i in chunks[chunk_i] ) for chunk_part_i in chunks_part) for i, v in enumerate(kept): if v: continue # Now we have broken sum, it need to be involved by both parts. involve = sum_involve[i] if any(part.isdisjoint(involve) for part in factors_part): break else: yield factors_part def _form_part(self, prod_node, factor_idxes, sum_involve, factor_infos): """Form a partition for the given factors.""" involved_exts, involved_sums = [ set.union(*[factor_infos[i][label] for i in factor_idxes]) for label in [0, 1] ] factors = [prod_node.factors[i] for i in factor_idxes] exts = [ v for i, v in enumerate(prod_node.exts) if i in involved_exts ] sums = [] for i, v in enumerate(prod_node.sums): if sum_involve[i].isdisjoint(factor_idxes): continue elif sum_involve[i] <= factor_idxes: sums.append(v) else: exts.append(v) continue ref, node = self._form_prod_interm(exts, sums, factors) return _Part(ref=ref, node=node) @staticmethod def _get_prod_final_cost(exts_total_size, sums_total_size) -> Expr: """Compute the final cost for a pairwise product evaluation.""" if sums_total_size == 1: return exts_total_size else: return _TWO * exts_total_size * sums_total_size def _form_prod_eval( self, prod_node: _Prod, broken_sums, parts: typing.Tuple[_Part, ...] ): """Form an evaluation for a product node.""" assert len(parts) == 2 coeff = _UNITY factors = [] for i in parts: curr_coeff, curr_ref = self._parse_interm_ref(i.ref) coeff *= curr_coeff factors.append(curr_ref) continue return _Prod( prod_node.base, prod_node.exts, broken_sums, coeff * prod_node.coeff, factors ) # # Utility constants. # _UNITY = Integer(1) _NEG_UNITY = Integer(-1) _TWO = Integer(2) _EXT = 0 _SUMMED_EXT = 1 _SUMMED = 2 _SUBSTED_EVAL_BASE = Symbol('gristmillSubstitutedEvalBase') # # Utility static functions. # class _SymbFactory(dict): """A small symbol factory.""" def __missing__(self, key): return Symbol('gristmillInternalSymbol{}'.format(key)) _SYMB_FACTORY = _SymbFactory() class _WildFactory(dict): """A small wild symbol factory.""" def __missing__(self, key): return Wild('gristmillInternalWild{}'.format(key)) _WILD_FACTORY = _WildFactory() def _get_canon_coeff(coeffs, preferred): """Get the canonical coefficient from a list of coefficients.""" coeff, _ = primitive(sum( v * _SYMB_FACTORY[i] for i, v in enumerate(coeffs) )) # The primitive computation does not take phase into account. n_neg = 0 n_pos = 0 for i in coeffs: if i.has(_NEG_UNITY) or i.is_negative: n_neg += 1 else: n_pos += 1 continue if n_neg > n_pos: phase = _NEG_UNITY elif n_pos > n_neg: phase = _UNITY else: preferred_phase = ( _NEG_UNITY if preferred.has(_NEG_UNITY) or preferred.is_negative else _UNITY ) phase = preferred_phase return coeff * phase # # Optimization result verification # -------------------------------- #
[docs]def verify_eval_seq( eval_seq: typing.Sequence[TensorDef], res: typing.Sequence[TensorDef], simplify=False ) -> bool: """Verify the correctness of an evaluation sequence for the results. The last entries of the evaluation sequence should be in one-to-one correspondence with the original form in the ``res`` argument. This function returns ``True`` when the evaluation sequence is symbolically equivalent to the given raw form. When a difference is found, ``ValueError`` will be raised with relevant information. Note that this function can be very slow for large evaluations. But it is advised to be used for all optimizations in mission-critical tasks. Parameters ---------- eval_seq The evaluation sequence to verify, can be the output from :py:func:`optimize` directly. res The original result to test the evaluation sequence against. It can be the input to :py:func:`optimize` directly. simplify If simplification is going to be performed after each step of the back-substitution. It is advised for larger complex evaluations. """ n_res = len(res) n_interms = len(eval_seq) - n_res substed_eval_seq = [] defs_dict = {} for idx, eval_ in enumerate(eval_seq): base = eval_.base free_vars = eval_.rhs.free_vars curr_defs = [ defs_dict[i] for i in free_vars if i in defs_dict ] rhs = eval_.rhs.subst_all(curr_defs, simplify=simplify) new_def = TensorDef(base, eval_.exts, rhs) substed_eval_seq.append(new_def) if idx < n_interms: defs_dict[ base.label if isinstance(base, IndexedBase) else base ] = new_def continue for i, j in zip(substed_eval_seq[-n_res:], res): if i.lhs != j.lhs: raise ValueError( 'Unequal left-hand sides', i.lhs, 'with', j.lhs, 'for', j ) diff = (i.rhs - j.rhs).simplify() if diff != 0: raise ValueError( 'Unequal definition for ', j.lhs, j ) continue return True