"""Optimizer for the contraction computations."""
import collections
import heapq
import itertools
import types
import typing
import warnings
from drudge import TensorDef, prod_, Term, Range, sum_
from sympy import (
Integer, Symbol, Expr, IndexedBase, Mul, Indexed, primitive, Wild,
default_sort_key, Pow
)
from sympy.utilities.iterables import multiset_partitions
from .utils import (
Size, get_total_size, mul_sizes, DSF, Tuple4Cmp, form_sized_range
)
#
# The public driver
# -----------------
#
class Strategy:
"""The optimization strategy for tensor contractions.
This class holds possible options for different aspects of the optimization
strategy for tensor contractions. Options for different aspects of the
problem should be combined by using the bitwise-or ``|`` operator.
For the optimization of the single-term contractions, we have
``GREEDY``
The contraction within each term will be optimized greedily. This
should only be used for inputs having terms containing many factors by a
very dense pattern.
``BEST``
The global minimum of each tensor contraction will be found by the
advanced algorithm in gristmill. And only the optimal contraction(s)
will be kept for the summation optimization.
``SEARCHED``
The same strategy as ``BEST`` will be attempted for the optimization of
contractions. But all evaluations searched in the optimization process
will be kept and considered in subsequent summation optimizations.
``ALL``
All possible contraction sequences will be considered for all terms.
This can be extremely slow. But it might be helpful for problems having
terms all with manageable number of factors.
For the summation factorization, we have
``SUM``
Factorize the summations in the result.
``INACCURATE``
Do not accurately calculate the saving in summation optimization. This
will skip the exact arithmetic for the costs and use a special heuristic
the estimate the actual saving.
For the common factor optimization, we have
``COMMON``
Skip computation of the same factor up to permutation of indices in
summations.
We also have the default optimization strategy as ``DEFAULT``, which will be
``SEARCHED | SUM | COMMON``.
"""
GREEDY = 0
BEST = 1
SEARCHED = 2
ALL = 3
PROD_MASK = 0b11
SUM = 1 << 2
INACCURATE = 1 << 3
COMMON = 1 << 4
# Internal options, not useful for users.  They control whether evaluations
# with negative local/global saving will be considered.  They turn out to be
# not quite useful and are pending removal.
RUSH_LOCAL = 1 << 5
RUSH_GLOBAL = 1 << 6
DEFAULT = SEARCHED | SUM | COMMON
MAX = 1 << 7
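# As an illustrative sketch (not from the original source), a strategy that
# insists on globally-optimal contractions while still factorizing sums and
# sharing common factors could be assembled as,
#
#     strategy = Strategy.BEST | Strategy.SUM | Strategy.COMMON
#     assert strategy & Strategy.PROD_MASK == Strategy.BEST
#     assert strategy & Strategy.SUM > 0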
def optimize(
computs: typing.Iterable[TensorDef], substs=None, interm_fmt='tau^{}',
simplify=True, strategy=Strategy.DEFAULT, greedy_cutoff=-1,
drop_cutoff=-1
) -> typing.List[TensorDef]:
"""Optimize the valuation of the given tensor contractions.
This function will transform the given computations, given as tensor
definitions, into another list of computations mathematically equivalent to
the given computation while requiring less floating-point operations
(FLOPs).
Parameters
----------
computs
The computations, can be given as an iterable of tensor definitions.
substs
A dictionary for making substitutions inside the sizes of ranges. All
the ranges need to have size in at most one undetermined variable after
the substitution, so that they can be totally ordered. When one symbol
still remains in the sizes, the asymptotic cost (scaling and prefactor)
will be optimized. Or when all symbols are gone after the substitution,
optimization is going to be based on the numeric sizes. Numeric sizes
tend to make the optimization faster due to the usage of built-in
integer or floating point arithmetic in lieu of the more complex
polynomial arithmetic.
interm_fmt
The format for the names of the intermediates.
simplify
If the input is going to be simplified before processing. It can be
disabled when the input is already simplified.
strategy
The optimization strategy, as explained in :py:class:`Strategy`.
greedy_cutoff
The depth cutoff for making greedy selection in summation optimization.
Beyond this depth in the recursion tree (inclusive), only the choices
making locally best saving will be considered. With negative values,
full Bron-Kerbosch backtracking is performed.
drop_cutoff
The depth cutoff for picking only a random one with greedy saving in
summation optimization. The difference with the option
``greedy_cutoff`` is that here only **one** choice giving the locally
best saving will be considered, rather than all of them. This could
give better acceleration than ``greedy_cutoff`` at the presence of large
degeneracy, while results could be less optimized. For large inputs, a
value of ``2`` is advised.
"""
substs = {} if substs is None else substs
computs = [
i.simplify() if simplify else i.reset_dumms()
for i in computs
]
if len(computs) == 0:
raise ValueError('No computation is given!')
if not isinstance(strategy, int) or strategy >= Strategy.MAX:
raise TypeError('Invalid optimization strategy', strategy)
opt = _Optimizer(
computs, substs=substs, interm_fmt=interm_fmt, strategy=strategy,
greedy_cutoff=greedy_cutoff, drop_cutoff=drop_cutoff
)
return opt.optimize()
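# A typical invocation is sketched below (hypothetical names: ``x`` and ``y``
# are assumed to be drudge tensor definitions sharing one drudge, and ``m``,
# ``n`` symbols appearing in the range sizes),
#
#     res = optimize(
#         [x, y], substs={n: 1000, m: 100}, strategy=Strategy.DEFAULT
#     )
#
# The result is a list of tensor definitions mathematically equivalent to the
# input, with the intermediates coming before the final results.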
#
# The internal optimization engine
# --------------------------------
#
# General small type definitions and functions
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# These named tuples should be upgraded when PySpark has support for Python
# 3.6 in its stable version.
#
# For general optimization.
#
_Grain = collections.namedtuple('_Grain', [
'base',
'exts',
'terms'
])
_IntermRef = collections.namedtuple('_IntermRef', [
'coeff',
'base',
'indices',
'power'
])
def _get_ref_from_interm_ref(self: _IntermRef):
"""Get the reference to intermediate without coefficient."""
return _index(self.base, self.indices) ** self.power
_IntermRef.ref = property(_get_ref_from_interm_ref)
# Symbol/range pairs.
#
# This type is mostly for the convenience of annotation.
_SrPairs = typing.Sequence[typing.Tuple[Symbol, Range]]
#
# Internals for summation and product optimization
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# Summation optimization.
#
# Organized references to products in a summation.
#
# Intermediate base -> (indices -> coefficient)
_OrgTerms = typing.DefaultDict[
Symbol, typing.DefaultDict[typing.Tuple[Expr, ...], Expr]
]
#
# Static description of the collection graph.
#
_LEFT = 0
_RIGHT = 1
_OPPOS = {
_LEFT: _RIGHT,
_RIGHT: _LEFT
}
# For type annotation; ideally it should be ``_LEFT | _RIGHT`` in Haskell-style
# algebraic data type notation.
_LR = int
_Ranges = collections.namedtuple('_Ranges', [
'exts',
'sums'
])
_Edge = collections.namedtuple('_Edge', [
'term',
'eval_',
'base',
'coeff',
'exc_cost',
])
_Adjs = typing.Tuple[
typing.Dict[Term, typing.Dict[Term, _Edge]],
typing.Dict[Term, typing.Dict[Term, _Edge]]
]
class _BaseInfo:
"""Information about a base referenced in a sum node.
This is an open struct, with most of its manipulation done inside the
optimizer class.
"""
__slots__ = [
'count',
'cost'
]
def __init__(self, cost):
"""Initialize the information.
The count will be initialized to one.
"""
self.count = 1
self.cost = cost
class _BaseInfoDict(dict):
"""Mapping from symbol of bases to its information.
Symbol -> _BaseInfo.
"""
def add_base(self, base, cost):
"""Add the given base."""
if base in self:
self[base].count += 1
else:
self[base] = _BaseInfo(cost)
return
def remove_terms(self, terms, term_base):
"""Remove the terms from base information dictionary.
The bases that have been updated by this will be returned in a
set.
"""
updated = set()
for i in terms:
base = term_base[i]
updated.add(base)
if self[base].count > 1:
self[base].count -= 1
else:
del self[base]
return updated
#
# Intermediate data and results for the Bron-Kerbosch process.
#
# Additional information about a node when it is used to augment the current
# biclique.
_Delta = collections.namedtuple('_Delta', [
'coeff',
'terms',
'bases',
'exc_cost',
'saving'
])
# Dictionary of the nodes that can possibly be used to augment the current
# biclique. To be used for variables like ``subg`` and ``cand`` in the
# Bron-Kerbosch algorithm.
_Nodes = typing.Dict[
typing.Tuple[int, Term], typing.Optional[_Delta]
]
_Biclique = collections.namedtuple('_Biclique', [
'nodes', # Left and right, nodes and coefficients.
'leading_coeff',
'terms',
'saving'
])
#
# Cost-related utilities for the Bron-Kerbosch process.
#
# The coefficients cached here make the computation of the saving of a
# biclique easy and fast.
_CostCoeffs = collections.namedtuple('_CostCoeffs', [
# The final cost for the contraction and the addition of the results.
'final',
# The cost of making an addition for left and right factors.
'preps'
])
def _get_cost_coeffs(ranges: _Ranges) -> _CostCoeffs:
"""Get the cost coefficients for the given ranges."""
ext_size = get_total_size(itertools.chain.from_iterable(
ranges.exts
))
final = _get_prod_final_cost(
ext_size, get_total_size(ranges.sums)
) + ext_size
preps = tuple(
get_total_size(itertools.chain(i, ranges.sums))
for i in ranges.exts
)
return _CostCoeffs(final=final, preps=preps)
_Saving = collections.namedtuple('_Saving', [
# Total current saving.
'saving',
# Additional saving when one more left/right factor is collected.
'deltas'
])
def _get_collect_saving(coeffs: _CostCoeffs, n_s: typing.Sequence[int]):
"""Get the saving for collection.
For the given ranges, when we make a collection of the given number of left
factors and the given number of right factors, we have saving,
.. math::
n_l n_r C(s) e_l e_r s + (n_l n_r - 1) e_l e_r
- (n_l - 1) e_l s - (n_r - 1) e_r s - C(s) e_l e_r s
where :math:`C(s)` equals one for no summation and two for the presence of
summations. It also equals
.. math::
(n_l n_r - 1) (C(s) e_l e_r s + e_l e_r)
- (n_l - 1) e_l s - (n_r - 1) e_r s
When we collect terms with :math:`n_l`, it reads,
.. math::
n_l (
n_r C(s) e_l e_r s + n_r e_l e_r - e_l s
)
- n_r e_r s
+ e_l s + e_r s - e_l e_r - C(s) e_l e_r s
or symmetrically
.. math::
n_r (
n_l C(s) e_l e_r s + n_l e_l e_r - e_r s
)
- n_l e_l s
+ e_l s + e_r s - e_l e_r - C(s) e_l e_r s
"""
assert len(n_s) == 2
assert len(coeffs.preps) == 2
n_terms = mul_sizes(n_s)
saving = (n_terms - 1) * coeffs.final - sum(
(i - 1) * j
for i, j in zip(n_s, coeffs.preps)
)
deltas = []
for i, v in enumerate(coeffs.preps):
o = 1 if i == 0 else 0
if n_s[i] == 0:
# This could allow bicliques empty in a direction to be augmented by
# any left or right term. A value of infinity has to be used to
# mask the possible non-zero excess costs.
deltas.append(float('inf'))
elif n_s[o] == 0:
# This prevents a dimension from being expanded without anything on the
# other side.
deltas.append(-float('inf'))
else:
deltas.append(n_s[o] * coeffs.final - v)
continue
return _Saving(saving=saving, deltas=tuple(deltas))
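# As a concrete check of the above formulas, collecting a * b + a * c into
# a * (b + c) with all ranges of size n gives e_l = e_r = s = n and C(s) = 2,
# so final = 2 * n**3 + n**2 and preps = (n**2, n**2).  With (n_l, n_r) =
# (1, 2) the saving is (1 * 2 - 1) * final - (2 - 1) * n**2 = 2 * n**3, i.e.
# exactly one spared pairwise contraction.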
#
# The core classes.
#
class _BronKerbosch:
"""Iterable for the maximal bicliques."""
def __init__(
self, adjs: _Adjs, base_infos: _BaseInfoDict,
ranges: _Ranges, greedy_cutoff=-1, drop_cutoff=-1,
rush_local=False, rush_global=False, inaccurate=False
):
"""Initialize the iterator."""
# Static data during the recursion.
self._adjs = adjs
self._base_infos = base_infos
self._cost_coeffs = _get_cost_coeffs(ranges)
self._greedy_cutoff = greedy_cutoff
self._drop_cutoff = drop_cutoff
self._rush_local = rush_local
self._rush_global = rush_global
self._inaccurate = inaccurate
# Dynamic data during the recursion.
#
# Nodes and coefficients, for left and right.
self._curr = (
([], []),
([], [])
)
# The set of terms currently in the biclique.
self._terms = set() # type: typing.Set[Symbol]
# The count of bases in the **uncollected** terms.
self._bases = collections.Counter()
for k, v in base_infos.items():
self._bases[k] = v.count
# The stack of excess costs.
self._exc_costs = []
# The leading coefficient.
self._leading_coeff = None
def __iter__(self):
"""Iterate over the maximal bicliques."""
# All left and right nodes.
nodes = {
(i, j): _Delta(
coeff=_UNITY, terms=set(), bases={}, exc_cost=0, saving=0
)
for i, v in enumerate(self._adjs)
for j in v.keys()
}
assert len(nodes) > 0
yield from self._expand(nodes, dict(nodes), dict(nodes), dict(nodes))
# If everything goes correctly, the stack should be reverted to its
# initial state by now.
for i in self._curr:
for j in i:
assert len(j) == 0
continue
continue
assert len(self._terms) == 0
for k, v in self._bases.items():
assert v == self._base_infos[k].count
assert len(self._exc_costs) == 0
assert self._leading_coeff is None
return
def _expand(
self, subg: _Nodes, curr_subg: _Nodes,
cand: _Nodes, curr_cand: _Nodes
):
"""Generate the bicliques from the current state.
This is the core of the Bron-Kerbosch algorithm.
"""
exc_costs = self._exc_costs
depth = len(exc_costs)
# The current state.
curr = self._curr
terms = self._terms
bases = self._bases
inaccurate = self._inaccurate
# The code here is adapted from the NetworkX code for the maximal clique
# problem of simple general graphs.  The original code is kept as much as
# possible and put in comments.  The code on which this is based can be
# found at,
#
# https://github.com/networkx/networkx/blob
# /48f4b5736174844c77044fae90e3e7adf1dabc10/networkx/algorithms
# /clique.py#L277-L299
#
#
# u = max(subg, key=lambda u: len(cand & adj[u]))
#
# Here it is very expensive to make sure that a node can serve as a
# pivot.  Hence pivoting is currently not performed here.
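# For orientation, the classic recursion being adapted reads roughly as
# follows (a simplified sketch assembled from the fragments quoted below,
# not executed here):
#
#     def expand(subg, cand):
#         u = max(subg, key=lambda u: len(cand & adj[u]))
#         for q in cand - adj[u]:
#             cand.remove(q)
#             Q.append(q)
#             adj_q = adj[q]
#             subg_q = subg & adj_q
#             if not subg_q:
#                 yield Q[:]
#             else:
#                 cand_q = cand & adj_q
#                 if cand_q:
#                     yield from expand(subg_q, cand_q)
#             Q.pop()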
# Recursion is stopped earlier than here.
assert len(curr_subg) > 0
to_loop = curr_cand.items()
if len(to_loop) == 0:
return
cut_greedy = (
0 <= self._greedy_cutoff <= depth
)
cut_full = (
0 <= self._drop_cutoff <= depth
)
if cut_greedy or cut_full:
greedy_saving = max(i[1].saving for i in to_loop)
to_loop = (
i for i in to_loop if i[1].saving == greedy_saving
)
if cut_full:
to_loop = [next(to_loop)]
#
# for q in cand - adj[u]:
#
for q, delta in to_loop:
#
# cand.remove(q)
#
colour, node = q
assert q in cand
del cand[q]
#
# Q.append(q)
#
new_terms = delta.terms
new_bases = delta.bases
curr[colour][0].append(node)
curr[colour][1].append(delta.coeff)
assert terms.isdisjoint(delta.terms)
terms |= new_terms
bases.subtract(new_bases)
exc_costs.append(delta.exc_cost)
oppos = _OPPOS[colour]
if len(curr[colour][0]) == 1 and len(curr[oppos][0]) > 0:
leading_edge = self._adjs[colour][node][
curr[oppos][0][0]
]
assert self._leading_coeff is None
self._leading_coeff = leading_edge.coeff
ns, saving = self._count_stack(inaccurate=inaccurate)
#
# adj_q = adj[q]
# subg_q = subg & adj_q
#
subg_q, curr_subg_q = self._filter_nodes(
subg, saving, colour, node
)
#
# if not subg_q:
# yield Q[:]
#
if len(curr_subg_q) == 0:
# These cases cannot possibly give saving.
if_skip = any(i == 0 for i in ns) or all(
i == 1 for i in ns
)
if not if_skip:
# The total saving.
if inaccurate:
saving = saving.saving
has_saving = saving > 0
assert has_saving
else:
saving = saving.saving - sum(exc_costs)
if not self._rush_global:
for k, v in self._bases.items():
if v > 0:
saving -= self._base_infos[k].cost * (
self._base_infos[k].count - v
)
has_saving = saving > 0
if has_saving:
yield _Biclique(
nodes=curr, leading_coeff=self._leading_coeff,
terms=terms, saving=saving
)
else:
#
# cand_q = cand & adj_q
#
cand_q, curr_cand_q = self._filter_nodes(
cand, saving, colour, node
)
# if cand_q:
# for clique in expand(subg_q, cand_q):
# yield clique
if len(curr_cand_q) > 0:
yield from self._expand(
subg_q, curr_subg_q, cand_q, curr_cand_q
)
#
# Q.pop()
#
for i in curr[colour]:
i.pop()
terms -= new_terms
bases.update(new_bases)
exc_costs.pop()
if len(curr[colour][0]) == 0:
self._leading_coeff = None
def _filter_nodes(
self, nodes: _Nodes,
saving: _Saving, new_colour: _LR, new_node: Term
) -> typing.Tuple[_Nodes, _Nodes]:
"""Filter the nodes for the current stack.
In the original Bron-Kerbosch algorithm, both subg and cand are filtered
by union with the adjacent nodes of the newly added node. Now the
computation can be a lot more complex than that. We need to note,
1. No term already contained can be decomposed in another way in a
different evaluation.
2. The coefficients need to match the existing proportion.
We also have less to note in that we do not require any connectivity
among nodes with the same colour.
Here all expandable nodes and the profitable ones among them for the
current step will be returned. The profitable nodes for the current
step contains only the nodes that is profitable right now. The all
expandable nodes has all nodes that are valid to be augmented into the
current stack.
"""
all_ = {}
for node_key, delta in nodes.items():
colour, node = node_key
res = self._update_delta(
new_colour, new_node, saving, colour, node, delta
)
if res is not None:
all_[node_key] = res
continue
curr = {k: v for k, v in all_.items() if v.saving > 0}
return all_, curr
def _update_delta(
self, new_colour, new_node, saving, colour, node, delta
) -> typing.Optional[_Delta]:
"""Update the delta when a new node is added to the stack.
This is the performance bottleneck of the Bron-Kerbosch algorithm.
"""
adj = self._adjs[colour][node]
inaccurate = self._inaccurate
# Most basic filtering. The node with the same colour as the new node
# will not be affected by the new addition.
if colour != new_colour and new_node not in adj:
return None
oppos_curr = self._curr[_OPPOS[colour]]
if len(oppos_curr[0]) > 0:
leading_edge = adj[oppos_curr[0][0]]
else:
leading_edge = None
if colour != new_colour:
new_edge = adj[new_node]
# The opposite side is non-empty: at the very least, the new node was
# just added.
assert leading_edge is not None
ratio = (new_edge.coeff / leading_edge.coeff).simplify()
if ratio != oppos_curr[1][-1]:
return None
res_terms = set(delta.terms)
res_terms.add(new_edge.term)
res_bases = collections.Counter()
res_bases.update(delta.bases)
res_bases[new_edge.base] += 1
if inaccurate:
if delta.exc_cost != -1 and new_edge.exc_cost == 0:
res_exc_cost = -1
else:
res_exc_cost = delta.exc_cost
else:
res_exc_cost = delta.exc_cost + new_edge.exc_cost
else:
res_terms = delta.terms
res_bases = delta.bases
res_exc_cost = delta.exc_cost
if not res_terms.isdisjoint(self._terms):
return None
res_coeff = (
delta.coeff
if self._leading_coeff is None or leading_edge is None
else leading_edge.coeff / self._leading_coeff
)
base_saving = saving.deltas[colour]
if inaccurate:
res_saving = self._get_inaccurate_delta_saving(
base_saving, res_exc_cost, res_bases
)
else:
res_saving = self._get_delta_saving(
base_saving, res_exc_cost, res_bases
)
res_delta = _Delta(
coeff=res_coeff, terms=res_terms, bases=res_bases,
exc_cost=res_exc_cost, saving=res_saving
)
# Sanity checking, should be disabled in production.
# assert res_delta == self._form_delta(colour, node, saving)
return res_delta
def _get_delta_saving(self, base_saving, exc_cost, bases) -> Size:
"""Get the saving incurred by applying a given delta."""
res = base_saving - exc_cost
if not self._rush_local:
for k, v in bases.items():
if self._bases[k] - v > 0:
res -= self._base_infos[k].cost * v
continue
return res
def _get_inaccurate_delta_saving(self, base_saving, exc_cost, bases):
"""Get the saving incurred by a delta in inaccurate mode."""
if exc_cost == -1 or abs(base_saving) == float('inf'):
res_saving = base_saving
else:
res_saving = 0
curbed_by_common = not self._rush_local and any(
self._bases[k] - v > 0 for k, v in bases.items()
)
if curbed_by_common:
res_saving = 0
return res_saving
def _count_stack(self, inaccurate=False):
"""Count the current size of the stack.
The saving will also be returned.
"""
ns = tuple(
len(self._curr[i][0]) for i in [_LEFT, _RIGHT]
)
if inaccurate:
deltas = []
for i, j in zip(ns, reversed(ns)):
if i == 0:
deltas.append(float('inf'))
elif j == 0:
deltas.append(-float('inf'))
else:
deltas.append(j)
continue
saving = _Saving(saving=ns[0] * ns[1], deltas=tuple(deltas))
else:
saving = _get_collect_saving(self._cost_coeffs, ns)
return ns, saving
def _form_delta(
self, colour: _LR, node: Term, saving: _Saving
) -> typing.Optional[_Delta]:
"""Form the delta for adding a new node from scratch.
When it is expandable, the relevant node information will be
returned, or None will be the result.
THIS FUNCTION IS DEPRECATED AND PENDING REMOVAL. Currently it is only
used for the sanity checking of the optimized result.
"""
# Cache frequently used information.
oppos_colour = _OPPOS[colour]
adjs = self._adjs[colour][node]
curr = self._curr
terms = self._terms
base_coeff = None
exc_cost = 0
new_terms = set()
new_bases = collections.Counter()
for i, v in enumerate(curr[oppos_colour][0]):
if v not in adjs:
return
edge = adjs[v] # type: _Edge
if edge.term in terms or edge.term in new_terms:
return
coeff = edge.coeff
if i == 0:
base_coeff = coeff
ratio = coeff / base_coeff
if ratio != curr[oppos_colour][1][i]:
return
exc_cost += edge.exc_cost
new_terms.add(edge.term)
new_bases[edge.base] += 1
continue
# When we get here, it should be expandable now.
if self._leading_coeff is None:
coeff = _UNITY
else:
coeff = base_coeff / self._leading_coeff
new_saving = self._get_delta_saving(
saving.deltas[colour], exc_cost, new_bases
)
# For empty stack, we always get here with base information (coeff=1,
# new_terms=empty, exc_cost=0).
return _Delta(
coeff=coeff, terms=new_terms, bases=new_bases, exc_cost=exc_cost,
saving=new_saving
)
class _CollectGraph:
"""Graph for the collectibles of a given range.
This data structure, and the maximal biclique generation in Bron-Kerbosch
style, are the core of the factorization algorithm for sums.
We have separate graph for different ranges. For each range, the graph has
the factors as nodes, and actual evaluations with the factors as edges.
Internally, the graph is stored as two sparse adjacent lists.
"""
def __init__(self):
"""Initialize the collectible table."""
self._adjs = (
collections.defaultdict(dict),
collections.defaultdict(dict)
)
self._terms = set()
self._base_infos = _BaseInfoDict()
# The optimal biclique in the current graph. None when it is not yet
# determined, zero when it is determined that there is no profitable
# biclique in the current graph.
self._opt_saving = None
self._opt_biclique = None
def add_edge(
self, left, right, term, eval_, base, coeff,
opt_cost, eval_cost
):
"""Add a new edge to the graph."""
edge = _Edge(
term=term, eval_=eval_, coeff=coeff,
exc_cost=eval_cost - opt_cost, base=base
)
left_adj = self._adjs[_LEFT][left]
if right not in left_adj:
left_adj[right] = edge
else:
# Because of symmetry, it is possible for two evaluations that are
# actually the same to be recorded twice for a product node.
assert left_adj[right].term == term
right_adj = self._adjs[_RIGHT][right]
if left not in right_adj:
right_adj[left] = edge
else:
assert right_adj[left].term == term
if term not in self._terms:
self._terms.add(term)
# We do not need the actual cost here.  For optimization purposes, the
# bases should always be read from the centralized base infos across
# all graphs.
self._base_infos.add_base(base, None)
def get_opt_biclique(
self, ranges: _Ranges, base_infos: _BaseInfoDict,
greedy_cutoff=-1, drop_cutoff=-1,
rush_local=False, rush_global=False, inaccurate=False
) -> typing.Tuple[typing.Optional[Size], typing.Optional[_Biclique]]:
"""Get the optimal biclique in the current graph.
"""
if self._opt_saving is not None:
if self._opt_saving == 0:
return None, None
else:
return self._opt_saving, self._opt_biclique
opt_saving = None
opt_biclique = None
for biclique in self.gen_bicliques(
ranges, base_infos,
greedy_cutoff=greedy_cutoff, drop_cutoff=drop_cutoff,
rush_local=rush_local, rush_global=rush_global,
inaccurate=inaccurate
):
saving = biclique.saving
if opt_saving is None or saving > opt_saving:
opt_saving = saving
# Make copies only when we need them.
opt_biclique = _Biclique(
nodes=tuple(
tuple(tuple(j) for j in i)
for i in biclique.nodes
),
leading_coeff=biclique.leading_coeff,
terms=frozenset(biclique.terms),
saving=biclique.saving
)
continue
if opt_saving is None:
assert opt_biclique is None
self._opt_saving = 0
self._opt_biclique = None
else:
if inaccurate:
saving = _get_collect_saving(_get_cost_coeffs(ranges), [
len(i[0]) for i in opt_biclique.nodes
]).saving
opt_biclique = opt_biclique._replace(saving=saving)
self._opt_saving = opt_saving
self._opt_biclique = opt_biclique
return opt_saving, opt_biclique
def gen_bicliques(
self, ranges: _Ranges, base_infos: _BaseInfoDict,
greedy_cutoff=-1, drop_cutoff=-1,
rush_local=False, rush_global=False, inaccurate=False
) -> typing.Iterable[_Biclique]:
"""Generate the bicliques within the graph.
For performance reasons, the bicliques generated will contain references
to internal mutable data.  It is the responsibility of the caller to
make proper copies when necessary.
"""
yield from _BronKerbosch(
self._adjs, base_infos, ranges,
greedy_cutoff=greedy_cutoff, drop_cutoff=drop_cutoff,
rush_local=rush_local, rush_global=rush_global,
inaccurate=inaccurate
)
def remove_terms(
self, terms: typing.AbstractSet[int], term_base,
updated_bases, base_infos
) -> bool:
"""Remove all edges and nodes involving the given terms.
If a value of True is returned, we have an empty graph after the
removal.
"""
if self._terms.isdisjoint(terms):
if_empty = False
if_updated = False
else:
if_updated = True
new_adjs = (
collections.defaultdict(dict),
collections.defaultdict(dict)
)
if_empty = True
for old, new in zip(self._adjs, new_adjs):
for from_node, conns in old.items():
new_conns = {
to_node: edge
for to_node, edge in conns.items()
if edge.term not in terms
}
if len(new_conns) > 0:
if_empty = False
new[from_node] = new_conns
continue
continue
self._adjs = new_adjs
self._base_infos.remove_terms(terms & self._terms, term_base)
self._terms -= terms
# We need to update the optimal biclique when a base has recently been
# updated such that it becomes exclusively involved by this graph.
if_dirty = if_updated or any(
i in self._base_infos and
base_infos[i].count == self._base_infos[i].count
for i in updated_bases
)
if if_dirty:
self._opt_saving = None
self._opt_biclique = None
return if_empty
_Collectibles = typing.DefaultDict[_Ranges, _CollectGraph]
#
# For product optimization.
#
_Part = collections.namedtuple('_Part', [
'ref',
'node'
])
def _get_prod_final_cost(exts_total_size, sums_total_size) -> Size:
"""Compute the final cost for a pairwise product evaluation."""
if sums_total_size == 1:
return exts_total_size
else:
return 2 * exts_total_size * sums_total_size
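# For instance, a plain matrix multiplication c[i, j] = sum(k, a[i, k] *
# b[k, j]) with all ranges of size n has exts_total_size = n**2 and
# sums_total_size = n, giving the familiar cost of 2 * n**3 FLOPs (one
# multiplication and one addition per innermost step).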
#
# Core evaluation DAG nodes
# ~~~~~~~~~~~~~~~~~~~~~~~~~
#
class _EvalNode:
"""A node in the evaluation graph.
"""
def __init__(self, base: Symbol, exts: _SrPairs):
"""Initialize the evaluation node.
"""
self.base = base
self.exts = exts
# For optimization.
self.evals = [] # type: typing.List[_EvalNode]
self.total_cost = None
# For result finalization.
self.n_refs = 0
self.generated = False
def get_substs(self, indices):
"""Get substitutions and symbols requiring exclusion before indexing.
First resetting dummies excluding the returned symbols and then making
the returned substitution on each term could achieve indexing. Since
the real free symbols are already gather from all inputs, the free
symbols are not considered here. But they should be added for the
resetting.
"""
substs = {}
excl = set()
assert len(indices) == len(self.exts)
for i, j in zip(indices, self.exts):
dumm = j[0]
substs[dumm] = i
excl.add(dumm)
excl |= i.atoms(Symbol)
continue
return substs, excl
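# For illustration, a node with exts ((a, R), (b, S)) indexed by (i, 2 * j)
# yields substs == {a: i, b: 2 * j} and excl == {a, b, i, j}.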
class _Sum(_EvalNode):
"""Sum nodes in the evaluation graph."""
def __init__(self, base, exts, sum_terms):
"""Initialize the node."""
super().__init__(base, exts)
self.sum_terms = sum_terms
def __repr__(self):
"""Form a representation string for the node."""
return '_Sum(base={}, exts={}, sum_terms={})'.format(
repr(self.base), repr(self.exts), repr(self.sum_terms)
)
class _Prod(_EvalNode):
"""Product nodes in the evaluation graph.
"""
def __init__(self, base, exts, sums, coeff, factors):
"""Initialize the node."""
super().__init__(base, exts)
self.sums = sums
self.coeff = coeff
self.factors = factors
def __repr__(self):
"""Form a representation string for the node."""
return '_Prod(base={}, exts={}, sums={}, coeff={}, factors={})'.format(
repr(self.base), repr(self.exts), repr(self.sums),
repr(self.coeff), repr(self.factors)
)
#
# Core optimizer class
# ~~~~~~~~~~~~~~~~~~~~
#
class _Optimizer:
"""Optimizer for tensor contraction computations.
This internal optimizer can only be used once, for a single set of inputs.
"""
#
# Public functions.
#
def __init__(
self, computs, substs, interm_fmt, strategy,
greedy_cutoff=-1, drop_cutoff=-1
):
"""Initialize the optimizer."""
# Information to be read from the input computations.
#
# The only drudge for the inputs.
self._drudge = None
# The only variable for range sizes.
self._range_var = None
# Mapping from the substituted range to original range.
self._input_ranges = {}
# Symbols that should not be used for dummies.
self._excl = set()
# Read, process, and verify user input.
self._grist = [
self._form_grain(comput, substs) for comput in computs
]
# Dummies stock in terms of the substituted range.
assert self._drudge is not None
self._dumms = {
k: self._drudge.dumms.value[v]
for k, v in self._input_ranges.items()
}
# Other internal data preparation.
self._interm_fmt = interm_fmt
self._strategy = strategy
self._greedy_cutoff = greedy_cutoff
self._drop_cutoff = drop_cutoff
self._next_internal_idx = 0
# From intermediate base to actual evaluation node.
self._interms = {}
# From the canonical form to intermediate base.
self._interms_canon = {}
self._res = None
def optimize(self):
"""Optimize the evaluation of the given computations.
"""
if self._res is not None:
return self._res
res_nodes = [self._form_node(i) for i in self._grist]
for i in res_nodes:
self._optimize(i)
continue
self._res = self._linearize(res_nodes)
return self._res
#
# User input pre-processing.
#
def _form_grain(self, comput, substs):
"""Form grain for a given computation.
"""
curr_drudge = comput.rhs.drudge
if self._drudge is None:
self._drudge = curr_drudge
elif self._drudge is not curr_drudge:
raise ValueError(
'Invalid computations to optimize, containing two drudges',
(self._drudge, curr_drudge)
)
else:
pass
# Externals processing.
exts = self._proc_sums(comput.exts, substs)
ext_symbs = {i for i, _ in exts}
# Terms processing.
terms = []
for term in comput.rhs_terms:
if not term.is_scalar:
raise ValueError(
'Invalid term to optimize', term, 'expecting scalar'
)
sums = self._proc_sums(term.sums, substs)
amp = term.amp
# Add the true free symbols to the exclusion set.
self._excl |= term.free_vars - ext_symbs
terms.append(Term(sums, amp, ()))
continue
return _Grain(
base=comput.base if len(exts) == 0 else comput.base.args[0],
exts=exts, terms=terms
)
def _proc_sums(self, sums, substs):
"""Process a summation list.
The ranges will be replaced by ranges with substituted sizes.  Relevant
members of the optimizer will also be updated.  User errors will also be
reported.
"""
res = []
for symb, range_ in sums:
new_range, range_var = form_sized_range(range_, substs)
if range_var is not None:
if self._range_var is None:
self._range_var = range_var
elif self._range_var != range_var:
raise ValueError(
'Invalid range', range_, 'unexpected symbol',
range_var, 'conflicting with', self._range_var
)
else:
pass
if new_range not in self._input_ranges:
self._input_ranges[new_range] = range_
elif range_.size != self._input_ranges[new_range].size:
raise ValueError(
'Invalid ranges', (range_, self._input_ranges[new_range]),
'duplicated labels'
)
else:
pass
res.append((symb, new_range))
continue
return tuple(res)
#
# Optimization result post-processing.
#
def _linearize(
self, optimized: typing.Sequence[_EvalNode]
) -> typing.List[TensorDef]:
"""Linearize optimized forms of the evaluation.
"""
for node in optimized:
self._set_n_refs(node)
continue
# Separate the intermediates and the results so that the results can be
# guaranteed to be at the end of the evaluation sequence.
interms = []
res = []
for node in optimized:
curr = self._linearize_node(node, interms, keep=True)
assert curr is not None
res.append(curr)
continue
return self._finalize(itertools.chain(interms, res))
def _set_n_refs(self, node: _EvalNode):
"""Set reference counts from an evaluation node.
"""
if len(node.evals) == 0:
self._optimize(node)
# We need to find an evaluation with optimal cost.
assert len(node.evals) > 0
node.evals = [next(
i for i in node.evals if i.total_cost == node.total_cost
)]
eval_ = node.evals[0]
if isinstance(eval_, _Prod):
possible_refs = [i for i in eval_.factors]
elif isinstance(eval_, _Sum):
possible_refs = eval_.sum_terms
else:
assert False
for i in possible_refs:
ref = self._parse_interm_ref(i)
if ref is None:
continue
dep_node = self._interms[ref.base]
dep_node.n_refs += 1
self._set_n_refs(dep_node)
continue
return
def _linearize_node(self, node: _EvalNode, res: list, keep=False):
"""Linearize evaluation rooted in the given node into the result.
If keep is set to True, the evaluation of the given node will not be
appended to the result list.
"""
if node.generated:
return None
def_, deps = self._form_def(node)
for i in deps:
self._linearize_node(self._interms[i], res)
continue
node.generated = True
if not keep:
res.append(def_)
return def_
def _form_def(self, node: _EvalNode):
"""Form the final definition of an evaluation node.
The dependencies will also be returned.
"""
assert len(node.evals) == 1
if isinstance(node, _Prod):
return self._form_prod_def(node)
elif isinstance(node, _Sum):
return self._form_sum_def(node)
else:
assert False
def _form_prod_def(self, node: _Prod):
"""Form the final definition of a product evaluation node."""
exts = node.exts
eval_ = node.evals[0]
assert isinstance(eval_, _Prod)
term, deps = self._form_prod_def_term(eval_)
return _Grain(
base=node.base, exts=exts, terms=[term]
), deps
def _form_prod_def_term(self, eval_: _Prod):
"""Form the term in the final definition of a product evaluation node.
"""
amp = eval_.coeff
deps = []
for factor in eval_.factors:
ref = self._parse_interm_ref(factor)
if ref is not None:
assert ref.coeff == 1
interm = self._interms[ref.base]
if self._is_input(interm):
# Inline trivial reference to an input.
content = self._get_content(factor)
assert len(content) == 1
assert len(content[0].sums) == 0
amp *= content[0].amp ** ref.power
else:
deps.append(ref.base)
amp *= factor
else:
amp *= factor
return Term(eval_.sums, amp, ()), deps
def _form_sum_def(self, node: _Sum):
"""Form the final definition of a sum evaluation node."""
exts = node.exts
exts_dict = dict(node.exts)
terms = []
deps = []
eval_ = node.evals[0]
assert isinstance(eval_, _Sum)
sum_terms = []
self._inline_sum_terms(eval_.sum_terms, sum_terms)
for term in sum_terms:
ref = self._parse_interm_ref(term)
if ref is None:
terms.append(Term((), term, ()))
# No dependency for pure scalars.
continue
assert ref.power == 1 # Higher power not possible in sum.
# Sum terms are guaranteed to be formed from references to products,
# never directly written in terms of an input.
term_node = self._interms[ref.base]
if term_node.n_refs == 1 or self._is_input(term_node):
# Inline intermediates only used here and simple input
# references.
eval_ = term_node.evals[0]
assert isinstance(eval_, _Prod)
contents = self._index_prod(eval_, ref.indices)
assert len(contents) == 1
term = contents[0]
factors, term_coeff = term.get_amp_factors(
self._interms, exts_dict
)
# Switch back to evaluation node for using the facilities for
# product nodes.
tmp_node = _Prod(
term_node.base, exts, term.sums,
ref.coeff * term_coeff, factors
)
new_term, term_deps = self._form_prod_def_term(tmp_node)
terms.append(new_term)
deps.extend(term_deps)
else:
terms.append(Term(
(), term, ()
))
deps.append(ref.base)
continue
return _Grain(
base=node.base, exts=exts, terms=terms
), deps
def _inline_sum_terms(
self, sum_terms: typing.Sequence[Expr], res: typing.List[Expr]
):
"""Inline the summation terms from single-reference terms.
This function mutates the given result list rather than returning the
result to avoid repeated list creation in recursive calls.
"""
for sum_term in sum_terms:
ref = self._parse_interm_ref(sum_term)
if ref is None:
res.append(sum_term)
continue
assert ref.power == 1
node = self._interms[ref.base]
assert len(node.evals) > 0
eval_ = node.evals[0]
if_inline = isinstance(eval_, _Sum) and (
node.n_refs == 1 or len(eval_.sum_terms) == 1
)
if if_inline:
if len(node.exts) == 0:
substs = None
else:
substs = {
i[0]: j for i, j in zip(eval_.exts, ref.indices)
}
proced_sum_terms = [
(
i.xreplace(substs) if substs is not None else i
) * ref.coeff for i in eval_.sum_terms
]
self._inline_sum_terms(proced_sum_terms, res)
continue
else:
res.append(sum_term)
continue
return
def _is_input(self, node: _EvalNode):
"""Test if a product node is just a trivial reference to an input."""
if isinstance(node, _Prod):
return len(node.sums) == 0 and len(node.factors) == 1 and (
self._parse_interm_ref(node.factors[0]) is None
)
else:
return False
def _finalize(
self, computs: typing.Iterable[_Grain]
) -> typing.List[TensorDef]:
"""Finalize the linearization result.
Things will be cast to drudge tensor definitions, with intermediates
holding names formed from the format given by user.
"""
next_idx = 0
substs = {} # For normal substitution of bases.
repls = [] # For removed shallow intermediates
def proc_amp(amp):
"""Process the amplitude by making the found substitutions."""
for i in reversed(repls):
amp = amp.replace(*i)
continue
return amp.xreplace(substs)
res = []
for comput in computs:
exts = tuple((s, self._input_ranges[r]) for s, r in comput.exts)
if_scalar = len(exts) == 0
base = comput.base if if_scalar else IndexedBase(comput.base)
terms = [
i.map(proc_amp, sums=tuple(
(s, self._input_ranges[r]) for s, r in i.sums
)) for i in comput.terms
]
# No internal intermediates should be leaked.
for i in terms:
assert not any(j in self._interms for j in i.free_vars)
if comput.base in self._interms:
if len(terms) == 1 and len(terms[0].sums) == 0:
# Remove shallow intermediates. The saving might be too
# modest to justify the additional memory consumption.
#
# TODO: Move it earlier to a better place.
repl_lhs = base if if_scalar else base[tuple(
_WILD_FACTORY[i] for i, _ in enumerate(exts)
)]
repl_rhs = proc_amp(terms[0].amp.xreplace(
{v[0]: _WILD_FACTORY[i] for i, v in enumerate(exts)}
))
repls.append((repl_lhs, repl_rhs))
continue # No new intermediate added.
final_base = (
Symbol if if_scalar else IndexedBase
)(self._interm_fmt.format(next_idx))
next_idx += 1
substs[base] = final_base
else:
final_base = base
res.append(TensorDef(
final_base, exts, self._drudge.create_tensor(terms)
).reset_dumms())
continue
return res
#
# Internal support utilities.
#
def _get_next_internal(self):
"""Get the symbol for the next internal intermediate.
"""
idx = self._next_internal_idx
self._next_internal_idx += 1
return Symbol('gristmillInternalIntermediate{}'.format(idx))
@staticmethod
def _write_in_orig_ranges(sums):
"""Write the summations in terms of undecorated bare ranges.
The labels in the ranges are assumed to be decorated.
"""
return tuple(
(i, j.replace_label(j.label[0])) for i, j in sums
)
def _canon_terms(self, new_sums: _SrPairs, terms: typing.Iterable[Term]):
"""Form a canonical label for a list of terms.
The new summation list is prepended to the summation list of all terms.
The coefficient ahead of the canonical form is returned before the
canonical form. And the permuted new summation list is also returned
after the canonical form. Note that this list contains the original
dummies given in the new summation list, while the terms has reset new
dummies.
Note that the ranges in the new summation list are assumed to be
decorated with labels earlier than _SUMMED. In the result, they are
still in decorated forms and are guaranteed to be permuted in the same
way for all given terms. The summations from the terms will be
internally decorated but written in bare ranges in the final result.
Note that this is definitely a poor man's version of canonicalization of
multi-term tensor definitions with external indices. A lot of cases
cannot be handled well. Hopefully it can be replaced with a systematic
treatment some day in the future.
"""
new_dumms = {i for i, _ in new_sums}
coeffs = []
candidates = collections.defaultdict(list)
for idx, term in enumerate(terms):
term, canon_sums = self._canon_term(new_sums, term)
factors, coeff = term.get_amp_factors(self._interms)
coeffs.append(coeff)
candidates[
term.map(lambda x: prod_(factors))
].append((canon_sums, idx))
continue
# Poor man's canonicalization of external indices.
#
# This algorithm is not guaranteed to work. Here we just choose an
# ordering of the external indices that is as safe as possible. But
# certainly it is not guaranteed to work for all cases.
#
# TODO: Fix it!
chosen = min(candidates.items(), key=lambda x: (
len(x[1]), -len(x[0].amp.atoms(Symbol) & new_dumms),
x[0].sort_key
))
canon_new_sums = set(i for i, _ in chosen[1])
if len(canon_new_sums) > 1:
warnings.warn(
'Internal deficiency: '
'summation intermediate may not be fully canonicalized'
)
# This could also fail when the chosen term has symmetry among the new
# summations not present in any other term. This can be hard to check.
canon_new_sum = canon_new_sums.pop()
preferred = prod_(
coeffs[i] for _, i in candidates[chosen[0]]
)
canon_coeff = _get_canon_coeff(coeffs, preferred)
res_terms = []
for term in terms:
canon_term, _ = self._canon_term(canon_new_sum, term, fix_new=True)
# TODO: Add support for complex conjugation.
res_terms.append(canon_term.map(lambda x: x / canon_coeff))
continue
return canon_coeff, tuple(
sorted(res_terms, key=lambda x: x.sort_key)
), canon_new_sum
def _canon_term(self, new_sums, term, fix_new=False):
"""Canonicalize a single term.
Internal method for _canon_terms, not supposed to be directly called.
"""
term = Term(tuple(itertools.chain(
(
(v[0], v[1].replace_label((v[1].label[0], _EXT, i)))
for i, v in enumerate(new_sums)
) if fix_new else new_sums,
(
(i, j.replace_label((j.label, _SUMMED)))
for i, j in term.sums
)
)), term.amp, ())
canoned = term.canon(symms=self._drudge.symms.value)
canon_sums = canoned.sums
canon_orig_sums = self._write_in_orig_ranges(canon_sums)
dumm_reset, _ = canoned.map(sums=canon_orig_sums).reset_dumms(
dumms=self._dumms, excl=self._excl
)
canon_new_sums = []
term_new_sums = []
term_sums = []
i_new = 0
for i, j in zip(dumm_reset.sums, canon_sums):
if j[1].label[1] == _SUMMED:
# Existing summations.
term_sums.append(i)
else:
if fix_new:
assert j[0] == new_sums[i_new][0]
range_ = new_sums[i_new][1]
else:
range_ = j[1]
canon_new_sums.append((j[0], range_))
term_new_sums.append((i[0], range_))
i_new += 1
continue
return dumm_reset.map(sums=tuple(itertools.chain(
term_new_sums, term_sums
))), tuple(canon_new_sums)
def _parse_interm_ref(self, expr: Expr) -> typing.Optional[_IntermRef]:
"""Parse an expression that is possibly an intermediate reference.
"""
coeff = _UNITY
base = None
indices = None
power = None
if isinstance(expr, Mul):
args = expr.args
else:
args = [expr]
for i in args:
if any(j in self._interms for j in i.atoms(Symbol)):
assert base is None
ref, power = i.as_base_exp()
if isinstance(ref, Indexed):
base = ref.base.args[0]
indices = ref.indices
elif isinstance(ref, Symbol):
base = ref
indices = ()
else:
assert False
assert base in self._interms
else:
coeff *= i
return None if base is None else _IntermRef(
coeff=coeff, base=base, indices=indices, power=power
)
def _get_content(self, interm_ref: Expr) -> typing.List[Term]:
"""Get the content of an intermediate reference.
This function might be removed after the new factorization algorithm is
implemented.
"""
ref = self._parse_interm_ref(interm_ref)
assert ref is not None
node = self._interms[ref.base]
if isinstance(node, _Sum):
content = self._index_sum(node, ref.indices)
elif isinstance(node, _Prod):
content = self._index_prod(node, ref.indices)
else:
assert False
return [
i.scale(ref.coeff) for i in self._raise_power(content, ref.power)
]
def _index_sum(self, node: _Sum, indices) -> typing.List[Term]:
"""Substitute the external indices in the sum node"""
substs, _ = node.get_substs(indices)
res = []
for i in node.sum_terms:
term = i.xreplace(substs)
ref = self._parse_interm_ref(term)
if ref is None:
res.append(term)
else:
term_def = self._get_content(term)
res.extend(term_def)
continue
return res
def _index_prod(self, node: _Prod, indices) -> typing.List[Term]:
"""Substitute the external indices in the evaluation node."""
substs, excl = node.get_substs(indices)
term = Term(
node.sums, node.coeff * prod_(node.factors), ()
).reset_dumms(
self._dumms, excl=self._excl | excl
)[0].map(lambda x: x.xreplace(substs))
return [term]
def _raise_power(
self, terms: typing.Sequence[Term], exp: int
) -> typing.List[Term]:
"""Raise the sum of the given terms to the given power."""
curr = [] # type: typing.List[Term]
for _ in range(exp):
if len(curr) == 0:
curr = list(terms)
else:
# TODO: Make the multiplication more efficient.
curr = [i.mul_term(
j, dumms=self._dumms,
excl=self._excl | i.free_vars | j.free_vars
) for i, j in itertools.product(curr, terms)]
return curr
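# E.g. raising the two terms [a, b] to the power 2 produces the four cross
# products a*a, a*b, b*a, b*b, with dummies reset so that the copies of the
# same terms do not clash.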
#
# General optimization.
#
def _form_node(self, grain: _Grain):
"""Form an evaluation node from a tensor definition.
This is the entry point for optimization.
"""
# We assume it is fully simplified and expanded by grist preparation.
exts = grain.exts
terms = grain.terms
assert len(terms) > 0
return self._form_sum_from_terms(
grain.base, exts, terms
)
def _optimize(self, node):
"""Optimize the evaluation of the given node.
The evals member will be filled with possibly multiple methods of
evaluation.
"""
if len(node.evals) > 0:
return node
if isinstance(node, _Sum):
return self._optimize_sum(node)
elif isinstance(node, _Prod):
return self._optimize_prod(node)
else:
assert False
def _form_prod_interm(
self, exts, sums, factors
) -> typing.Tuple[Expr, _EvalNode]:
"""Form a product intermediate.
The factors are assumed to be all non-trivial factors needing
processing.
"""
decored_exts = tuple(
(i, j.replace_label((j.label, _EXT)))
for i, j in exts
)
n_exts = len(decored_exts)
term = Term(tuple(sums), prod_(factors).simplify(), ())
coeff, key, canon_exts = self._canon_terms(
decored_exts, [term]
)
assert len(key) == 1
if key in self._interms_canon:
base = self._interms_canon[key]
else:
base = self._get_next_internal()
self._interms_canon[key] = base
key_term = key[0]
key_exts = self._write_in_orig_ranges(key_term.sums[:n_exts])
key_sums = key_term.sums[n_exts:]
# The external symbols will automatically be considered in
# get_amp_factors since they are in the summation list right now.
key_factors, key_coeff = key_term.get_amp_factors(self._interms)
interm = _Prod(
base, key_exts, key_sums, key_coeff, key_factors
)
self._interms[base] = interm
return coeff * _index(
base, canon_exts, strip=True
), self._interms[base]
def _form_sum_interm(
self, exts: _SrPairs, terms: typing.Sequence[Term]
) -> typing.Tuple[Expr, _EvalNode]:
"""Form a sum intermediate.
"""
decored_exts = tuple(
(i, j.replace_label((j.label, _EXT)))
for i, j in exts
)
n_exts = len(decored_exts)
coeff, canon_terms, canon_exts = self._canon_terms(decored_exts, terms)
if canon_terms in self._interms_canon:
base = self._interms_canon[canon_terms]
else:
base = self._get_next_internal()
self._interms_canon[canon_terms] = base
node_exts = None
node_terms = []
for term in canon_terms:
term_exts = self._write_in_orig_ranges(term.sums[:n_exts])
if node_exts is None:
node_exts = term_exts
else:
assert node_exts == term_exts
node_terms.append(term.map(sums=term.sums[n_exts:]))
continue
node = self._form_sum_from_terms(base, node_exts, node_terms)
self._interms[base] = node
self._optimize(node)
return coeff * _index(
base, canon_exts, strip=True
), self._interms[base]
def _form_sum_from_terms(
self, base: Symbol, exts: _SrPairs, terms: typing.Iterable[Term]
):
"""Form a summation node for given the terms.
No processing is done in this method. It just forms the node.
"""
sum_terms = []
plain_scalars = []
ext_symbs = {i for i, _ in exts}
for term in terms:
sums = term.sums
factors, coeff = term.get_amp_factors(self._interms, ext_symbs)
if len(factors) == 0:
plain_scalars.append(coeff)
else:
interm_ref, _ = self._form_prod_interm(exts, sums, factors)
sum_terms.append(interm_ref * coeff)
continue
if len(plain_scalars) > 0:
sum_terms.append(sum_(plain_scalars))
return _Sum(base, exts, sum_terms)
#
# Sum optimization.
#
def _optimize_sum(self, sum_node: _Sum):
"""Optimize the summation node."""
# We first optimize the common terms.
exts = sum_node.exts
scalars, terms, _ = self._organize_sum_terms(sum_node.sum_terms)
if self._strategy & Strategy.SUM > 0:
new_terms, old_terms = self._factorize_sum(terms, exts)
else:
new_terms = []
old_terms = terms
if self._strategy & Strategy.COMMON > 0:
old_terms = self._optimize_common_symmetrization(old_terms, exts)
res_terms = scalars + old_terms + new_terms
sum_node.evals = [_Sum(
sum_node.base, sum_node.exts, res_terms
)]
return
def _organize_sum_terms(self, terms: typing.Iterable[Expr]) -> typing.Tuple[
typing.List[Expr], typing.List[Expr], _OrgTerms
]:
"""Organize terms in the summation node.
"""
# Intermediate base -> (indices -> coefficient)
#
# This first gathers terms with the same reference to deeper nodes.
org_terms = collections.defaultdict(
lambda: collections.defaultdict(lambda: _ZERO)
)
plain_scalars = []
for term in terms:
ref = self._parse_interm_ref(term)
if ref is None:
plain_scalars.append(term)
continue
assert ref.power == 1
org_terms[ref.base][ref.indices] += ref.coeff
continue
res_terms = []
for k, v in org_terms.items():
assert len(v) > 0
for indices, coeff in v.items():
coeff = coeff.simplify()
if coeff != 0:
res_terms.append(
_index(k, indices) * coeff
)
continue
return plain_scalars, res_terms, org_terms
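# As an illustration, summation terms like 2*tau[i, j] + 3*tau[i, j] +
# tau[j, i] organize into tau -> {(i, j): 5, (j, i): 1}, with the first two
# references merged into a single 5 * tau[i, j] in the returned terms.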
def _optimize_common_symmetrization(self, terms, exts):
"""Optimize common symmetrization in the intermediate references.
"""
res_terms = []
exts_dict = dict(exts)
scalars, _, org_terms = self._organize_sum_terms(terms)
assert len(scalars) == 0
# Indices, coeffs tuple -> base, coeff
pull_info = collections.defaultdict(list)
for k, v in org_terms.items():
if len(v) == 0:
assert False
elif len(v) == 1:
indices, coeff = v.popitem()
res_terms.append(
_index(k, indices) * coeff
)
else:
# Here we use names for sorting directly, since we cannot have
# general expressions here, hence there is no need for the expensive
# sort_key.
raw = list(v.items()) # Indices/coefficient pairs.
raw.sort(key=lambda x: [i.name for i in x[0]])
leading_coeff = raw[0][1]
pull_info[tuple(
(i, j / leading_coeff) for i, j in raw
)].append((k, leading_coeff))
# Now we treat the terms from which new intermediates might be pulled
# out.
for k, v in pull_info.items():
pivot = k[0][0]
assert len(pivot) > 0
assert k[0][1] == 1
if len(v) == 1:
# No need to form a new intermediate.
base, coeff = v[0]
pivot_ref = _index(base, pivot) * coeff
else:
# We need to form an intermediate here.
interm_exts = tuple(
(i, exts_dict[i]) for i in pivot
)
interm_terms = [
term.scale(coeff)
for base, coeff in v
for term in self._get_content(_index(base, pivot))
]
pivot_ref, interm_node = self._form_sum_interm(
interm_exts, interm_terms
)
self._optimize(interm_node)
for indices, coeff in k:
substs = {
i: j for i, j in zip(pivot, indices)
}
res_terms.append(
pivot_ref.xreplace(substs) * coeff
)
continue
continue
return res_terms
def _factorize_sum(
self, terms: typing.Sequence[Expr], exts: _SrPairs
):
"""Factorize the summations greedily.
"""
if_keep = [True for _ in terms]
new_terms = []
collectibles, base_infos, term_base = self._find_collectibles(
terms, exts
)
while True:
ranges, biclique = self._choose_collectible(
collectibles, base_infos
)
if ranges is None:
break
new_terms.append(self._form_factored_term(ranges, biclique))
self._clean_up_collected(
biclique, collectibles, base_infos, term_base, if_keep
)
continue
# End Main loop.
old_terms = [i for i, j in zip(terms, if_keep) if j]
return new_terms, old_terms
def _find_collectibles(
self, terms: typing.Sequence[Expr], exts: _SrPairs
) -> typing.Tuple[_Collectibles, _BaseInfoDict, typing.List[Symbol]]:
"""Find all collectibles for the given terms..
"""
coll = collections.defaultdict(_CollectGraph) # type: _Collectibles
base_infos = _BaseInfoDict()
term_base = []
for term_idx, term in enumerate(terms):
ref = self._parse_interm_ref(term)
if ref is None:
term_base.append(None)
continue
base = ref.base
node = self._interms[base]
assert isinstance(node, _Prod)
self._optimize(node)
for eval_idx, eval_ in enumerate(node.evals):
assert isinstance(eval_, _Prod)
self._find_collectibles_eval(
term_idx, ref, eval_idx, eval_, exts, coll
)
continue
base_infos.add_base(base, node.total_cost)
term_base.append(base)
return coll, base_infos, term_base
def _find_collectibles_eval(
self, term_idx: int, ref: _IntermRef, eval_idx: int, eval_: _Prod,
exts: _SrPairs, res: _Collectibles
):
"""Get the collectibles for a particular evaluations of a product.
"""
if len(eval_.factors) < 2:
return
assert len(eval_.factors) == 2
eval_cost = eval_.total_cost
opt_cost = self._interms[ref.base].total_cost
eval_terms = self._index_prod(eval_, ref.indices)
assert len(eval_terms) == 1
eval_term = eval_terms[0]
ext_symbs = {i for i, _ in eval_.exts}
factors, coeff = eval_term.get_amp_factors(
self._interms, ext_symbs
)
coeff *= ref.coeff
assert factors[0] != factors[1]
sums = tuple(sorted(
eval_term.sums, key=lambda x: default_sort_key(x[0])
))
excl = self._excl | ext_symbs
symms = self._drudge.symms.value
# Information about the (two) factors,
#
# expr: The original expression for the factor.
# exts: Indices of the involved externals.
# canon_content: The canonicalized content for the factor.
factor_infos = [
types.SimpleNamespace(expr=i) for i in factors
]
for f_i in factor_infos:
content = self._get_content(f_i.expr)
assert len(content) == 1
content = content[0]
symbs = f_i.expr.atoms(Symbol)
f_i.exts = tuple(
i for i, v in enumerate(exts) if v[0] in symbs
) # Index only.
for i, _ in sums:
assert i in symbs
# To be really sure, the content will be re-canonicalized with respect
# to the current ambient externals and summations.
canon_content = content.canon(symms=symms).reset_dumms(
self._dumms, excl=excl | content.free_vars
)[0]
_, canon_coeff = canon_content.get_amp_factors(
self._interms, ext_symbs
)
f_i.canon_content = canon_content.map(
lambda x: x / canon_coeff, skip_vecs=True
)
coeff *= canon_coeff
continue
factor_infos.sort(key=lambda x: x.exts)
l_exts, r_exts = [
tuple(exts[j] for j in i.exts)
for i in factor_infos
]
ranges = _Ranges(exts=(l_exts, r_exts), sums=sums)
# When the left and right externals differ, the two factors have a
# determined colour; otherwise we need to add an edge for each of the
# two possible colour assignments.
lr_factor_idxes = [(0, 1)]
if l_exts == r_exts:
lr_factor_idxes.append((1, 0))
lr_factors = [
tuple(factor_infos[j].canon_content for j in i)
for i in lr_factor_idxes
]
for i in lr_factors:
res[ranges].add_edge(
left=i[0], right=i[1], term=term_idx, eval_=eval_idx,
base=ref.base, coeff=coeff, opt_cost=opt_cost,
eval_cost=eval_cost
)
continue
return
def _choose_collectible(
self, collectibles: _Collectibles, base_infos: _BaseInfoDict
):
"""Choose the most profitable collectible factor.
"""
rush_local = self._strategy & Strategy.RUSH_LOCAL > 0
rush_global = self._strategy & Strategy.RUSH_GLOBAL > 0
inaccurate = self._strategy & Strategy.INACCURATE > 0
opt_saving = None
opt_ranges = None
opt_biclique = None
for ranges, graph in collectibles.items():
curr_opt_saving, curr_opt_biclique = graph.get_opt_biclique(
ranges, base_infos,
greedy_cutoff=self._greedy_cutoff,
drop_cutoff=self._drop_cutoff,
rush_local=rush_local, rush_global=rush_global,
inaccurate=inaccurate
)
if curr_opt_saving is None:
continue
if opt_saving is None or curr_opt_saving > opt_saving:
opt_saving = curr_opt_saving
opt_ranges = ranges
opt_biclique = curr_opt_biclique
continue
return opt_ranges, opt_biclique
def _form_factored_term(
self, ranges: _Ranges, biclique: _Biclique
) -> Expr:
"""Form the factored term for the given factorization."""
# Form and optimize the two new summation nodes.
factors = [biclique.leading_coeff]
for exts_i, nodes_i in zip(ranges.exts, biclique.nodes):
scaled_terms = [
i.scale(j) for i, j in zip(*nodes_i)
]
exts = exts_i + ranges.sums
if len(scaled_terms) > 1:
expr, eval_node = self._form_sum_interm(exts, scaled_terms)
else:
scaled_term = scaled_terms[0]
expr, eval_node = self._form_prod_interm(
exts, scaled_term.sums, [scaled_term.amp]
)
factors.append(expr)
self._optimize(eval_node)
continue
# Form the contraction node for the two new summation nodes.
exts = tuple(sorted(
set(itertools.chain.from_iterable(ranges.exts)),
key=lambda x: default_sort_key(x[0])
))
expr, eval_node = self._form_prod_interm(
exts, ranges.sums, factors
)
# Make phony optimization of the intermediate.
eval_node.total_cost = 1
eval_node.evals = [eval_node]
return expr
@staticmethod
def _clean_up_collected(
biclique: _Biclique, collectibles: _Collectibles,
base_infos: _BaseInfoDict, term_base: typing.List[Symbol],
if_keep: typing.List[bool]
):
"""Clean up the collectibles and the terms after factorization."""
for i in biclique.terms:
assert if_keep[i]
if_keep[i] = False
continue
updated_bases = base_infos.remove_terms(biclique.terms, term_base)
to_remove = []
for ranges, graph in collectibles.items():
if_empty = graph.remove_terms(
biclique.terms, term_base, updated_bases, base_infos
)
if if_empty:
to_remove.append(ranges)
continue
for i in to_remove:
del collectibles[i]
continue
#
# Product optimization.
#
def _optimize_prod(self, prod_node):
"""Optimize the product evaluation node.
"""
assert len(prod_node.evals) == 0
n_factors = len(prod_node.factors)
if n_factors < 2:
assert n_factors == 1
prod_node.evals.append(prod_node)
prod_node.total_cost = _get_prod_final_cost(
get_total_size(prod_node.exts),
get_total_size(prod_node.sums)
)
return
strategy = self._strategy & Strategy.PROD_MASK
evals = prod_node.evals
optimal_cost = None
for final_cost, broken_sums, parts_gen in self._gen_factor_parts(
prod_node
):
            def need_break():
                """Whether the current loop needs to be broken."""
if strategy == Strategy.GREEDY:
return True
elif strategy == Strategy.BEST or strategy == Strategy.SEARCHED:
return final_cost > optimal_cost
elif strategy == Strategy.ALL:
return False
else:
assert False
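            # The partitions come in increasing order of the final cost, so
            # the search can stop once the final cost alone exceeds the
            # current optimum.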
if (optimal_cost is not None) and need_break():
break
# Else
for parts in parts_gen:
# Recurse, two parts.
assert len(parts) == 2
for i in parts:
self._optimize(i.node)
continue
total_cost = (
final_cost
+ parts[0].node.total_cost
+ parts[1].node.total_cost
)
if_new_optimal = (
optimal_cost is None or optimal_cost > total_cost
)
if if_new_optimal:
optimal_cost = total_cost
if strategy == Strategy.BEST:
evals.clear()
# New optimal is always added.
                def need_add_eval():
                    """Whether the current evaluation needs to be added."""
if strategy == Strategy.BEST:
return total_cost == optimal_cost
else:
return True
if if_new_optimal or need_add_eval():
new_eval = self._form_prod_eval(
prod_node, broken_sums, parts
)
new_eval.total_cost = total_cost
evals.append(new_eval)
continue
assert len(evals) > 0
prod_node.total_cost = optimal_cost
return
def _gen_factor_parts(self, prod_node: _Prod):
"""Generate all the partitions of factors in a product node."""
        # Precompute quantities invariant across the different summation
        # choices, for performance.
exts = prod_node.exts
exts_total_size = get_total_size(exts)
factor_atoms = [
i.atoms(Symbol) for i in prod_node.factors
]
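        # For each summation, the set of indices of the factors involving
        # its dummy.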
sum_involve = [
{j for j, v in enumerate(factor_atoms) if i in v}
for i, _ in prod_node.sums
]
dumm2index = tuple(
{v[0]: j for j, v in enumerate(i)}
for i in [prod_node.exts, prod_node.sums]
)
        # Indices of the external and internal dummies involved by each
        # factor.
factor_infos = [
tuple(
set(i[j] for j in atoms if j in i)
for i in dumm2index
)
for atoms in factor_atoms
]
# Actual generation.
for broken_size, kept in self._gen_kept_sums(prod_node.sums):
broken_sums = [i for i, j in zip(prod_node.sums, kept) if not j]
final_cost = _get_prod_final_cost(
exts_total_size, broken_size
)
yield final_cost, broken_sums, self._gen_parts_w_kept_sums(
prod_node, kept, sum_involve, factor_infos
)
continue
@staticmethod
def _gen_kept_sums(sums):
"""Generate kept summations in increasing size of broken summations.
The results will be given as boolean array giving if the corresponding
entry is to be kept.
"""
sizes = [i.size for _, i in sums]
n_sums = len(sizes)
        def get_size(kept):
            """Wrap a kept mask with the total size of its broken sums."""
size = prod_(
i for i, j in zip(sizes, kept) if not j
)
return Tuple4Cmp((size, kept))
init = [True] * n_sums # Everything is kept.
queue = [get_size(init)]
while len(queue) > 0:
curr = heapq.heappop(queue)
yield curr
curr_kept = curr[1]
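            # Each kept pattern enters the queue exactly once: a summation
            # is broken only while every summation before it is still kept,
            # so each pattern is generated by breaking its summations in
            # decreasing index order.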
for i in range(n_sums):
if curr_kept[i]:
new_kept = list(curr_kept)
new_kept[i] = False
heapq.heappush(queue, get_size(new_kept))
continue
else:
break
continue
def _gen_parts_w_kept_sums(
self, prod_node: _Prod, kept, sum_involve, factor_infos
):
"""Generate all partitions with given summations kept.
First we the factors are divided into chunks indivisible according to
the kept summations. Then their bipartitions which really break the
broken sums are generated.
"""
dsf = DSF(i for i, _ in enumerate(factor_infos))
for i, j in zip(kept, sum_involve):
if i:
dsf.union(j)
continue
chunks = dsf.sets
if len(chunks) < 2:
return
for part in self._gen_parts_from_chunks(kept, chunks, sum_involve):
assert len(part) == 2
yield tuple(
self._form_part(prod_node, i, sum_involve, factor_infos)
for i in part
)
return
@staticmethod
def _gen_parts_from_chunks(kept, chunks, sum_involve):
"""Generate factor partitions from chunks.
Here special care is taken to respect the broken summations in the
result.
"""
n_chunks = len(chunks)
for chunks_part in multiset_partitions(n_chunks, m=2):
factors_part = tuple(set(
factor_i for chunk_i in chunk_part_i
for factor_i in chunks[chunk_i]
) for chunk_part_i in chunks_part)
for i, v in enumerate(kept):
if v:
continue
                # A broken summation needs to be involved by both parts.
involve = sum_involve[i]
if any(part.isdisjoint(involve) for part in factors_part):
break
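            # By the for-else, the partition is yielded only when no broken
            # summation is confined to a single part.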
else:
yield factors_part
def _form_part(self, prod_node, factor_idxes, sum_involve, factor_infos):
"""Form a partition for the given factors."""
involved_exts, involved_sums = [
set.union(*[factor_infos[i][label] for i in factor_idxes])
for label in [0, 1]
]
factors = [prod_node.factors[i] for i in factor_idxes]
exts = [
v
for i, v in enumerate(prod_node.exts)
if i in involved_exts
]
sums = []
for i, v in enumerate(prod_node.sums):
if sum_involve[i].isdisjoint(factor_idxes):
continue
elif sum_involve[i] <= factor_idxes:
sums.append(v)
else:
exts.append(v)
continue
ref, node = self._form_prod_interm(exts, sums, factors)
return _Part(ref=ref, node=node)
def _form_prod_eval(
self, prod_node: _Prod, broken_sums, parts: typing.Tuple[_Part, ...]
):
"""Form an evaluation for a product node."""
assert len(parts) == 2
coeff = _UNITY
factors = []
for i in parts:
curr_ref = self._parse_interm_ref(i.ref)
coeff *= curr_ref.coeff
factors.append(curr_ref.ref)
continue
assert len(factors) == 2
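        # Identical factors are collapsed into a square so that the common
        # factor is evaluated only once.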
if factors[0] == factors[1]:
factors = [factors[0] ** 2]
return _Prod(
prod_node.base, prod_node.exts, broken_sums,
coeff * prod_node.coeff, factors
)
#
# Utility constants.
#
_ZERO = Integer(0)
_UNITY = Integer(1)
_NEG_UNITY = Integer(-1)
_EXT = 0
_SUMMED_EXT = 1
_SUMMED = 2
_SUBSTED_EVAL_BASE = Symbol('gristmillSubstitutedEvalBase')
#
# Utility static functions.
#
class _SymbFactory(dict):
"""A small symbol factory."""
def __missing__(self, key):
return Symbol('gristmillInternalSymbol{}'.format(key))
_SYMB_FACTORY = _SymbFactory()
class _WildFactory(dict):
"""A small wild symbol factory."""
def __missing__(self, key):
return Wild('gristmillInternalWild{}'.format(key))
_WILD_FACTORY = _WildFactory()
def _get_canon_coeff(coeffs, preferred):
"""Get the canonical coefficient from a list of coefficients."""
expr = sum(
v * _SYMB_FACTORY[i] for i, v in enumerate(coeffs)
).together()
frac = _UNITY # The fractional part.
if isinstance(expr, Mul):
for i in expr.args:
if isinstance(i, Pow) and i.args[1] < 0:
frac *= i
continue
expr /= frac
coeff, _ = primitive(expr, *[
_SYMB_FACTORY[i] for i, _ in enumerate(coeffs)
])
# Initial coefficient without phase.
init_coeff = coeff * frac
# The primitive computation does not take phase into account.
negs = []
poses = []
for i in coeffs:
i /= init_coeff
if i.has(_NEG_UNITY) or i.is_negative:
negs.append(-i)
else:
poses.append(i)
continue
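    # The sign groups are compared first by size and then by the sorted sort
    # keys of their members, so that the chosen phase makes the dominant
    # group positive, with ties settled by the preferred coefficient.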
neg_sig, pos_sig = [
(len(i), tuple(sorted(default_sort_key(j) for j in i)))
for i in [negs, poses]
]
if neg_sig > pos_sig:
phase = _NEG_UNITY
elif pos_sig > neg_sig:
phase = _UNITY
else:
preferred_phase = (
_NEG_UNITY if preferred.has(_NEG_UNITY) or preferred.is_negative
else _UNITY
)
phase = preferred_phase
return (coeff * phase * frac).simplify()
def _index(base, indices, strip=False) -> Expr:
"""Index the given base with indices.
When strip is set to true, the indices are assumed to be symbol/range pairs
list.
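
    For instance, ``_index(t, (i, j))`` gives ``IndexedBase(t)[i, j]``, while
    an empty sequence of indices returns the base unchanged.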
"""
if strip:
indices = tuple(i for i, _ in indices)
else:
indices = tuple(indices)
return base if len(indices) == 0 else IndexedBase(base)[indices]
#
# Optimization result verification
# --------------------------------
#
def verify_eval_seq(
eval_seq: typing.Sequence[TensorDef], res: typing.Sequence[TensorDef],
simplify=False
) -> bool:
"""Verify the correctness of an evaluation sequence for the results.
The last entries of the evaluation sequence should be in one-to-one
correspondence with the original form in the ``res`` argument. This
function returns ``True`` when the evaluation sequence is symbolically
equivalent to the given raw form. When a difference is found,
``ValueError`` will be raised with relevant information.
Note that this function can be very slow for large evaluations. But it is
advised to be used for all optimizations in mission-critical tasks.
Parameters
----------
eval_seq
The evaluation sequence to verify, can be the output from
:py:func:`optimize` directly.
res
The original result to test the evaluation sequence against. It can be
the input to :py:func:`optimize` directly.
    simplify
        If simplification is to be performed after each step of the
        back-substitution. It is advised for larger and more complex
        evaluations.
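
    Examples
    --------
    A minimal sketch of the intended round trip, assuming ``targets`` is a
    list of tensor definitions acceptable to :py:func:`optimize`::

        eval_seq = optimize(targets)
        assert verify_eval_seq(eval_seq, targets, simplify=True)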
"""
n_res = len(res)
n_interms = len(eval_seq) - n_res
substed_eval_seq = []
defs_dict = {}
for idx, eval_ in enumerate(eval_seq):
base = eval_.base
free_vars = eval_.rhs.free_vars
curr_defs = [
defs_dict[i] for i in free_vars if i in defs_dict
]
exts = {i for i, _ in eval_.exts}
rhs = eval_.rhs.subst_all(curr_defs, simplify=simplify, excl=exts)
new_def = TensorDef(base, eval_.exts, rhs)
substed_eval_seq.append(new_def)
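        # Only intermediate definitions are recorded for substitution into
        # later evaluations; the trailing results are compared against the
        # references directly.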
if idx < n_interms:
defs_dict[
base.label if isinstance(base, IndexedBase) else base
] = new_def
continue
for i, j in zip(substed_eval_seq[-n_res:], res):
ref = j.simplify()
if i.lhs != ref.lhs:
raise ValueError(
'Unequal left-hand sides', i.lhs, 'with', ref.lhs
)
diff = (i.rhs - ref.rhs).simplify()
if diff != 0:
            raise ValueError(
                'Unequal definition for', j.lhs, j
            )
continue
return True