"""Tensor term definition and utility."""
import abc
import collections
import functools
import itertools
import operator
import typing
import warnings
from collections.abc import Iterable, Mapping, Callable, Sequence
from sympy import (
sympify, Symbol, KroneckerDelta, Eq, solve, S, Integer, Add, Mul, Indexed,
IndexedBase, Expr, Basic, Pow, Wild, conjugate
)
from sympy.core.sympify import CantSympify
from .canon import canon_factors
from .utils import (
ensure_symb, ensure_expr, sympy_key, is_higher, NonsympifiableFunc
)
#
# Utility constants
# -----------------
#
_UNITY = Integer(1)
_NEG_UNITY = Integer(-1)
_NAUGHT = Integer(0)
#
# Fundamental classes
# --------------------
#
class Range:
"""A symbolic range that can be summed over.
This class is for symbolic ranges that are going to be summed over in
tensors. Each range should have a label, and optionally lower and upper
bounds, which should be either both given or both absent. The label can be
any hashable and ordered Python type. The bounds will not be directly used
for symbolic computation, but are rather designed for printers and for
conversion to SymPy summation. Note that ranges are assumed to be atomic
and disjoint. Even in the presence of lower and upper bounds, unequal
ranges are assumed to be disjoint.
.. warning::
Ranges with the same label but different bounds will be considered
unequal. Although no error will be given, using different bounds with an
identical label is strongly advised against.
.. warning::
Unequal ranges are always assumed to be disjoint.
"""
__slots__ = [
'_label',
'_lower',
'_upper'
]
def __init__(self, label, lower=None, upper=None):
"""Initialize the symbolic range.
"""
self._label = label
self._lower = (
ensure_expr(lower, 'lower bound') if lower is not None else lower
)
if self._lower is None:
if upper is not None:
raise ValueError('lower bound has not been given.')
else:
self._upper = None
else:
if upper is None:
raise ValueError('upper bound has not been given.')
else:
self._upper = ensure_expr(upper, 'upper bound')
@property
def label(self):
"""The label of the range."""
return self._label
@property
def lower(self):
"""The lower bound of the range."""
return self._lower
@property
def upper(self):
"""The upper bound of the range."""
return self._upper
@property
def size(self):
"""The size of the range.
This property gives None for unbounded ranges. For bounded ranges, it
is the difference between the upper and lower bounds. Note that this
deviates from the deeply entrenched mathematical convention of including
both end points of a range, but it does give a lot of convenience and
elegance.
"""
return self._upper - self._lower if self.bounded else None
@property
def bounded(self):
"""If the range is explicitly bounded."""
return self._lower is not None
@property
def args(self):
"""The arguments for range creation.
When the bounds are present, we have a triple; otherwise we have a
singleton tuple of only the label.
"""
if self.bounded:
return self._label, self._lower, self._upper
else:
return self._label,
def __hash__(self):
"""Hash the symbolic range.
"""
return hash(self.args)
def __eq__(self, other):
"""Compare equality of two ranges.
"""
return isinstance(other, type(self)) and (
self.args == other.args
)
def __repr__(self):
"""Form the representative string.
"""
return ''.join([
'Range(', ', '.join(repr(i) for i in self.args), ')'
])
def __str__(self):
"""Form readable string representation.
"""
return str(self._label)
@property
def sort_key(self):
"""The sort key for the range."""
key = [self._label]
if self.bounded:
key.extend(sympy_key(i) for i in [self._lower, self._upper])
return key
def replace_label(self, new_label):
"""Replace the label of a given range.
The bounds will be the same as the original range.
"""
return Range(new_label, self._lower, self._upper)
def __lt__(self, other):
"""Compare two ranges.
This method is meant to avoid explicitly invoking the sort key when that
is not convenient.
"""
if not isinstance(other, Range):
raise TypeError('Invalid range to compare', other)
return self.sort_key < other.sort_key
class ATerms(abc.ABC):
"""Abstract base class for terms.
This abstract class is meant for things that can be interpreted as a local
collection of some tensor terms, mostly used for user input of tensor terms.
"""
@abc.abstractproperty
def terms(self) -> typing.List['Term']:
"""Get an list for the terms.
"""
pass
#
# Mathematical operations.
#
_op_priority = 19.0 # Just less than the full tensor.
def __mul__(self, other):
"""Multiply something on the right."""
if is_higher(other, self._op_priority):
return NotImplemented
return self._mul(self.terms, parse_terms(other))
def __rmul__(self, other):
"""Multiply something on the left."""
if is_higher(other, self._op_priority):
return NotImplemented
return self._mul(parse_terms(other), self.terms)
@staticmethod
def _mul(left_terms, right_terms):
"""Multiply the left terms with the right terms.
Note that the terms should not have any conflicting dummies. Actually,
by the common user-input scheme in drudge, the terms should normally
have no summations at all. So this function has different semantics
than the :py:meth:`Term.mul_term` method.
"""
prod_terms = []
for i, j in itertools.product(left_terms, right_terms):
# A shallow check on sums; normally we have no sums by design.
sums = _cat_sums(i.sums, j.sums)
amp = i.amp * j.amp
vecs = i.vecs + j.vecs
prod_terms.append(Term(sums, amp, vecs))
continue
return Terms(prod_terms)
def __add__(self, other):
"""Add something on the right."""
if is_higher(other, self._op_priority):
return NotImplemented
return self._add(self.terms, parse_terms(other))
def __radd__(self, other):
"""Add something on the left."""
if is_higher(other, self._op_priority):
return NotImplemented
return self._add(parse_terms(other), self.terms)
def __sub__(self, other):
"""Subtract something on the right."""
if is_higher(other, self._op_priority):
return NotImplemented
other_terms = self._neg_terms(parse_terms(other))
return self._add(self.terms, other_terms)
def __rsub__(self, other):
"""Be subtracted from something on the left."""
if is_higher(other, self._op_priority):
return NotImplemented
self_terms = self._neg_terms(parse_terms(self))
return self._add(parse_terms(other), self_terms)
def __neg__(self):
"""Negate the terms."""
return Terms(self._neg_terms(parse_terms(self)))
@staticmethod
def _add(left_terms, right_terms):
"""Add the terms together.
"""
return Terms(itertools.chain(left_terms, right_terms))
@staticmethod
def _neg_terms(terms: typing.Iterable['Term']):
"""Negate the given terms.
The resulting terms are lazily evaluated.
"""
return (
Term(i.sums, i.amp * _NEG_UNITY, i.vecs)
for i in terms
)
class Terms(ATerms):
"""A local collection of terms.
This class is a concrete collection of terms. Any mathematical operation on
the abstract terms objects will be elevated to instances of this class.
"""
__slots__ = ['_terms']
def __init__(self, terms: typing.Iterable['Term']):
"""Initialize the terms object.
The possibly lazy iterable of terms will be instantiated here, and zero
terms will be filtered out.
"""
self._terms = list(i for i in terms if i.amp != 0)
@property
def terms(self):
"""Get the terms in the collection."""
return self._terms
def parse_terms(obj) -> typing.List['Term']:
"""Parse the object into a list of terms."""
if isinstance(obj, ATerms):
return obj.terms
else:
expr = ensure_expr(obj)
return [Term((), expr, ())]
class Vec(ATerms, CantSympify):
"""Vectors.
Vectors are the basic non-commutative quantities. Each vector consists of a
label for its base and some indices. The label is allowed to be any
hashable and ordered Python object, although small objects, like strings,
are advised. The indices are always sympified into SymPy expressions.
Vectors can be created directly by giving the label and indices, or
existing vector objects can be subscripted to get new ones. The semantics
is similar to Haskell functions.
Note that users cannot directly assign to the attributes of this class.
This class can be used by itself; it can also be subclassed for special
use cases.
Despite a very different internal data structure, this class attempts to
emulate the behaviour of the SymPy ``IndexedBase`` class.
"""
__slots__ = ['_label', '_indices']
def __init__(self, label, indices=()):
"""Initialize a vector.
Atomic indices are added as the only index. Iterable values will
have all of their entries added.
"""
self._label = label
if not isinstance(indices, Iterable):
indices = (indices,)
self._indices = tuple(ensure_expr(i, 'vector index') for i in indices)
@property
def label(self):
"""The label for the base of the vector.
"""
return self._label
@property
def base(self):
"""The base of the vector.
This base can be subscribed to get other vectors.
"""
return Vec(self._label, [])
@property
def indices(self):
"""The indices to the vector.
"""
return self._indices
def __getitem__(self, item):
"""Append the given indices to the vector.
When multiple new indices are to be given, they have to be given as a
tuple.
"""
if not isinstance(item, tuple):
item = (item,)
# Pay attention to subclassing.
return type(self)(self._label, itertools.chain(self._indices, item))
def __repr__(self):
"""Form the repr string for the vector."""
return ''.join([
type(self).__name__, '(', repr(self._label), ', (',
', '.join(repr(i) for i in self._indices),
'))'
])
def __str__(self):
"""Form a more readable string representation."""
label = str(self._label)
if len(self._indices) > 0:
indices = ''.join([
'[', ', '.join(str(i) for i in self._indices), ']'
])
else:
indices = ''
return label + indices
def __hash__(self):
"""Compute the hash value of a vector."""
return hash((self._label, self._indices))
def __eq__(self, other):
"""Compare two vectors for equality."""
return (
(isinstance(self, type(other)) or isinstance(other, type(self))) and
self._label == other.label and self._indices == other.indices
)
@property
def sort_key(self):
"""The sort key for the vector.
This is a generic sort key for vectors. Note that this is only useful
for sorting the simplified terms and should not be used in the
normal-ordering operations.
"""
key = [self._label]
key.extend(sympy_key(i) for i in self._indices)
return key
#
# Misc facilities
#
def map(self, func):
"""Map the given function to indices."""
return Vec(self._label, (func(i) for i in self._indices))
@property
def terms(self):
"""Get the terms from the vector.
This is for the user input.
"""
return [Term((), _UNITY, (self,))]
class Term(ATerms):
"""Terms in tensor expression.
This is the core class for storing symbolic tensor expressions. The
actual symbolic tensor type is just a shallow wrapper over a list of
terms. A term basically comprises three fields: a list of summations,
a SymPy expression giving the amplitude, and a list of non-commutative
vectors.
"""
__slots__ = [
'_sums',
'_amp',
'_vecs',
'_free_vars',
'_dumms'
]
def __init__(
self, sums: typing.Tuple[typing.Tuple[Symbol, Range], ...],
amp: Expr, vecs: typing.Tuple[Vec, ...],
free_vars: typing.FrozenSet[Symbol] = None,
dumms: typing.Mapping[Symbol, Range] = None,
):
"""Initialize the tensor term.
Users seldom have the need to create terms directly by this function,
so this constructor is mostly a developer function; no sanity checking
is performed on the input for performance. Most importantly, this
constructor does **not** copy either the summations or the vectors and
directly expects them to be tuples (for hashability). And the amplitude
is **not** sympified.
Also, it is important that the free variables and dummies dictionary be
given only when they really correspond to the content of the term.
"""
# For performance reasons, no checking is done.
#
# Uncomment for debugging.
# valid = (
# isinstance(sums, tuple) and isinstance(amp, Expr)
# and isinstance(vecs, tuple)
# )
# if not valid:
# raise TypeError('Invalid argument to term creation')
self._sums = sums
self._amp = amp
self._vecs = vecs
self._free_vars = free_vars
self._dumms = dumms
@property
def sums(self):
"""The summations of the term."""
return self._sums
@property
def amp(self) -> Expr:
"""The amplitude expression."""
return self._amp
@property
def vecs(self):
"""The vectors in the term."""
return self._vecs
@property
def is_scalar(self):
"""If the term is a scalar."""
return len(self._vecs) == 0
@property
def args(self):
"""The triple of summations, amplitude, and vectors."""
return self._sums, self._amp, self._vecs
def __hash__(self):
"""Compute the hash of the term."""
return hash(self.args)
def __eq__(self, other):
"""Evaluate the equality with another term."""
return isinstance(other, type(self)) and self.args == other.args
def __repr__(self):
"""Form the representative string of a term."""
return 'Term(sums=[{}], amp={}, vecs=[{}])'.format(
', '.join(repr(i) for i in self._sums),
repr(self._amp),
', '.join(repr(i) for i in self._vecs)
)
def __str__(self):
"""Form the readable string representation of a term."""
if len(self._sums) > 0:
header = 'sum_{{{}}} '.format(
', '.join(str(i[0]) for i in self._sums))
else:
header = ''
factors = [str(self._amp)]
factors.extend(str(i) for i in self._vecs)
return header + ' * '.join(factors)
@property
def sort_key(self):
"""The sort key for a term.
This key attempts to sort the terms by complexity, with simpler terms
coming earlier. This capability of sorting the terms will make the
equality comparison of multiple terms easier.
This sort key also ensures that terms that can be merged are always put
into adjacent positions.
"""
vec_keys = [i.sort_key for i in self._vecs]
sum_keys = [(i[1].sort_key, sympy_key(i[0])) for i in self._sums]
return (
len(vec_keys), vec_keys,
len(sum_keys), sum_keys,
sympy_key(self._amp)
)
@property
def terms(self):
"""The singleton list of the current term.
This property is for the rare cases where direct construction of tensor
inputs from SymPy expressions and vectors is not sufficient.
"""
return [self]
def scale(self, factor):
"""Scale the term by a factor.
"""
return Term(self._sums, self._amp * factor, self._vecs)
def mul_term(self, other, dumms=None, excl=None):
"""Multiply with another tensor term.
Note that this function does not automatically exclude the free symbols
in the two operands.
"""
lhs, rhs = self.reconcile_dumms(other, dumms, excl)
return Term(
lhs.sums + rhs.sums, lhs.amp * rhs.amp, lhs.vecs + rhs.vecs
)
def comm_term(self, other, dumms=None, excl=None):
"""Commute with another tensor term.
In the same way as the multiplication operation, here the free symbols
in the operands are not automatically excluded.
"""
lhs, rhs = self.reconcile_dumms(other, dumms, excl)
sums = lhs.sums + rhs.sums
amp0 = lhs.amp * rhs.amp
return [
Term(sums, amp0, lhs.vecs + rhs.vecs),
Term(sums, -amp0, rhs.vecs + lhs.vecs)
]
def reconcile_dumms(self, other, dumms, excl):
"""Reconcile the dummies in two terms."""
lhs, dummbegs = self.reset_dumms(dumms, excl=excl)
rhs, _ = other.reset_dumms(dumms, dummbegs=dummbegs, excl=excl)
return lhs, rhs
#
# SymPy related
#
@property
def exprs(self):
"""Loop over the sympy expression in the term.
Note that the summation dummies are not looped over.
"""
yield self._amp
for vec in self._vecs:
yield from vec.indices
@property
def free_vars(self):
"""The free symbols used in the term.
"""
if self._free_vars is None:
dumms = self.dumms
self._free_vars = set(
i for expr in self.exprs for i in expr.atoms(Symbol)
if i not in dumms
)
return self._free_vars
@property
def dumms(self):
"""Get the mapping from dummies to their range.
"""
if self._dumms is None:
self._dumms = dict(self._sums)
return self._dumms
@property
def amp_factors(self):
"""The factors in the amplitude expression.
This is a convenience wrapper over :py:meth:`get_amp_factors` for the
case of no special additional symbols.
"""
return self.get_amp_factors(set())
def get_amp_factors(self, special_symbs):
"""Get the factors in the amplitude and the coefficient.
The indexed factors, together with the factors involving dummies or
symbols in the given special symbols set, will be returned as a list,
with the rest returned as a single SymPy expression.
An error will be raised if the amplitude is not a monomial.
"""
amp = self._amp
if isinstance(amp, Add):
raise ValueError('Invalid amplitude: ', amp, 'expecting monomial')
if isinstance(amp, Mul):
all_factors = amp.args
else:
all_factors = (amp,)
dumms = self.dumms
factors = []
coeff = _UNITY
for factor in all_factors:
need_treatment = any(
(i in dumms or i in special_symbs)
for i in factor.atoms(Symbol)
) or isinstance(factor, Indexed)
if need_treatment:
factors.append(factor)
else:
coeff *= factor
continue
return factors, coeff
def map(self, func, sums=None, amp=None, vecs=None, skip_vecs=False):
"""Map the given function to the SymPy expressions in the term.
The given function will **not** be mapped to the dummies in the
summations. When operations on summations are needed, a **tuple**
for the new summations can be given.
By passing the identity function, this function can also be used to
replace the summation list, the amplitude expression, or the vector
part.
"""
return Term(
self._sums if sums is None else sums,
func(self._amp if amp is None else amp),
tuple(
i.map(func) if not skip_vecs else i
for i in (self._vecs if vecs is None else vecs)
)
)
def subst(self, substs, sums=None, amp=None, vecs=None, purge_sums=False):
"""Perform symbol substitution on the SymPy expressions.
After the replacement of the given fields, the given substitutions are
going to be performed simultaneously by the SymPy ``xreplace`` method.
If ``purge_sums`` is set, the summations whose dummies are substituted are
going to be removed.
"""
if sums is None:
sums = self._sums
if purge_sums:
sums = tuple(i for i in sums if i[0] not in substs)
return self.map(
lambda x: x.xreplace(substs), sums=sums, amp=amp, vecs=vecs
)
def reset_dumms(self, dumms, dummbegs=None, excl=None):
"""Reset the dummies in the term.
The term with dummies reset will be returned alongside the new dummy
begins dictionary. Note that the dummy begins dictionary will be
mutated if one is given.
A ValueError will be raised when no more dummies are available.
"""
new_sums, substs, dummbegs = self.reset_sums(
self._sums, dumms, dummbegs, excl
)
return self.subst(substs, new_sums), dummbegs
@staticmethod
def reset_sums(sums, dumms, dummbegs=None, excl=None):
"""Reset the given summations.
The new summation list, substitution dictionary, and the new dummy begin
dictionary will be returned.
"""
if dummbegs is None:
dummbegs = {}
new_sums = []
substs = {}
for dumm_i, range_i in sums:
# For linter.
new_dumm = None
new_beg = None
beg = dummbegs[range_i] if range_i in dummbegs else 0
for i in itertools.count(beg):
try:
tentative = dumms[range_i][i]
except KeyError:
raise ValueError('Dummies for range', range_i,
'are not given')
except IndexError:
raise ValueError('Dummies for range', range_i, 'are used up')
if excl is None or tentative not in excl:
new_dumm = tentative
new_beg = i + 1
break
else:
continue
new_sums.append((new_dumm, range_i))
substs[dumm_i] = new_dumm
dummbegs[range_i] = new_beg
continue
return tuple(new_sums), substs, dummbegs
#
# Amplitude simplification
#
def simplify_deltas(self, resolvers):
"""Simplify deltas in the amplitude of the expression."""
new_amp, substs = simplify_deltas_in_expr(
self.dumms, self._amp, resolvers
)
# Note that here the substitutions need to be performed in order.
return self.subst(substs, purge_sums=True, amp=new_amp)
def simplify_sums(self):
"""Simplify the summations in the term."""
involved = {
i for expr in self.exprs for i in expr.atoms(Symbol)
}
new_sums = []
factor = _UNITY
dirty = False
for symb, range_ in self._sums:
if symb not in involved and range_.bounded:
dirty = True
factor *= range_.size
else:
new_sums.append((symb, range_))
continue
if dirty:
return Term(tuple(new_sums), factor * self._amp, self._vecs)
else:
return self
def expand(self):
"""Expand the term into many terms."""
expanded_amp = self.amp.expand()
if expanded_amp == 0:
return []
elif isinstance(expanded_amp, Add):
amp_terms = expanded_amp.args
else:
amp_terms = (expanded_amp,)
return [self.map(lambda x: x, amp=i) for i in amp_terms]
#
# Canonicalization.
#
def canon(self, symms=None, vec_colour=None):
"""Canonicalize the term.
The given vector colour should be a callable accepting the index
within the vector list (under the keyword ``idx``), the vector itself
(under the keyword ``vec``), and the containing term (under the keyword
``term``). By default, each vector has a colour equal to its index
within the list of vectors.
Note that whether or not colours for the vectors are given, the vectors
are never permuted in the result.
"""
# Factors to canonicalize.
factors = []
# Additional information for factor reconstruction.
#
# It holds integral placeholders for vectors and for scalar factors
# without any indexed quantity; for factors with an indexed quantity, it
# holds the expression with (the only) indexed quantity replaced by the
# placeholder.
factors_info = []
vec_factor = 1
unindexed_factor = 2
#
# Get the factors in the amplitude.
#
# Cache globals for performance.
wrapper_base = _WRAPPER_BASE
indexed_placeholder = _INDEXED_PLACEHOLDER
# Extractors for the indexed, defined here to avoid repeated list and
# function creation for each factor.
indexed = []
def replace_indexed(base, *indices):
"""Replace the indexed quantity inside the factor."""
indexed.append(base[indices])
return indexed_placeholder
amp_factors, coeff = self.amp_factors
for i in amp_factors:
amp_no_indexed = i.replace(
Indexed, NonsympifiableFunc(replace_indexed)
)
n_indexed = len(indexed)
if n_indexed > 1:
raise ValueError(
'Invalid amplitude factor containing multiple indexed', i
)
elif n_indexed == 1:
factors.append((
indexed[0], (
_COMMUTATIVE,
indexed[0].base.label.name,
sympy_key(amp_no_indexed)
)
))
factors_info.append(amp_no_indexed)
indexed.clear() # Clean the container for the next factor.
else: # No indexed.
# When the factor has no indexed base, we treat it as
# indexing a uni-valence internal indexed base.
factors.append((
wrapper_base[i], (_COMMUTATIVE,)
))
factors_info.append(unindexed_factor)
continue
#
# Get the factors in the vectors.
#
for i, v in enumerate(self._vecs):
colour = i if vec_colour is None else vec_colour(
idx=i, vec=v, term=self
)
factors.append((
v, (_NON_COMMUTATIVE, colour)
))
factors_info.append(vec_factor)
continue
#
# Invoke the core simplification.
#
res_sums, canoned_factors, canon_coeff = canon_factors(
self._sums, factors, symms if symms is not None else {}
)
#
# Reconstruct the result
#
res_amp = coeff * canon_coeff
res_vecs = []
for i, j in zip(canoned_factors, factors_info):
if j == vec_factor:
# When we have a vector.
res_vecs.append(i)
elif j == unindexed_factor:
res_amp *= i.indices[0]
else:
res_amp *= j.xreplace({indexed_placeholder: i})
continue
return Term(tuple(res_sums), res_amp, tuple(res_vecs))
def canon4normal(self, symms):
"""Canonicalize the term for normal-ordering.
This is the preparation task for normal-ordering. The term will be
canonicalized with all the vectors considered identical, and the dummies
will be reset internally according to the summation list.
"""
# Make the internal dummies factory to canonicalize the vectors.
dumms = collections.defaultdict(list)
for i, v in self.sums:
dumms[v].append(i)
for i in dumms.values():
# This is important for ordering vectors according to SymPy key in
# normal ordering.
i.sort(key=sympy_key)
canon_term = (
self.canon(symms=symms, vec_colour=lambda idx, vec, term: 0)
.reset_dumms(dumms)[0]
)
return canon_term
def has_base(self, base):
"""Test if the given base is present in the current term."""
if isinstance(base, (IndexedBase, Symbol)):
return self._amp.has(base)
elif isinstance(base, Vec):
label = base.label
return any(
i.label == label for i in self._vecs
)
else:
raise TypeError('Invalid base to test presence', base)
_WRAPPER_BASE = IndexedBase(
'internalWrapper', shape=('internalShape',)
)
_INDEXED_PLACEHOLDER = Symbol('internalIndexedPlaceholder')
# For colour of factors in a term.
_COMMUTATIVE = 1
_NON_COMMUTATIVE = 0
#
# Substitution by tensor definition
# ---------------------------------
#
def subst_vec_in_term(term: Term, lhs: Vec, rhs_terms: typing.List[Term],
dumms, dummbegs, excl):
"""Substitute a matching vector in the given term.
"""
sums = term.sums
vecs = term.vecs
amp = term.amp
for i, v in enumerate(vecs):
substs = _match_indices(v, lhs)
if substs is None:
continue
else:
substed_vec_idx = i
break
else:
return None  # No matching vector; None signals failure (bind protocol).
subst_states = _prepare_subst_states(
rhs_terms, substs, dumms, dummbegs, excl
)
res = []
for i, j in subst_states:
new_vecs = list(vecs)
new_vecs[substed_vec_idx:substed_vec_idx + 1] = i.vecs
res.append((
Term(sums + i.sums, amp * i.amp, tuple(new_vecs)), j
))
continue
return res
def subst_factor_in_term(term: Term, lhs, rhs_terms: typing.List[Term],
dumms, dummbegs, excl, full_simplify=True):
"""Substitute a scalar factor in the term.
While the vector part is always a flattened list of vectors, the amplitude
part can be a lot more complex. Here we strive to replace only one instance
of the LHS by a sum of two placeholders; the substitution is possible only
if the result expands to two terms, each containing only one of the
placeholders.
"""
amp = term.amp
placeholder1 = Symbol('internalSubstPlaceholder1')
placeholder2 = Symbol('internalSubstPlaceholder2')
found = [False]
substs = {}
if isinstance(lhs, Symbol):
label = lhs
def query_func(expr):
"""Filter for the given symbol."""
return not found[0] and expr == lhs
def replace_func(_):
"""Replace the symbol."""
found[0] = True
return placeholder1 + placeholder2
elif isinstance(lhs, Indexed):
label = lhs.base.label
# Here, in order to avoid the matching being called twice, we separate
# the actual checking into both the query and the replace call-back.
def query_func(expr):
"""Query for a reference to a given indexed base."""
return not found[0] and isinstance(expr, Indexed)
def replace_func(expr):
"""Replace the reference to the indexed base."""
match_res = _match_indices(expr, lhs)
if match_res is None:
return expr
found[0] = True
assert len(substs) == 0
substs.update(match_res)
return placeholder1 + placeholder2
else:
raise TypeError(
'Invalid LHS for substitution', lhs,
'expecting symbol or indexed quantity'
)
# Some special treatment is needed for powers.
pow_placeholder = Symbol('internalSubstPowPlaceholder')
pow_val = [None]
def decouple_pow(base, e):
"""Decouple a power."""
if pow_val[0] is None and base.has(label):
pow_val[0] = base
return base * Pow(pow_placeholder, e - 1)
else:
return Pow(base, e)
amp = amp.replace(Pow, NonsympifiableFunc(decouple_pow))
amp = amp.replace(query_func, NonsympifiableFunc(replace_func))
if not found[0]:
return None
if pow_val[0] is not None:
amp = amp.xreplace({pow_placeholder: pow_val[0]})
amp = (
amp.simplify() if full_simplify else amp
).expand()
# It is called a nonlinear error, but some nonlinear forms, like
# conjugation, can be handled.
nonlinear_err = ValueError(
'Invalid amplitude', term.amp, 'not expandable in', lhs
)
if not isinstance(amp, Add) or len(amp.args) != 2:
raise nonlinear_err
amp_term1, amp_term2 = amp.args
diff = amp_term1.atoms(Symbol) ^ amp_term2.atoms(Symbol)
if diff != {placeholder1, placeholder2}:
raise nonlinear_err
if amp_term1.has(placeholder1):
amp = amp_term1
else:
amp = amp_term2
subst_states = _prepare_subst_states(
rhs_terms, substs, dumms, dummbegs, excl
)
sums = term.sums
vecs = term.vecs
res = []
for i, j in subst_states:
res.append((
Term(sums + i.sums, amp.xreplace({placeholder1: i.amp}), vecs), j
))
continue
return res
def _match_indices(target, expr):
"""Match the target against the give expression for the indices.
Both arguments must be scalar or vector indexed quantities. The second
argument should contain Wilds.
"""
if target.base != expr.base or len(target.indices) != len(expr.indices):
return None
substs = {}
for i, j in zip(target.indices, expr.indices):
res = i.match(j)
if res is None:
return None
else:
substs.update(res)
continue
return substs
def _prepare_subst_states(rhs_terms, substs, dumms, dummbegs, excl):
"""Prepare the substitution states.
Here we only build partially-finished substitution states for the next
loop: for each substituting term on the RHS, the given wild symbols in
it will be substituted, then its dummies are going to be resolved. Pairs of
the prepared RHS terms and the corresponding dummbegs will be returned. It
is the responsibility of the caller to assemble the terms into the actual
substitution state, using information from the term being substituted.
"""
subst_states = []
for i, v in enumerate(rhs_terms):
# Reuse existing dummy begins only for the first term.
if i == 0:
curr_dummbegs = dummbegs
else:
curr_dummbegs = dict(dummbegs)
curr_term, curr_dummbegs = v.reset_dumms(dumms, curr_dummbegs, excl)
subst_states.append((
curr_term.subst(substs), curr_dummbegs
))
continue
return subst_states
def rewrite_term(
term: Term, vecs: typing.Sequence[Vec], new_amp: Expr
) -> typing.Tuple[typing.Optional[Term], Term]:
"""Rewrite the given term.
When a rewriting happens, the result will be the pair of the rewritten term
and the term for the definition of the new amplitude, or the result will be
None and the original term.
"""
if len(term.vecs) != len(vecs):
return None, term
substs = {}
for i, j in zip(term.vecs, vecs):
curr_substs = _match_indices(i, j)
if curr_substs is None:
break
for wild, expr in curr_substs.items():
if wild in substs:
if substs[wild] != expr:
break
else:
substs[wild] = expr
else:
# When a match is found.
res_amp = new_amp.xreplace(substs)
res_symbs = res_amp.atoms(Symbol)
res_sums = tuple(i for i in term.sums if i[0] in res_symbs)
def_sums = tuple(i for i in term.sums if i[0] not in res_symbs)
return Term(res_sums, res_amp, term.vecs), Term(
def_sums, term.amp, ()
)
return None, term
#
# User interface support
# ----------------------
#
def sum_term(sum_args, summand, predicate=None) -> typing.List[Term]:
"""Sum the given expression.
This function is meant for easy creation of tensor terms. The arguments
should start with summations and end with the expression that is summed.
This core function is designed to be wrapped in functions working with full
symbolic tensors.
"""
# Too many SymPy objects are callable.
if isinstance(summand, Callable) and not isinstance(summand, Basic):
inp_terms = None
inp_func = summand
else:
inp_terms = parse_terms(summand)
inp_func = None
if len(sum_args) == 0:
return list(inp_terms)
sums, substs = _parse_sums(sum_args)
res = []
for sum_i in itertools.product(*sums):
for subst_i in itertools.product(*substs):
subst_dict = dict(subst_i)
# We always assemble the call sequence here, since this part should
# never be performance critical.
call_seq = dict(sum_i)
call_seq.update(subst_dict)
if not (predicate is None or predicate(call_seq)):
continue
if inp_terms is not None:
curr_inp_terms = inp_terms
else:
curr_inp_terms = parse_terms(inp_func(call_seq))
curr_terms = [i.subst(
subst_dict, sums=_cat_sums(i.sums, sum_i)
) for i in curr_inp_terms]
res.extend(curr_terms)
continue
continue
return res
def _parse_sums(args):
"""Parse the summation arguments passed to the sum interface.
The result will be the decomposed form of the summations and substitutions
from the arguments. For either of them, each entry in the result is a list
of pairs of the dummy with the actual range or symbolic expression.
"""
sums = []
substs = []
for arg in args:
if not isinstance(arg, Sequence):
raise TypeError('Invalid summation', arg, 'expecting a sequence')
if len(arg) < 2:
raise ValueError('Invalid summation', arg,
'expecting dummy and range')
dumm = ensure_symb(arg[0], 'dummy')
flattened = []
for i in arg[1:]:
if isinstance(i, Iterable):
flattened.extend(i)
else:
flattened.append(i)
continue
contents = []
expecting_range = None
for i in flattened:
if isinstance(i, Range):
if expecting_range is None:
expecting_range = True
elif not expecting_range:
raise ValueError('Invalid summation on', i,
'expecting expression')
contents.append((dumm, i))
else:
if expecting_range is None:
expecting_range = False
elif expecting_range:
raise ValueError('Invalid summation on', i,
'expecting a range')
expr = ensure_expr(i)
contents.append((dumm, expr))
if expecting_range:
sums.append(contents)
else:
substs.append(contents)
return sums, substs
def _cat_sums(sums1, sums2):
"""Concatenate two summation lists.
This function forms the tuple and ensures that there is no conflicting
dummies in the two summations. This function is mostly for sanitizing user
inputs.
"""
sums = tuple(itertools.chain(sums1, sums2))
# Construction of the counter is separate from the addition of
# content due to a PyCharm bug.
dumm_counts = collections.Counter()
dumm_counts.update(i[0] for i in sums)
if any(i > 1 for i in dumm_counts.values()):
raise ValueError(
'Invalid summations to be concatenated', (sums1, sums2),
'expecting no conflict in dummies'
)
return sums
def einst_term(term: Term, resolvers):
"""Add summations according to the Einstein convention to a term.
To make problems easy for users to detect, here we only add the most
certain Einstein summations, while giving warnings when anything looks
like a summation but is not added because of something suspicious.
"""
# Strategy: find all indices to indexed bases and replace them with
# placeholder symbols, so that we can also detect other free symbols in
# the amplitude.
next_idx = [0]
indices = []
def replace_cb(_, *curr_indices):
"""Replace indexed quantities."""
indices.extend(curr_indices)
placeholder = Symbol('internalEinstPlaceholder{}'.format(next_idx[0]))
next_idx[0] += 1
return placeholder
res_amp = term.amp.replace(Indexed, NonsympifiableFunc(replace_cb))
for i in term.vecs:
indices.extend(i.indices)
# Usage tally of the symbols, in bare form and in complex expressions.
use_tally = collections.defaultdict(lambda: [0, 0])
for index in indices:
if isinstance(index, Symbol):
use_tally[index][0] += 1
else:
for i in index.atoms(Symbol):
use_tally[i][1] += 1
continue
continue
existing_dumms = term.dumms
new_sums = []
for symb, use in use_tally.items():
if symb in existing_dumms:
continue
if use[0] != 2 and use[0] + use[1] != 2:
# No chance to be an Einstein summation.
continue
if use[1] != 0:
warnings.warn(
'Symbol {} is not summed due to its usage in complex indices'
.format(symb)
)
continue
if res_amp.has(symb):
warnings.warn(
'Symbol {} is not summed due to its usage in the amplitude'
.format(symb)
)
continue
range_ = try_resolve_range(symb, {}, resolvers)
if range_ is None:
warnings.warn(
'Symbol {} is not summed because its range cannot be resolved'
.format(symb)
)
continue
# Now we have an Einstein summation.
new_sums.append((symb, range_))
continue
# Make summation from Einstein convention deterministic.
new_sums.sort(key=lambda x: (x[1].sort_key, x[0].name))
return Term(_cat_sums(term.sums, new_sums), term.amp, term.vecs)
def parse_term(term):
"""Parse a term.
Other things that can be interpreted as a term are also accepted.
"""
if isinstance(term, Term):
return term
elif isinstance(term, Vec):
return Term((), _UNITY, (term,))
else:
return Term((), sympify(term), ())
#
# Delta simplification utilities.
# -------------------------------
#
# The core idea of delta simplification is that a delta can be replaced by a
# new, possibly simpler, expression, with a possible substitution on a dummy.
# The functions here aim to find and compose them.
#
def simplify_deltas_in_expr(sums_dict, amp, resolvers):
"""Simplify the deltas in the given expression.
A new amplitude will be returned with all the deltas simplified, along with
a dictionary giving the substitutions from the deltas.
"""
substs = {}
if amp == 0:
return amp, substs
new_amp = amp.replace(KroneckerDelta, NonsympifiableFunc(functools.partial(
_proc_delta_in_amp, sums_dict, resolvers, substs
)))
return new_amp, substs
def compose_simplified_delta(amp, new_substs, substs, sums_dict, resolvers):
"""Compose delta simplification result with existing substitutions.
This function can be interpreted as follows. First we have a delta that has
been resolved to be equivalent to an amplitude expression and some
substitutions. Then by this function, we get what it is equivalent to when
a bunch of earlier substitutions already exists.
The new substitutions should be given as an iterable of old/new pairs. Then
the new amplitude and substitution from delta simplification can be composed
with the existing substitution dictionary. The new amplitude will be
returned as the first return value. The given substitution dictionary will
be mutated and returned as the second return value. When the new
substitution is incompatible with existing ones, the first return value will
be a plain zero.
The amplitude is a local thing in the expression tree, while the
substitutions are always global across the entire term. This function
aggregates and expands them.
"""
for subst in new_substs:
if subst is None:
continue
old = subst[0]
new = subst[1].xreplace(substs)
if old in substs:
comp_amp, comp_subst = proc_delta(
substs[old], new, sums_dict, resolvers
)
amp = amp * comp_amp
if comp_subst is not None:
# The new substitution cannot involve substituted symbols.
substs[comp_subst[0]] = comp_subst[1]
# amp could now be zero.
else:
# Easier case: a new symbol is being added.
replace_old = {old: new}
for i in substs.keys():
substs[i] = substs[i].xreplace(replace_old)
substs[old] = new
continue
return amp, substs
def proc_delta(arg1, arg2, sums_dict, resolvers):
"""Processs a delta.
An amplitude and a substitution pair is going to be returned. The given
delta will be equivalent to the returned amplitude factor with the
substitution performed. None will be returned for the substitution when no
substitution is needed.
"""
dumms = [
i for i in set.union(arg1.atoms(Symbol), arg2.atoms(Symbol))
if i in sums_dict
]
if len(dumms) == 0:
return KroneckerDelta(arg1, arg2), None
eqn = Eq(arg1, arg2)
# We try to solve for each of the dummies. Most likely this will only be
# executed for one iteration.
for dumm in dumms:
range_ = sums_dict[dumm]
sol = solve(eqn, dumm)
if sol is S.true:
# Now we can be sure that we got an identity.
return _UNITY, None
elif len(sol) > 0:
for i in sol:
# Try to get the range of the substituting expression.
range_of_i = try_resolve_range(i, sums_dict, resolvers)
if range_of_i is None:
continue
if range_of_i == range_:
return _UNITY, (dumm, i)
else:
# We assume atomic and disjoint ranges!
return _NAUGHT, None
# We cannot resolve the range of any of the solutions. Try next
# dummy.
continue
else:
# No solution.
return _NAUGHT, None
# When we get here, all the solutions we found have an undetermined range;
# we have to return the unprocessed form.
return KroneckerDelta(arg1, arg2), None
def _proc_delta_in_amp(sums_dict, resolvers, substs, *args):
"""Process a delta in the amplitude expression.
The partial application of this function is going to be used as the
call-back to the SymPy ``replace`` function. This function only returns
SymPy expressions to satisfy the SymPy ``replace`` interface. All actions
on the substitutions are handled by an input/output argument.
"""
# We first perform the substitutions found thus far.
args = [i.xreplace(substs) for i in args]
# Process the new delta.
amp, subst = proc_delta(*args, sums_dict=sums_dict, resolvers=resolvers)
new_amp, _ = compose_simplified_delta(
amp, [subst], substs, sums_dict=sums_dict, resolvers=resolvers
)
return new_amp
#
# Gradient computation
# --------------------
#
def diff_term(term: Term, variable, real, wirtinger_conj):
"""Differentiate a term.
"""
symb = _GRAD_REAL_SYMB if real else _GRAD_SYMB
if isinstance(variable, Symbol):
lhs = variable
rhs = lhs + symb
elif isinstance(variable, Indexed):
indices = variable.indices
wilds = tuple(
Wild(_GRAD_WILD_FMT.format(i)) for i, _ in enumerate(indices)
)
lhs = variable.base[wilds]
rhs = lhs + functools.reduce(
operator.mul,
(KroneckerDelta(i, j) for i, j in zip(wilds, indices)), symb
)
else:
raise ValueError('Invalid differentiation variable', variable)
if real:
orig_amp = term.amp.replace(conjugate(lhs), lhs)
else:
orig_amp = term.amp
replaced_amp = (orig_amp.replace(lhs, rhs)).simplify()
if real:
eval_substs = {symb: 0}
else:
replaced_amp = replaced_amp.replace(
conjugate(symb), _GRAD_CONJ_SYMB
)
eval_substs = {_GRAD_CONJ_SYMB: 0, symb: 0}
if wirtinger_conj:
diff_var = _GRAD_CONJ_SYMB
else:
diff_var = symb
# Core evaluation.
res_amp = replaced_amp.diff(diff_var).xreplace(eval_substs)
res_amp = res_amp.simplify()
return term.map(lambda x: x, amp=res_amp)
# Internal symbols for gradients.
_GRAD_SYMB_FMT = 'internalGradient{tag}Placeholder'
_GRAD_SYMB = Symbol(_GRAD_SYMB_FMT.format(tag=''))
_GRAD_CONJ_SYMB = Symbol(_GRAD_SYMB_FMT.format(tag='Conj'))
_GRAD_REAL_SYMB = Symbol(
_GRAD_SYMB_FMT.format(tag='Real'), real=True
)
_GRAD_WILD_FMT = 'InternalWildSymbol{}'
#
# Misc public functions
# ---------------------
#
def try_resolve_range(i, sums_dict, resolvers):
"""Attempt to resolve the range of an expression.
None will be returned if it cannot be resolved.
"""
for resolver in itertools.chain([sums_dict], resolvers):
if isinstance(resolver, Mapping):
if i in resolver:
return resolver[i]
else:
continue
elif isinstance(resolver, Callable):
range_ = resolver(i)
if range_ is None:
continue
else:
if isinstance(range_, Range):
return range_
else:
raise TypeError('Invalid range: ', range_,
'from resolver', resolver,
'expecting range or None')
else:
raise TypeError('Invalid resolver: ', resolver,
'expecting callable or mapping')
# Never resolved, and no error raised.
return None