@@ -193,7 +193,14 @@ def interpolate_to_points(functionSpace, nodalField):
193
193
return jax .vmap (interpolate_to_element_points , (None , 0 , 0 ))(nodalField , functionSpace .shapes , functionSpace .mesh .conns )
194
194
195
195
196
- def integrate_over_block (functionSpace , U , stateVars , dt , func , block ,
196
+ def vmapPropValue (propArray ):
197
+ numAxes = len (propArray .shape )
198
+ if numAxes > 1 :
199
+ return 0
200
+ else :
201
+ return None
202
+
203
+ def integrate_over_block (functionSpace , U , stateVars , props , dt , func , block ,
197
204
* params , modify_element_gradient = default_modify_element_gradient ):
198
205
"""Integrates a density function over a block of the mesh.
199
206
@@ -219,11 +226,11 @@ def integrate_over_block(functionSpace, U, stateVars, dt, func, block,
219
226
block of elements.
220
227
"""
221
228
222
- vals = evaluate_on_block (functionSpace , U , stateVars , dt , func , block , * params , modify_element_gradient = modify_element_gradient )
229
+ vals = evaluate_on_block (functionSpace , U , stateVars , props , dt , func , block , * params , modify_element_gradient = modify_element_gradient )
223
230
return np .dot (vals .ravel (), functionSpace .vols [block ].ravel ())
224
231
225
232
226
- def evaluate_on_block (functionSpace , U , stateVars , dt , func , block ,
233
+ def evaluate_on_block (functionSpace , U , stateVars , props , dt , func , block ,
227
234
* params , modify_element_gradient = default_modify_element_gradient ):
228
235
"""Evaluates a density function at every quadrature point in a block of the mesh.
229
236
@@ -249,23 +256,60 @@ def evaluate_on_block(functionSpace, U, stateVars, dt, func, block,
249
256
density functional ``func`` at every quadrature point in the block.
250
257
"""
251
258
fs = functionSpace
252
- compute_elem_values = jax .vmap (evaluate_on_element , (None , None , 0 , None , 0 , 0 , 0 , 0 , None , None , * tuple (0 for p in params )))
253
259
254
- blockValues = compute_elem_values (U , fs .mesh .coords , stateVars [block ], dt , fs .shapes [block ],
260
+ compute_elem_values = jax .vmap (evaluate_on_element , (None , None , 0 , vmapPropValue (props ), None , 0 , 0 , 0 , 0 , None , None , * tuple (0 for p in params )))
261
+
262
+ blockValues = compute_elem_values (U , fs .mesh .coords , stateVars [block ], props , dt , fs .shapes [block ],
263
+ fs .shapeGrads [block ], fs .vols [block ],
264
+ fs .mesh .conns [block ], func , modify_element_gradient , * params )
265
+ return blockValues
266
+
267
+
268
+ def evaluate_on_block_heterogeneous_props (
269
+ functionSpace , U , stateVars , props , dt , func , block ,
270
+ * params , modify_element_gradient = default_modify_element_gradient
271
+ ):
272
+ """Evaluates a density function at every quadrature point in a block of the mesh.
273
+
274
+ Args:
275
+ functionSpace: Function space object to do the evaluation with.
276
+ U: The vector of dofs for the primal field in the functional.
277
+ stateVars: Internal state variable array.
278
+ dt: Current time increment
279
+ func: Lagrangian density function to evaluate, Must have the signature
280
+ ``func(u, dudx, q, x, *params) -> scalar``, where ``u`` is the primal field, ``q`` is the
281
+ value of the internal variables, ``x`` is the current point coordinates, and ``*params`` is
282
+ a variadic set of additional parameters, which correspond to the ``*params`` argument.
283
+ block: Group of elements to evaluate over. This is an array of element indices. For
284
+ performance, the elements within the block should be numbered consecutively.
285
+ *params: Optional parameter fields to pass into Lagrangian density function. These are
286
+ represented as a single value per element.
287
+ modify_element_gradient: Optional function that modifies the gradient at the element level.
288
+ This can be to set the particular 2D mode, and additionally to enforce volume averaging
289
+ on the gradient operator. This is a keyword-only argument.
290
+
291
+ Returns:
292
+ An array of shape (numElements, numQuadPtsPerElement) that contains the scalar values of the
293
+ density functional ``func`` at every quadrature point in the block.
294
+ """
295
+ fs = functionSpace
296
+ compute_elem_values = jax .vmap (evaluate_on_element , (None , None , 0 , 0 , None , 0 , 0 , 0 , 0 , None , None , * tuple (0 for p in params )))
297
+
298
+ blockValues = compute_elem_values (U , fs .mesh .coords , stateVars [block ], props [block ], dt , fs .shapes [block ],
255
299
fs .shapeGrads [block ], fs .vols [block ],
256
300
fs .mesh .conns [block ], func , modify_element_gradient , * params )
257
301
return blockValues
258
302
259
303
260
- def integrate_element_from_local_field (elemNodalField , elemNodalCoords , elemStates , dt , elemShapes , elemShapeGrads , elemVols , func , modify_element_gradient = default_modify_element_gradient ):
304
+ def integrate_element_from_local_field (elemNodalField , elemNodalCoords , elemStates , elemProps , dt , elemShapes , elemShapeGrads , elemVols , func , modify_element_gradient = default_modify_element_gradient ):
261
305
"""Integrate over element with element nodal field as input.
262
306
This allows element residuals and element stiffness matrices to computed.
263
307
"""
264
308
elemVals = jax .vmap (interpolate_to_point , (None ,0 ))(elemNodalField , elemShapes )
265
309
elemGrads = jax .vmap (compute_quadrature_point_field_gradient , (None ,0 ))(elemNodalField , elemShapeGrads )
266
310
elemGrads = modify_element_gradient (elemGrads , elemShapes , elemVols , elemNodalField , elemNodalCoords )
267
311
elemPoints = jax .vmap (interpolate_to_point , (None ,0 ))(elemNodalCoords , elemShapes )
268
- fVals = jax .vmap (func , (0 , 0 , 0 , 0 , None ))(elemVals , elemGrads , elemStates , elemPoints , dt )
312
+ fVals = jax .vmap (func , (0 , 0 , 0 , None , 0 , None ))(elemVals , elemGrads , elemStates , elemProps , elemPoints , dt )
269
313
return np .dot (fVals , elemVols )
270
314
271
315
@@ -299,12 +343,13 @@ def integrate_element(U, coords, elemStates, elemShapes, elemShapeGrads, elemVol
299
343
return np .dot (fVals , elemVols )
300
344
301
345
302
- def evaluate_on_element (U , coords , elemStates , dt , elemShapes , elemShapeGrads , elemVols , elemConn , kernelFunc , modify_element_gradient , * params ):
346
+ def evaluate_on_element (U , coords , elemStates , props , dt , elemShapes , elemShapeGrads , elemVols , elemConn , kernelFunc , modify_element_gradient , * params ):
303
347
elemVals = interpolate_to_element_points (U , elemShapes , elemConn )
304
348
elemGrads = compute_element_field_gradient (U , coords , elemShapes , elemShapeGrads , elemVols , elemConn , modify_element_gradient )
305
349
elemXs = interpolate_to_element_points (coords , elemShapes , elemConn )
306
- vmapArgs = 0 , 0 , 0 , 0 , None , * tuple (None for p in params )
307
- fVals = jax .vmap (kernelFunc , vmapArgs )(elemVals , elemGrads , elemStates , elemXs , dt , * params )
350
+ # vmapArgs = 0, 0, 0, vmapPropValue(props), 0, None, *tuple(None for p in params)
351
+ vmapArgs = 0 , 0 , 0 , None , 0 , None , * tuple (None for p in params )
352
+ fVals = jax .vmap (kernelFunc , vmapArgs )(elemVals , elemGrads , elemStates , props , elemXs , dt , * params )
308
353
return fVals
309
354
310
355
0 commit comments