Skip to content

Commit 638c225

Browse files
committed
added dimension and derivative type inference
1 parent eb64882 commit 638c225

File tree

5 files changed

+100
-155
lines changed

5 files changed

+100
-155
lines changed

nrpylatex/core/generator.py

Lines changed: 17 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ def __init__(self, parser):
1616

1717
def generate(self, LHS, RHS, impsum=True):
1818
# perform implied summation on indexed expression
19-
LHS_RHS, dimension = self.expand_summation(LHS, RHS, impsum)
19+
LHS_RHS, dimension, suffix = self.expand_summation(LHS, RHS, impsum)
2020
if self._property['debug']:
2121
lineno = '[%d]' % self._property['debug']
2222
print('%s Python' % (len(lineno) * ' '))
@@ -37,7 +37,7 @@ def generate(self, LHS, RHS, impsum=True):
3737
except IndexError:
3838
raise GeneratorError('index out of range; change loop/summation range')
3939

40-
return global_env, dimension
40+
return global_env, dimension, suffix
4141

4242
def expand_summation(self, LHS, RHS, impsum=True):
4343
tree, indexing = ExprTree(LHS), []
@@ -75,10 +75,11 @@ def expand_summation(self, LHS, RHS, impsum=True):
7575
subexpr = subtree.expr
7676
if subexpr.func == Function('Tensor'):
7777
symbol = str(subexpr.args[0])
78-
dimension = self._namespace[symbol].dimension
7978
for index in subexpr.args[1:]:
8079
if str(index) in self._property['index']:
8180
dimension = self._property['index'][str(index)]
81+
else:
82+
dimension = self._namespace[symbol].dimension
8283
if str(index) in index_range and dimension != index_range[str(index)]:
8384
raise GeneratorError('inconsistent loop/summation range for index \'%s\'' % index)
8485
index_range[str(index)] = dimension
@@ -94,10 +95,11 @@ def expand_summation(self, LHS, RHS, impsum=True):
9495
argument = subexpr.args[0]
9596
derivative = 'diff(' + srepr(argument)
9697
symbol = str(argument.args[0])
97-
dimension = self._namespace[symbol].dimension
9898
for index, order in subexpr.args[1:]:
9999
if str(index) in self._property['index']:
100100
dimension = self._property['index'][str(index)]
101+
else:
102+
dimension = self._namespace[symbol].dimension
101103
if str(index) in index_range and dimension != index_range[str(index)]:
102104
raise GeneratorError('inconsistent loop/summation range for index \'%s\'' % index)
103105
index_range[str(index)] = dimension
@@ -160,11 +162,16 @@ def expand_summation(self, LHS, RHS, impsum=True):
160162
dimension_LHS = index_range[index]
161163

162164
# shift tensor indexing forward whenever dimension > upper bound
165+
# and infer derivative suffix of LHS tensor from RHS tensors
166+
suffix_LHS = None
163167
for subtree in tree.preorder():
164168
subexpr = subtree.expr
165169
if subexpr.func == Function('Tensor'):
166170
symbol = str(subexpr.args[0])
167171
dimension = self._namespace[symbol].dimension
172+
suffix = self._namespace[symbol].suffix
173+
if suffix is not None:
174+
suffix_LHS = self._property['suffix']
168175
tensor = IndexedSymbol(subexpr, dimension)
169176
indexing = IndexedSymbol.indexing(subexpr)
170177
for index in subexpr.args[1:]:
@@ -177,7 +184,7 @@ def expand_summation(self, LHS, RHS, impsum=True):
177184
indexing[i] = ('%s + %s' % (idx, shift), pos)
178185
equation[-1] = equation[-1].replace(tensor.array_format(subexpr), tensor.array_format(indexing))
179186

180-
return ' = '.join(equation), dimension_LHS
187+
return ' = '.join(equation), dimension_LHS, suffix_LHS
181188

