Skip to content

Commit cfc412e

Browse files
authored
Merge pull request #116 from sandialabs/plumbing/variable_properties
Plumbing/variable properties
2 parents 4fdc0b6 + 12f928d commit cfc412e

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

42 files changed

+951
-383
lines changed

optimism/FunctionSpace.py

Lines changed: 55 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,14 @@ def interpolate_to_points(functionSpace, nodalField):
193193
return jax.vmap(interpolate_to_element_points, (None, 0, 0))(nodalField, functionSpace.shapes, functionSpace.mesh.conns)
194194

195195

196-
def integrate_over_block(functionSpace, U, stateVars, dt, func, block,
196+
def vmapPropValue(propArray):
197+
numAxes = len(propArray.shape)
198+
if numAxes > 1:
199+
return 0
200+
else:
201+
return None
202+
203+
def integrate_over_block(functionSpace, U, stateVars, props, dt, func, block,
197204
*params, modify_element_gradient=default_modify_element_gradient):
198205
"""Integrates a density function over a block of the mesh.
199206
@@ -219,11 +226,11 @@ def integrate_over_block(functionSpace, U, stateVars, dt, func, block,
219226
block of elements.
220227
"""
221228

222-
vals = evaluate_on_block(functionSpace, U, stateVars, dt, func, block, *params, modify_element_gradient=modify_element_gradient)
229+
vals = evaluate_on_block(functionSpace, U, stateVars, props, dt, func, block, *params, modify_element_gradient=modify_element_gradient)
223230
return np.dot(vals.ravel(), functionSpace.vols[block].ravel())
224231

225232

226-
def evaluate_on_block(functionSpace, U, stateVars, dt, func, block,
233+
def evaluate_on_block(functionSpace, U, stateVars, props, dt, func, block,
227234
*params, modify_element_gradient=default_modify_element_gradient):
228235
"""Evaluates a density function at every quadrature point in a block of the mesh.
229236
@@ -249,23 +256,60 @@ def evaluate_on_block(functionSpace, U, stateVars, dt, func, block,
249256
density functional ``func`` at every quadrature point in the block.
250257
"""
251258
fs = functionSpace
252-
compute_elem_values = jax.vmap(evaluate_on_element, (None, None, 0, None, 0, 0, 0, 0, None, None, *tuple(0 for p in params)))
253259

254-
blockValues = compute_elem_values(U, fs.mesh.coords, stateVars[block], dt, fs.shapes[block],
260+
compute_elem_values = jax.vmap(evaluate_on_element, (None, None, 0, vmapPropValue(props), None, 0, 0, 0, 0, None, None, *tuple(0 for p in params)))
261+
262+
blockValues = compute_elem_values(U, fs.mesh.coords, stateVars[block], props, dt, fs.shapes[block],
263+
fs.shapeGrads[block], fs.vols[block],
264+
fs.mesh.conns[block], func, modify_element_gradient, *params)
265+
return blockValues
266+
267+
268+
def evaluate_on_block_heterogeneous_props(
269+
functionSpace, U, stateVars, props, dt, func, block,
270+
*params, modify_element_gradient=default_modify_element_gradient
271+
):
272+
"""Evaluates a density function at every quadrature point in a block of the mesh.
273+
274+
Args:
275+
functionSpace: Function space object to do the evaluation with.
276+
U: The vector of dofs for the primal field in the functional.
277+
stateVars: Internal state variable array.
278+
dt: Current time increment
279+
func: Lagrangian density function to evaluate, Must have the signature
280+
``func(u, dudx, q, x, *params) -> scalar``, where ``u`` is the primal field, ``q`` is the
281+
value of the internal variables, ``x`` is the current point coordinates, and ``*params`` is
282+
a variadic set of additional parameters, which correspond to the ``*params`` argument.
283+
block: Group of elements to evaluate over. This is an array of element indices. For
284+
performance, the elements within the block should be numbered consecutively.
285+
*params: Optional parameter fields to pass into Lagrangian density function. These are
286+
represented as a single value per element.
287+
modify_element_gradient: Optional function that modifies the gradient at the element level.
288+
This can be to set the particular 2D mode, and additionally to enforce volume averaging
289+
on the gradient operator. This is a keyword-only argument.
290+
291+
Returns:
292+
An array of shape (numElements, numQuadPtsPerElement) that contains the scalar values of the
293+
density functional ``func`` at every quadrature point in the block.
294+
"""
295+
fs = functionSpace
296+
compute_elem_values = jax.vmap(evaluate_on_element, (None, None, 0, 0, None, 0, 0, 0, 0, None, None, *tuple(0 for p in params)))
297+
298+
blockValues = compute_elem_values(U, fs.mesh.coords, stateVars[block], props[block], dt, fs.shapes[block],
255299
fs.shapeGrads[block], fs.vols[block],
256300
fs.mesh.conns[block], func, modify_element_gradient, *params)
257301
return blockValues
258302

259303

260-
def integrate_element_from_local_field(elemNodalField, elemNodalCoords, elemStates, dt, elemShapes, elemShapeGrads, elemVols, func, modify_element_gradient=default_modify_element_gradient):
304+
def integrate_element_from_local_field(elemNodalField, elemNodalCoords, elemStates, elemProps, dt, elemShapes, elemShapeGrads, elemVols, func, modify_element_gradient=default_modify_element_gradient):
261305
"""Integrate over element with element nodal field as input.
262306
This allows element residuals and element stiffness matrices to computed.
263307
"""
264308
elemVals = jax.vmap(interpolate_to_point, (None,0))(elemNodalField, elemShapes)
265309
elemGrads = jax.vmap(compute_quadrature_point_field_gradient, (None,0))(elemNodalField, elemShapeGrads)
266310
elemGrads = modify_element_gradient(elemGrads, elemShapes, elemVols, elemNodalField, elemNodalCoords)
267311
elemPoints = jax.vmap(interpolate_to_point, (None,0))(elemNodalCoords, elemShapes)
268-
fVals = jax.vmap(func, (0, 0, 0, 0, None))(elemVals, elemGrads, elemStates, elemPoints, dt)
312+
fVals = jax.vmap(func, (0, 0, 0, None, 0, None))(elemVals, elemGrads, elemStates, elemProps, elemPoints, dt)
269313
return np.dot(fVals, elemVols)
270314

271315

@@ -299,12 +343,13 @@ def integrate_element(U, coords, elemStates, elemShapes, elemShapeGrads, elemVol
299343
return np.dot(fVals, elemVols)
300344

301345

302-
def evaluate_on_element(U, coords, elemStates, dt, elemShapes, elemShapeGrads, elemVols, elemConn, kernelFunc, modify_element_gradient, *params):
346+
def evaluate_on_element(U, coords, elemStates, props, dt, elemShapes, elemShapeGrads, elemVols, elemConn, kernelFunc, modify_element_gradient, *params):
303347
elemVals = interpolate_to_element_points(U, elemShapes, elemConn)
304348
elemGrads = compute_element_field_gradient(U, coords, elemShapes, elemShapeGrads, elemVols, elemConn, modify_element_gradient)
305349
elemXs = interpolate_to_element_points(coords, elemShapes, elemConn)
306-
vmapArgs = 0, 0, 0, 0, None, *tuple(None for p in params)
307-
fVals = jax.vmap(kernelFunc, vmapArgs)(elemVals, elemGrads, elemStates, elemXs, dt, *params)
350+
# vmapArgs = 0, 0, 0, vmapPropValue(props), 0, None, *tuple(None for p in params)
351+
vmapArgs = 0, 0, 0, None, 0, None, *tuple(None for p in params)
352+
fVals = jax.vmap(kernelFunc, vmapArgs)(elemVals, elemGrads, elemStates, props, elemXs, dt, *params)
308353
return fVals
309354

310355

0 commit comments

Comments
 (0)