Skip to content

Commit

Permalink
Merge pull request #39 from sandialabs/physics-cleanup
Browse files Browse the repository at this point in the history
Physics cleanup
  • Loading branch information
cmhamel authored Jan 21, 2025
2 parents 309d01f + 3b20333 commit 4747f64
Show file tree
Hide file tree
Showing 15 changed files with 180 additions and 50 deletions.
103 changes: 103 additions & 0 deletions examples/forward_problems/burgers/example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
from pancax import *
# import pancax
# import pancax.domains_new
# import pancax.domains_new.variational_domain
# import pancax.physics_new
# import pancax.problems.forward_problem

##################
# for reproducibility
##################
key = random.key(10)

##################
# file management
##################
mesh_file = find_mesh_file('mesh_quad4.g')
logger = Logger('pinn.log', log_every=250)
pp = PostProcessor(mesh_file, 'exodus')

##################
# domain setup
##################
times = jnp.linspace(0.0, 10.0, 21)
domain = CollocationDomain(mesh_file, times)
print(domain)
##################
# physics setup
##################
physics = BurgersEquation()

##################
# bcs
##################
def bc_func(x, t, z):
x, y = x[0], x[1]
# return (0.5 + x) * (0.5 - x) * z
u = z
u = u.at[0].set((0.5 + x) * (0.5 - x) * u[0])
u = u.at[1].set((0.5 + x) * (0.5 - x) * u[1])

# u = u.at[1].set()
return u

def ic_func(x):
x, y = x[0], x[1]
return jnp.array([jnp.exp(-x**2 / 2.), 0.])

physics = physics.update_dirichlet_bc_func(bc_func)

ics = [
ic_func
]
essential_bcs = [
EssentialBC('nodeset_2', 0, lambda x, t: 0.0),
EssentialBC('nodeset_4', 0, lambda x, t: 0.0),
EssentialBC('nodeset_2', 1, lambda x, t: 0.0),
EssentialBC('nodeset_4', 1, lambda x, t: 0.0),
]
natural_bcs = [
]

##################
# problem setup
##################
forward_problem = ForwardProblem(domain, physics, ics, essential_bcs, natural_bcs)

##################
# ML setup
##################
n_dims = domain.coords.shape[1]
field = MLP(n_dims + 1, physics.n_dofs, 50, 5, jax.nn.tanh, key)
# props = FixedProperties([])
params = FieldPhysicsPair(field, forward_problem.physics)

print(forward_problem.physics.x_mins)
print(forward_problem.physics.x_maxs)

loss_function = EnergyLoss()
# loss_function_2 = StrongFormResidualLoss()
loss_function_2 = CombineLossFunctions(
StrongFormResidualLoss(1.0),
ICLossFunction(1.0),
# DirichletBCLoss(1.0)
)
opt = Adam(loss_function_2, learning_rate=1e-3, has_aux=True)
opt_st = opt.init(params)
# # print(opt_st)

for epoch in range(25000):
params, opt_st, loss = opt.step(params, forward_problem, opt_st)

if epoch % 100 == 0:
print(epoch)
print(loss)

##################
# post-processing
##################
pp.init(forward_problem, 'output.e',
node_variables=['field_values']
)
pp.write_outputs(params, forward_problem)
pp.close()
Binary file not shown.
51 changes: 28 additions & 23 deletions examples/forward_problems/mechanics/example_incompressible_2d.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,5 @@
from pancax import *

##################
# for debugging nans... this will slow things down though.
##################
# jax.config.update("jax_debug_nans", True)

##################
# for reproducibility
##################
Expand All @@ -15,7 +10,13 @@
##################
mesh_file = find_mesh_file('mesh_quad4.g')
logger = Logger('pinn.log', log_every=250)
pp = PostProcessor(mesh_file, 'vtk')
pp = PostProcessor(mesh_file, 'exodus')

##################
# domain setup
##################
times = jnp.linspace(0.0, 1.0, 11)
domain = VariationalDomain(mesh_file, times, q_order=2)

##################
# physics setup
Expand All @@ -24,9 +25,15 @@
essential_bc_func = UniaxialTensionLinearRamp(
final_displacement=1.0, length=1.0, direction='y', n_dimensions=2
)
model = NeoHookean()
model = NeoHookean(
bulk_modulus=1000.0,
shear_modulus=1.,
)
formulation = PlaneStrain()
physics_kernel = SolidMechanics(mesh_file, essential_bc_func, model, formulation)
physics = SolidMechanics(model, formulation)
physics = physics.update_dirichlet_bc_func(essential_bc_func)
ics = [
]
essential_bcs = [
EssentialBC('nset_1', 0),
EssentialBC('nset_1', 1),
Expand All @@ -35,37 +42,35 @@
]
natural_bcs = [
]
domain = VariationalDomain(physics_kernel, essential_bcs, natural_bcs, mesh_file, times, q_order=2)