182189
@staticmethod
183190
def separate_indexing(indexing, symbol_LHS, impsum=True):
@@ -213,39 +220,21 @@ def generate_metric(symbol, dimension, suffix):
213220
r'\epsilon_{' + ' '.join('j_' + str(i) for i in range(1, 1 + dimension)) + '} '
214221
det_latex = prefix + ' '.join(r'\mathrm{{{symbol}}}^{{i_{n} j_{n}}}'.format(symbol=symbol[:-2], n=i) for i in range(1, 1 + dimension))
215222
inv_latex = prefix + ' '.join(r'\mathrm{{{symbol}}}^{{i_{n} j_{n}}}'.format(symbol=symbol[:-2], n=i) for i in range(2, 1 + dimension))
216-
if suffix:
217-
latex_config += r"% declare {symbol}det {inv_symbol} --dim {dimension} --suffix {suffix}" \
218-
.format(suffix=suffix, symbol=symbol[:-2], inv_symbol=symbol.replace('U', 'D'), dimension=dimension)
219-
else:
220-
latex_config += r"% declare {symbol}det --dim {dimension}".format(symbol=symbol[:-2], dimension=dimension)
223+
latex_config += r"% declare {symbol}det --dim {dimension}".format(symbol=symbol[:-2], dimension=dimension)
221224
latex_config += r"""
222225
\mathrm{{{symbol}det}} = \frac{{1}}{{({dimension})({factorial})}} {det_latex} \\
223226
\mathrm{{{symbol}}}_{{i_1 j_1}} = \frac{{1}}{{{factorial}}} \mathrm{{{symbol}det}}^{{{{-1}}}} ({inv_latex}) \\""" \
224-
.format(symbol=symbol[:-2], inv_symbol=symbol.replace('U', 'D'), dimension=dimension,
225-
factorial=math.factorial(dimension - 1), det_latex=det_latex, inv_latex=inv_latex)
226-
# latex_config += '\n' + r"% assign {symbol}det --dim {dimension}".format(symbol=symbol[:-2], dimension=dimension)
227-
# if suffix:
228-
# latex_config += '\n' + r"% assign {symbol}det {inv_symbol} --suffix {suffix}" \
229-
# .format(suffix=suffix, symbol=symbol[:-2], inv_symbol=symbol.replace('U', 'D'))
227+
.format(symbol=symbol[:-2], dimension=dimension, factorial=math.factorial(dimension - 1), det_latex=det_latex, inv_latex=inv_latex)
230228
else:
231229
prefix = r'\epsilon^{' + ' '.join('i_' + str(i) for i in range(1, 1 + dimension)) + '} ' + \
232230
r'\epsilon^{' + ' '.join('j_' + str(i) for i in range(1, 1 + dimension)) + '} '
233231
det_latex = prefix + ' '.join(r'\mathrm{{{symbol}}}_{{i_{n} j_{n}}}'.format(symbol=symbol[:-2], n=i) for i in range(1, 1 + dimension))
234232
inv_latex = prefix + ' '.join(r'\mathrm{{{symbol}}}_{{i_{n} j_{n}}}'.format(symbol=symbol[:-2], n=i) for i in range(2, 1 + dimension))
235-
if suffix:
236-
latex_config += r"% declare {symbol}det {inv_symbol} --dim {dimension} --suffix {suffix}" \
237-
.format(suffix=suffix, symbol=symbol[:-2], inv_symbol=symbol.replace('D', 'U'), dimension=dimension)
238-
else:
239-
latex_config += r"% declare {symbol}det --dim {dimension}".format(symbol=symbol[:-2], dimension=dimension)
233+
latex_config += r"% declare {symbol}det --dim {dimension}".format(symbol=symbol[:-2], dimension=dimension)
240234
latex_config += r"""
241235
\mathrm{{{symbol}det}} = \frac{{1}}{{({dimension})({factorial})}} {det_latex} \\
242236
\mathrm{{{symbol}}}^{{i_1 j_1}} = \frac{{1}}{{{factorial}}} \mathrm{{{symbol}det}}^{{{{-1}}}} ({inv_latex}) \\""" \
243-
.format(symbol=symbol[:-2], inv_symbol=symbol.replace('D', 'U'), dimension=dimension,
244-
factorial=math.factorial(dimension - 1), det_latex=det_latex, inv_latex=inv_latex)
245-
# latex_config += '\n' + r"% assign {symbol}det --dim {dimension}".format(symbol=symbol[:-2], dimension=dimension)
246-
# if suffix:
247-
# latex_config += '\n' + r"% assign {symbol}det {inv_symbol} --suffix {suffix}" \
248-
# .format(suffix=suffix, symbol=symbol[:-2], inv_symbol=symbol.replace('D', 'U'))
237+
.format(symbol=symbol[:-2], dimension=dimension, factorial=math.factorial(dimension - 1), det_latex=det_latex, inv_latex=inv_latex)
249238
return latex_config
250239

251240
@staticmethod
@@ -288,10 +277,7 @@ def generate_covdrv(function, covdrv_index, symbol=None, diacritic=None, dimensi
288277
RHS += '^{%s}_{%s %s} (%s)' % (index, bound_index, covdrv_index, latex)
289278
else:
290279
RHS += '^{%s}_{%s %s} (%s)' % (bound_index, index, covdrv_index, latex)
291-
config = ('% declare ' + symbol + ' --dim %d --suffix dD\n' % dimension) if symbol else ''
292-
return config + LHS + ' = ' + RHS
293-
# config = (' % assign ' + symbol + ' --suffix dD\n') if symbol else ''
294-
# return LHS + ' = ' + RHS + config
280+
return LHS + ' = ' + RHS
295281

296282
@staticmethod
297283
def generate_liedrv(function, vector, weight=None):

0 commit comments

Comments
 (0)