From 23e0397b26eb45e475945e4dac68c96fc56d05ca Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Thu, 9 Jan 2025 20:36:51 +0000 Subject: [PATCH] Boundary Quadrature element --- FIAT/polynomial_set.py | 2 +- finat/element_factory.py | 7 ++-- finat/fiat_elements.py | 7 ++-- finat/point_set.py | 7 ++-- finat/quadrature_element.py | 65 ++++++++++++++++++++++++++++--------- 5 files changed, 61 insertions(+), 27 deletions(-) diff --git a/FIAT/polynomial_set.py b/FIAT/polynomial_set.py index 08e62b8d8..be97e43df 100644 --- a/FIAT/polynomial_set.py +++ b/FIAT/polynomial_set.py @@ -69,7 +69,7 @@ def tabulate_new(self, pts): def tabulate(self, pts, jet_order=0): """Returns the values of the polynomial set.""" base_vals = self.expansion_set._tabulate(self.embedded_degree, pts, order=jet_order) - result = {alpha: numpy.dot(self.coeffs, base_vals[alpha]) for alpha in base_vals} + result = {alpha: numpy.tensordot(self.coeffs, base_vals[alpha], (-1, 0)) for alpha in base_vals} return result def get_expansion_set(self): diff --git a/finat/element_factory.py b/finat/element_factory.py index 48db428d8..b95043b25 100644 --- a/finat/element_factory.py +++ b/finat/element_factory.py @@ -149,13 +149,14 @@ def convert(element, **kwargs): @convert.register(finat.ufl.FiniteElement) def convert_finiteelement(element, **kwargs): cell = as_fiat_cell(element.cell) - if element.family() == "Quadrature": + if element.family() in {"Quadrature", "Boundary Quadrature"}: degree = element.degree() - scheme = element.quadrature_scheme() + scheme = element.quadrature_scheme() or "default" if degree is None or scheme is None: raise ValueError("Quadrature scheme and degree must be specified!") - return finat.make_quadrature_element(cell, degree, scheme), set() + codim = 1 if element.family() == "Boundary Quadrature" else 0 + return finat.make_quadrature_element(cell, degree, scheme, codim), set() lmbda = supported_elements[element.family()] if element.family() == "Real" and element.cell.cellname() in {"quadrilateral", "hexahedron"}: lmbda = None diff --git a/finat/fiat_elements.py b/finat/fiat_elements.py index 0203a3a7f..6a1a2fc18 100644 --- a/finat/fiat_elements.py +++ b/finat/fiat_elements.py @@ -100,7 +100,7 @@ def basis_evaluation(self, order, ps, entity=None, coordinate_mapping=None): ''' space_dimension = self._element.space_dimension() value_size = np.prod(self._element.value_shape(), dtype=int) - fiat_result = self._element.tabulate(order, ps.points, entity) + fiat_result = self._element.tabulate(order, ps.points.reshape(-1, ps.points.shape[-1]), entity) result = {} # In almost all cases, we have # self.space_dimension() == self._element.space_dimension() @@ -116,9 +116,8 @@ def basis_evaluation(self, order, ps, entity=None, coordinate_mapping=None): continue derivative = sum(alpha) - table_roll = fiat_table.reshape( - space_dimension, value_size, len(ps.points) - ).transpose(1, 2, 0) + table = fiat_table.reshape(space_dimension, value_size, *ps.points.shape[:-1]) + table_roll = np.moveaxis(table, 0, -1) exprs = [] for table in table_roll: diff --git a/finat/point_set.py b/finat/point_set.py index 1497308c7..82424c04b 100644 --- a/finat/point_set.py +++ b/finat/point_set.py @@ -24,8 +24,7 @@ def points(self): @property def dimension(self): """Point dimension.""" - _, dim = self.points.shape - return dim + return self.points.shape[-1] @abstractproperty def indices(self): @@ -130,7 +129,7 @@ def __init__(self, points): :arg points: A vector of N points of shape (N, D) where D is the dimension of each point.""" points = numpy.asarray(points) - assert len(points.shape) == 2 + assert len(points.shape) > 1 self.points = points @cached_property @@ -139,7 +138,7 @@ def points(self): @cached_property def indices(self): - return (gem.Index(extent=len(self.points)),) + return tuple(gem.Index(extent=e) for e in self.points.shape[:-1]) @cached_property def expression(self): diff --git a/finat/quadrature_element.py b/finat/quadrature_element.py index 3f17ec399..c3e6d29ed 100644 --- a/finat/quadrature_element.py +++ b/finat/quadrature_element.py @@ -1,4 +1,4 @@ -from finat.point_set import UnknownPointSet +from finat.point_set import UnknownPointSet, PointSet from functools import reduce import numpy @@ -13,7 +13,7 @@ from finat.quadrature import make_quadrature, AbstractQuadratureRule -def make_quadrature_element(fiat_ref_cell, degree, scheme="default"): +def make_quadrature_element(fiat_ref_cell, degree, scheme="default", codim=0): """Construct a :class:`QuadratureElement` from a given a reference element, degree and scheme. @@ -23,9 +23,16 @@ def make_quadrature_element(fiat_ref_cell, degree, scheme="default"): integrate exactly. :param scheme: The quadrature scheme to use - e.g. "default", "canonical" or "KMV". + :param codim: The codimension of the quadrature scheme. :returns: The appropriate :class:`QuadratureElement` """ - rule = make_quadrature(fiat_ref_cell, degree, scheme=scheme) + if codim: + sd = fiat_ref_cell.get_spatial_dimension() + rule_ref_cell = fiat_ref_cell.construct_subcomplex(sd - codim) + else: + rule_ref_cell = fiat_ref_cell + + rule = make_quadrature(rule_ref_cell, degree, scheme=scheme) return QuadratureElement(fiat_ref_cell, rule) @@ -42,8 +49,6 @@ def __init__(self, fiat_ref_cell, rule): self.cell = fiat_ref_cell if not isinstance(rule, AbstractQuadratureRule): raise TypeError("rule is not an AbstractQuadratureRule") - if fiat_ref_cell.get_spatial_dimension() != rule.point_set.dimension: - raise ValueError("Cell dimension does not match rule's point set dimension") self._rule = rule @cached_property @@ -64,10 +69,16 @@ def formdegree(self): @cached_property def _entity_dofs(self): - # Inspired by ffc/quadratureelement.py + top = self.cell.get_topology() entity_dofs = {dim: {entity: [] for entity in entities} - for dim, entities in self.cell.get_topology().items()} - entity_dofs[self.cell.get_dimension()] = {0: list(range(self.space_dimension()))} + for dim, entities in top.items()} + ps = self._rule.point_set + dim = ps.dimension + num_pts = len(ps.points) + cur = 0 + for entity in sorted(top[dim]): + entity_dofs[dim][entity] = list(range(cur, cur + num_pts)) + cur += num_pts return entity_dofs def entity_dofs(self): @@ -76,9 +87,22 @@ def entity_dofs(self): def space_dimension(self): return numpy.prod(self.index_shape, dtype=int) + @cached_property + def _point_set(self): + ps = self._rule.point_set + sd = self.cell.get_spatial_dimension() + dim = ps.dimension + if dim != sd: + # Tile the quadrature rule on each subentity + entity_ids = self.entity_dofs() + pts = [self.cell.get_entity_transform(dim, entity)(ps.points) + for entity in entity_ids[dim]] + ps = PointSet(numpy.stack(pts, axis=0)) + return ps + @property def index_shape(self): - ps = self._rule.point_set + ps = self._point_set return tuple(index.extent for index in ps.indices) @property @@ -87,7 +111,7 @@ def value_shape(self): @cached_property def fiat_equivalent(self): - ps = self._rule.point_set + ps = self._point_set if isinstance(ps, UnknownPointSet): raise ValueError("A quadrature element with rule with runtime points has no fiat equivalent!") weights = getattr(self._rule, 'weights', None) @@ -107,8 +131,13 @@ def basis_evaluation(self, order, ps, entity=None, coordinate_mapping=None): :param ps: the point set object. :param entity: the cell entity on which to tabulate. ''' - if entity is not None and entity != (self.cell.get_dimension(), 0): - raise ValueError('QuadratureElement does not "tabulate" on subentities.') + rule_dim = self._rule.point_set.dimension + if entity is None: + entity = (rule_dim, 0) + entity_dim, entity_id = entity + if entity_dim != rule_dim: + raise ValueError(f"Cannot tabulate QuadratureElement of dimension {rule_dim}" + f" on subentities of dimension {entity_dim}.") if order: raise ValueError("Derivatives are not defined on a QuadratureElement.") @@ -121,15 +150,21 @@ def basis_evaluation(self, order, ps, entity=None, coordinate_mapping=None): product = reduce(gem.Product, [gem.Delta(q, r) for q, r in zip(ps.indices, multiindex)]) - dim = self.cell.get_spatial_dimension() - return {(0,) * dim: gem.ComponentTensor(product, multiindex)} + sd = self.cell.get_spatial_dimension() + if sd != ps.dimension: + data = numpy.zeros(self.index_shape[:-1], dtype=object) + data[...] = gem.Zero(product.shape) + data[entity_id] = product + product = gem.Indexed(gem.ListTensor(data), multiindex[-1]) + + return {(0,) * sd: gem.ComponentTensor(product, multiindex)} def point_evaluation(self, order, refcoords, entity=None): raise NotImplementedError("QuadratureElement cannot do point evaluation!") @property def dual_basis(self): - ps = self._rule.point_set + ps = self._point_set multiindex = self.get_indices() # Evaluation matrix is just an outer product of identity # matrices, evaluation points are just the quadrature points.