##################
# problem setup
##################
problem = ForwardProblem(domain, physics, ics, essential_bcs, natural_bcs)

##################
# ML setup
##################
# loss_function = CombineLossFunctions(
# EnergyAndResidualLoss(residual_weight=250.0),
# QuadratureIncompressibilityConstraint(weight=100.0)
# )
loss_function = EnergyLoss()
field_network = MLP(physics_kernel.n_dofs + 1, physics_kernel.n_dofs, 50, 5, jax.nn.tanh, key)
# props = FixedProperties([1000.0, 0.3846])
shear = 1.0
props = FixedProperties([1000.0 * shear, shear])
params = FieldPropertyPair(field_network, props)
field_network = MLP(physics.n_dofs + 1, physics.n_dofs, 50, 5, jax.nn.tanh, key)
params = FieldPhysicsPair(field_network, problem.physics)

##################
# train network
##################
opt = Adam(loss_function, learning_rate=1.0e-3, has_aux=True, clip_gradients=False)
opt_st = opt.init(params)
for epoch in range(1000):
params, opt_st, loss = opt.step(params, domain, opt_st)
for epoch in range(100000):
params, opt_st, loss = opt.step(params, problem, opt_st)
logger.log_loss(loss, epoch)

##################
# post-processing
##################
pp.init(domain, 'output.vtm',
pp.init(problem, 'output.e',
node_variables=[
'displacement',
'field_values'
# 'displacement',
# 'internal_force'
],
element_variables=[
Expand All @@ -76,5 +81,5 @@
# 'element_pk1_stress'
]
)
pp.write_outputs(params, domain)
pp.write_outputs(params, problem)
pp.close()
4 changes: 2 additions & 2 deletions examples/forward_problems/poisson/collocation_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,13 +53,13 @@ def bc_func(x, t, z):
##################
n_dims = domain.coords.shape[1]
field = MLP(n_dims + 1, physics.n_dofs, 50, 3, jax.nn.tanh, key)
params = FieldPropertyPair(field, problem.physics)
params = FieldPhysicsPair(field, problem.physics)

loss_function = StrongFormResidualLoss()
opt = Adam(loss_function, learning_rate=1e-3, has_aux=True)
opt_st = opt.init(params)

for epoch in range(500):
for epoch in range(5000):
params, opt_st, loss = opt.step(params, problem, opt_st)

if epoch % 100 == 0:
Expand Down
5 changes: 3 additions & 2 deletions pancax/loss_functions/strong_form_loss_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ def __call__(self, params, domain):
return self.weight * residual, dict(residual=residual)

def load_step(self, params, domain, t):
func = domain.physics.strong_form_residual
# func = domain.physics.strong_form_residual
func = params.physics.strong_form_residual
# TODO this will fail on delta PINNs currently
residuals = vmap(func, in_axes=(None, 0, None))(params, domain.coords, t)
residuals = vmap(func, in_axes=(None, 0, None))(params.fields, domain.coords, t)
return jnp.square(residuals).mean()
2 changes: 1 addition & 1 deletion pancax/loss_functions/weak_form_loss_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def load_step(self, params, domain, t):
# pi = potential_energy(domain, us, props)
field, physics = params
us = physics.vmap_field_values(field, domain.coords, t)
pi = physics.potential_energy(params, domain.domain, t, us)
pi = physics.potential_energy(physics, domain.domain, t, us)
return pi


Expand Down
2 changes: 1 addition & 1 deletion pancax/networks/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .elm import ELM, ELM2
from .field_property_pair import FieldPropertyPair
from .field_physics_pair import FieldPhysicsPair
from .mlp import Linear
from .mlp import MLP
from .mlp import MLPBasis
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,34 +14,34 @@ def serialise(self, base_name, epoch):
eqx.tree_serialise_leaves(file_name, self)


class FieldPropertyPair(BasePancaxModel):
class FieldPhysicsPair(BasePancaxModel):
"""
Data structure for storing a set of field network
parameters and a set of material properties
parameters and a physics object
:param fields: field network parameters object
:param properties: property parameters object
:param physics: physics object
"""
fields: eqx.Module
properties: eqx.Module
physics: eqx.Module

def __iter__(self):
"""
Iterator for user friendliness
"""
return iter((self.fields, self.properties))
return iter((self.fields, self.physics))

def freeze_fields_filter(self):
filter_spec = jtu.tree_map(lambda _: False, self)
filter_spec = eqx.tree_at(
# lambda tree: tree.properties.prop_params,
lambda tree: tree.properties,
lambda tree: tree.physics,
filter_spec,
replace=True
)
return filter_spec

def freeze_props_filter(self):
def freeze_physics_filter(self):
filter_spec = jtu.tree_map(lambda _: False, self)
for n in range(len(self.fields.layers)):
filter_spec = eqx.tree_at(
Expand Down
4 changes: 4 additions & 0 deletions pancax/physics_kernels/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,10 @@ def update_dirichlet_bc_func(self, bc_func: Callable):
def update_normalization(self, domain):
x_mins = jnp.min(domain.coords, axis=0)
x_maxs = jnp.max(domain.coords, axis=0)

# x_mins = jnp.append(x_mins, jnp.min(domain.times))
# x_maxs = jnp.append(x_maxs, jnp.max(domain.times))

new_pytree = eqx.tree_at(lambda x: x.x_mins, self, x_mins)
new_pytree = eqx.tree_at(lambda x: x.x_maxs, new_pytree, x_maxs)
return new_pytree
Expand Down
20 changes: 20 additions & 0 deletions pancax/physics_kernels/beer_lambert_law.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from .base import BaseStrongFormPhysics
from ..constitutive_models import Property
from jaxtyping import Array, Float
import equinox as eqx


class BeerLambertLaw(BaseStrongFormPhysics):
field_value_names: tuple[int, ...]
d: Float[Array, "3"] = eqx.field(static=True)
sigma: Property

def __init__(self, d: Float[Array, "3"], sigma: Property):
super().__init__(('I',))
self.d = d
self.sigma = sigma

def strong_form_residual(self, params, x, t, *args):
I = self.field_values(params, x, t, *args)
grad_I = self.field_gradients(params, x, t, *args)
return self.d @ grad_I + self.sigma * I
8 changes: 3 additions & 5 deletions pancax/physics_kernels/burgers_equation.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,8 @@ def __init__(self):
# self.v = v

def strong_form_residual(self, params, x, t, *args):
field, _ = params
u = self.field_values(field, x, t, *args)
grad_u = self.field_gradients(field, x, t, *args)
# delta_u = self.field_laplacians(field, x, t, *args)
dudt = self.field_time_derivatives(field, x, t, *args)
u = self.field_values(params, x, t, *args)
grad_u = self.field_gradients(params, x, t, *args)
dudt = self.field_time_derivatives(params, x, t, *args)
# return dudt + 0.01 * jnp.dot(grad_u, grad_u.T)
return dudt + jnp.dot(u, grad_u.T)
6 changes: 2 additions & 4 deletions pancax/physics_kernels/poisson.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,10 @@ def energy(self, params, x, t, u, grad_u, *args):
return jnp.sum(pi)

def strong_form_neumann_bc(self, params, x, t, n, *args):
field, _ = params
grad_u = self.field_gradients(field, x, t, *args)
grad_u = self.field_gradients(params, x, t, *args)
return -jnp.dot(grad_u, n)

def strong_form_residual(self, params, x, t, *args):
field, _ = params
delta_u = self.field_laplacians(field, x, t, *args)
delta_u = self.field_laplacians(params, x, t, *args)
f = self.f(x)
return -delta_u - f
1 change: 1 addition & 0 deletions test/data/test_global_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def test_global_data():
)


@pytest.mark.skip(reason='Failing on missions with bad tk')
def test_global_data_with_plotting():
data_file = os.path.join(Path(__file__).parent, 'data_global.csv')
mesh_file = os.path.join(Path(__file__).parent, 'mesh.g')
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from pancax import FieldPropertyPair, Properties, MLP
from pancax import FieldPhysicsPair, Properties, MLP
from pathlib import Path
import equinox as eqx
import jax
Expand All @@ -12,7 +12,7 @@ def test_field_property_pair():
prop_maxs=[2., 3.],
key=jax.random.key(0)
)
model = FieldPropertyPair(network, props)
model = FieldPhysicsPair(network, props)
x = jax.numpy.ones(3)

network, props = model
Expand All @@ -30,7 +30,7 @@ def test_model_serialisation():
prop_maxs=[2., 3.],
key=jax.random.key(0)
)
model = FieldPropertyPair(network, props)
model = FieldPhysicsPair(network, props)
x = jax.numpy.ones(3)

network, props = model
Expand Down
4 changes: 2 additions & 2 deletions test/test_post_processors.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from jax import random
from pancax import EssentialBC, VariationalDomain, NeoHookean, ThreeDimensional, SolidMechanics
from pancax import FieldPropertyPair, MLP
from pancax import FieldPhysicsPair, MLP
from pancax import PostProcessor, ForwardProblem
from pathlib import Path
import jax
Expand Down Expand Up @@ -37,7 +37,7 @@ def problem():
def params(problem):
key = random.key(10)
field_network = MLP(4, 3, 20, 3, jax.nn.tanh, key)
return FieldPropertyPair(field_network, problem.physics)
return FieldPhysicsPair(field_network, problem.physics)


def test_post_processor(params, problem):
Expand Down

0 comments on commit 4747f64

Please sign in to comment.